## Fix the RNG seed so the randomisation used for the OSA residuals of zero
## observations is the same from run to run.
set.seed(1)

## Project helper scripts (their contents are not visible from this file).
source("src/common.R")
source("src/forecast.R")

## Not referenced again in this script; presumably read by a sourced script.
doOSA <- TRUE

## PARAMS
##
## JUMPMULT  : passed to the model as both elements of data$jumpmult.
## JUMPLIMIT : threshold on the absolute log-difference in `tc` four time
##             steps apart; larger differences are flagged in `bigjump`.
## MMULT     : extra multiplier applied to aux$auxM.
## LSDLC3    : value assigned to param$logSdLogObs[3], which is held fixed
##             through the NA entry in `mymap`.
JUMPMULT <- 100
JUMPLIMIT <- 2.9
MMULT <- 1.0
LSDLC3 <- -0.3

library(TMB)
library(xtable)
library(ellipse)

## Build the TMB template and load the resulting shared library.
compile("src/sesam.cpp") ##,tracesweep=TRUE)
dyn.load(dynlib("src/sesam"))

## Input tables. The code below relies on `obs` providing (at least) the
## columns t1, fleet, ageFrom and obs. `xsa` is not referenced again in the
## visible part of this script.
obs <- read.table("data/obs.dat")
aux <- read.table("data/aux.dat")
xsa <- read.table("data/xsa.dat")

## Order the observations by time step, then fleet, then age.
obs <- obs[order(obs$t1, obs$fleet, obs$ageFrom), ]

## Number of positive observations per fleet x age, on the original fleet codes.
x <- xtabs(obs$obs>0~obs$fleet+obs$ageFrom)

## Highest fleet code before any fleet is dropped or renumbered.
maxFleet <- max(obs$fleet)

## remove commercial CPUE series (fleet 2) and shift the higher fleet codes
## down by one so the numbering has no gap
obs <- subset(obs, fleet != 2)
obs$fleet[obs$fleet > 2] <- obs$fleet[obs$fleet > 2] - 1

## remove scottish age 0
##obs = subset(obs,!(fleet==4 & ageFrom==0))
## split scottish age 0: age-0 observations of fleet 4 after 2013 become
## a separate fleet 6
obs$fleet[obs$fleet == 4 & obs$ageFrom == 0 & obs$t1 > 2013] <- 6


## Fraction into the quarter at which each survey is conducted.
## The vector is indexed by fleet number (see the leave-one-out branch, which
## drops entry `LOnumber`); the first entry is NA.
quarterFrac <- c(NA, rep(0.5, 8))

## Leave-one-out run: only active when the caller has defined `is.LO` as TRUE
## and `LOnumber` as an integer (e.g. 3L). The chosen fleet is removed and the
## higher fleet codes are shifted down by one.
if (exists("is.LO") && is.LO && is.integer(LOnumber)) {
    cat("Doing LO run. Removing fleet ",LOnumber,"\n")
    quarterFrac <- quarterFrac[-LOnumber]
    obs <- subset(obs, fleet != LOnumber)
    obs$fleet[obs$fleet > LOnumber] <- obs$fleet[obs$fleet > LOnumber] - 1
}

## A fleet/time step in which every age has a zero observation is treated as
## "not observed": all of its rows are dropped. Fleet 1 is exempt.
for (f in setdiff(unique(obs$fleet), 1)) {
    for (yy in unique(obs$t1)) {
        sel <- which(obs$fleet == f & obs$t1 == yy)
        mysum <- sum(obs$obs[sel])
        if (length(sel) > 0 && mysum == 0) {
            obs <- obs[-sel, ]
        }
    }
}

## Likewise, a fleet/age combination that is zero in every time step is
## treated as "not observed" and dropped (this applies to all fleets).
for (f in unique(obs$fleet)) {
    for (aa in unique(obs$ageFrom)) {
        sel <- which(obs$fleet == f & obs$ageFrom == aa)
        mysum <- sum(obs$obs[sel], na.rm = TRUE)
        if (length(sel) > 0 && mysum == 0) {
            obs <- obs[-sel, ] ##obs$obs[sel]=NA
        }
    }
}

## Reduce from 4+ to 3+: where a fleet/time step has both an age-3 and an
## age-4 row, their sum is stored in the first of the two rows and the second
## row is deleted. Steps with only one of the two ages are left unchanged.
for (f in unique(obs$fleet)) {
    for (yy in unique(obs$t1)) {
        sel <- which(obs$fleet == f & obs$t1 == yy & obs$ageFrom %in% c(3, 4))
        mysum <- sum(obs$obs[sel])
        if (length(sel) == 2) {
            obs$obs[sel[1]] <- mysum
            obs <- obs[-sel[2], ]
        }
    }
}

## Keep auxiliary data for ages 0-3 only (3 is the plus group).
aux <- aux[aux$age <= 3, ]

## Start in 1984
obs <- obs[obs$t1 >= 1984, ]
aux <- aux[aux$t1 >= 1984, ]

## Retrospective run: only active when the caller has defined `is.RETRO` as
## TRUE and `RETROnumber` as an integer (e.g. 2L). For every fleet the last
## `RETROnumber` years, counted from that fleet's own final time step, are
## removed; `aux` is cut at the overall last observation time minus
## `RETROnumber`.
if (exists("is.RETRO") && is.RETRO && is.integer(RETROnumber)) {
    cat("Doing RETRO run. Removing the last ",RETROnumber," years\n")
    maxY <- max(obs$t1) - RETROnumber
    ## one row per fleet present in `obs`: columns `fleet` and `t1` (last step)
    maxYs <- aggregate(t1 ~ fleet, data = obs, FUN = max)

    ## Look the last time step up by fleet id rather than by row position:
    ## row f of `maxYs` is fleet f only when the fleet codes are exactly 1..K.
    ## With a gap in the codes, positional indexing would trim with another
    ## fleet's year, or (NA cutoff) silently drop the whole fleet.
    for (f in maxYs$fleet) {
        lastKept <- maxYs$t1[maxYs$fleet == f] - RETROnumber
        obs <- subset(obs, !(fleet == f & t1 > lastKept))
    }

    aux <- aux[aux$t1 <= maxY, ]
}

## Prefix every aux column with "aux" so the names cannot clash with the
## obs columns once both tables are merged into one list.
names(aux) <- paste0("aux", names(aux))
## NOTE: SXSA multiplies M by the time step DT, so M in sesam must be
## (1/DT) = 4 times larger.
aux$auxM <- 4 * aux$auxM
## Additional user-set scaling of M (see MMULT at the top of the script).
aux$auxM <- MMULT * aux$auxM

## Fleet 1 observations; presumably the commercial catch-at-age (TODO confirm).
obsc <- subset(obs, fleet == 1)

## Time step x age tables of fleet-1 observations and of auxCW, and the
## row sums of their product (one value per time step).
cn <- xtabs(obs ~ t1 + ageFrom, data = obsc)
cw <- xtabs(auxCW ~ auxt1 + auxage, data = aux)
tc <- rowSums(cn * cw)

## Flag time steps where `tc` differs strongly, on the log(x + 1) scale, from
## its value four steps earlier. Both vectors hold one element more than `tc`;
## the first four and the extra last element stay at zero.
jump <- rep(0, length(tc) + 1)
bigjump <- rep(0, length(tc) + 1)
for (i in 5:length(tc)) {
    jump[i] <- abs(log(tc[i] + 1) - log(tc[i - 4] + 1))
    if (jump[i] > JUMPLIMIT) {
        bigjump[i] <- 1
    }
}

## Model input list: all obs columns followed by all (renamed) aux columns.
data <- c(as.list(obs), as.list(aux))

data$auxbigjump <- bigjump
## All distinct time points occurring in the obs or aux tables, ascending.
data$times <- sort(unique(c(data$t1, data$t2, data$auxt1, data$auxt2)))

## All distinct ages occurring as ageFrom or ageTo, ascending.
data$ages <- sort(unique(c(data$ageFrom, data$ageTo)))
data$bday <- 0.0
data$jumpmult <- c(JUMPMULT, JUMPMULT) ## only second value is used
## Zero-based row index into the aux columns of the global `data` list for
## age `a` at time `t`.
##
## a   : age, compared for equality with data$auxage.
## t   : time; a row matches when auxt1 - eps < t < auxt2 - eps.
## eps : tolerance for the time comparison.
##
## Returns the matching position(s) minus one; numeric(0) when nothing matches.
getidx <- function(a, t, eps = 0.0001) {
    inStep <- (data$auxt1 - eps) < t & t < (data$auxt2 - eps)
    sameAge <- data$auxage == a
    which(inStep & sameAge) - 1
}

## age x time table of zero-based aux row indices (see getidx).
data$mapaux <- outer(data$ages, data$times, Vectorize(getidx, c("a", "t")))
## quarter 4 aux missing, copy from last year (the column four steps back)
data$mapaux[, ncol(data$mapaux)] <- data$mapaux[, ncol(data$mapaux) - 4]
## Flatten to a plain matrix with the same dimensions.
data$mapaux <- matrix(unlist(data$mapaux),
                      nrow = nrow(data$mapaux),
                      ncol = ncol(data$mapaux))

## Data-driven observation-variance key. It is superseded by the hand-written
## table just below; the statements are kept because they also set the
## global `x`.
x <- xtabs(data$obs>0~data$fleet+data$ageFrom)
data$keyVarObs <- row(x)
data$keyVarObs[x == 0] <- NA

## Setup variance parameter couplings.
## One row per fleet, one column per age group; cells with the same number
## share an observation-variance parameter and NA marks an unused
## fleet/age cell.
data$keyVarObs <- rbind(
    c(0, 1, 1, 2),
    c(NA, 3, 4, 5), ##
    c(6, 7, 8, 9),
    c(10, 11, 11, 12),
    c(13, 14, 15, NA),
    c(10, NA, NA, NA)
)

## Renumber the codes to consecutive integers starting at 0.
data$keyVarObs[] <- as.numeric(as.factor(data$keyVarObs)) - 1
if (exists("is.LO") && is.LO && is.integer(LOnumber)) {
    ## Leave-one-out run: drop the row of the removed fleet and renumber.
    data$keyVarObs <- data$keyVarObs[-LOnumber, ]

    data$keyVarObs[] <- as.numeric(as.factor(data$keyVarObs)) - 1
}

## Couplings for the logN process variances (the first assignment is
## immediately overwritten by the explicit vector).
data$keyVarLogN <- rep(0, length(data$ages))
data$keyVarLogN <- c(0, 0, 0, 1)
data$keyLogFsta <- c(0, 1, 2, 3) ##c(0:(length(data$ages)-1));

##data$keyVarLogF<-rep(0,max(data$keyLogFsta)+1)
data$keyVarLogF <- c(0, 0, 1, 1)

## Data-driven catchability key derived from keyVarObs. It is superseded by
## the hand-written table just below; kept because it also sets `x`.
x <- data$keyVarObs
x[1, ] <- NA
data$keyLogQ <- t(sapply(1:nrow(x), function(i) {
    xx <- x[i, ]
    n <- sum(!is.na(xx))
    xx[!is.na(xx)] <- 1:n + sum(!is.na(x[1:(i - 1), ]))
    xx
})) - 1

## Catchability couplings, laid out like keyVarObs. Row 1 is all NA, so
## fleet 1 has no catchability parameter.
data$keyLogQ <- rbind(
    c(NA, NA, NA, NA),
    c(NA, 0, 1, 2),
    c(3, 4, 5, 5),
    c(6, 7, 8, 8),
    c(9, 10, 11, NA),
    c(12, NA, NA, NA)
)

data$keyLogQ[] <- as.numeric(as.factor(data$keyLogQ)) - 1

if (exists("is.LO") && is.LO && is.integer(LOnumber)) {
    data$keyLogQ <- data$keyLogQ[-LOnumber, ]
    data$keyLogQ[] <- as.numeric(as.factor(data$keyLogQ)) - 1
}

## Zero-based positions of each observation's start and end time in data$times.
data$idx1 <- match(data$t1, data$times) - 1
data$idx2 <- match(data$t2, data$times) - 1
data$fbarrange <- c(1, 2)

## Position in the global data$times of the time point lying `lag` before `t`.
##
## t   : time to look back from.
## lag : how far to look back, in the units of data$times.
## eps : tolerance; an entry matches when |t - (entry + lag)| < eps.
##
## Returns the (one-based) position of the first match, or NA if none exists.
getlagback <- function(t, lag = 1.0, eps = 0.01) {
    target <- data$times + lag
    ret <- which(t > (target - eps) & t < (target + eps))
    if (length(ret) == 0) NA else ret[1]
}
## For every time step: zero-based index of the step one time unit earlier
## (NA where no such step exists).
data$maplagR <- sapply(data$times, Vectorize(getlagback, c("t"))) - 1

## Two-row lag table: row 1 points to the immediately preceding step
## (NA for the first step), row 2 is maplagR.
data$maplagF <- rbind(c(NA, 1:(length(data$times) - 1) - 1), data$maplagR)
## 0/1 pattern (0, 0, 1, 0) repeated along the time steps.
data$recruitTimes <- rep(c(0, 0, 1, 0), length.out = length(data$times))

## Set the detection limit by fleet to half of the smallest positive observation
data$eps <- unlist(
    lapply(split(data$obs, data$fleet), function(v) min(v[v > 0]))
) / 2
data$quarterFrac <- quarterFrac

## Quarter number derived from the fractional part of the time stamp.
data$quarterObs <- (data$t1 - floor(data$t1)) * 4 + 1
data$quarter <- (data$times - floor(data$times)) * 4 + 1

## Starting values for the model parameters. Vector lengths follow the key
## tables: one parameter per distinct (zero-based) key value.
param <- list()

param$logSdLogFsta <- rep(0, max(data$keyVarLogF) + 1)
param$logSdLogN <- rep(1, max(data$keyVarLogN) + 1)
param$logSdLogObs <- rep(0, max(data$keyVarObs, na.rm = TRUE) + 1)
param$logQ <- rep(0, max(data$keyLogQ, na.rm = TRUE) + 1)
param$trans_rho <- .7
param$logSdLogR <- 1
param$logF <- matrix(0, ncol = length(data$times), nrow = max(data$keyLogFsta) + 1)
param$logN <- matrix(0, ncol = length(data$times), nrow = length(data$ages))
param$logR <- rep(5, length(data$times))

## Log observations. Zero observations (log = -Inf) are replaced by the log
## of a value just below the fleet's detection limit data$eps.
data$logobs <- log(data$obs)
## Missing or negative observations cannot be handled; fail with a clear
## message (the former element-wise loop also aborted on them, but with an
## obscure `if` error).
stopifnot(!anyNA(data$logobs))
zeroObs <- which(data$logobs == -Inf)
## data$eps is indexed by fleet number, as in the original loop.
data$logobs[zeroObs] <- log(data$eps[data$fleet[zeroObs]] - 1e-4)

## Length of the time series, rounded down to whole units of data$times.
data$noYears <- floor(data$times[length(data$times)] - data$times[1])

##mymap = list(logSdLogN=factor(c(NA,1:(max(data$keyVarLogN)))),logSdLogObs=factor(c(1,NA,3:length(param$logSdLogObs)))) ##,logSdLogFsta=factor(NA))

##mymap = list(logSdLogObs=factor(c(1,2,NA,4:length(param$logSdLogObs))),trans_rho=factor(NA)) ##,logSdLogFsta=factor(NA))
## Parameter map for MakeADFun: the NA in position 3 keeps logSdLogObs[3]
## fixed at its starting value; all other elements are estimated freely.
mymap <- list(logSdLogObs = factor(c(1, 2, NA, 4:length(param$logSdLogObs)))) ##,logSdLogFsta=factor(NA))
##mymap = list()


## Value at which the fixed observation SD (element 3) is held.
param$logSdLogObs[3] <- LSDLC3
##param$trans_rho = 5

##param$logSdLogObs[2] = -0.5
##param$logSdLogN[1] = -1
##param$logSdLogFsta[1] = 0.5

## Build the objective function; logF, logN and logR are random effects.
obj <- MakeADFun(data, param,
                 random = c("logF", "logN", "logR"),
                 DLL = "sesam",
                 inner.control = list(maxit = 1000),
                 map = mymap)

## Box constraints: every log-SD parameter within [-4, 4], trans_rho at most 5;
## all other parameters are unbounded.
lower <- rep(-Inf, length(obj$par))
upper <- rep(Inf, length(obj$par))
lower[grep("^logSd", names(obj$par))] <- -4
upper[grep("^logSd", names(obj$par))] <- 4
upper[grep("^trans_rho", names(obj$par))] <- 5

opt <- nlminb(obj$par, obj$fn, obj$gr,
              control = list(trace = 1, eval.max = 1200, iter.max = 900,
                             rel.tol = 1e-10),
              lower = lower, upper = upper)

## nlminb reports successful convergence as code 0.
if (opt$convergence != 0) warning("Error: model did not converge.")

## Parameter list, model report, and standard errors with and without the
## joint precision of fixed and random effects.
pl <- obj$env$parList()
rep <- obj$report()
## TRUE rather than T: `T` is an ordinary variable that can be reassigned.
jointrep <- sdreport(obj, getJointPrecision = TRUE, bias.correct = TRUE)
sdrep <- sdreport(obj, bias.correct = TRUE)
## Standard deviations from the inverse of the joint precision matrix,
## arranged in the same list layout as the parameters.
allsd <- sqrt(diag(solve(jointrep$jointPrecision)))
plsd2 <- obj$env$parList(par = allsd)

## `plsd` mirrors the layout of `pl` but holds the standard deviations of the
## random effects, taken from the diagonal of their covariance in `sdrep`.
plsd <- pl
sel <- which(names(sdrep$par.random) == "logN")
plsd$logN[, ] <- sqrt(sdrep$diag.cov.random[sel])
sel <- which(names(sdrep$par.random) == "logF")
plsd$logF[, ] <- sqrt(sdrep$diag.cov.random[sel])
sel <- which(names(sdrep$par.random) == "logR")
plsd$logR <- sqrt(sdrep$diag.cov.random[sel])

## Time steps where the first row of logN is exactly 0 take their estimate
## and SD from logR instead.
## NOTE(review): presumably these are steps where age-0 abundance is carried
## by logR and logN stays at its starting value; confirm against sesam.cpp.
sel2 <- pl$logN[1, ] == 0

estN <- pl$logN
estN[1, sel2] <- pl$logR[sel2]
pl$logN <- estN

sdN <- plsd$logN
sdN[1, sel2] <- plsd$logR[sel2]
plsd$logN <- sdN

## Estimate / standard-error table of everything reported by sdreport.
## (The name shadows base::sum as a variable; calls to sum() still work.)
sum <- summary(sdrep)

## Rows for each derived quantity; column 1 is the estimate, column 2 its SE.
ssb <- sum[rownames(sum) == "ssb", ]
tsb <- sum[rownames(sum) == "tsb", ]
logfbar <- sum[rownames(sum) == "logfbar", ]
lssb <- sum[rownames(sum) == "logssb", ]
ltsb <- sum[rownames(sum) == "logtsb", ]



## Point estimates on the natural scale; interval bounds are +/- 2 SE on the
## log scale, back-transformed.
SSB <- ssb[, 1]
SSB.lo <- exp(lssb[, 1] - 2 * lssb[, 2])
SSB.hi <- exp(lssb[, 1] + 2 * lssb[, 2])

TSB <- tsb[, 1]
TSB.lo <- exp(ltsb[, 1] - 2 * ltsb[, 2])
TSB.hi <- exp(ltsb[, 1] + 2 * ltsb[, 2])

FBAR <- exp(logfbar[, 1])
FBAR.lo <- exp(logfbar[, 1] - 2 * logfbar[, 2])
FBAR.hi <- exp(logfbar[, 1] + 2 * logfbar[, 2])

## Same for the quantity reported as "logfbarY".
logfbar.y <- sum[rownames(sum) == "logfbarY", ]
fbar.y <- exp(logfbar.y[, 1])
fbar.y.lo <- exp(logfbar.y[, 1] - 2 * logfbar.y[, 2])
fbar.y.hi <- exp(logfbar.y[, 1] + 2 * logfbar.y[, 2])

## Output stage.
## - RETRO run: store the results under the name "retrorun<k>" and save that
##   object to RETRO/retrorun<k>.RData.
## - LO run:    same, under "LOrun<k>" in LO/.
## - MPROF set: nothing is written here.
## - otherwise: run the plotting script.
if (exists("is.RETRO") && is.RETRO && is.integer(RETROnumber)) {
    listname <- paste0("retrorun", RETROnumber)
    assign(listname, list(sum, SSB, FBAR, pl, plsd, sdrep, opt = opt))
    ## save() by name, so the object keeps the name `listname` in the file
    save(list = listname, file = paste0("RETRO/", listname, ".RData"))
} else if (exists("is.LO") && is.LO && is.integer(LOnumber)) {
    listname <- paste0("LOrun", LOnumber)
    assign(listname, list(sum, SSB, FBAR, pl, plsd, sdrep, opt = opt))
    save(list = listname, file = paste0("LO/", listname, ".RData"))

} else if (exists("MPROF") && MPROF) {
    ## do nothing
} else {
    source("src/plotscript.R")
}
