# Stochastic surplus Production model in Continuous-Time (SPiCT)
#    Copyright (C) 2015  Martin Waever Pedersen, mawp@dtu.dk or wpsgodd@gmail.com
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.


#' @name summary.spictcls
#' @title Output a summary of a fit.spict() run.
#' @details The output includes the parameter estimates with 95% confidence intervals, estimates of derived parameters (Bmsy, Fmsy, MSY) with 95% confidence intervals, and predictions of biomass, fishing mortality, and catch.
#' Sections are printed in this order: convergence status (plus the gradient if the optimiser reports non-convergence), objective value, Euler time step, number of observations, catch unit, residual diagnostics (p-values below 0.05 are marked with '*'), priors in use, fixed parameters, parameter estimates and, when an sdreport is available, deterministic/stochastic reference points, states and (if inp$reportall) predictions.
#' @param object A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return Nothing (invisible NULL). Prints a summary to the screen.
#' @examples
#' data(pol)
#' rep <- fit.spict(pol$albacore)
#' summary(rep)
#' @export
summary.spictcls <- function(object, numdigits=8){
    rep <- object
    cat(paste0('Convergence: ', rep$opt$convergence, '  MSG: ', rep$opt$message, '\n'))
    if(rep$opt$convergence>0){
        cat('WARNING: Model did not obtain proper convergence! Estimates and uncertainties are most likely invalid and cannot be trusted.\n')
        grad <- rep$obj$gr()
        names(grad) <- names(rep$par.fixed)
        cat('Gradient at current parameter vector\n')
        cat('', paste(capture.output(grad),' \n'))
        cat('\n')
    }
    if('sderr' %in% names(rep)) cat('WARNING: Could not calculate standard deviations. The optimum found may be invalid. Proceed with caution.\n')
    cat(paste0('Objective function at optimum: ', round(rep$opt$objective, numdigits), '\n'))
    cat(paste0('Euler time step: 1/', 1/rep$inp$dteuler, ' or ', rep$inp$dteuler, '\n'))
    # seq_len() so that zero indices does not produce bogus 'Nobs I1', 'Nobs I0' labels.
    cat(paste0('Nobs C: ', rep$inp$nobsC, paste0(paste0(',  Nobs I', seq_len(rep$inp$nindex)), ': ', rep$inp$nobsI, collapse=''), '\n'))
    # -- Catch/biomass unit --
    if(rep$inp$catchunit != ''){
        cat(paste('Catch/biomass unit:', rep$inp$catchunit, '\n'))
    }
    # -- Residual diagnostics --
    statout <- unlist(rep$stats)
    if(length(statout)>0){
        cat('\nResidual diagnostics\n')
        # Star p-values below 0.05. Match the literal '.p' in the name: as a regular
        # expression '.p' would match any 'p' preceded by some character.
        inds <- grep('.p', names(statout), fixed=TRUE)
        sig <- which(statout[inds]<0.05)
        names(statout)[inds][sig] <- paste0('*', names(statout)[inds][sig])
        cat('', paste(capture.output(statout),' \n'))
    }
    # -- Priors --
    indso <- which(rep$inp$priorsuseflag==1)
    if(length(indso)>0){
        usepriors <- names(rep$inp$priors)[indso]
        npriors <- length(usepriors)
        # Priors on states get element 4 of their prior vector (presumably the time the
        # prior applies to) appended to the label. inds indexes usepriors, so it must be
        # mapped back through indso to address the full priors list.
        repriors <- c('logB', 'logF', 'logBBmsy', 'logFFmsy')
        inds <- match(repriors, usepriors)
        inds <- inds[!is.na(inds)]
        for(j in inds) usepriors[j] <- paste0(usepriors[j], fd(rep$inp$priors[[indso[j]]][4]))
        cat('\nPriors\n')
        maxchar <- max(nchar(usepriors))
        for(i in seq_len(npriors)){
            prior <- rep$inp$priors[[indso[i]]]
            priorstr <- paste0('~  N[log(', round(exp(prior[1]), 3), '), ', round(prior[2], 3), '^2]', ifelse(prior[2] <= 1e-3, ' (fixed)', ''))
            usepriors[i] <- formatC(usepriors[i], width = maxchar, flag = 0)
            cat(paste0(' ', usepriors[i], '  ', priorstr, '\n'))
        }
    }
    # -- Fixed parameters --
    resout <- sumspict.fixedpars(rep, numdigits=numdigits)
    if(!is.null(resout)){
        cat('\nFixed parameters\n')
        cat('', paste(capture.output(resout),' \n'))
    }
    # -- Model parameters --
    cat('\nModel parameter estimates w 95% CI \n')
    resout <- sumspict.parest(rep, numdigits=numdigits)
    cat('', paste(capture.output(resout),' \n'), '\n')
    if(rep$inp$do.sd.report && !'sderr' %in% names(rep)){
        # Deterministic ref points
        cat('Deterministic reference points (Drp)\n')
        derout <- sumspict.drefpoints(rep, numdigits=numdigits)
        cat('', paste(capture.output(derout),' \n'))
        # Stochastic derived estimates
        cat('Stochastic reference points (Srp)\n')
        derout <- sumspict.srefpoints(rep, numdigits=numdigits)
        cat('', paste(capture.output(derout),' \n'))
        # States
        cat(paste0('\nStates w 95% CI (inp$msytype: ', rep$inp$msytype, ')\n'))
        stateout <- sumspict.states(rep, numdigits=numdigits)
        cat('', paste(capture.output(stateout),' \n'))
        # Predictions
        if(rep$inp$reportall){
            cat(paste0('\nPredictions w 95% CI (inp$msytype: ', rep$inp$msytype, ')\n'))
            predout <- sumspict.predictions(rep, numdigits=numdigits)
            cat('', paste(capture.output(predout),' \n'))
        } else {
            cat(paste0('\nPredictions omitted because inp$reportall = FALSE\n'))
        }
    }
    invisible(NULL)
}


#' @name get.order
#' @title Get order of printed quantities.
#' @details Column indices used by the sumspict.* functions to pick, from a get.par() result, the values placed in the estimate, cilow, ciupp and est.in.log columns (see get.colnms()). Index 2 appears twice: the repeated copy is log-transformed by the callers.
#' @return Vector containing indices of printed quantities.
get.order <- function(){
    estcol <- 2
    c(estcol, 1, 3, estcol)
}


#' @name get.colnms
#' @title Get column names for data.frames.
#' @details Names matching the columns selected by get.order().
#' @return Vector containing column names of data frames.
get.colnms <- function(){
    colnms <- c('estimate',    # point estimate
                'cilow',       # lower CI bound
                'ciupp',       # upper CI bound
                'est.in.log')  # estimate on the log scale
    colnms
}


#' @name sumspict.parest
#' @title Parameter estimates of a fit.spict() run.
#' @details When an sdreport is available (inp$do.sd.report), every fixed-effect parameter is back-transformed from its estimation scale (log, logit or logp1, recognised from the parameter name) together with the interval estimate +/- 1.96 sd computed on the estimation scale. The derived parameters alpha, beta and r are prepended. If inp$true exists, true values and whether they fall inside the interval are added (-9 marks derived parameters whose true values could not be matched to the estimated rows). Without an sdreport only the raw optimiser parameters are returned.
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return A matrix of parameter estimates; a data.frame of raw optimiser estimates when inp$do.sd.report is FALSE; NULL when neither is available.
#' @export
sumspict.parest <- function(rep, numdigits=8){
    # Defined up front so the function returns NULL instead of failing when there is
    # neither an sdreport nor an optimiser result.
    resout <- NULL
    if(rep$inp$do.sd.report){
        order <- get.order()
        colnms <- get.colnms()
        parsd <- sqrt(diag(rep$cov.fixed))
        nms <- names(rep$par.fixed)
        # 'logp1' and 'logit' names also contain 'log'; keep the three index sets disjoint.
        loginds <- grep('log', nms)
        logp1inds <- grep('logp1',nms)
        logitinds <- grep('logit',nms)
        loginds <- setdiff(loginds, c(logp1inds, logitinds))
        # Map a vector aligned with par.fixed from the estimation scale to the natural scale.
        backtrans <- function(x){
            x[loginds] <- exp(x[loginds])
            x[logitinds] <- invlogit(x[logitinds])
            x[logp1inds] <- invlogp1(x[logp1inds])
            x
        }
        est <- backtrans(rep$par.fixed)
        cilow <- backtrans(rep$par.fixed - 1.96*parsd)
        ciupp <- backtrans(rep$par.fixed + 1.96*parsd)
        if('true' %in% names(rep$inp)){
            npar <- length(nms)
            truepar <- NULL
            # Take as many true values per parameter name as there are estimated elements;
            # true values beyond that are not shown.
            for(unm in unique(nms)){
                tp <- rep$inp$true[[unm]]
                nestpar <- sum(names(est) == unm)
                truepar <- c(truepar, tp[1:nestpar])
            }
            truepar <- backtrans(truepar)
            ci <- rep(0, npar)
            for(i in seq_len(npar)) ci[i] <- as.numeric(truepar[i] > cilow[i] & truepar[i] < ciupp[i])
            resout <- cbind(estimate=round(est,numdigits), true=round(truepar,numdigits), cilow=round(cilow,numdigits), ciupp=round(ciupp,numdigits), true.in.ci=ci, est.in.log=round(rep$par.fixed,numdigits))
        } else {
            resout <- cbind(estimate=round(est,numdigits), cilow=round(cilow,numdigits), ciupp=round(ciupp,numdigits), est.in.log=round(rep$par.fixed,numdigits))
        }
        # Strip the transformation prefix from the row names and number repeated names.
        nms[loginds] <- sub('log', '', names(rep$par.fixed[loginds]))
        nms[logitinds] <- sub('logit', '', names(rep$par.fixed[logitinds]))
        nms[logp1inds] <- sub('logp1', '', names(rep$par.fixed[logp1inds]))
        for(inm in unique(nms)){
            inds <- which(inm == nms)
            if(length(inds) > 1) nms[inds] <- paste0(inm, seq_along(inds))
        }
        rownames(resout) <- nms
        # Derived variables (one alpha per estimated logsdi)
        nalpha <- sum(names(rep$par.fixed) == 'logsdi')
        derout <- rbind(get.par(parname='logalpha', rep, exp=TRUE)[1:nalpha, order],
                        get.par(parname='logbeta', rep, exp=TRUE)[, order],
                        get.par(parname='logr', rep, exp=TRUE)[, order])
        derout[, 4] <- log(derout[, 4])
        derout <- round(derout, numdigits)
        nr <- dim(derout)[1]
        if('true' %in% names(rep$inp)){
            dertrue <- exp(c(rep(rep$inp$true$logalpha, nalpha), rep$inp$true$logbeta, rep$inp$true$logr))
            ndertrue <- length(dertrue)
            if(ndertrue == nr){
                cider <- numeric(ndertrue)
                for(i in seq_len(ndertrue)) cider[i] <- as.numeric(dertrue[i] > derout[i, 2] & dertrue[i] < derout[i, 3])
            } else {
                # True values cannot be aligned with the estimated rows.
                dertrue <- rep(-9, nr)
                cider <- rep(-9, nr)
            }
            derout <- cbind(est=derout[, 1], true=dertrue, ll=derout[, 2], ul=derout[, 3], tic=cider, eil=derout[, 4])
        }
        if(nr>1 && 'yearsepgrowth' %in% names(rep$inp)){
            rnms <- c('r     ', paste0('r', rep$inp$yearsepgrowth))
        } else {
            rnms <- 'r    '
        }
        if(nalpha > 1){
            alphanms <- paste0('alpha', 1:nalpha)
        } else {
            alphanms <- 'alpha'
        }
        rownames(derout) <- c(alphanms, 'beta', rnms)
        resout <- rbind(derout, resout)
        if('true' %in% names(rep$inp)){
            colnames(resout) <- c(colnms[1], 'true', colnms[2:3], 'true.in.ci', colnms[4])
        } else {
            colnames(resout) <- colnms
        }
    } else {
        if('opt' %in% names(rep)) resout <- data.frame(estimate=rep$opt$par)
    }
    return(resout)
}


#' @name sumspict.drefpoints
#' @title Deterministic reference points of a fit.spict() run.
#' @details One row per reference point (Bmsyd, Fmsyd, MSYd); when several growth rates r are used (length(inp$ini$logr) > 1) each reference point gets one row per r, suffixed 1, 2, .... If inp$true exists, true values and interval coverage are added; -9 marks true values that cannot be aligned with the estimated rows.
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return Matrix containing deterministic reference points.
#' @export
sumspict.drefpoints <- function(rep, numdigits=8){
    order <- get.order()
    colnms <- get.colnms()
    derout <- rbind(get.par(parname='logBmsyd', rep, exp=TRUE)[,order],
                    get.par(parname='logFmsyd', rep, exp=TRUE)[,order],
                    get.par(parname='logMSYd', rep, exp=TRUE)[,order])
    derout[, 4] <- log(derout[, 4])
    derout <- round(derout, numdigits)
    colnames(derout) <- colnms
    nr <- length(rep$inp$ini$logr)
    rpnms <- c('Bmsyd', 'Fmsyd', 'MSYd')
    if(nr > 1){
        # Bmsyd1..Bmsydnr, Fmsyd1..Fmsydnr, MSYd1..MSYdnr (was hard-coded to 2 r values).
        rownames(derout) <- c(t(outer(rpnms, seq_len(nr), paste0)))
    } else {
        rownames(derout) <- rpnms
    }
    if('true' %in% names(rep$inp)){
        nder <- nrow(derout)
        trueder <- c(rep$inp$true$Bmsyd, rep$inp$true$Fmsyd, rep$inp$true$MSYd)
        if(length(trueder) == nder){
            cider <- as.numeric(trueder > derout[, 2] & trueder < derout[, 3])
        } else {
            # Avoid silent recycling when the true values do not match the rows.
            trueder <- rep(-9, nder)
            cider <- rep(-9, nder)
        }
        derout <- cbind(derout[, 1], round(trueder,numdigits), derout[, 2:3], cider, derout[, 4])
        colnames(derout) <- c(colnms[1], 'true', colnms[2:3], 'true.in.ci', colnms[4])
    }
    return(derout)
}


#' @name sumspict.srefpoints
#' @title Stochastic reference points of a fit.spict() run.
#' @details One row per reference point (Bmsys, Fmsys, MSYs); when several growth rates r are used (length(inp$ini$logr) > 1) each reference point gets one row per r, suffixed 1, 2, .... A column rel.diff.Drp with the relative difference to the corresponding deterministic reference point is appended. If inp$true exists, true values and interval coverage are added; -9 marks true values that cannot be aligned with the estimated rows.
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return Matrix containing stochastic reference points.
#' @export
sumspict.srefpoints <- function(rep, numdigits=8){
    order <- get.order()
    colnms <- get.colnms()
    derout <- rbind(get.par(parname='logBmsys', rep, exp=TRUE)[,order],
                    get.par(parname='logFmsys', rep, exp=TRUE)[,order],
                    get.par(parname='logMSYs', rep, exp=TRUE)[,order])
    derout[, 4] <- log(derout[, 4])
    derout <- round(derout, numdigits)
    colnames(derout) <- colnms
    nr <- length(rep$inp$ini$logr)
    rpnms <- c('Bmsys', 'Fmsys', 'MSYs')
    if(nr > 1){
        # Bmsys1..Bmsysnr, Fmsys1..Fmsysnr, MSYs1..MSYsnr (was hard-coded to 2 r values).
        rownames(derout) <- c(t(outer(rpnms, seq_len(nr), paste0)))
    } else {
        rownames(derout) <- rpnms
    }
    if('true' %in% names(rep$inp)){
        nder <- nrow(derout)
        trueder <- c(rep$inp$true$Bmsy, rep$inp$true$Fmsy, rep$inp$true$MSY)
        if(length(trueder) == nder){
            cider <- as.numeric(trueder > derout[, 2] & trueder < derout[, 3])
        } else {
            # Avoid silent recycling when the true values do not match the rows.
            trueder <- rep(-9, nder)
            cider <- rep(-9, nder)
        }
        derout <- cbind(derout[, 1], round(trueder,numdigits), derout[, 2:3], cider, derout[, 4])
        colnames(derout) <- c(colnms[1], 'true', colnms[2:3], 'true.in.ci', colnms[4])
    }
    # Relative difference between stochastic and deterministic point estimates.
    Drp <- c(get.par(parname='logBmsyd', rep, exp=TRUE)[, 2],
             get.par(parname='logFmsyd', rep, exp=TRUE)[, 2],
             get.par(parname='logMSYd', rep, exp=TRUE)[, 2])
    rel.diff.Drp <- (derout[, 1] - Drp)/derout[, 1]
    if(length(rel.diff.Drp) == dim(derout)[1]) derout <- cbind(derout, rel.diff.Drp)
    return(derout)
}


#' @name sumspict.states
#' @title State estimates of a fit.spict() run.
#' @details Rows are biomass, fishing mortality and both relative to their MSY reference levels, evaluated at the time of the last observation (inp$time[inp$indlastobs]).
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return Matrix containing state estimates.
#' @export
sumspict.states <- function(rep, numdigits=8){
    cols <- get.order()
    parnames <- c('logBl', 'logFl', 'logBlBmsy', 'logFlFmsy')
    rows <- lapply(parnames, function(pn) get.par(parname=pn, rep, exp=TRUE)[cols])
    out <- do.call(rbind, rows)
    # The fourth column repeats the estimate; report it on the log scale.
    out[, 4] <- log(out[, 4])
    out <- round(out, numdigits)
    colnames(out) <- get.colnms()
    tlast <- fd(rep$inp$time[rep$inp$indlastobs])
    bnm <- paste0('B_', tlast)
    fnm <- paste0('F_', tlast)
    rownames(out) <- c(bnm, fnm, paste0(bnm, '/Bmsy'), paste0(fnm, '/Fmsy'))
    out
}


#' @name sumspict.predictions
#' @title Predictions of a fit.spict() run.
#' @details Rows: biomass, fishing mortality and both relative to their MSY reference levels at the prediction time (inp$time[inp$dtprediind]), the last predicted catch, and E(B_inf) from get.EBinf(). The catch row is dropped when there is no catch prediction interval (inp$dtpredc == 0). The est.in.log column is NA where the prediction is not positive.
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return Matrix containing predictions.
#' @export
sumspict.predictions <- function(rep, numdigits=8){
    order <- get.order()
    colnms <- get.colnms()
    EBinf <- get.EBinf(rep)
    predout <- rbind(
        get.par(parname='logBp', rep, exp=TRUE)[order],
        get.par(parname='logFp', rep, exp=TRUE)[order],
        get.par(parname='logBpBmsy', rep, exp=TRUE)[order],
        get.par(parname='logFpFmsy', rep, exp=TRUE)[order],
        tail(get.par(parname='logCpred', rep, exp=TRUE),1)[order],
        c(EBinf, NA, NA, EBinf))
    # Column 4 is reported on the log scale; non-positive values have no logarithm.
    logcol <- predout[, 4]
    predout[which(logcol <= 0), 4] <- NA
    pos <- which(logcol > 0)
    predout[pos, 4] <- log(logcol[pos])
    predout <- round(predout, numdigits)
    colnames(predout) <- c('prediction', colnms[2:4])
    et <- fd(rep$inp$time[rep$inp$dtprediind])
    rownames(predout) <- c(paste0('B_',et), paste0('F_',et), paste0('B_',et,'/Bmsy'), paste0('F_',et,'/Fmsy'), paste0('Catch_', fd(tail(rep$inp$timeCpred,1))), 'E(B_inf)')
    # Without a catch prediction interval the catch row (row 5) carries no information;
    # drop it and keep E(B_inf) (previously the last row, E(B_inf), was dropped instead).
    if(rep$inp$dtpredc == 0) predout <- predout[-5, , drop=FALSE]
    return(predout)
}


#' @name print.spictcls
#' @title Output a summary of a fit.spict() run.
#' @details Print method for spictcls objects; it calls summary() on the object, so printing a fit shows the same output as summary.spictcls.
#' @param object A result report as generated by running fit.spict.
#' @param ... Ignored; accepted so calls passing extra arguments to print() do not fail.
#' @return The input object, invisibly (the summary is printed as a side effect).
#' @export
print.spictcls <- function(object, ...){
    summary(object)
    invisible(object)
}


#' @name sumspict.fixedpars
#' @title Fixed parameters table.
#' @details Lists parameters whose phase in inp$phases is negative. Random effects (inp$RE) and parameters of model components that are switched off (robust options, seasonal spline, seasonal SDE) are left out. A parameter with several elements gets one row per element, suffixed 1, 2, ....
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Present values with this number of digits after the dot.
#' @return data.frame containing fixed parameter information (values on the natural scale, plus true values if inp$true exists), or NULL if no parameter is listed.
#' @export
sumspict.fixedpars <- function(rep, numdigits=8){
    inds <- which(unlist(rep$inp$phases) < 0)
    nms <- names(rep$inp$phases)[inds]
    # Collect names to leave out and remove them by membership. (Removing by
    # nms[-match(...)] turned nms into NAs or errored whenever a name was absent.)
    dropnms <- rep$inp$RE                                   # random effects
    if(!any(rep$inp$robflagi | rep$inp$robflagc)){          # robust options not used
        dropnms <- c(dropnms, 'logitpp', 'logp1robfac')
    }
    if(rep$inp$seasontype != 1){                            # seasonal spline not used
        dropnms <- c(dropnms, 'logphi')
    }
    if(rep$inp$seasontype != 2){                            # seasonal SDE not used
        dropnms <- c(dropnms, 'logsdu', 'loglambda')
    }
    nms <- nms[!nms %in% dropnms]
    nnms <- length(nms)
    if(nnms == 0) return(NULL)
    vallist <- vector('list', nnms)
    valnmlist <- vector('list', nnms)
    for(i in seq_len(nnms)){
        gp <- get.par(parname=nms[i], rep)
        # Column 2 of get.par() holds the value (cf. get.order()); take it for every
        # element. Plain gp[2] picked linear element 2, i.e. a wrong single value
        # whenever the parameter has more than one element.
        val <- if(is.matrix(gp)) gp[, 2] else gp[2]
        vallist[[i]] <- unname(val)
        valnmlist[[i]] <- if(length(val) > 1) paste0(nms[i], seq_along(val)) else nms[i]
    }
    vals <- unlist(vallist)
    valnms <- unlist(valnmlist)
    # Transformation name per value, so trans2real() indexes line up with vals.
    valtrans <- rep(nms, lengths(vallist))
    names(vals) <- valnms
    vals <- trans2real(vals, valtrans)
    df <- data.frame(fixed.value=vals)
    df <- round(df, numdigits)
    if('true' %in% names(rep$inp)){
        # unlist() names multi-element entries name1, name2, ..., matching valnms.
        alltrue <- unlist(rep$inp$true)
        truevals <- alltrue[match(valnms, names(alltrue))]
        truevals <- trans2real(truevals, valtrans, chgnms=FALSE)
        df <- cbind(df, true=truevals)
    }
    return(df)
}


#' @name trans2real
#' @title Get real parameter values from transformed ones.
#' @details The transformation of each value is read from the corresponding entry of nms: names containing 'logit' get invlogit(), names containing 'logp1' get invlogp1(), and any other name containing 'log' gets exp(). Values whose name contains none of these are returned unchanged.
#' @param vals Parameters in transformed domain.
#' @param nms Names of transformed parameters (including log etc.)
#' @param chgnms Remove transformation indication from the parameter names (e.g. remove log from logK).
#' @return Parameter values in the natural domain.
trans2real <- function(vals, nms, chgnms=TRUE){
    # 'logit' and 'logp1' also contain 'log', so plain-log entries exclude both.
    is.logit <- grep('logit', nms)
    is.logp1 <- grep('logp1', nms)
    is.log <- setdiff(grep('log', nms), c(is.logp1, is.logit))
    vals[is.log] <- exp(vals[is.log])
    vals[is.logit] <- invlogit(vals[is.logit])
    vals[is.logp1] <- invlogp1(vals[is.logp1])
    if(!chgnms) return(vals)
    newnms <- names(vals)
    newnms[is.logit] <- gsub('logit', '', newnms[is.logit])
    newnms[is.logp1] <- gsub('logp1', '', newnms[is.logp1])
    newnms[is.log] <- gsub('log', '', newnms[is.log])
    names(vals) <- newnms
    vals
}


#' @name sumspict.priors
#' @title Table of the priors in use.
#' @details One row per prior whose flag in inp$priorsuseflags equals 1. The mean and std columns are elements 1 and 2 of the prior vector. Priors on logF and logB are labelled with element 4 of their prior vector appended after an underscore.
#' @param rep A result report as generated by running fit.spict.
#' @param numdigits Currently unused; kept for consistency with the other sumspict.* functions.
#' @return data.frame with columns mean and std (zero rows if no prior is used).
#' @export
sumspict.priors <- function(rep, numdigits=8){
    inds <- which(rep$inp$priorsuseflags == 1)
    nms <- names(rep$inp$priors[inds])
    ninds <- length(inds)
    means <- numeric(ninds)
    stds <- numeric(ninds)
    priornms <- nms
    for(i in seq_len(ninds)){
        # Index the full priors list with inds[i]: only the used priors are tabulated.
        prior <- rep$inp$priors[[inds[i]]]
        means[i] <- prior[1]
        stds[i] <- prior[2]
        if(nms[i] %in% c('logF', 'logB')){
            priornms[i] <- paste0(priornms[i], '_', prior[4])
        } else {
            priornms[i] <- paste0(priornms[i], '    ')
        }
    }
    df <- data.frame(mean=means, std=stds)
    rownames(df) <- priornms
    return(df)
}
