#include <iostream>
#include <vector>
#include <utility>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <cmath>

#include "omp.h"

#include "VersionHeader.hpp"
#include "AssignBatches.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// Genotype-cluster record for one variant.
// Array indices, as used by main():
//   [b]    batch slot: 0 = all batches of the set pooled, 1.. = individual batches
//   [CN]   copy number (main() fills 1..3)
//   [geno] genotype index (main() fills 0..CN)
struct SNP {
  // per-cluster counts read from the cluster files (compared to MIN_N_TRAIN / MIN_N_HOM)
  int n[MAX_BATCHES_PER_SET + 1][4][4];

  // per-cluster Gaussian parameters: [0],[1] = means (X, Y); [2],[3] = variances (XX, YY);
  // [4] = covariance (XY)
  double params[MAX_BATCHES_PER_SET + 1][4][4][5];

  int flipType; // read from / written to the cluster files unchanged
  int chr;      // chromosome the record was read from
  int snpNum;   // index of the variant within its chromosome file
};

// Returns x squared.
inline double sq(double x) {
  return x * x;
}

// Determinant of the symmetric 2x2 matrix [[SigmaXX, SigmaXY], [SigmaXY, SigmaYY]].
double det(double SigmaXX, double SigmaYY, double SigmaXY) {
  const double offDiagSquared = SigmaXY * SigmaXY;
  return SigmaXX * SigmaYY - offDiagSquared;
}

// Distance between two bivariate Gaussians, computed as one minus their Bhattacharyya
// coefficient (i.e. the squared Hellinger distance).
//
// pars1, pars2: arrays of 5 doubles laid out as
//   [0] muX, [1] muY, [2] SigmaXX, [3] SigmaYY, [4] SigmaXY.
// Returns a value in [0, 1]: 0 for identical distributions, approaching 1 as they separate.
// Both covariance matrices are expected to have positive determinant; otherwise the
// result is NaN (caught by the assertion below in debug builds).
double hellinger(const double *pars1, const double *pars2) {
  double dmuX = pars1[0] - pars2[0];
  double dmuY = pars1[1] - pars2[1];
  // average of the two covariance matrices
  double SigmaXX = 0.5*(pars1[2]+pars2[2]);
  double SigmaYY = 0.5*(pars1[3]+pars2[3]);
  double SigmaXY = 0.5*(pars1[4]+pars2[4]);
  double detSigma = det(SigmaXX, SigmaYY, SigmaXY);
  double dist = 1 -
    pow(det(pars1[2], pars1[3], pars1[4]) * det(pars2[2], pars2[3], pars2[4]), 0.25)
    * pow(detSigma, -0.5)
    // quadratic form dmu^T Sigma^{-1} dmu, written out via the 2x2 adjugate
    * exp(-0.125 * (SigmaYY*sq(dmuX) + SigmaXX*sq(dmuY) - 2*SigmaXY*dmuX*dmuY) / detSigma);
  // Mathematically dist >= 0, but for (nearly) identical inputs the product above can round
  // to marginally more than 1. Tolerate that round-off and clamp instead of aborting; the
  // assertion still fires on NaN or on a materially negative value.
  assert(dist >= -1e-9);
  return dist < 0 ? 0.0 : dist;
}

// Kullback-Leibler divergence KL(N0 || N1) between two bivariate Gaussians.
//
// pars0, pars1: arrays of 5 doubles laid out as
//   [0] muX, [1] muY, [2] SigmaXX, [3] SigmaYY, [4] SigmaXY.
// Returns 0.5 * (tr(Sigma1^{-1} Sigma0) + dmu^T Sigma1^{-1} dmu - 2 + log(det1/det0)),
// which is >= 0 and equals 0 for identical distributions.
// Both covariance matrices are expected to have positive determinant; otherwise the
// result is NaN (caught by the assertion below in debug builds).
double KLdiv(const double *pars0, const double *pars1) {
  double dmuX = pars1[0] - pars0[0];
  double dmuY = pars1[1] - pars0[1];

  double Sigma0XX = pars0[2];
  double Sigma0YY = pars0[3];
  double Sigma0XY = pars0[4];
  double detSigma0 = det(Sigma0XX, Sigma0YY, Sigma0XY);
  double Sigma1XX = pars1[2];
  double Sigma1YY = pars1[3];
  double Sigma1XY = pars1[4];
  double detSigma1 = det(Sigma1XX, Sigma1YY, Sigma1XY);

  // trace term and quadratic form both use the 2x2 adjugate of Sigma1 divided by detSigma1
  double div = 0.5 * ((Sigma1YY*Sigma0XX + Sigma1XX*Sigma0YY - 2*Sigma1XY*Sigma0XY) / detSigma1
                      + (Sigma1YY*sq(dmuX) + Sigma1XX*sq(dmuY) - 2*Sigma1XY*dmuX*dmuY) / detSigma1
                      - 2 + log(detSigma1 / detSigma0));
  // Mathematically div >= 0, but cancellation between the terms can leave a tiny negative
  // value for nearly identical inputs. Tolerate that round-off and clamp instead of
  // aborting; the assertion still fires on NaN or on a materially negative value.
  assert(div >= -1e-9);
  return div < 0 ? 0.0 : div;
}

// Entry point of predict_clusters.
//
// Takes exactly five command-line arguments (listed in the usage text printed below).
// For every batch set returned by assignBatches() it
//   1. reads the reference cluster files of chromosomes 1-22 (pooled set + each batch),
//   2. for each SNP on chromosome chrPred, ranks training SNPs by the Hellinger distance
//      between their pooled CN=2 clusters and averages the (mean-shifted) cluster parameters
//      of the K closest training SNPs, separately per batch slot,
//   3. prints RMSEs of predicted vs. observed cluster means for a few batch slots,
//   4. writes one predicted-cluster file per batch.
// Returns 0 on success; exits with status 1 on invalid arguments.
int main(int argc, char *argv[]) {

  printVersion();

  cout << "predict_clusters:" << endl;
  cout << "- arg1 = $LRR_STD_SCALE_FILE" << endl;
  cout << "- arg2 = $CHR" << endl;
  cout << "- arg3 = $REF_CLUSTER_PREFIX_NO_CHR (.chr*.batch*.txt.gz)" << endl;
  cout << "- arg4 = $PRED_CLUSTER_PREFIX (.batch*.txt.gz output files)" << endl;
  cout << "- arg5 = $THREADS" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc != 6) {
    cout << "ERROR: 5 arguments required" << endl;
    exit(1);
  }

  const char *lrrStdScaleFile = checkInputFileExt(argv, 1, ".txt");
  // Numeric arguments are parsed outside of assert(): the parse has to run even when the
  // program is built with NDEBUG, and sscanf returns EOF (non-zero) on empty input, so
  // success is exactly one converted field.
  int chrPred;
  if (sscanf(argv[2], "%d", &chrPred) != 1) {
    cout << "ERROR: unable to parse chromosome number: " << argv[2] << endl;
    exit(1);
  }
  const char *refClusterPrefixNoChr = argv[3];
  const char *predClusterPrefix = argv[4];
  int threads;
  if (sscanf(argv[5], "%d", &threads) != 1) {
    cout << "ERROR: unable to parse number of threads: " << argv[5] << endl;
    exit(1);
  }

  FileUtils::requireReadable(lrrStdScaleFile);
  FileUtils::requireReadable(refClusterPrefixNoChr + string(".chr1.batch1.txt.gz"));
  FileUtils::requireWriteable(predClusterPrefix + string(".batch1.txt.gz"));

  Timer timer; double t0 = timer.get_time();

  cout << "Predicting genotype cluster parameters for variants on chr" << chrPred << endl;
  if (chrPred < 1 || chrPred > 22) {
    cout << "ERROR: chromosome number must be between 1 and 22" << endl;
    exit(1);
  }
  cout << "Setting number of threads to " << threads << endl;
  if (threads <= 0) {
    cout << "ERROR: number of threads must be positive" << endl;
    exit(1);
  }
  cout << endl;
  omp_set_num_threads(threads);

  const int MIN_N_TRAIN = 10; // min cluster count for a cluster to be used / trusted
  const int MIN_N_HOM = 25;   // min count of a pooled CN=2 hom cluster used for matching
  const int K = 20; // number of closest SNPs

  /***** read lrr std scale file (noise per indiv) *****/
  vector < pair <int, int> > batchMinMaxPerSet;
  vector <int> batch; vector <double> relScale; vector <string> IDpairs;
  assignBatches(batchMinMaxPerSet, batch, relScale, IDpairs, lrrStdScaleFile);

  for (size_t s = 0; s < batchMinMaxPerSet.size(); s++) {
    int bMinSet = batchMinMaxPerSet[s].first, bMaxSet = batchMinMaxPerSet[s].second;
    assert(bMaxSet - bMinSet + 1 <= MAX_BATCHES_PER_SET);
    // Batch slots: b = 0 is the whole set pooled, b = 1..nBatch are batches bMinSet..bMaxSet.
    const int nBatch = bMaxSet - bMinSet + 1;

    /***** read reference clusters of all autosomes for this set *****/
    vector <SNP> snpsTrain, snpsTest;
    for (int chr = 1; chr <= 22; chr++) {
      FileUtils::AutoGzIfstream fins[nBatch+1];
      int Mchr = 0;
      for (int b = 0; b <= nBatch; b++) {
        string fileName = refClusterPrefixNoChr + string(".chr") + StringUtils::itos(chr)
          + ".batch" + StringUtils::itos(max(0, b-1) + bMinSet) + ".txt.gz";
        if (b==0) { // all samples in set together
          if (bMinSet != bMaxSet)
            fileName = refClusterPrefixNoChr + string(".chr") + StringUtils::itos(chr)
              + ".batch" + StringUtils::itos(bMinSet) + "-" + StringUtils::itos(bMaxSet)+ ".txt.gz";
          // the pooled file determines how many variant records are read from every file
          Mchr = FileUtils::AutoGzIfstream::lineCount(fileName);
        }
        cout << "Reading reference CNV genotype clusters: " << fileName << endl;
        fins[b].openOrExit(fileName);
      }

      for (int m = 0; m < Mchr; m++) {
        // zero-fill so that entries never read from file (CN=0, geno>CN, unused batch
        // slots) hold defined values when the record is copied into the vectors
        SNP snp = SNP(); bool hasTrain = false;
        for (int b = 0; b <= nBatch; b++) {
          for (int CN = 1; CN <= 3; CN++)
            for (int geno = 0; geno <= CN; geno++) {
              fins[b] >> snp.n[b][CN][geno];
              // a SNP is usable for training if any single batch has a CN=1 or CN=3 cluster
              if (b != 0 && CN != 2 && snp.n[b][CN][geno] >= MIN_N_TRAIN)
                hasTrain = true;
              for (int k = 0; k < 5; k++)
                fins[b] >> snp.params[b][CN][geno][k];
            }
          fins[b] >> snp.flipType >> /*snp.chr >> */snp.snpNum;
          assert(snp.snpNum == m); // records of all files must be in the same order
          snp.chr = chr;
        }
        if (hasTrain) snpsTrain.push_back(snp);
        // NOTE(review): chrPred == 0 cannot pass the argument validation above, so the
        // second clause is currently unreachable.
        if (chr == chrPred || (chrPred == 0 && hasTrain)) snpsTest.push_back(snp);
      }
      for (int b = 0; b <= nBatch; b++) fins[b].close();
    }

    int Mtrain = snpsTrain.size();
    int Mtest = snpsTest.size();
    cout << endl;
    cout << "Read " << Mtrain << " SNPs with training data for at least one CNV cluster" << endl;
    cout << "Computing predictions for " << Mtest << " SNPs" << endl;

    // store predicted params; elements are value-initialized (all zeros), which the
    // "+=" accumulation and the Sigma == 0 check below rely on
    vector <SNP> snpsPred(Mtest);

    // squared-error accumulators for predicted vs. observed cluster means (x and y)
    double err2sums[nBatch+1][4][4][2]; memset(err2sums, 0, sizeof(err2sums));
    int err2nums[nBatch+1][4][4]; memset(err2nums, 0, sizeof(err2nums));

    /***** predict clusters: each thread writes only snpsPred[mTest] (+ guarded err sums) *****/
#pragma omp parallel for
    for (int mTest = 0; mTest < Mtest; mTest++) {
      // one pass per pooled CN=2 homozygous cluster (geno 0 and geno 2) that is large enough
      for (int genoHom = 0; genoHom <= 2; genoHom += 2) {
        if (snpsTest[mTest].n[0][2][genoHom] < MIN_N_HOM) continue;
        const SNP &snpTest = snpsTest[mTest];
        SNP &snpPred = snpsPred[mTest];

        // rank training SNPs by distance between pooled CN=2 clusters
        vector < pair <double, int> > dist_inds;
        for (int mTrain = 0; mTrain < Mtrain; mTrain++) {
          const SNP &snpTrain = snpsTrain[mTrain];
          if (snpTest.chr == snpTrain.chr && snpTest.snpNum == snpTrain.snpNum)
            continue; // don't allow prediction using SNP itself
          if (snpTrain.n[0][2][genoHom] < MIN_N_HOM) continue;
          double dist_combined =
            hellinger(snpTest.params[0][2][genoHom], snpTrain.params[0][2][genoHom]); // CN=2 hom
          if (snpTest.n[0][2][1] >= MIN_N_TRAIN) { // test SNP has het cluster
            if (snpTrain.n[0][2][1] < MIN_N_TRAIN) // required train SNP to also have het cluster
              continue;
            dist_combined +=
              hellinger(snpTest.params[0][2][1], snpTrain.params[0][2][1]); // CN=2 het, batchALL
          }
          dist_inds.push_back(make_pair(dist_combined, mTrain));
        }
        sort(dist_inds.begin(), dist_inds.end());

        for (int CN = 1; CN <= 3; CN++)
          for (int geno = 0; geno <= CN; geno++) {
            // CN=2 cluster used as anchor for the mean offset: the matched hom cluster for
            // the all-major/all-minor genotype on the genoHom side, otherwise the het cluster
            int genoClosest = ((geno==0 || geno==CN) && geno/CN == genoHom/2) ? genoHom : 1;
            if (snpTest.n[0][2][2-genoHom] >= MIN_N_HOM // have opposite cluster => don't pred opp
                && ((CN==1 && geno!=genoHom/2) || (CN==3 && geno/2!=genoHom/2))) continue;
            if (CN==2 && geno==genoHom) // for CN=2, only pred opp hom and het
              continue;                 // (overwrite with real CN=2 data later if present)
            if (snpTest.n[0][2][1] < MIN_N_TRAIN // if no het cluster, only pred all-major CN=1,3
                && !((CN==1 && geno==genoHom/2) || (CN==3 && geno==genoHom/2*CN)))
              continue;
            // NOTE(review): when both hom passes run and the test SNP has a het cluster,
            // CN=2/geno=1 appears to reach this point in both passes, so its prediction is
            // accumulated (and mean-corrected) twice -- confirm whether this is intended.

            // take the K closest training SNPs per batch slot that have this cluster
            int Kleft = (nBatch+1)*K; vector <int> KleftPerBatch(nBatch+1, K);
            for (int k = 0; k < (int) dist_inds.size() && Kleft; k++) {
              int mTrain = dist_inds[k].second;
              const SNP &snpTrain = snpsTrain[mTrain];
              for (int b = 0; b <= nBatch; b++) {
                if (snpTrain.n[b][CN][geno] >= MIN_N_TRAIN && KleftPerBatch[b]) {
                  Kleft--;
                  KleftPerBatch[b]--;
                  double muPred[2];
                  for (int xy = 0; xy < 2; xy++) {
                    // adjust means according to offset btwn centers (using batchALL if N too low)
                    int bClosest = (snpTest.n[b][2][genoClosest] >= MIN_N_TRAIN &&
                                    snpTrain.n[b][2][genoClosest] >= MIN_N_TRAIN) ? b : 0;
                    double deltaMu = (snpTest.params[bClosest][2][genoClosest][xy] -
                                      snpTrain.params[bClosest][2][genoClosest][xy]);
                    muPred[xy] = snpTrain.params[b][CN][geno][xy] + deltaMu;
                    snpPred.params[b][CN][geno][xy] += muPred[xy] / K;

                    // augment variance parameters: E[X^2]
                    snpPred.params[b][CN][geno][xy+2] +=
                      (snpTrain.params[b][CN][geno][xy+2] + sq(muPred[xy])) / K;
                  }
                  // augment covariance parameter: E[XY]
                  snpPred.params[b][CN][geno][4] +=
                    (snpTrain.params[b][CN][geno][4] + muPred[0]*muPred[1]) / K;
                }
              }
            }
            // the "/ K" averaging above is only valid if every slot found K neighbours
            assert(Kleft==0);
            // subtract mu^2 to get variance, muX*muY to get covariance
            for (int b = 0; b <= nBatch; b++) {
              double *pred_params_b = &snpPred.params[b][CN][geno][0];
              for (int xy = 0; xy < 2; xy++)
                pred_params_b[xy+2] -= sq(pred_params_b[xy]);
              pred_params_b[4] -= pred_params_b[0] * pred_params_b[1];
            }

            // error accumulators are shared between threads
#pragma omp critical
            for (int b = 0; b <= nBatch; b++) {
              if (snpTest.n[b][CN][geno] >= MIN_N_TRAIN) { // check against gold standard cluster
                err2nums[b][CN][geno]++;
                for (int xy = 0; xy < 2; xy++)
                  err2sums[b][CN][geno][xy] +=
                    sq(snpTest.params[b][CN][geno][xy] - snpPred.params[b][CN][geno][xy]);
              }
            }
          }
      }
      for (int CN = 1; CN <= 3; CN++)
        for (int geno = 0; geno <= CN; geno++)
          for (int b = 0; b <= nBatch; b++)
            for (int xy = 0; xy < 2; xy++) {
              double &Sigma = snpsPred[mTest].params[b][CN][geno][xy+2];
              if (Sigma == 0) {
                Sigma = 1e6; // for non-predicted clusters, set variance to 1e6
                assert(snpsPred[mTest].params[b][CN][geno][xy] == 0);
              }
            }
    }

    /***** report RMSEs for the pooled slot and the first, middle and last batch *****/
    for (int b = 0; b <= nBatch; b++) {
      if (b != 0 && b != 1 && b != (nBatch+1)/2 && b != nBatch) continue;
      if (b == 0)
        cout << "batch" << bMinSet << "-" << bMaxSet << ":" << endl;
      else
        cout << "batch" << b-1+bMinSet << ":" << endl;
      for (int CN = 1; CN <= 3; CN += 2) {
        double RMSEs[4][2];
        // a cluster with no gold-standard observations (err2nums == 0) prints as NaN
        for (int geno = 0; geno <= CN; geno++)
          for (int xy = 0; xy < 2; xy++)
            RMSEs[geno][xy] = sqrt(err2sums[b][CN][geno][xy] / err2nums[b][CN][geno]);
        if (CN==1)
          printf("RMSEs: %5.2f %5.2f | %5.2f %5.2f\n",
                 RMSEs[0][0], RMSEs[0][1], RMSEs[1][0], RMSEs[1][1]);
        else
          printf("RMSEs: %5.2f %5.2f | %5.2f %5.2f | %5.2f %5.2f | %5.2f %5.2f\n",
                 RMSEs[0][0], RMSEs[0][1], RMSEs[1][0], RMSEs[1][1],
                 RMSEs[2][0], RMSEs[2][1], RMSEs[3][0], RMSEs[3][1]);
      }
      cout << endl;
    }

    /***** write one predicted-cluster file per individual batch (slot 0 is not written) *****/
    for (int b = 1; b <= nBatch; b++) {
      FileUtils::AutoGzOfstream fout;
      fout.openOrExit(predClusterPrefix + string(".batch") + StringUtils::itos(b-1+bMinSet)
                      + ".txt.gz");
      for (int mTest = 0; mTest < Mtest; mTest++) {
        const SNP &snpPred = snpsPred[mTest];
        const SNP &snpTest = snpsTest[mTest];
        for (int CN = 1; CN <= 3; CN++)
          for (int geno = 0; geno <= CN; geno++) {
            // observed CN=2 clusters take precedence over predictions when large enough
            const SNP &snpPrint = (CN==2 && snpTest.n[b][CN][geno]>=MIN_N_TRAIN) ? snpTest:snpPred;
            if (!(CN==1 && geno==0)) fout << "\t";
            // first column of each cluster: genotype index for CN=1,3; observed count for CN=2
            if (CN!=2) fout << geno;
            else fout << snpTest.n[b][CN][geno];
            for (int p = 0; p < 5; p++)
              fout << "\t" << snpPrint.params[b][CN][geno][p];
          }
        fout << "\t" << snpTest.flipType << "\t" << snpTest.chr << "\t" << snpTest.snpNum << endl;
      }
      fout.close();
    }
  }

  cout << "Finished predict_clusters; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
