## Fix the RNG stream so that the randomised OSA residuals for zero
## observations come out the same on every run.
set.seed(1)

## Project helpers and forecasting routines.
source("src/common.R")
source("src/forecast.R")

## Flag not read in this file's visible code; presumably consumed by a
## later-sourced script (e.g. src/plotscript.R) -- confirm.
doOSA <- TRUE

library(TMB)
library(xtable)
library(ellipse)

## Build the TMB C++ template and load the resulting shared library.
compile(file.path("src", "sesam.cpp"))
dyn.load(dynlib(file.path("src", "sesam")))

## Raw inputs: observations, auxiliary data and XSA results.
obs <- read.table("data/obs.dat")
aux <- read.table("data/aux.dat")
xsa <- read.table("data/xsa.dat")

## Sort by time, then fleet, then age. The 4+ -> 3+ merge further down
## relies on this ordering.
obs <- obs[order(obs$t1, obs$fleet, obs$ageFrom), ]

## Fleet-by-age counts of positive observations.
## NOTE(review): `x` is recomputed from `data` before it is next used.
x <- xtabs(obs$obs > 0 ~ obs$fleet + obs$ageFrom)

## Fraction into the year at which each fleet's survey is conducted,
## indexed by fleet number (NA for fleet 1).
quarterFrac <- c(NA, 0.5, 0.5, 0.6, 0.75, 0.75, 0.5, 0.5)

## Number of fleets in the raw data, taken before any fleet is dropped.
maxFleet <- max(obs$fleet)

## Remove fleet 2 (the commercial CPUE series), shift the later fleet
## numbers down by one to close the gap, and drop fleet 2's timing entry.
obs <- subset(obs, fleet != 2)
obs$fleet[obs$fleet > 2] <- obs$fleet[obs$fleet > 2] - 1
quarterFrac <- quarterFrac[-2]

## If all ages in a year are zero for a fleet => not observed.
## A row is kept only when its (fleet, year) total is non-zero. Groups are
## independent, so this single vectorised pass matches the former nested
## loop. NAs are ignored in the total, as in the age-group filter below.
## The former unguarded sum() turned any NA into an NA total, and the if()
## test then stopped with an error.
obs <- obs[ave(obs$obs, obs$fleet, obs$t1,
               FUN = function(v) sum(v, na.rm = TRUE)) != 0, ]

## If all years are zero for a given age group and fleet => not observed:
## the rows are dropped (an earlier variant set obs$obs[sel] to NA instead).
for (f in unique(obs$fleet)) {
    for (aa in unique(obs$ageFrom)) {
        sel <- which(obs$fleet == f & obs$ageFrom == aa)
        mysum <- sum(obs$obs[sel], na.rm = TRUE)
        if (length(sel) > 0 && mysum == 0) {
            obs <- obs[-sel, ]
        }
    }
}

## Reduce the plus group from 4+ to 3+. When a fleet has both an age-3 and
## an age-4 row in a year, their sum goes into the age-3 row and the age-4
## row is dropped. sel[1] is the age-3 row because obs is sorted by
## (t1, fleet, ageFrom). This assumes one row per (fleet, year, age) --
## TODO confirm.
## NOTE(review): only obs$obs is updated; any ageTo column of the merged
## row is left unchanged.
for (f in unique(obs$fleet)) {
    for (yy in unique(obs$t1)) {
        sel <- which(obs$fleet == f & obs$t1 == yy & obs$ageFrom %in% c(3, 4))
        mysum <- sum(obs$obs[sel])
        if (length(sel) != 2) next
        obs$obs[sel[1]] <- mysum
        obs <- obs[-sel[2], ]
    }
}

## Auxiliary data are only needed up to age 3, the new plus group.
aux <- aux[aux$age <= 3, ]

## The assessment starts in 1984.
obs <- obs[obs$t1 >= 1984, ]
aux <- aux[aux$t1 >= 1984, ]

## Retrospective run: when the caller has set `is.RETRO` and an integer
## `RETROnumber` before running this script, drop the last RETROnumber
## units of t1 (years, per the message) from observations and aux data.
## isTRUE() keeps an NA/non-scalar flag from erroring inside if().
if (exists("is.RETRO") && isTRUE(is.RETRO) && is.integer(RETROnumber)) {
    cat("Doing RETRO run. Removing the last", RETROnumber, "years\n")
    maxY <- max(obs$t1) - RETROnumber
    obs <- obs[obs$t1 <= maxY, ]
    aux <- aux[aux$t1 <= maxY, ]
}

## Prefix auxiliary columns with "aux" so that they stay distinct from the
## observation columns once both are merged into one list.
names(aux) <- paste0("aux", names(aux))
## NB: SXSA multiplies by DT, so M in sesam must be (1/DT) larger.
aux$auxM <- aux$auxM * 4

data <- c(as.list(obs), as.list(aux))

## Model grids: every distinct interval endpoint (times) and age bound.
data$times <- sort(unique(c(data$t1, data$t2, data$auxt1, data$auxt2)))
data$ages <- sort(unique(c(data$ageFrom, data$ageTo)))
data$bday <- 0

## Zero-based indices of the aux rows whose interval [auxt1, auxt2), shifted
## by the tolerance `eps`, contains time `t` and whose age equals `a`.
## Reads the global `data`; returns an empty vector when nothing matches.
getidx <- function(a, t, eps = 0.0001) {
    inInterval <- t > data$auxt1 - eps & t < data$auxt2 - eps
    sameAge <- a == data$auxage
    which(inInterval & sameAge) - 1
}

## Map every (age, time) cell of the model grid to the zero-based aux row
## that covers it (via getidx()).
data$mapaux <- outer(data$ages, data$times, Vectorize(getidx, c("a", "t")))
## The last time point only closes the final interval, so no aux row starts
## there; reuse the mapping of the preceding time point.
data$mapaux[, ncol(data$mapaux)] <- data$mapaux[, ncol(data$mapaux) - 1]
## Every cell must map to exactly one aux row. unlist() below silently
## drops empty cells and expands multi-matches, which would shift every
## later entry and misalign the matrix, so fail loudly instead.
if (any(lengths(data$mapaux) != 1)) {
    stop("mapaux: every (age, time) cell must match exactly one aux row",
         call. = FALSE)
}
data$mapaux <- matrix(unlist(data$mapaux), nrow = nrow(data$mapaux),
                      ncol = ncol(data$mapaux))


## Fleet-by-age counts of positive observations. Each observed cell is
## keyed by its fleet (row) number; cells with no positive data get NA.
x <- xtabs(data$obs > 0 ~ data$fleet + data$ageFrom)
data$keyVarObs <- replace(row(x), x == 0, NA)

## Setup variance parameter couplings. Rows are fleets and columns are ages.
## Each entry is a zero-based index into logSdLogObs, so equal numbers share
## one observation SD; NA marks a cell without its own observation variance.
## NOTE(review): these hard-coded rows assume 5 fleets x 4 ages and
## overwrite the NA pattern derived from the data; keep them in sync.
data$keyVarObs[1:5, ] <- rbind(
    c( 0,  1,  1,  0),
    c(NA,  2,  2,  2),
    c( 3,  3, NA, NA),
    c( 4,  4, NA, NA),
    c(NA, NA,  5,  5)
)

## One process-SD parameter shared by all ages of log N.
data$keyVarLogN <- rep(0, length(data$ages))
## One F state per age (zero-based, integer like the former 0:(n-1)).
data$keyLogFsta <- seq_along(data$ages) - 1L
## One SD parameter shared by all F states.
data$keyVarLogF <- rep(0, max(data$keyLogFsta) + 1)

## Catchability key: number the observed (fleet, age) cells consecutively,
## fleet by fleet, as zero-based indices into logQ. Fleet 1's row is
## blanked first, so it gets no catchability parameter.
## seq_len() avoids the 1:0 == c(1, 0) trap for the first fleet and for
## fleets with no observed ages. rbind() keeps the fleets-by-ages shape
## even when there is a single age column, which t(sapply()) did not.
x <- data$keyVarObs
x[1, ] <- NA
data$keyLogQ <- do.call(rbind, lapply(seq_len(nrow(x)), function(i) {
    xx <- x[i, ]
    ## Q parameters already taken by the fleets before fleet i.
    used <- sum(!is.na(x[seq_len(i - 1), , drop = FALSE]))
    xx[!is.na(xx)] <- seq_len(sum(!is.na(xx))) + used
    xx
})) - 1

## Zero-based positions of each observation's start (t1) and end (t2)
## time within data$times.
data[c("idx1", "idx2")] <- lapply(data[c("t1", "t2")],
                                  function(tt) match(tt, data$times) - 1)
## Range used for Fbar -- presumably ages or age indices; confirm against
## the C++ template.
data$fbarrange <- c(1, 2)

## Locate the time point lying `lag` time units before `t`.
##
## t:     a single time value.
## lag:   how far back to look (default 1).
## eps:   matching tolerance.
## times: candidate time points; defaults to the global data$times, so
##        existing callers are unaffected.
## Returns the one-based index into `times` of the first candidate within
## `eps` of t - lag, or NA_integer_ when none exists. A plain if/else
## replaces the scalar ifelse(), which only kept the first match implicitly.
getlagback <- function(t, lag = 1.0, eps = 0.01, times = data$times) {
    hits <- which(t > times + lag - eps & t < times + lag + eps)
    if (length(hits) == 0) NA_integer_ else hits[1]
}
## For each model time point, the zero-based index of the time point one
## unit earlier (getlagback's default lag), or NA if the series does not
## reach that far back. vapply() already iterates, so Vectorize() was
## redundant here.
data$maplagR <- vapply(data$times, getlagback, numeric(1)) - 1

## Lag map for F. Row 1 points to the immediately preceding time step (NA
## for the first step); row 2 is maplagR. seq_len() keeps row 1 the right
## length even for a single time point, where 1:0 produced c(1, 0).
data$maplagF <- rbind(c(NA, seq_len(length(data$times) - 1) - 1),
                      data$maplagR)

## Recruitment indicator per time step, cycling with period 4: only the
## third step of each cycle is flagged (presumably the recruitment quarter;
## confirm against the time grid).
data$recruitTimes <- rep(c(0, 0, 1, 0), length.out = length(data$times))

## Detection limit per fleet: half of that fleet's smallest positive
## observation. The result is named by fleet in split() order.
data$eps <- vapply(split(data$obs, data$fleet),
                   function(v) min(v[v > 0]),
                   numeric(1)) / 2
data$quarterFrac <- quarterFrac


## Starting values for the TMB parameters. Lengths and dimensions follow
## the key vectors and time/age grids stored in `data`.
param <- list(
    logSdLogFsta = rep(0, max(data$keyVarLogF) + 1),
    logSdLogN    = rep(1, max(data$keyVarLogN) + 1),
    logSdLogObs  = rep(0, max(data$keyVarObs, na.rm = TRUE) + 1),
    logQ         = rep(0, max(data$keyLogQ, na.rm = TRUE) + 1),
    trans_rho    = 0.7,
    logSdLogR    = 1,
    logF         = matrix(0, ncol = length(data$times),
                          nrow = max(data$keyLogFsta) + 1),
    logN         = matrix(0, ncol = length(data$times),
                          nrow = length(data$ages)),
    logR         = rep(5, length(data$times))
)

## Observations on the log scale. Zeros (log = -Inf) are moved to just
## below their fleet's detection limit data$eps. This is vectorised;
## which() skips NA/NaN entries, where the former per-element
## if (x == -Inf) stopped with an error (as it did for empty data).
## NOTE(review): data$eps is indexed by position, so fleets must be
## numbered 1..K. An eps <= 1e-4 would yield -Inf/NaN here.
data$logobs <- log(data$obs)
zeroObs <- which(data$logobs == -Inf)
data$logobs[zeroObs] <- log(data$eps[data$fleet[zeroObs]] - 1e-4)

## Whole number of time units (presumably years) spanned by the time grid.
data$noYears <- floor(data$times[length(data$times)] - data$times[1])


## Build the TMB objective; logF, logN and logR are declared random effects.
obj <- MakeADFun(data, param, random = c("logF", "logN", "logR"),
                 DLL = "sesam", inner.control = list(maxit = 1000))

## Box-constrain the log-SD parameters to [-4, 4]; every other fixed
## effect is unbounded.
isLogSd <- grepl("^logSd", names(obj$par))
lower <- ifelse(isLogSd, -4, -Inf)
upper <- ifelse(isLogSd, 4, Inf)

opt <- nlminb(obj$par, obj$fn, obj$gr,
              control = list(trace = 1, eval.max = 1200, iter.max = 900),
              lower = lower, upper = upper)

## nlminb signals success with convergence == 0. Non-convergence is only a
## warning (the script carries on), so it should not be labelled "Error";
## the optimiser's own message is added for diagnosis.
if (opt$convergence != 0) {
    warning("model did not converge: ", opt$message, call. = FALSE)
}

## Results at the optimum.
## pl: parameter estimates, structured like `param`.
## rep: quantities REPORTed by the template. The variable masks base::rep
## only as a value; calls to rep() still find the function.
pl <- obj$env$parList()
rep <- obj$report()
## NOTE(review): sdreport() runs twice; jointrep already holds everything
## sdrep does (plus the joint precision) and could be reused.
jointrep <- sdreport(obj, getJointPrecision = TRUE, bias.correct = TRUE)
sdrep <- sdreport(obj, bias.correct = TRUE)
## Marginal SDs of all parameters (fixed and random), taken from the
## inverse of the joint precision and laid out like `pl`.
allsd <- sqrt(diag(solve(jointrep$jointPrecision)))
plsd2 <- obj$env$parList(par = allsd)

## Marginal SDs of the random effects, arranged like `pl`. The sdreport
## vector is split by parameter name and written into the matching slots
## (column-major for the logN and logF matrices).
sdRandom <- sqrt(sdrep$diag.cov.random)
rnames <- names(sdrep$par.random)
plsd <- pl
plsd$logN[, ] <- sdRandom[rnames == "logN"]
plsd$logF[, ] <- sdRandom[rnames == "logF"]
sel <- which(rnames == "logR")
plsd$logR <- sdRandom[sel]

## Where the first age row of logN is exactly 0 (presumably cells driven by
## logR rather than estimated -- confirm), substitute the logR estimate and
## its SD, so that pl$logN and plsd$logN carry the full recruitment row.
sel2 <- pl$logN[1, ] == 0

estN <- pl$logN
estN[1, sel2] <- pl$logR[sel2]
pl$logN <- estN

sdN <- plsd$logN
sdN[1, sel2] <- plsd$logR[sel2]
plsd$logN <- sdN

## Summary table (estimate, std. error) of all sdreport quantities.
## NOTE(review): the name `sum` masks base::sum only as a value; calls to
## sum() still reach the base function.
sum <- summary(sdrep)

## All summary rows belonging to one reported quantity.
sdrepRows <- function(what) sum[rownames(sum) == what, ]

ssb <- sdrepRows("ssb")
tsb <- sdrepRows("tsb")
logfbar <- sdrepRows("logfbar")
lssb <- sdrepRows("logssb")
ltsb <- sdrepRows("logtsb")



## Point estimates with +/- 2 standard-error bands. The bands are built on
## the log scale and back-transformed, so they are asymmetric on the
## natural scale.
logBand <- function(logEst, logSe, k) exp(logEst + k * logSe)

SSB <- ssb[, 1]
SSB.lo <- logBand(lssb[, 1], lssb[, 2], -2)
SSB.hi <- logBand(lssb[, 1], lssb[, 2], 2)

TSB <- tsb[, 1]
TSB.lo <- logBand(ltsb[, 1], ltsb[, 2], -2)
TSB.hi <- logBand(ltsb[, 1], ltsb[, 2], 2)

FBAR <- exp(logfbar[, 1])
FBAR.lo <- logBand(logfbar[, 1], logfbar[, 2], -2)
FBAR.hi <- logBand(logfbar[, 1], logfbar[, 2], 2)

## The same for the "logfbarY" quantity (presumably a yearly Fbar; confirm).
logfbar.y <- sum[rownames(sum) == "logfbarY", ]
fbar.y <- exp(logfbar.y[, 1])
fbar.y.lo <- logBand(logfbar.y[, 1], logfbar.y[, 2], -2)
fbar.y.hi <- logBand(logfbar.y[, 1], logfbar.y[, 2], 2)

## Retrospective run: store the key results under "retrorun<N>" and save
## them to RETRO/retrorun<N>.RData. Otherwise draw the standard plots.
## save(list = ...) takes the object name directly, which replaces the
## former eval(substitute(save(...))) construction. isTRUE() keeps an
## NA/non-scalar flag from erroring inside if().
if (exists("is.RETRO") && isTRUE(is.RETRO) && is.integer(RETROnumber)) {
    listname <- paste0("retrorun", RETROnumber)
    assign(listname, list(sum, SSB, FBAR, pl, plsd, sdrep))
    save(list = listname,
         file = file.path("RETRO", paste0(listname, ".RData")))
} else {
    source("src/plotscript.R")
}
