#include <iostream>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include "VersionHeader.hpp"
#include "AssignBatches.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// Sentinel value of an lrr byte marking a missing measurement; entries equal to
// it are skipped in LRR sums and contribute 0 in logPlrr.
// NOTE(review): relies on plain `char` being signed on the target platform -- confirm.
const char NAN_CHAR = -128;
// Bound on the magnitude of a single log-probability term (-log(1e-4) = log(1e4));
// applied symmetrically by clipLogP.
const double CLIP_LOGP = -log(1e-4);

// One per-sample, per-variant record of the binary input file; extractBatchData
// fread()s rows of these directly, so the field order and sizes are the on-disk layout.
struct GenoInfo {
  char lrr;   // LRR byte; NAN_CHAR marks a missing value
  char theta; // read from the file but not used by the code in this file
  /*
  Former bit-field form of genoConf (kept for reference):
  unsigned char geno: 2;
  unsigned char conf: 6;
  */
  unsigned char genoConf; // packed byte: (genoConf & 3) is the genotype (1 = het), (genoConf >> 2) the confidence
};

// In-memory subset of GenoInfo kept per (sample, variant) after loading a batch:
// the theta byte is dropped, the other two fields are copied unchanged.
struct GenoInfoTrim {
  char lrr;               // LRR byte; NAN_CHAR marks a missing value
  unsigned char genoConf; // packed genotype (low 2 bits) + confidence (upper 6 bits)
};

// Returns x squared.
inline double sq(double x) {
  const double squared = x * x;
  return squared;
}

// Clamps x to the closed interval [-CLIP_LOGP, CLIP_LOGP].
// Written with explicit comparisons so that a NaN input is returned unchanged.
inline double clipLogP(double x) {
  const double upper = CLIP_LOGP;
  const double lower = -CLIP_LOGP;
  return (x > upper) ? upper
       : (x < lower) ? lower
       : x;
}

// Clipped log-likelihood ratio of observing `lrr` under a CNV state with mean
// muCNV versus the reference state with mean muSigma.first, both modelled as
// Gaussians with standard deviation muSigma.second * relScale:
//   0.5 * ((lrr - mu)^2 - (lrr - muCNV)^2) / (sigma * relScale)^2
// A missing LRR (NAN_CHAR) contributes 0, i.e. no evidence either way.
double logPlrr(pair <double, double> muSigma, char lrr, double muCNV, double relScale) {
  if (lrr == NAN_CHAR)
    return 0;
  const double devRef = lrr - muSigma.first;
  const double devCNV = lrr - muCNV;
  const double variance = sq(muSigma.second * relScale);
  return clipLogP(0.5*(sq(devRef) - sq(devCNV)) / variance);
}

// Reads a (possibly gzipped) bim file and returns the base-pair coordinate
// (4th whitespace-separated column) of each variant, in file order.
// Reading stops at the first line whose first four fields do not parse as
// (int, string, double, int).
// NOTE(review): because the chromosome column is parsed as an int, a
// non-numeric chromosome code (e.g. X, Y, MT) silently ends the read there;
// confirm the inputs to this tool only use numeric chromosome codes.
vector <int> readBimCoords(const char *bimFile) {
  cout << "Reading bim file: " << bimFile << endl;

  FileUtils::AutoGzIfstream finBim;
  finBim.openOrExit(bimFile);

  vector <int> coords;
  int chrom;
  string snpID;
  double genPos;
  int bpPos;
  string restOfLine;
  while (finBim >> chrom >> snpID >> genPos >> bpPos) {
    coords.push_back(bpPos);
    getline(finBim, restOfLine); // ignore rest of line
  }

  finBim.close();
  return coords;
}

/**
 * Loads the LRR and genotype+confidence bytes of every sample in batch `b` and
 * computes per-variant LRR statistics within that batch.
 *
 * The input file is read as M consecutive rows (one per variant), each holding
 * batch.size() GenoInfo records (one per sample); it must contain exactly that
 * much data.
 *
 * @param muSigmas  [out] resized to M; entry m is (mean, sample std dev) of the
 *                  non-missing LRR bytes of batch-b samples at variant m, or the
 *                  default (0, 1e6) when 50 or fewer such values are available
 * @param M         number of variants (rows in the file)
 * @param batch     batch label of each sample, in file column order
 * @param b         batch to extract
 * @param lrrThetaGenoFile  path of the binary input file
 * @return new[]-allocated array of (#samples in batch b) * M records laid out
 *         sample-major: the record of the k-th batch-b sample (in order of
 *         appearance in `batch`) at variant m is at index k*M + m.  The caller
 *         owns the array and must delete[] it.
 *
 * Prints an error and exits with status 1 if the file cannot be opened, is too
 * short, or has trailing data.  These checks are real runtime checks rather
 * than assert()s so that the reads are not compiled out under -DNDEBUG.
 */
GenoInfoTrim *extractBatchData(vector < pair <double, double> > &muSigmas, uint64 M,
			       const vector <int> &batch, int b, const char *lrrThetaGenoFile) {

  muSigmas = vector < pair <double, double> > (M, make_pair(0.0, 1e6));
  const uint64 Nbatch = count(batch.begin(), batch.end(), b);
  // operator new[] signals failure by throwing std::bad_alloc, so no NULL check is needed
  GenoInfoTrim *data = new GenoInfoTrim[Nbatch * M];
  int N = batch.size();

  /***** read, process, store lrr theta confgeno info *****/
  cout << "Extracting SNP-array data for batch" << b << " from " << lrrThetaGenoFile << endl;
  FILE *finBin = fopen(lrrThetaGenoFile, "rb");
  if (finBin == NULL) {
    cerr << "ERROR: unable to open " << lrrThetaGenoFile << " for reading" << endl;
    exit(1);
  }
  vector <GenoInfo> genoRow(N);
  for (uint64 m = 0; m < M; m++) {
    // read one variant row (all N samples); skip the call for N == 0 so that
    // &genoRow[0] is never formed on an empty vector
    const size_t numRead = (N > 0) ? fread(&genoRow[0], sizeof(GenoInfo), N, finBin) : 0;
    if (numRead != (size_t) N) {
      cerr << "ERROR: " << lrrThetaGenoFile << " ended early at variant row " << m
	   << " (read " << numRead << " of " << N << " records)" << endl;
      exit(1);
    }
    // single pass over the batch's samples: accumulate LRR moments and copy
    // the lrr / genoConf bytes into the sample-major output array
    int num = 0; double sum = 0, sum2 = 0;
    uint64 ib = 0;
    for (int i = 0; i < N; i++) {
      if (batch[i] != b) continue;
      if (genoRow[i].lrr != NAN_CHAR) {
	num++;
	sum += genoRow[i].lrr;
	sum2 += sq(genoRow[i].lrr);
      }
      data[ib*M + m].lrr = genoRow[i].lrr;
      data[ib*M + m].genoConf = genoRow[i].genoConf;
      ib++;
    }
    // only trust the estimates when enough values are present; otherwise the
    // default (0, 1e6) set above is kept
    if (num > 50) {
      muSigmas[m].first = sum / num;
      muSigmas[m].second = sqrt((sum2 - sq(sum)/num) / (num-1));
    }
  }
  // the file must hold exactly M rows
  if (fgetc(finBin) != EOF) {
    cerr << "ERROR: " << lrrThetaGenoFile << " contains more data than the expected "
	 << M << " rows of " << N << " records" << endl;
    exit(1);
  }
  fclose(finBin);

  return data;
}

int main(int argc, char *argv[]) {

  printVersion();

  cout << "call_CNVs_prelim:" << endl;
  cout << "- arg1 = $LRR_STD_SCALE_FILE" << endl;
  cout << "- arg2 = $BIM_FILE" << endl;
  cout << "- arg3 = $LRR_THETA_GENO_FILE" << endl;
  cout << "- arg4 = $LRR_DEL" << endl;
  cout << "- arg5 = $LRR_DUP" << endl;
  cout << "- arg6 = $LRR_REFINE_ITERS" << endl;
  cout << "- arg7 = $PRELIM_CNV_CALL_PREFIX (.batch*.txt.gz output files)" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc != 8) {
    cout << "ERROR: 7 arguments required" << endl;
    exit(1);
  }

  const char *lrrStdScaleFile = checkInputFileExt(argv, 1, ".txt");
  const char *bimFile = checkInputFileExt(argv, 2, ".bim");
  const char *lrrThetaGenoFile = checkInputFileExt(argv, 3, ".bin");
  double lrrDEL; assert(sscanf(argv[4], "%lf", &lrrDEL));
  double lrrDUP; assert(sscanf(argv[5], "%lf", &lrrDUP));
  double muLRRs64[2] = {64*lrrDEL, 64*lrrDUP};
  int muRefineIters; assert(sscanf(argv[6], "%d", &muRefineIters));
  const char *prelimCNVcallPrefix = argv[7];

  FileUtils::requireReadable(lrrStdScaleFile);
  FileUtils::requireReadable(bimFile);
  FileUtils::requireReadable(lrrThetaGenoFile);
  FileUtils::requireWriteable(prelimCNVcallPrefix + string(".batch1.txt.gz"));

  Timer timer; double t0 = timer.get_time();

  cout << "Initial guess for LRR of DELs: " << lrrDEL << endl;
  assert(lrrDEL < 0 && lrrDEL >= -1);
  cout << "Initial guess for LRR of DUPs: " << lrrDUP << endl;
  assert(lrrDUP > 0 && lrrDUP <= 1);
  cout << "Number of refinement iterations to update initial guesses: " << muRefineIters << endl;
  assert(muRefineIters >= 0);
  cout << endl;

  /***** read lrr std scale file (noise per indiv) *****/
  vector < pair <int, int> > batchMinMaxPerSet;
  vector <int> batch; vector <double> relScale; vector <string> IDpairs;
  int B = assignBatches(batchMinMaxPerSet, batch, relScale, IDpairs, lrrStdScaleFile);
  int N = IDpairs.size();

  vector <int> bps = readBimCoords(bimFile);
  uint64 M = bps.size();
  cout << "Read " << M << " variants from bim file" << endl;
  
  /***** set lookup table defining penalty for hets within DELs *****/
  double logHetDelPenalty[256]; // log P(geno,conf|del): 0 unless geno==1 (het)
  const double pErrMin = 5e-6;
  const double scale = 63.5 / -log(pErrMin);
  for (int u = 0; u < 256; u++) {
    if ((u&3)==1) // het => penalty based on genotyping confidence (negative)
      logHetDelPenalty[u] = -clipLogP((u>>2) / scale);
    else // not het => no penalty
      logHetDelPenalty[u] = 0;
  }

  for (int b = 1; b <= B; b++) { // analyze all batches in turn

    FileUtils::AutoGzOfstream fout;
    fout.openOrExit(prelimCNVcallPrefix + string(".batch") + StringUtils::itos(b) + ".txt.gz");

    cout << "Reading data for batch " << b << " of " << B << endl;

    // allocate storage for LRR and geno+conf bytes 
    vector < pair <double, double> > muSigmas;
    GenoInfoTrim *data = extractBatchData(muSigmas, M, batch, b, lrrThetaGenoFile);

    for (int iter = 1; iter <= muRefineIters+1; iter++) {

      // initialize cumulative totals of LRR within interior of long CNV calls
      int lrrLargeCNVnum[2] = {0, 0};
      double lrrLargeCNVsum[2] = {0, 0};

      /***** run HMM caller *****/
      double (*cumLogP)[3] = new double[M][3];
      char (*prev)[3] = new char[M][3];
      const double logPjump = log(1e-4);
      int ib = 0;
      for (int i = 0; i < N; i++) {
	if (!(batch[i]==b)) continue;
	const GenoInfoTrim *iData = &data[ib*M]; ib++;
	// initialize
	uint m = 0;
	cumLogP[m][0] = logPjump + logPlrr(muSigmas[m], iData[m].lrr, muLRRs64[0], relScale[i])
	  + logHetDelPenalty[iData[m].genoConf]; // del
	cumLogP[m][1] = 0; // CN=2
	cumLogP[m][2] = logPjump + logPlrr(muSigmas[m], iData[m].lrr, muLRRs64[1], relScale[i]); // dup
	// iterate
	for (m = 1; m < M; m++) {
	  // transition; set prev
	  for (int s = 0; s <= 2; s++)
	    for (int t = 0; t <= 2; t++)
	      if (t==0 || (cumLogP[m][s] < cumLogP[m-1][t] + abs(s-t)*logPjump)) {
		cumLogP[m][s] = cumLogP[m-1][t] + abs(s-t)*logPjump;
		prev[m][s] = t;
	      }
	  // emission
	  cumLogP[m][0] += logPlrr(muSigmas[m], iData[m].lrr, muLRRs64[0], relScale[i])
	    + logHetDelPenalty[iData[m].genoConf]; // del
	  cumLogP[m][2] += logPlrr(muSigmas[m], iData[m].lrr, muLRRs64[1], relScale[i]); // dup
	}
	// finalize: penalty for ending in del/dup state
	m = M-1;
	cumLogP[m][0] += logPjump;
	cumLogP[m][2] += logPjump;
    
	// backtrack
	int s = 0;
	if (cumLogP[m][1] > cumLogP[m][s]) s = 1;
	if (cumLogP[m][2] > cumLogP[m][s]) s = 2;
	int mSegEnd = m;
	vector < pair <int, int> > CNVsegs; // inclusive
	vector <bool> states; // 0=del, 1=dup
	while (m > 0) {
	  if (prev[m][s] != s) {
	    if (s != 1) { CNVsegs.push_back(make_pair(m, mSegEnd)); states.push_back(s/2); }
	    mSegEnd = m-1;
	    s = prev[m][s];
	  }
	  m--;
	}
	if (s != 1) { CNVsegs.push_back(make_pair(m, mSegEnd)); states.push_back(s/2); }

	for (int k = CNVsegs.size()-1; k >= 0; k--) {
	  int mStart = CNVsegs[k].first, mEnd = CNVsegs[k].second, state = states[k];
	  int lrrNum = 0, genoNum = 0, hets = 0; double lrrSum = 0, lrrSum2 = 0, logPsum = 0;
	  for (int m = mStart; m <= mEnd; m++)
	    if (iData[m].lrr != NAN_CHAR) {
	      lrrNum++;
	      lrrSum += iData[m].lrr;
	      lrrSum2 += sq(iData[m].lrr);
	    }
	  // augment totals of LRR within interior of long CNV calls
	  const int minSNPs[2] = {15, 50};
	  const int dropEdgeSNPs[2] = {3, 10};
	  if (lrrNum >= minSNPs[state]) {
	    int lrrInd = 0;
	    for (int m = mStart; m <= mEnd; m++)
	      if (iData[m].lrr != NAN_CHAR) {
		if (lrrInd >= dropEdgeSNPs[state] && lrrInd < lrrNum-dropEdgeSNPs[state]) {
		  lrrLargeCNVnum[state]++;
		  lrrLargeCNVsum[state] += iData[m].lrr;
		}
		lrrInd++;
	      }
	  }

	  if (iter == muRefineIters+1) { // print output
	    for (int m = mStart; m <= mEnd; m++) {
	      logPsum += logPlrr(muSigmas[m], iData[m].lrr, muLRRs64[state], relScale[i]);
	      if (state==0)
		logPsum += logHetDelPenalty[iData[m].genoConf]; // del
	      if ((iData[m].genoConf&3)==1) hets++;
	      if ((iData[m].genoConf&3)!=3) genoNum++;
	    }
	    double mu = lrrSum / lrrNum;
	    double SE = sqrt((lrrSum2 - sq(lrrSum)/lrrNum) / (lrrNum-1) / lrrNum);
	    const char types[2][4] = {"DEL", "DUP"};
	    char buf[100];
	    sprintf(buf, "%s %3s %7.2f %6.2f (%4.2f) %6.1f kb %5.1f %5.1f %3d %3d %3d %5d %5d\n",
		    IDpairs[i].c_str(), types[state], logPsum/log(10), mu/64, SE/64,
		    (bps[mEnd]-bps[mStart])*1e-3, bps[mStart]*1e-6, bps[mEnd]*1e-6,
		    lrrNum, genoNum, hets, mStart, mEnd);
	    fout << string(buf);
	  }
	}
      }
  
      delete[] prev;
      delete[] cumLogP;

      for (int state = 0; state <= 1; state++)
	muLRRs64[state] = lrrLargeCNVsum[state] / lrrLargeCNVnum[state];
      printf("  iter %d: estimated LRR DEL = %.3f, DUP = %.3f (from %d+%d values)",
	     iter, muLRRs64[0]/64, muLRRs64[1]/64, lrrLargeCNVnum[0], lrrLargeCNVnum[1]);
      cout << endl;
    }
    delete[] data;
    fout.close();
  }

  cout << "Finished call_CNVs_prelim; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
