## Forecasting functions for SESAM
## Simulate the stock state (numbers-at-age logN, fishing mortality logF and
## recruitment) forward from the last estimated states of a fitted SESAM model.
##
##   obj           fitted model object (kept in the interface; not used here)
##   jointrep      joint report object with $value, $cov and $par.fixed
##   data          model data list (keyLogFsta, ages, times, mapaux, auxM, ...)
##   stepsback     number of final estimated time steps to skip (the very last
##                 states are not informed by data)
##   lags.needed   number of historic quarters needed to seed the forecast
##   lags.forward  number of quarters to project forward
##   noSim         number of simulation replicates
##   seed          RNG seed, set at entry for reproducibility
##   scale         multiplicative F scaling; length 1 or lags.forward
##   logrecruitVec NULL, a numeric vector of log-recruitments to resample, or
##                 one of "RW" (random-walk recruitment) / "historic"
##                 (resample the estimated logR states)
##   oldstyle      if TRUE use the original output time indexing
##   Myearrange    optional c(from, to) year range over which to average M
##
## Returns a list of length noSim; each element holds logN and logF matrices
## for the forecast period. The "times" attribute gives the forecast times.
forecast.sesam<-function(obj,jointrep,data,stepsback=2,lags.needed=4,lags.forward=4,noSim=1000,seed=1,scale=1, logrecruitVec=NULL,oldstyle=TRUE,Myearrange=NULL){

    ## A character logrecruitVec must be exactly one of the two keywords;
    ## the explicit length check also rejects character vectors of length > 1
    ## (which would otherwise make the if-condition non-scalar).
    if(is.character(logrecruitVec) && (length(logrecruitVec)!=1 || !logrecruitVec%in%c("RW","historic"))) stop("logrecruitVec must be a numeric vector or one of the character strings 'RW' or 'historic'");  
    
    set.seed(seed)
    ##require(MASS) Dont use MASS -- mvrnorm is platform dependent
    nF = max(data$keyLogFsta)+1
    nN = length(data$ages)
        
    ## Extract state + covariances
    pn = names(jointrep$value)
    Fidx = which(pn=="logF")
    lf = length(Fidx)
    Nidx = which(pn=="logN")
    ln = length(Nidx)
    Ridx = which(pn=="logR" & jointrep$value>-1) ## OBS hack to leave out essentially zero recruitments (timesteps with no recruitment). Will fail if N is scaled to be very low!
    lr = length(Ridx)
    
    ## Keep only the lags.needed states ending stepsback before the last one
    Fidx = Fidx[ (lf-nF*(stepsback+lags.needed)+1):(lf-(nF*stepsback)) ]
    Nidx = Nidx[ (ln-nN*(stepsback+lags.needed)+1):(ln-(nN*stepsback)) ] 
        
    last.years.states = jointrep$value[ c(Nidx,Fidx,Ridx) ]

    nfixed = length(jointrep$par.fixed)
    last.years.state.cov = jointrep$cov[ c(Nidx,Fidx,Ridx), c(Nidx,Fidx,Ridx)] + diag(1e-12,length(last.years.states)) ## add small value to diagonal to enforce positive definite
    
    ## Joint draw of all starting states for all replicates (noSim x nstates)
    last.years.state.sim = stockassessment::rmvnorm(noSim, mu=last.years.states, Sigma=last.years.state.cov) ##mvrnorm(noSim, mu=last.years.states, Sigma=last.years.state.cov)

    nTimes = length(data$times)
    timesIdx = (nTimes-stepsback-lags.needed+1):(nTimes-stepsback)
    auxidx = data$mapaux[,timesIdx] + 1
    times=data$times[timesIdx]

    ## Positions of each state type within the sampled state vector
    Nidx2 = which(names(last.years.states)=="logN")
    Fidx2 = which(names(last.years.states)=="logF")
    Ridx2 = which(names(last.years.states)=="logR")
    
    
    
    ## Natural mortality for the seed quarters (ages x quarters)
    M = matrix( data$auxM[auxidx],nN)
    if(!is.null(Myearrange)){
        stopifnot(length(Myearrange)==2)
        ## NOTE(review): 'aux' is not an argument of this function; this relies
        ## on a data frame named 'aux' in the calling environment -- confirm.
        avgM = aggregate(auxM ~ auxage+quarter,FUN=mean,data=subset(aux,auxt1 >= Myearrange[1] & auxt1 <= Myearrange[2]))
        M = matrix( avgM$auxM,nN)
    }
    
    
    ## Recycle M forward so it covers the full forecast horizon
    M = cbind(M,M,M,M,M,M,M,M,M)
    DT=1/4
    sdlogNproc = exp(jointrep$par.fixed["logSdLogN"]) ## OBS assumes only one
    sdlogRproc = exp(jointrep$par.fixed["logSdLogR"])

    ## Decimal year -> quarter number (1..4)
    y2q<-function(x) (x-floor(x))*4 + 1

    newtimes = times[length(times)] + 1:lags.forward*DT ##orig
    newtimes2 = times[length(times)-1]+1:(lags.forward+1)*DT

    ## Recruitment occurs in quarter 3
    recruitTimes = as.numeric(y2q(newtimes)==3)
    
    if(length(scale)==1) scale=rep(scale,length(newtimes))
    ##if(length(scale)!=(length(newtimes)-1)) stop("Scale must have length 1 or length equal to lags.forward");
    
    ## Step forward: project one replicate's sampled state through the
    ## forecast quarters (survival, plus-group bookkeeping, recruitment).
    step<-function(x,scale2){
        
        if(length(logrecruitVec)==1 && logrecruitVec=="historic")
            logrecruitVec = x[Ridx2];
        
        logN = cbind( matrix(x[Nidx2],nrow=nN), matrix(NA,nN,lags.forward))
        logF = cbind( matrix(x[Fidx2],nrow=nF), matrix(NA,nF,lags.forward))
                
        count = 0;
        for(i in (lags.needed+1):(lags.needed+lags.forward)){
            count=count+1
            oldlogN = logN[,i-1]
                        
            ## scale F's one lag back - the last F state reported is not influenced by data
            logF[,i] <- logF[,i-4]; 
            
            logF[,i-1] <- logF[,i-1] + log(scale2[count]);
            
            ## Survival: decay N by total mortality (F + M) plus process noise
            logN[,i] <- oldlogN - (exp(logF[,i-1] ) + M[,i-1])*DT  + rnorm(nN,0,sdlogNproc);
            
            if( y2q(newtimes[count])==1){## new year: shift age groups up, plus-group absorbs
                pg = logN[nN,i]
                logN[2:nN,i] = logN[1:(nN-1),i]
                logN[nN,i] = log( exp(logN[nN,i]) + exp(pg) );
                logN[1,i] = 0
            }

            if(recruitTimes[i-lags.needed]==1){ ## Random walk recruitment
                if(length(logrecruitVec)==1 && logrecruitVec=="RW"){
                    rr = logN[1,i-4] + rnorm(1,0,sdlogRproc)
                    logN[1,i]= rr
                } else { ## sample from recruitment vector (historic or given)
                    rr = logrecruitVec[ sample.int(length(logrecruitVec),1) ]
                    logN[1,i]=rr
                }
                
            }
            
        }
        sel = (ncol(logN)-lags.forward):ncol(logN)
        if(oldstyle) sel = (ncol(logN)-lags.forward+1):ncol(logN) ## orig
        
        list(logN=logN[,sel,drop=FALSE],logF=logF[,sel,drop=FALSE])
        
     }

    sim = apply(last.years.state.sim,1, step,scale2=scale)

    attr(sim,"times")<-newtimes2
    if(oldstyle) attr(sim,"times")<-newtimes  
    sim
}


## Mean fishing mortality (Fbar) over the rows of logF given by fbarrange,
## per time step, for each simulation replicate.
## Returns a list (one element per replicate) of numeric vectors.
get.fbar.sim<-function(sim,fbarrange=2:3){
    ## drop=FALSE keeps the matrix shape so a single-row fbarrange also works
    ## (without it, x$logF[fbarrange,] collapses to a vector and colMeans fails)
    lapply(sim, function(x) colMeans(exp(x$logF[fbarrange, , drop=FALSE])))
}

## Spawning stock biomass per time step for each simulation replicate:
## sum over ages of N * stock weight (SW) * proportion mature (PM),
## with SW/PM columns recycled over quarters 1..4.
get.ssb.sim<-function(sim,SW,PM){
    quarterIdx <- rep_len(1:4, ncol(sim[[1]]$logN))
    ssb.one <- function(repl) {
        colSums(exp(repl$logN) * SW[, quarterIdx] * PM[, quarterIdx])
    }
    lapply(sim, ssb.one)
}

## Total stock biomass per time step for each simulation replicate:
## sum over ages of N * stock weight (SW), with SW columns recycled
## over quarters 1..4.
get.tsb.sim<-function(sim,SW){
    quarterIdx <- rep_len(1:4, ncol(sim[[1]]$logN))
    tsb.one <- function(repl) colSums(exp(repl$logN) * SW[, quarterIdx])
    lapply(sim, tsb.one)
}


## Catch weight per time step for one simulation replicate using the
## Baranov catch equation: C = F/Z * (1 - exp(-Z*DT)) * N, summed over ages
## after weighting by catch weight-at-age (CW). NM columns are recycled
## over quarters 1..4; DT is the time-step length in years.
catchfun<-function(x,CW,NM,DT=1/4){
    quarterIdx <- rep_len(1:4, ncol(x$logN))
    Fmort <- exp(x$logF)
    Ztot <- Fmort + NM[, quarterIdx]
    abundance <- exp(x$logN)
    catchN <- Fmort / Ztot * (1 - exp(-Ztot * DT)) * abundance
    colSums(catchN * CW[, quarterIdx])
}

## Catch weight per time step for every simulation replicate
## (thin wrapper applying catchfun to each element of sim).
get.catch.sim<-function(sim,CW,NM){
    lapply(sim, function(x) catchfun(x, CW = CW, NM = NM))
}

## Recruitment (numbers in age group 1) per time step for each replicate.
## Only time steps falling in the given quarter carry recruits; all other
## entries of the returned vectors are zero.
get.recruits.sim<-function(sim, quarter){
   recruitCols <- which(y2q(attr(sim, "times")) == quarter)
   nCols <- ncol(sim[[1]]$logN)
   lapply(sim, function(x) {
       out <- numeric(nCols)
       out[recruitCols] <- exp(x$logN[1, recruitCols])
       out
   })
}


## Find log-scale F scalings (one per catch group) such that the median
## simulated total catch summed within each group matches catchtarget.
##
##   catchgrouping  maps each forecast time step to a target group,
##                  e.g. c(1,1,2,2,2,2) -> two groups
##   catchtarget    one target total catch per group
##
## Returns exp() of the optimised log-scalings, one per group.
##
## Fixes vs previous version: the call to forecast.sesam passed
## scaleFonelagback=scaleFonelagback, a name defined nowhere (and not an
## argument of forecast.sesam), which errored at runtime; and the default
## Myearrange=Myearrange was self-referential (promise-under-evaluation
## error when omitted) -- it now defaults to NULL like forecast.sesam's.
get.sim.scaled.to.catch<-function(obj,jointrep,data,stepsback=0,lags.forward=6,noSim=1000,seed=5,
                                  catchgrouping=c(1,1,2,2,2,2),catchtarget=c(70000,100000),CW,NM,trace=FALSE,Myearrange=NULL){

    ## Objective for stepwise.optimize: sum of squared deviations (in
    ## thousand tonnes) between group-wise median catch and the targets.
    ## stepwise.optimize calls fn(par, i), so match.subset receives the index
    ## of the coordinate being optimised (unused here).
    targetDev<-function(scalings,match.subset=rep(TRUE,length(scalings))){ 

        myscale = exp(scalings[catchgrouping]) ## scaling are log-scalings!
        
        sim=forecast.sesam(obj,jointrep,data,stepsback=stepsback,lags.forward=lags.forward,noSim=noSim,seed=seed,scale=myscale,Myearrange=Myearrange)
        catch.sim = get.catch.sim(sim,CW,NM)

        ## median catch over replicates, per forecast time step
        median.catch.sim = apply( do.call("rbind",catch.sim),2,median)

        ## sum median catch within each target group
        targets = aggregate(median.catch.sim,by=list(catchgrouping),FUN=sum)$x

        ss = sum(( (catchtarget-targets)/1000 )^2)

        if(trace) cat(myscale, " : ", ss," target: ", targets, "\n");

        ss
        
        }

    opt = stepwise.optimize(rep(0,max(catchgrouping)),targetDev,interval=c(-4,4))
    exp(opt)

}

## Coordinate-wise minimisation: each parameter is optimised once, in order,
## over `interval`, holding the others fixed at their current values.
## fn is called as fn(par, i), where i is the index of the coordinate
## currently being optimised. Extra arguments in ... are forwarded to
## optimize(). Returns the vector of optimised parameters.
stepwise.optimize<-function(par,fn,interval,...){
    res <- par
    n <- length(par)
    for (k in seq_len(n)) {

        ## One-dimensional slice of fn along coordinate i at the current point
        objective <- function(val, par2, fn2, i) {
            candidate <- par2
            candidate[i] <- val
            fn2(candidate, i)
        }

        fit <- optimize(objective, interval = interval, par2 = res, fn2 = fn,
                        i = k, tol = 1e-6, ...)
        res[k] <- fit$minimum
        cat("optimized step ", k, " out of ", n, "\n")
    }
    res
}

## Find log-scale F scalings (one per blim group) such that the lower `prob`
## quantile of simulated SSB at the last time point of each group matches the
## corresponding blimtarget. Returns exp() of the optimised log-scalings.
get.sim.scaled.to.blim<-function(obj,jointrep,data,stepsback=0,lags.forward=6,noSim=1000,seed=5,
                                  blimgrouping=rep(1,lags.forward+1),blimtarget=c(70000,100000),SW,PM,logrecruitVec=NULL,trace=FALSE,prob=0.05){

    res = rep(NA,max(blimgrouping)) ## NOTE(review): never used below -- confirm it can be removed

    
    ## Objective for stepwise.optimize. NOTE: stepwise.optimize calls
    ## fn(par, i) with i = index of the coordinate being optimised, so during
    ## optimisation match.subset is that scalar index and only the
    ## corresponding group's deviation enters the sum of squares.
    targetDev<-function(scalings,match.subset=rep(TRUE,length(scalings))){

        ## blimgrouping expands the per-group log-scalings to one per time step
        myscale = exp(scalings[blimgrouping])
        
        ## oldstyle=FALSE selects the output indexing with one extra leading
        ## time step (see forecast.sesam)
        sim=forecast.sesam(obj,jointrep,data,stepsback=stepsback,lags.forward=lags.forward,noSim=noSim,seed=seed,scale=myscale,logrecruitVec=logrecruitVec,oldstyle=FALSE)
        ssb.sim = get.ssb.sim(sim,SW,PM)

        ## lower prob-quantile of SSB across replicates, per time step
        lo.ssb.sim = apply( do.call("rbind",ssb.sim),2,quantile,probs=prob)

        ## get ssb from last time point in each grouping
        ## (sel is TRUE exactly where the group label changes, i.e. at each
        ## group's final time step)
        tmp = c(blimgrouping,max(blimgrouping)+1)
        sel = as.logical( diff( tmp ))
        
        targets = lo.ssb.sim[ sel ]

        ## deviations in thousand tonnes, restricted to match.subset
        ss = sum(( (blimtarget-targets)[match.subset]/1000 )^2)
        
        if(trace) cat(myscale, " : ", ss," target: ", targets, "\n");
        
        ss
        
        }

    opt = stepwise.optimize(rep(0,max(blimgrouping)),targetDev,interval=c(-5,5))

    exp(opt)

}

## Predict catch weight-at-age (CW) per age and quarter from the aux table
## using a GAM on sqrt(auxCW): age x quarter fixed effects, a cohort-by-age
## tensor smooth, and a per-time-point random effect that is switched off
## for prediction via the `dum` dummy covariate.
## Returns list(model = fitted gam, aux = input with predCW column added,
## CW = predicted age x quarter table for the final year).
CWpredict<-function(aux){
    ## require() returns FALSE instead of erroring when the package is
    ## missing; fail loudly here since everything below depends on mgcv.
    if(!require(mgcv)) stop("CWpredict requires the 'mgcv' package", call.=FALSE)
    aux$cohort=floor(aux$auxt1)-aux$auxage
    aux$year=factor(floor(aux$auxt1))
    aux$timefac=factor(aux$auxt1)

    ## dum=1 enables the per-time random effect while fitting
    aux$dum=1
    m=gam(sqrt(auxCW)~factor(auxage)*factor(quarter)+te(cohort,auxage,k=c(10,4))+s(timefac,bs="re",by=dum),data=aux[aux$auxCW>0,],method="ML" )

    aux$dum=0 ## disable random year*quarter effects in last years for predictions
    ## Predict where CW was observed, plus all rows within the final year
    sel = aux$auxCW>0 | aux$auxt1>(max(aux$auxt1)-1)
    p=predict(m,newdata=aux[sel,])
    aux$predCW=0
    aux$predCW[sel]  = p^2 ## back-transform from the sqrt scale
    
    CWpred = xtabs(predCW ~ auxage + quarter,data=subset(aux,auxt1>max(aux$auxt1)-1))
          
    list(model=m,aux=aux,CW=CWpred)
    
}

## Plot historic estimated Fbar (with +-2 sd band) together with simulated
## forecast Fbar (median and the requested quantile band); forecast medians
## are additionally highlighted as red points. The last forecast time step
## is dropped from the display.
forecast.plotFbar<-function(data, sim, sdrep, fbarrange, quantiles=c(0.025,0.975)){

    obsTime <- data$times
    histSel <- which(obsTime <= max(data$t1))

    simmat <- do.call("rbind", get.fbar.sim(sim, fbarrange = fbarrange))
    med <- apply(simmat, 2, median)
    lo  <- apply(simmat, 2, quantile, probs = quantiles[1])
    hi  <- apply(simmat, 2, quantile, probs = quantiles[2])

    ftimes <- attr(sim, "times")
    last <- length(ftimes)

    smry <- summary(sdrep)
    logfbar <- smry[rownames(smry) == "logfbar", ]

    FBAR    <- exp(logfbar[, 1])
    FBAR.lo <- exp(logfbar[, 1] - 2 * logfbar[, 2])
    FBAR.hi <- exp(logfbar[, 1] + 2 * logfbar[, 2])

    uncertaintyPlot(c(obsTime[histSel], ftimes[-last]),
                    c(FBAR[histSel], med[-last]),
                    c(FBAR.lo[histSel], lo[-last]),
                    c(FBAR.hi[histSel], hi[-last]),
                    xlab="Time", ylab="FBAR", type="b", cex=.5, las=1)
    points(ftimes[-last], med[-last], col=2, pch=16)

}

## Plot historic estimated SSB (with +-2 sd band) together with simulated
## forecast SSB (median and the requested quantile band). Forecast medians
## are highlighted as red points; an optional horizontal Blim line is drawn.
forecast.plotSSB<-function(data, sim, sdrep, quantiles=c(0.05,0.95),SW,PM,blim=NULL){
    obsTime <- data$times
    histSel <- which(obsTime <= max(data$t1))
    ftimes <- attr(sim, "times")
    quarterCols <- y2q(ftimes)[seq_along(ftimes)]

    ## NOTE(review): SW/PM are subset by quarter here AND again inside
    ## get.ssb.sim (which recycles columns 1:4) -- confirm this double
    ## indexing is intended when the forecast does not start in quarter 1.
    simmat <- do.call("rbind", get.ssb.sim(sim, SW[, quarterCols], PM[, quarterCols]))

    med <- apply(simmat, 2, median)
    lo  <- apply(simmat, 2, quantile, probs = quantiles[1])
    hi  <- apply(simmat, 2, quantile, probs = quantiles[2])

    smry <- summary(sdrep)
    ssbEst <- smry[rownames(smry) == "ssb", ]
    SSB <- ssbEst[, 1]
    SSB.lo <- SSB - 2 * ssbEst[, 2]
    SSB.hi <- SSB + 2 * ssbEst[, 2]

    uncertaintyPlot(c(obsTime, ftimes[-1]), c(SSB, med[-1]),
                    c(SSB.lo, lo[-1]), c(SSB.hi, hi[-1]),
                    xlab="Time", ylab="SSB", type="b", cex=.5, las=1)

    points(ftimes[-1], med[-1], col=2, pch=16)

    if(!is.null(blim)) abline(h=blim, lwd=2, col=2)
}

## Plot historic total catch (computed from the aux table and state estimates
## in pl) together with simulated forecast catch (median and quantile band).
## The historic series has no band (lower = upper = observed); the last
## observed and last forecast time steps are dropped from the display.
forecast.plotcatch<-function(data, sim, sdrep,aux,pl,quantiles=c(0.025,0.975),CW,NM){
    obsTime <- data$times
    nObs <- length(obsTime)
    histSel <- which(obsTime <= max(data$t1))
    ftimes <- attr(sim, "times")
    last <- length(ftimes)
    quarterCols <- y2q(ftimes)[seq_along(ftimes)]

    simmat <- do.call("rbind", get.catch.sim(sim, CW[, quarterCols], NM[, quarterCols]))
    med <- apply(simmat, 2, median)
    lo  <- apply(simmat, 2, quantile, probs = quantiles[1])
    hi  <- apply(simmat, 2, quantile, probs = quantiles[2])

    ## Historic observed catch weight-at-age by time, from the aux table
    CW.all <- xtabs(auxCW ~ auxage + auxt1, data = aux)
    totcatch <- catchfun(pl, CW.all, NM)

    uncertaintyPlot(c(obsTime[-nObs], ftimes[-c(last)]),
                    c(totcatch[-nObs], med[-c(last)]),
                    c(totcatch[-nObs], lo[-c(last)]),
                    c(totcatch[-nObs], hi[-c(last)]),
                    xlab="Time", ylab="total catch", type="b", cex=.5, las=1)

    points(ftimes[-c(last)], med[-c(last)], col=2, pch=16)
}

## Plot historic estimated recruitment (with +-2 sd band) together with
## simulated forecast recruitment in quarter 3 (median and quantile band).
## Forecast medians are highlighted as red points; the last forecast time
## step is dropped from the display.
forecast.plotRecruits<-function(data, sim, sdrep, quantiles=c(0.025,0.975)){

    obsTime <- data$times
    histSel <- which(obsTime <= max(data$t1))

    ## Recruitment is reported in quarter 3
    simmat <- do.call("rbind", get.recruits.sim(sim, 3))
    med <- apply(simmat, 2, median)
    lo  <- apply(simmat, 2, quantile, probs = quantiles[1])
    hi  <- apply(simmat, 2, quantile, probs = quantiles[2])

    ftimes <- attr(sim, "times")
    last <- length(ftimes)

    smry <- summary(sdrep)
    logrec <- smry[rownames(smry) == "logR", ]

    REC    <- exp(logrec[, 1])
    REC.lo <- exp(logrec[, 1] - 2 * logrec[, 2])
    REC.hi <- exp(logrec[, 1] + 2 * logrec[, 2])

    uncertaintyPlot(c(obsTime[histSel], ftimes[-last]),
                    c(REC[histSel], med[-last]),
                    c(REC.lo[histSel], lo[-last]),
                    c(REC.hi[histSel], hi[-last]),
                    xlab="Time", ylab="REC", type="b", cex=.5, las=1)
    points(ftimes[-last], med[-last], col=2, pch=16)

}

## Line plot of y vs x with a shaded uncertainty band between lower and upper.
## r/g/b give the line colour; the band uses the same colour at alpha 0.4.
## Extra arguments in ... are forwarded to plot().
uncertaintyPlot<-function(x, y, lower, upper, r=0/256,g=0/256,b=128/256, ...) {
    lineCol <- rgb(r, g, b, 1)
    bandCol <- rgb(r, g, b, 0.4)
    ## Main plot, with y-limits wide enough to contain the band
    plot(x, y, lwd=3, col=lineCol,
         ylim=c(min(lower, na.rm=TRUE), max(upper, na.rm=TRUE)), ...)
    grid(col="darkgray")
    ## Only draw the band where the bounds are finite
    ok <- which(is.finite(lower))
    polygon(c(x[ok], rev(x[ok])), c(lower[ok], rev(upper[ok])),
            col=bandCol, border=NA)
    ## Redraw the central line on top of the band
    lines(x, y, lwd=3, lty=1, col=lineCol)
}

## Convert a decimal year (e.g. 2000.25) to its quarter number (1..4).
y2q<-function(x) {
    frac <- x - floor(x)
    frac * 4 + 1
}
