#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <utility>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include "VersionHeader.hpp"
#include "AssignBatches.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// A single copy-number change event for one individual. Events are stored per SNP index
// (see cnvBounds in main/addCalls); when the SNP scan reaches that index, the individual's
// current CN call is shifted by deltaCN.
struct IndDeltaCN {
  int ind;      // individual index (position in IDpairs, looked up via IDtoInd)
  char deltaCN; // signed CN change: addCalls writes -1/+1 at call start and the opposite at end

  IndDeltaCN(int indIn, char deltaIn)
    : ind(indIn),
      deltaCN(deltaIn) {}
};

void addCalls(vector < vector <IndDeltaCN> > &cnvBounds, map <string, int> &IDtoInd,
	      const string &prelimCNVcallPrefix, int b) {
  const int minSNPs[2] = {15, 50}; // minimum number of SNPs spanned by "large" DEL/DUP
  const int dropEdgeSNPs[2] = {3, 10};
  string prelimCNVcallFile = prelimCNVcallPrefix + ".batch" + StringUtils::itos(b) + ".txt.gz";
  cout << "Reading large CNVs called using only LRR: " << prelimCNVcallFile << endl;
  FileUtils::AutoGzIfstream fin; fin.openOrExit(prelimCNVcallFile);
  string ID_1, ID_2, type; float log10p, mlrr, SElrr, kb, startMb, endMb;
  int lrrNum, genoNum, hets, mStart, mEnd;
  while (fin >> ID_1 >> ID_2 >> type) {
    string IDpair = ID_1 + "\t" + ID_2;
    string line; getline(fin, line);
    sscanf(line.c_str(), "%f %f (%f) %f kb %f %f %d %d %d %d %d", &log10p, &mlrr, &SElrr, &kb,
	   &startMb, &endMb, &lrrNum, &genoNum, &hets, &mStart, &mEnd);
    int t = type == "DEL" ? 0 : 1;
    if (lrrNum >= minSNPs[t] && IDtoInd.find(IDpair) != IDtoInd.end()) {
      cnvBounds[mStart + dropEdgeSNPs[t]].push_back(IndDeltaCN(IDtoInd[IDpair], t==0 ? -1 : 1));
      cnvBounds[mEnd - dropEdgeSNPs[t] + 1].push_back(IndDeltaCN(IDtoInd[IDpair], t==0 ? 1 : -1));
    }
  }
}


// Sentinel byte in the lrr/theta fields of GenoInfo marking a missing measurement.
const char NAN_CHAR = static_cast<char>(-128);
// Standard deviation given to a default-constructed (empty) Cluster; its square is the
// default variance.
const double INF_STD = 1e3;

// One per-individual record of the binary SNP-array file ($LRR_THETA_GENO_FILE).
// main() reads N consecutive records per SNP with fread(..., sizeof(GenoInfo), N, ...), so the
// field order, types and bit-field widths define how the file bytes are interpreted and must
// not be changed (NOTE(review): bit-field layout is compiler-defined; must match the writer).
struct GenoInfo {
  char lrr;   // LRR value stored in one byte (scaling defined by the file writer -- confirm); NAN_CHAR = missing
  char theta; // theta value stored in one byte (scaling defined by the file writer -- confirm); NAN_CHAR = missing
  unsigned char geno: 2; // genotype call: 0-2 are clustered as CN=2 genotypes; 3 is treated as a no-call in main()
  unsigned char conf: 6; // presumably a call-confidence score; not read in this file
};

// A 2-D cluster of points fitted by computeCluster. X denotes theta and Y denotes lrr
// (computeCluster takes muY from the lrr coordinate and muX from the theta coordinate).
// A default-constructed Cluster stands for "no data": zero means, spread INF_STD, n = 0.
struct Cluster {
  double muX;     // mean theta
  double muY;     // mean lrr
  double sigmaX;  // std dev of theta after removing slope * lrr
  double sigmaY;  // std dev of lrr
  double slope;   // d(theta)/d(lrr)
  double SigmaXX; // maximum-likelihood covariance entries of (theta, lrr)
  double SigmaYY;
  double SigmaXY;
  int n;          // points used in the fit; main() stores -|n| to mask a cluster
  Cluster()
    : muX(0),
      muY(0),
      sigmaX(INF_STD),
      sigmaY(INF_STD),
      slope(0),
      SigmaXX(INF_STD*INF_STD),
      SigmaYY(INF_STD*INF_STD),
      SigmaXY(0),
      n(0) {}
};

// Returns the square of its argument.
inline double sq(double val) {
  const double squared = val * val;
  return squared;
}

/**
 * Fits a robust 2-D cluster to a set of (lrr, theta) points.
 *
 * A point is masked as an outlier if, in either coordinate, it lies more than
 * IQRmult * max(IQR, 5) from the median (IQRmult = 3 for CN=2, 2 otherwise).
 * From the remaining nTrim points it estimates:
 *   - muX, muY: means of theta (X) and lrr (Y);
 *   - slope: least-squares d(theta)/d(lrr), clipped to [-1.5, 1.5];
 *   - sigmaX, sigmaY: sample std devs of slope-corrected theta and of lrr, lower-bounded by 1;
 *   - SigmaXX, SigmaYY, SigmaXY: divide-by-n (MLE) covariance entries, variances lower-bounded by 1.
 *
 * @param lrrThetaVec  points as (lrr, theta) pairs
 * @param CN           copy number the points are assumed to have (only widens/narrows outlier bounds)
 * @return the fitted cluster; a default Cluster (n = 0) if there are no points or none survive trimming
 */
Cluster computeCluster(const vector < pair <char, char> > &lrrThetaVec, int CN) {
  Cluster ans;
  int n = lrrThetaVec.size();
  if (n == 0) return ans;

  // mask outliers separately in each dimension (dim 0 = lrr, dim 1 = theta)
  vector <bool> masked(n);
  vector <char> coords(n); // std::vector rather than a VLA (char coords[n] is not standard C++)
  for (int dim = 0; dim < 2; dim++) {
    for (int i = 0; i < n; i++)
      coords[i] = dim==0 ? lrrThetaVec[i].first : lrrThetaVec[i].second;
    sort(coords.begin(), coords.end());
    // each quantile averages two neighboring order statistics
    double q1 = 0.5 * (coords[(n-1)/4] + coords[(n-1+3)/4]);
    double q2 = 0.5 * (coords[(n-1)/2] + coords[(n-1+1)/2]);
    double q3 = 0.5 * (coords[3*(n-1)/4] + coords[(3*(n-1)+3)/4]);
    if (n <= 2) { q1 = coords[0]; q3 = coords[n-1]; }
    double IQRmult = CN==2 ? 3 : 2; // allow more outliers (be more conservative) for CN=2
    double IQRmin = 5;
    double lo = q2 - IQRmult * max(q3-q1, IQRmin);
    double hi = q2 + IQRmult * max(q3-q1, IQRmin);
    for (int i = 0; i < n; i++) {
      char coord = dim==0 ? lrrThetaVec[i].first : lrrThetaVec[i].second;
      if (coord<lo || coord>hi)
	masked[i] = true;
    }
  }

  int nTrim = 0;
  for (int i = 0; i < n; i++)
    nTrim += !masked[i];
  // nothing survived trimming: return the default cluster instead of dividing by zero
  if (nTrim == 0) return ans;

  ans.n = nTrim;
  ans.SigmaXX = 0;
  ans.SigmaYY = 0;
  for (int i = 0; i < n; i++)
    if (!masked[i]) {
      ans.muY += lrrThetaVec[i].first;
      ans.muX += lrrThetaVec[i].second;
    }
  ans.muY /= nTrim; ans.muX /= nTrim;
  double numer = 0, denom = 0;
  for (int i = 0; i < n; i++)
    if (!masked[i]) {
      numer += (lrrThetaVec[i].first - ans.muY) * (lrrThetaVec[i].second - ans.muX);
      denom += sq(lrrThetaVec[i].first - ans.muY);
    }
  ans.slope = denom!=0 ? numer / denom : 0;
  // clip slopes to reasonable range [very rarely has an effect]
  ans.slope = max(min(ans.slope, 1.5), -1.5);

  double sumX = 0, sum2X = 0, sumY = 0, sum2Y = 0;
  for (int i = 0; i < n; i++)
    if (!masked[i]) {
      double y = lrrThetaVec[i].first;
      double x = lrrThetaVec[i].second - ans.slope * y;
      sumX += x; sum2X += sq(x);
      sumY += y; sum2Y += sq(y);
      // terms for computing MLE of MVN
      double dx_no_slope = lrrThetaVec[i].second - ans.muX;
      double dy = y - ans.muY;
      ans.SigmaXX += sq(dx_no_slope);
      ans.SigmaYY += sq(dy);
      ans.SigmaXY += dx_no_slope * dy;
    }
  // sample std devs need >= 2 points: nTrim == 1 would give 0/0 = NaN, which the max() lower
  // bound below does not repair. Also clamp tiny negative rounding residues before sqrt.
  ans.sigmaX = nTrim > 1 ? sqrt(max(sum2X - sq(sumX)/nTrim, 0.0) / (nTrim-1)) : 0;
  ans.sigmaY = nTrim > 1 ? sqrt(max(sum2Y - sq(sumY)/nTrim, 0.0) / (nTrim-1)) : 0;
  ans.SigmaXX /= nTrim;
  ans.SigmaYY /= nTrim;
  ans.SigmaXY /= nTrim;

  // clip lrr centers to reasonable range [turned off to be safe]
  //ans.muY = max(min(ans.muY, 64.0), -64.0);
  // lower-bound std [very rarely has an effect]
  ans.sigmaX = max(ans.sigmaX, 1.0);
  ans.sigmaY = max(ans.sigmaY, 1.0);
  ans.SigmaXX = max(ans.SigmaXX, 1.0);
  ans.SigmaYY = max(ans.SigmaYY, 1.0);

  assert(ans.SigmaXX * ans.SigmaYY * (1 + 1e-9) >= sq(ans.SigmaXY));
  return ans;
}

int main(int argc, char *argv[]) {

  printVersion();

  cout << "compute_ref_clusters:" << endl;
  cout << "- arg1 = $LRR_STD_SCALE_FILE" << endl;
  cout << "- arg2 = $BIM_FILE" << endl;
  cout << "- arg3 = $LRR_THETA_GENO_FILE" << endl;
  cout << "- arg4 = $PRELIM_CNV_CALL_PREFIX" << endl;
  cout << "- arg5 = $REF_CLUSTER_PREFIX (.batch*.txt.gz output files)" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc != 6) {
    cout << "ERROR: 5 arguments required" << endl;
    exit(1);
  }

  const char *lrrStdScaleFile = checkInputFileExt(argv, 1, ".txt");
  const char *bimFile = checkInputFileExt(argv, 2, ".bim");
  const char *lrrThetaGenoFile = checkInputFileExt(argv, 3, ".bin");
  const char *prelimCNVcallPrefix = argv[4];
  const char *refClusterPrefix = argv[5];

  FileUtils::requireReadable(lrrStdScaleFile);
  FileUtils::requireReadable(bimFile);
  FileUtils::requireReadable(lrrThetaGenoFile);
  FileUtils::requireReadable(prelimCNVcallPrefix + string(".batch1.txt.gz"));
  FileUtils::requireWriteable(refClusterPrefix + string(".batch1.txt.gz"));

  Timer timer; double t0 = timer.get_time();

  /***** read lrr std scale file (noise per indiv) *****/
  vector < pair <int, int> > batchMinMaxPerSet;
  vector <int> batch; vector <double> relScale; vector <string> IDpairs;
  assignBatches(batchMinMaxPerSet, batch, relScale, IDpairs, lrrStdScaleFile);
  int N = IDpairs.size();
  map <string, int> IDtoInd;
  for (int i = 0; i < N; i++)
    IDtoInd[IDpairs[i]] = i;
  
  // look up number of typed SNPs on chromosome
  int M = FileUtils::AutoGzIfstream::lineCount(bimFile);

  for (uint s = 0; s < batchMinMaxPerSet.size(); s++) {
    int bMinSet = batchMinMaxPerSet[s].first, bMaxSet = batchMinMaxPerSet[s].second;
    // analyze each batch alone + analyze all batches in set together (if 2+ batches)
    for (int b = bMinSet; b <= bMaxSet + (bMinSet!=bMaxSet); b++) {

      /***** read LRR-based prelim CNV calls *****/
      vector < vector <IndDeltaCN> > cnvBounds(M+1);
      int bMin, bMax;
      string refClusterPrefixBatch;
      if (b == bMaxSet+1) { // all batches in set together
	bMin = bMinSet; bMax = bMaxSet;
	refClusterPrefixBatch = refClusterPrefix + string(".batch") + StringUtils::itos(bMin)
	  + "-" + StringUtils::itos(bMax);
	cout << "Analyzing genotyping set " << s+1 << ", batch " << bMin << "-" << bMax << endl;
      }
      else { // single batch
	bMin = bMax = b;
	refClusterPrefixBatch = refClusterPrefix + string(".batch") + StringUtils::itos(b);
	cout << "Analyzing genotyping set " << s+1 << ", batch " << b << endl;
      }
      for (int b2 = bMin; b2 <= bMax; b2++)
	addCalls(cnvBounds, IDtoInd, prelimCNVcallPrefix, b2);

      vector <char> callCN(N, 2);

      /***** read, process, store lrr theta confgeno info *****/
      cout << "Processing SNP-array data from " << lrrThetaGenoFile << endl;
      FILE *finBin; finBin = fopen(lrrThetaGenoFile, "rb"); assert(finBin != NULL);
      FileUtils::AutoGzOfstream fout; fout.openOrExit(refClusterPrefixBatch + ".txt.gz");
      FileUtils::AutoGzOfstream foutExamples;
      foutExamples.openOrExit(refClusterPrefixBatch + ".examples.txt.gz");
      vector <GenoInfo> genoRow(N);
      int ctrNonMonotonic = 0;
      for (int m = 0; m < M; m++) {
	// update CNV call status from lrrCall
	for (int j = 0; j < (int) cnvBounds[m].size(); j++)
	  callCN[cnvBounds[m][j].ind] += cnvBounds[m][j].deltaCN;
	// read genotype intensities for SNP m
	fread(&genoRow[0], sizeof(GenoInfo), N, finBin);
	vector <int> genoCounts(4);
	const int DEL_IND = 3, DUP_IND = 4; // indices into lrrThetaVecs
	vector < vector < pair <char, char> > > lrrThetaVecs(5);
	vector < vector <int> > genoVecs(5); // later, to indicate DEL and DUP sub-genotypes
	int num = 0; double sum = 0, sum2 = 0;
	for (int i = 0; i < N; i++)
	  if (bMin <= batch[i] && batch[i] <= bMax) {
	    int g = genoRow[i].geno;
	    char lrr = genoRow[i].lrr;
	    char theta = genoRow[i].theta;
	    genoCounts[g]++;
	    // assign each point to 1 of 5 groups: CN=2[0,1,2], CN=1, CN=3 (ignore no-call no-CNV)
	    if ((g <= 2 || callCN[i]!=2) && lrr != NAN_CHAR && theta != NAN_CHAR) {
	      int ind_v = g;
	      if (callCN[i]==1) ind_v = DEL_IND;
	      if (callCN[i]==3) ind_v = DUP_IND;
	      lrrThetaVecs[ind_v].push_back(make_pair(lrr, theta));
	    }
	    if (lrr != NAN_CHAR) {
	      num++;
	      sum += lrr;
	      sum2 += sq(lrr);
	    }
	  }

	Cluster clusters[4][4]; int flipType = 0; // +1 (no flipLR), -1 (flipLR), 0 (uncertain)

	// compute CN=2 clusters
	for (int geno = 0; geno <= 2; geno++)
	  clusters[2][geno] = computeCluster(lrrThetaVecs[geno], 2);

	// set flipType (letting hom-major outvote hom-minor)
	double muXhet = (clusters[2][1].n > 0 ? clusters[2][1].muX : 0);
	flipType = (clusters[2][0].n * (clusters[2][0].muX < muXhet ? 1 : -1) +
		    clusters[2][2].n * (clusters[2][2].muX > muXhet ? 1 : -1)) > 0 ? 1 : -1;

	// swap g=0,2 clusters if flipped
	if (flipType == -1)
	  swap(clusters[2][0], clusters[2][2]);

	if (clusters[2][0].n && clusters[2][2].n &&
	    (clusters[2][0].muX < muXhet) != (clusters[2][2].muX > muXhet)) {
	  flipType = 0; // muX not monotonic
	  ctrNonMonotonic++;
	}

	// subdivide CN=1, CN=3 clusters -> estimate params
	for (int CN = 1; CN <= 3; CN += 2) {
	  int ind_v = CN==1 ? DEL_IND : DUP_IND;
	  const double alpha = 2.0/3; // split distance from het cluster to hom cluster
	  vector < vector < pair <char, char> > > lrrThetaVecsGenos(CN+1);
	  const vector < pair <char, char> > &lrrThetaVecsCNV = lrrThetaVecs[ind_v];
	  genoVecs[ind_v].resize(lrrThetaVecsCNV.size());
	  for (int j = 0; j < (int) lrrThetaVecsCNV.size(); j++) {
	    int &geno = genoVecs[ind_v][j];
	    double x = lrrThetaVecsCNV[j].second, y = lrrThetaVecsCNV[j].first;
	    const Cluster &hetCluster = clusters[2][1];
	    if (CN==1)
	      geno = x < hetCluster.muX ? 0 : 1;
	    else {
	      geno = x < hetCluster.muX ? 0 : 2;
	      if (geno == 0) {
		const Cluster &homCluster = clusters[2][0];
		if (homCluster.n < 25)
		  geno = 1; // assume no CN=3 hom
		else {
		  double xDiv = alpha*homCluster.muX + (1-alpha)*hetCluster.muX;
		  double yDiv = alpha*homCluster.muY + (1-alpha)*hetCluster.muY;
		  geno += x-clusters[2][0].slope*(y-yDiv) < xDiv ? 0 : 1;
		}
	      }
	      else {
		const Cluster &homCluster = clusters[2][2];
		if (homCluster.n < 25)
		  geno = 2; // assume no CN=3 hom
		else {
		  double xDiv = alpha*homCluster.muX + (1-alpha)*hetCluster.muX;
		  double yDiv = alpha*homCluster.muY + (1-alpha)*hetCluster.muY;
		  geno += x-clusters[2][2].slope*(y-yDiv) < xDiv ? 0 : 1;
		}
	      }
	    }
	    lrrThetaVecsGenos[geno].push_back(lrrThetaVecsCNV[j]);
	  }
	  for (int geno = 0; geno <= CN; geno++)
	    clusters[CN][geno] = computeCluster(lrrThetaVecsGenos[geno], CN);
	}
	// post-process: mask minor CN=1 cluster if MAF<0.05
	double maf = (2*min(genoCounts[0], genoCounts[2])+genoCounts[1])
	  / (2.0*(genoCounts[0]+genoCounts[1]+genoCounts[2]));
	if (maf < 0.05) {
	  if (clusters[1][0].n < clusters[1][1].n)
	    clusters[1][0].n *= -1;
	  else
	    clusters[1][1].n *= -1;
	}
	// post-process: mask CN=3 clusters that overlap other CN=3 clusters with larger N
	for (int geno = 0; geno <= 3; geno++)
	  for (int g2 = geno-1; g2 <= geno+1; g2 += 2) {
	    const double sepRelMin = geno>0 && geno<3 ? 2 : 2.5; // more stringent for hom CN=3
	    if (g2 >= 0 && g2 <= 3 && abs(clusters[3][g2].n) >= max(10, abs(clusters[3][geno].n))
		&& fabs(clusters[3][geno].muX - clusters[3][g2].muX)
		< sepRelMin * (sqrt(clusters[3][geno].SigmaXX) + sqrt(clusters[3][g2].SigmaXX)))
	      clusters[3][geno].n = -abs(clusters[3][geno].n);
	  }
	// post-process: mask clusters with very high variance
	const double varRelMax = 1.5;
	for (int CN = 1; CN <= 3; CN += 2)
	  for (int geno = 0; geno <= CN; geno++) {
	    double totSigmaXX = 0, totSigmaYY = 0;
	    if (clusters[2][0].n < 10) {
	      totSigmaXX = 1.5 * (clusters[2][1].SigmaXX + clusters[2][2].SigmaXX);
	      totSigmaYY = 1.5 * (clusters[2][1].SigmaYY + clusters[2][2].SigmaYY);
	    }
	    else if (clusters[2][2].n < 10) {
	      totSigmaXX = 1.5 * (clusters[2][0].SigmaXX + clusters[2][1].SigmaXX);
	      totSigmaYY = 1.5 * (clusters[2][0].SigmaYY + clusters[2][1].SigmaYY);
	    }
	    else {
	      totSigmaXX = clusters[2][0].SigmaXX + clusters[2][1].SigmaXX + clusters[2][2].SigmaXX;
	      totSigmaYY = clusters[2][0].SigmaYY + clusters[2][1].SigmaYY + clusters[2][2].SigmaYY;
	    }
	    if (clusters[CN][geno].SigmaXX > varRelMax * totSigmaXX ||
		clusters[CN][geno].SigmaYY > varRelMax * totSigmaYY)
	      clusters[CN][geno].n = -abs(clusters[CN][geno].n);
	  }
	// mask all CN=1, CN=3 clusters if not enough CN=2 hets (don't use for prediction)
	if (genoCounts[1] < 25)
	  for (int CN = 1; CN <= 3; CN += 2)
	    for (int geno = 0; geno <= CN; geno++)
	      clusters[CN][geno].n = -abs(clusters[CN][geno].n);

	// write cluster data for this SNP
	bool hasCNVgenoNge10 = false;
	for (int CN = 1; CN <= 3; CN++)
	  for (int geno = 0; geno <= CN; geno++) {
	    if (!(CN==1 && geno==0)) fout << "\t";
	    const Cluster &c = clusters[CN][geno];
	    fout << c.n << "\t" << c.muX << "\t" << c.muY << "\t" << c.SigmaXX << "\t" << c.SigmaYY
		 << "\t" << c.SigmaXY;
	    if (CN != 2 && c.n >= 10)
	      hasCNVgenoNge10 = true;
	  }
	fout << "\t" << flipType << "\t" << /*chr << "\t" << */m << endl;

	// if this SNP has a CNV genotype with n=10+, write 500 example data points per cluster
	if (flipType != 0 && hasCNVgenoNge10) {
	  for (int ind_v = 0; ind_v < 5; ind_v++) {
	    vector <int> order(lrrThetaVecs[ind_v].size());
	    for (int j = 0; j < (int) order.size(); j++)
	      order[j] = j;
	    random_shuffle(order.begin(), order.end());
	    for (int j = 0; j < min(500, (int) lrrThetaVecs[ind_v].size()); j++) {
	      int CN, geno;
	      if (ind_v < 3) {
		CN = 2;
		geno = flipType==-1 ? 2-ind_v : ind_v;
	      }
	      else {
		CN = ind_v==DEL_IND ? 1 : 3;
		geno = genoVecs[ind_v][order[j]];
	      }
	      foutExamples << m << "\t" << CN << "\t" << geno << "\t"
			   << (int) lrrThetaVecs[ind_v][order[j]].second << "\t"
			   << (int) lrrThetaVecs[ind_v][order[j]].first << endl;
	    }
	  }
	}
      }
      fout.close();
      foutExamples.close();
      fclose(finBin);

      cout << "Finished analyzing batch; ctrNonMonotonic = " << ctrNonMonotonic << endl;
    }
  }

  cout << "Finished compute_ref_clusters; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
