//  -------------------------------------------------------------------------------
//  Seasonal State-space assessment model (SESAM)
//
//  Copyright (c) 2015, Casper W. Berg and Anders Nielsen <an@aqua.dtu.dk>  
//  -------------------------------------------------------------------------------
     
#include <TMB.hpp>
#include <iostream>

template<class Type>
bool isNA(Type x){
  return R_IsNA(asDouble(x));
}

// True when the integer x equals R's integer NA marker (NA_INTEGER).
bool isNAINT(int x){
  const bool missing = (x == NA_INTEGER);
  return missing;
}

// Inverse logit: 1 / (1 + exp(-x)), mapping the real line onto (0, 1).
template <class Type>
Type ilogit(Type x){
  const Type one(1.0);
  return one/(one+exp(-x));
}

// Floor of x, computed on its double value and wrapped back into Type.
// NOTE(review): the value goes through asDouble, so for AD types the result is
// presumably treated as a constant (no derivative) - confirm against TMB docs.
template <class Type>
Type myfloor(Type x){
  const double floored = std::floor(asDouble(x));
  return Type(floored);
}

// Maps the real line onto (-1, 1) via 2/(1+exp(-2x)) - 1 (algebraically
// tanh(x)); the objective function uses it to turn trans_rho into rho.
template <class Type>
Type f(Type x){
  const Type two(2);
  const Type denom = Type(1) + exp(-two * x);
  return two/denom - Type(1);
}
 
template<class Type>
Type objective_function<Type>::operator() ()
{
  // ---------------------------------------------------------------------------
  // SESAM negative log-likelihood.
  //
  // Latent states:
  //   logF(a,t) - log fishing mortality; increments between time steps follow a
  //               multivariate normal with AR-type correlation rho^|i-j|
  //               across F states, scaled by sqrt(time step).
  //   logN(a,t) - log abundance; survival with process error, ages shifted
  //               and plus-group accumulated when a bday crossing occurs.
  //   logR(t)   - log recruitment; random walk over recruitment time steps,
  //               tight N(-3, 0.1) prior on the other steps.
  // Observations are matched to predicted catches (fleet 1) or survey indices
  // (fleet >= 2) with a normal likelihood on the log scale, censored below the
  // per-fleet limit eps.
  //
  // Returns the total negative log-likelihood `ans`.
  // ---------------------------------------------------------------------------
  DATA_VECTOR(t1);
  DATA_VECTOR(t2);
  DATA_IVECTOR(ageFrom);
  DATA_IVECTOR(ageTo);
  DATA_VECTOR(obs);
  DATA_IVECTOR(fleet);
  DATA_VECTOR(auxt1);
  DATA_VECTOR(auxt2);
  DATA_IVECTOR(auxage);
  DATA_VECTOR(auxM);
  DATA_VECTOR(auxPM);
  DATA_VECTOR(auxSW);
  DATA_VECTOR(auxCW);
  DATA_VECTOR(auxDW);
  DATA_VECTOR(auxLF);
  DATA_IVECTOR(auxbigjump);
  DATA_VECTOR(times);
  DATA_IVECTOR(ages);
  DATA_SCALAR(bday);
  DATA_IARRAY(mapaux);
  DATA_IARRAY(keyVarObs);
  DATA_IVECTOR(keyVarLogN);
  DATA_IVECTOR(keyLogFsta);
  DATA_IVECTOR(keyVarLogF);
  DATA_IARRAY(keyLogQ);
  DATA_IVECTOR(idx1); // index into times corresponding to t1
  DATA_IVECTOR(idx2); // index into times corresponding to t2
  DATA_IVECTOR(fbarrange);
  DATA_IVECTOR(maplagR);
  DATA_IMATRIX(maplagF);
  DATA_IVECTOR(recruitTimes);
  DATA_VECTOR(eps); // minimum obs possible (less than this is treated as censored) by fleet
  DATA_VECTOR(logobs);
  DATA_INTEGER(noYears);
  DATA_VECTOR(quarterFrac); // for each survey fleet fraction of quarter
  DATA_VECTOR(jumpmult); // F process "big jump" standard dev. multiplier
  DATA_IVECTOR(quarter);
  //DATA_IVECTOR(dontfit);
  DATA_VECTOR_INDICATOR(keep, logobs);


  PARAMETER_VECTOR(logSdLogFsta);
  PARAMETER_VECTOR(logSdLogN);
  PARAMETER_VECTOR(logSdLogObs);
  PARAMETER_VECTOR(logQ);
  PARAMETER(trans_rho);
  PARAMETER(logSdLogR);
  PARAMETER_ARRAY(logF);
  PARAMETER_ARRAY(logN);
  PARAMETER_VECTOR(logR);

  int timeSteps=times.size();
  int stateDimF=logF.dim[0];
  int stateDimN=logN.dim[0];
  int nobs=obs.size();
  Type rho=f(trans_rho);
  vector<Type> sdLogFsta=exp(logSdLogFsta);
  vector<Type> sdLogN=exp(logSdLogN);
  Type sdLogR=exp(logSdLogR);
  double timeEps = 1e-6;

  vector<Type> sdLogObs=exp(logSdLogObs);

  Type ans=0; //negative log-likelihood

  //First take care of F
  matrix<Type> fvar(stateDimF,stateDimF);
  matrix<Type> fcor(stateDimF,stateDimF);
  vector<Type> fsd(stateDimF);
  for(int i=0; i<stateDimF; ++i){
    fcor(i,i)=1.0;
  }

  for(int i=0; i<stateDimF; ++i){
    for(int j=0; j<i; ++j){
      fcor(i,j)=pow(rho,abs(i-j)); //pow(rho,abs(Type(i-j)));
      fcor(j,i)=fcor(i,j);
    }
  }

  for(int i=0; i<stateDimF; ++i){
    fsd(i)=sdLogFsta(keyVarLogF(i));
  }

  for(int i=0; i<stateDimF; ++i){
    for(int j=0; j<stateDimF; ++j){
      fvar(i,j)=fsd(i)*fsd(j)*fcor(i,j);
    }
  }

  using namespace density;
  MVNORM_t<Type> neg_log_densityF(fvar);
  Type sqrtStep;
  for(int i=1;i<timeSteps;i++){
    if(!isNAINT(maplagF(1,i))){
      sqrtStep=sqrt(times(i)-times(maplagF(1,i)));
      // Down weighting of some chosen increments by pretending the timestep is really large
      //if(times(i)>2004.9 && times(i)<2008)sqrtStep*=Type(100);
      if(auxbigjump(i)>0) sqrtStep*= jumpmult( auxbigjump(i) ); //Type(100);
      // Scaled increment density plus the Jacobian term stateDimF*log(sqrtStep).
      ans+=neg_log_densityF((logF.col(i)-logF.col(maplagF(1,i)))/sqrtStep)+Type(stateDimF)*log(sqrtStep);// F-Process likelihood
    }
  }

  // calc fbar: mean F over ages within [fbarrange(0), fbarrange(1)]
  vector<Type> fbar(timeSteps);
  vector<Type> logfbar(timeSteps);
  vector<Type> logfbarY(noYears);
  for(int i=0;i<timeSteps;i++){
    fbar(i)=0.0;
    for(int j=0; j<stateDimN; ++j){
      if((ages(j)<=fbarrange(1))&&(ages(j)>=fbarrange(0))){
        fbar(i)+=exp(logF(j,i));
      }
    }
    fbar(i)/=Type(fbarrange(1)-fbarrange(0)+1);
    logfbar(i)=log(fbar(i));
  }

  // calc fbar yearly: average of fbar over the time steps between consecutive
  // bday crossings.
  logfbarY.setZero();
  Type tmp=0.0;
  Type tt1, tt2;
  tmp+=fbar(0);
  int idx=0; // year index
  int tmpidx=1; // number of values added for current year
  for(int i=1;i<timeSteps;i++){
    tt1=times(i-1);
    tt2=times(i);
    if(((tt1-myfloor(tt2)+timeEps)<bday)&&((tt2-myfloor(tt2)+timeEps)>bday)){
      if(idx<noYears) logfbarY(idx) = log(tmp/tmpidx); // guard against more crossings than years
      idx++;
      tmpidx=0;
      tmp=0.0;
    }
    tmp+=fbar(i);
    tmpidx++;
  }
  // The final year is not closed by any later bday crossing, so flush it here
  // (tmpidx >= 1 at this point).
  if(idx<noYears) logfbarY(idx) = log(tmp/tmpidx);


  // calc ssb
  vector<Type> ssb(timeSteps);
  vector<Type> logssb(timeSteps);
  for(int i=0;i<timeSteps;i++){
    ssb(i)=0.0;
    for(int j=0; j<stateDimN; ++j){
      ssb(i)+=exp(logN(j,i))*auxPM(mapaux(j,i))*auxSW(mapaux(j,i));
    }
    logssb(i)=log(ssb(i));
  }

  // calc tsb
  vector<Type> tsb(timeSteps);
  vector<Type> logtsb(timeSteps);
  for(int i=0;i<timeSteps;i++){
    tsb(i)=0.0;
    for(int j=0; j<stateDimN; ++j){
      tsb(i)+=exp(logN(j,i))*auxSW(mapaux(j,i));
    }
    logtsb(i)=log(tsb(i));
  }


  //Now take care of N

  // Find the first recruitment step; steps before it get a N(-3, 0.1) prior on
  // logR. firstR defaults to timeSteps so that, when no step has
  // recruitTimes==1, the random-walk loop below does not run and the prior is
  // not counted a second time for steps 1..timeSteps-1.
  int firstR=timeSteps;
  for(int i=0;i<timeSteps;i++){
    if(recruitTimes(i)==0) { ans += -dnorm(logR(i),Type(-3),Type(0.1),true); } else
    { firstR=i; break; }
  }
  // logR follows a random walk across recruitment steps; the remaining steps
  // get the same N(-3, 0.1) prior.
  int lastR = firstR;
  for(int i=firstR+1;i<timeSteps;i++){
    if(recruitTimes(i)==1){
      ans += -dnorm(logR(i),logR(lastR),sdLogR,true);
      lastR = i;
    } else {
      ans += -dnorm(logR(i),Type(-3),Type(0.1),true);
    }
  }

  vector<Type> nsd(stateDimN);
  for(int j=0; j<stateDimN; ++j){
    nsd(j)=sdLogN(keyVarLogN(j));
  }
  vector<Type> predN(stateDimN);
  Type deltat, sqrtdeltat;
  ans+=-dnorm(logN(0,0),Type(0),Type(1),true);
  logN(0,0)=logR(0);
  // Standardized N process residuals; cells that are never assigned below
  // (row 0, and age 0 outside the quarter-4 branch) are reported as 0.
  matrix<Type> residN(timeSteps,stateDimN);
  residN.setZero();
  for(int i=1;i<timeSteps;i++){
    tt1=times(i-1);
    tt2=times(i);
    deltat=tt2-tt1;
    sqrtdeltat=sqrt(deltat);
    // Deterministic survival over the step: N decays at rate F + M.
    for(int j=0; j<stateDimN; ++j){
      predN(j)=logN(j,i-1)-(exp(logF(keyLogFsta(j),i-1))+auxM(mapaux(j,i-1)))*deltat;
    }
    if(((tt1-myfloor(tt2)+timeEps)<bday)&&((tt2-myfloor(tt2)+timeEps)>bday)){ // if time from is before bday and time to is after bday
      //if(recruitTimes(i)==1){
      ans+=-dnorm(logN(0,i),Type(0),Type(1),true); // trick to disable logN(0,) in timesteps with recruitment: put a prior on it, and then overwrite(!)
      logN(0,i)=logR(i);
      //}
      // Birthday: accumulate the plus group, then shift every other age up by one.
      predN(stateDimN-1)=log(exp(predN(stateDimN-1))+exp(predN(stateDimN-2)));
      for(int j=(stateDimN-2); j>0; --j){
        predN(j)=predN(j-1);
      }

      for(int j=1; j<stateDimN; ++j){
        ans+=-dnorm((logN(j,i)-predN(j))/sqrtdeltat,Type(0),nsd(j),true)+log(sqrtdeltat);
        residN(i,j) = (logN(j,i)-predN(j))/(sqrtdeltat*nsd(j));
      }

    }else{
      //predN(0)=log(exp(predN(0))+exp(logR(i)));
      for(int j=1; j<stateDimN; ++j){
        ans+=-dnorm((logN(j,i)-predN(j))/sqrtdeltat,Type(0),nsd(j),true)+log(sqrtdeltat);
        residN(i,j) = (logN(j,i)-predN(j))/(sqrtdeltat*nsd(j));
      }
      if(recruitTimes(i)==1 || quarter(i)<4){
        //recruitment
        ans+=-dnorm(logN(0,i),Type(0),Type(1),true);
        logN(0,i)=log(exp(logR(i)) + exp(predN(0)));
      } else if(quarter(i)==4) { // quarter 4 has survival process error for age 0
        ans+=-dnorm((logN(0,i)-predN(0))/sqrtdeltat,Type(0),nsd(0),true)+log(sqrtdeltat);
        residN(i,0) = (logN(0,i)-predN(0))/(sqrtdeltat*nsd(0));
      }

    }

  }

  // Now finally match to observations

  int i1, i2, fl, a1, a2; // fl: fleet number (1 = catch, >= 2 = survey)
  vector<Type> predObs(nobs);
  predObs.setZero();
  vector<Type> predSd(nobs);
  Type zz, dt;
  for(int i=0;i<nobs;i++){
    i1=idx1(i);
    i2=idx2(i);
    fl=fleet(i);
    a1=ageFrom(i)-ages(0);
    a2=ageTo(i)-ages(0);
    if(fl==1){
      // Catch: Baranov-type N*F/Z*(1-exp(-Z*dt)), summed over steps [i1, i2) and ages.
      for(int ti=i1; ti<i2; ++ti){
        dt=times(ti+1)-times(ti);
        for(int a=a1; a<=a2; ++a){
          zz=exp(logF(keyLogFsta(a),ti))+auxM(mapaux(a,ti));
          predObs(i)+=exp(logN(a,ti)-log(zz)+log(Type(1)-exp(-zz*dt))+logF(keyLogFsta(a),ti));
        }
      }
    }

    if(fl>=2){
      // Survey: q * N at step i1, discounted by Z over quarterFrac of the step.
      for(int a=a1; a<=a2; ++a){
        if(!isNAINT(keyLogQ(fl-1,a))){
          dt=times(1)-times(0); // obs, assumes constant dt
          zz=exp(logF(keyLogFsta(a),i1))+auxM(mapaux(a,i1));
          predObs(i)+=exp(logN(a,i1) -zz*dt*quarterFrac(fl-1) +logQ(keyLogQ(fl-1,a)));
        }
      }
    }

    // NOTE(review): predObs is only log-transformed (and predSd only set) for
    // non-NA observations, so REPORT(predObs) mixes natural and log scale.
    if(!isNA(obs(i))){
      predObs(i)=log(predObs(i));
      if(!isNAINT(keyVarObs(fl-1,a1))){
        predSd(i)=sdLogObs(keyVarObs(fl-1,a1)); // obs using first age in range
      }

      // observation likelihood, branch to pnorm likelihood of obs not greater than detection limit (eps)
      // CppAD Conditional expression must be used here for OSA calculations to work
      //if(dontfit(i)!=1){
      ans -= keep(i) * CppAD::CondExpGt(obs(i),
                                        eps(fl-1),
                                        dnorm(logobs(i),predObs(i),predSd(i),true),
                                        log(pnorm(log(eps(fl-1)),predObs(i),predSd(i))));
      //}
    }
  }

  //PRIORS
  // Weak N(0, 20) prior on logF for optimization stability. give_log=true so the
  // penalty is a proper -log density, consistent with every other dnorm term.
  for(int i=0;i<timeSteps;i++){
    for(int j=0; j<stateDimF; j++){
      ans+=-dnorm(logF(j,i),Type(0),Type(20),true);
    }
  }

  REPORT(predObs);
  REPORT(predSd);
  ADREPORT(ssb);
  ADREPORT(tsb);
  ADREPORT(logssb);
  ADREPORT(logtsb);
  ADREPORT(logfbar);
  ADREPORT(logfbarY);
  ADREPORT(rho);
  ADREPORT(logN); // ADreport these, because of the "overwrite" hack
  ADREPORT(logF);
  ADREPORT(logR);
  ADREPORT(residN);
  return ans;
}


