// alternative compilation using icpc (module load intel/2016)

// icpc -DMKL_ILP64 -O3 -qopenmp -Wall denoise_lrr.cpp -I/n/groups/price/poru/HSPH_SVN/src/EAGLE -I/home/pl88/boost_1_58_0/install/include -I/n/groups/price/poru/external_software/intel_mkl_2019u4/mkl/include -Wl,-rpath,/n/groups/price/poru/external_software/intel_mkl_2019u4/mkl/lib/intel64 -o denoise_lrr -L/n/groups/price/poru/external_software/intel_mkl_2019u4/mkl/lib/intel64 -lmkl_intel_ilp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread -lm -ldl

#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <utility>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include "mkl.h"
#include "omp.h"

#include "VersionHeader.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// Sentinel byte marking a missing value in the 1-byte denoised LRR output: output buffers
// are pre-filled with it and it is skipped when accumulating per-sample statistics.
// charCrop() clamps to [-127, 127], so real values never collide with it.
const char NAN_CHAR = -128;
// Maximum number of samples per group (batch) when a genotyping set is split up.
const int BATCH_SIZE = 5000;

// Rounds x to the nearest integer (halves round up, via floor(x+0.5)) and clamps the
// result to [-127, 127] so that it fits in one signed byte; -128 is never returned.
// The clamp is applied in floating point *before* converting to an integer type, because
// converting an out-of-range or NaN floating-point value to int is undefined behavior.
// NaN input yields -127.
char charCrop(float x) {
  const double rounded = std::floor(x + 0.5); // 0.5 is a double, so the sum is done in double
  if (rounded >= 127.0)
    return (char) 127;
  if (!(rounded > -127.0)) // negated comparison so that NaN also takes this branch
    return (char) -127;
  return (char) (int) rounded;
}

// Integer mixing function: two rounds of (xor with the value shifted right by 16, then
// multiply by 0x45d9f3b), followed by one final xor-shift.  Arithmetic wraps modulo the
// width of uint.
uint hash(uint x) {
  const uint multiplier = 0x45d9f3b;
  for (int round = 0; round < 2; round++) {
    x ^= x >> 16;
    x *= multiplier;
  }
  x ^= x >> 16;
  return x;
}

// Counterpart of hash(): same structure (two xor-shift/multiply rounds and a final
// xor-shift) but with multiplier 0x119de1f3.  defineBatches() uses it to recover the
// sample index that was passed through hash().
uint unhash(uint x) {
  const uint multiplier = 0x119de1f3;
  for (int round = 0; round < 2; round++) {
    x ^= x >> 16;
    x *= multiplier;
  }
  x ^= x >> 16;
  return x;
}

// Returns the square of x.
double sq(double x) {
  const double squared = x * x;
  return squared;
}

// Builds the path of the temporary raw-LRR file for one batch:
//   <outPrefix>.tmpLRR.group<batch>.bin
string tmpLRRperBatch(const char *outPrefix, int batch) {
  string path(outPrefix);
  path += ".tmpLRR.group";
  path += StringUtils::itos(batch);
  path += ".bin";
  return path;
}

// Builds the path of the temporary denoised-LRR file for one batch:
//   <outPrefix>.tmpDenoisedLRR.group<batch>.bin
string tmpDenoisedLRRperBatch(const char *outPrefix, int batch) {
  string path(outPrefix);
  path += ".tmpDenoisedLRR.group";
  path += StringUtils::itos(batch);
  path += ".bin";
  return path;
}

// Reads the sample info file and partitions the samples of each genotyping set into
// groups ("batches") of at most BATCH_SIZE samples.
//
// Input file: one header line (discarded) followed by whitespace-delimited records
//   ID_1 ID_2 ID_imp genotypingSetName inUniformAncestrySet
// Reading stops at the first record that cannot be parsed in this form.
//
// Within a set, samples are sorted by (not-in-uniform-ancestry-set, hash(sample index)),
// so samples flagged as being in the uniform ancestry set come first and the order within
// each class is scrambled by the hash.  Consecutive runs of BATCH_SIZE samples in that
// order form the batches; batch indices are global and increase across sets.
//
// Outputs (all appended to / resized; callers are expected to pass empty containers):
//   IDpairsGenoSet[n]     "ID_1\tID_2\tgenotypingSetName" for sample n (file order)
//   genotypingSetNames[s] name of set s, in order of first appearance
//   batchIndsPerSet[s]    global batch indices belonging to set s
//   famIndToBatch[n]      global batch index of sample n
//   NperBatch[b]          number of samples in batch b
void defineBatches(vector <string> &IDpairsGenoSet, vector <string> &genotypingSetNames,
                   vector < vector <int> > &batchIndsPerSet, vector <int> &famIndToBatch,
                   vector <uint64> &NperBatch, const char *sampleInfoFile) {

  cout << "Reading sample info file: " << sampleInfoFile << endl;
  FileUtils::AutoGzIfstream fin; fin.openOrExit(sampleInfoFile);
  string line; getline(fin, line); // throw away header
  map <string, int> setNameToInd; // map genotyping set names to genotypingSets indices
  // per set: (NOT in uniform ancestry set, hashed sample index), used as the sort key
  vector < vector < pair <bool, uint> > > ancHashedIndsPerSet;
  string ID_1, ID_2, ID_imp, genotypingSetName; int inUniformAncestrySet;
  int famInd = 0;
  while (fin >> ID_1 >> ID_2 >> ID_imp >> genotypingSetName >> inUniformAncestrySet) {
    IDpairsGenoSet.push_back(ID_1 + "\t" + ID_2 + "\t" + genotypingSetName);
    map <string, int>::iterator setIt = setNameToInd.find(genotypingSetName);
    if (setIt == setNameToInd.end()) { // first sample of a new genotyping set
      const int newSetInd = genotypingSetNames.size();
      setIt = setNameToInd.insert(make_pair(genotypingSetName, newSetInd)).first;
      ancHashedIndsPerSet.push_back(vector < pair <bool, uint> > ());
      genotypingSetNames.push_back(genotypingSetName);
    }
    // NOTE(review): hash/unhash are qualified with :: because, with `using namespace std`
    // in effect, an unqualified `hash` can be ambiguous with std::hash on some toolchains.
    ancHashedIndsPerSet[setIt->second].
      push_back(make_pair((bool) !inUniformAncestrySet, ::hash((uint) famInd)));
    famInd++;
  }
  fin.close();
  const uint64 N = famInd;
  famIndToBatch.resize(N);
  cout << "Read sample info for " << N << " samples" << endl;
  cout << "Read " << genotypingSetNames.size() << " genotyping sets:" << endl;
  const int S = genotypingSetNames.size();
  batchIndsPerSet.resize(S);
  int B = 0; // current (global) batch index
  for (int s = 0; s < S; s++) {
    vector < pair <bool, uint> > &setInds = ancHashedIndsPerSet[s];
    const int Nset = setInds.size();
    cout << "  " << genotypingSetNames[s] << " (N = " << Nset << ")" << endl;
    const int Bset = (Nset+BATCH_SIZE-1)/BATCH_SIZE; // ceil(Nset / BATCH_SIZE)
    cout << "  Splitting into " << Bset << " groups of at most " << BATCH_SIZE << " samples"
         << endl;
    sort(setInds.begin(), setInds.end());
    for (int bs = 0; bs < Bset; bs++) {
      int ctrInUniformAncestrySet = 0, ctr = 0;
      for (int i = bs*BATCH_SIZE; i < (bs+1)*BATCH_SIZE && i < Nset; i++) {
        ctrInUniformAncestrySet += !setInds[i].first;
        ctr++;
        const uint n = ::unhash(setInds[i].second); // recover the original sample index
        if (n >= N) { // would mean unhash() failed to invert hash()
          cerr << "ERROR: Internal error recovering sample index in defineBatches" << endl;
          exit(1);
        }
        famIndToBatch[n] = B;
      }
      batchIndsPerSet[s].push_back(B);
      NperBatch.push_back(ctr);
      cout << "    group" << B << ": " << NperBatch[B] << " samples, " << ctrInUniformAncestrySet
           << " in uniform ancestry set" << endl;
      B++;
    }
  }
}
  
// Splits the per-chromosome text LRR files into one temporary binary file per batch.
//
// For each chromosome c in 1..22 the input path is obtained by formatting lrrFormatStr
// with c.  Every line of an input file is one variant and must contain exactly N
// whitespace-delimited values (N = famIndToBatch.size()), one per sample in sample-info
// order.  Each value is parsed as a float; the tokens "NA" and "nan" are stored as NAN.
// The value of sample n is appended as a raw float to the temporary file of batch
// famIndToBatch[n], so each temporary file is laid out variant by variant.
//
// Outputs: M receives the total number of variants; MperChr is reset to 23 entries and
// MperChr[c] receives the number of variants on chromosome c (index 0 stays 0).
//
// All parsing and I/O is checked explicitly (not via assert, which would silently drop
// the reads and writes in an NDEBUG build); on failure an error is printed to stderr and
// the program exits with status 1.
void splitLRR(uint64 &M, vector <uint64> &MperChr, const char *lrrFormatStr,
              const vector <int> &famIndToBatch, int B, const char *outPrefix) {

  Timer timer;
  const uint64 N = famIndToBatch.size();

  vector <FILE *> foutPerBatch(B);
  for (int b = 0; b < B; b++) {
    const string tmpLRRfile = tmpLRRperBatch(outPrefix, b);
    cout << "Writing temporary LRR file: " << tmpLRRfile << endl;
    foutPerBatch[b] = fopen(tmpLRRfile.c_str(), "wb");
    if (foutPerBatch[b] == NULL) {
      cerr << "ERROR: Unable to open " << tmpLRRfile << " for writing" << endl;
      exit(1);
    }
  }

  M = 0;
  MperChr = vector <uint64> (23);
  for (int c = 1; c <= 22; c++) {
    char lrrFile[1000];
    // bounded formatting: a long format string must not overrun the path buffer
    const int pathLen = snprintf(lrrFile, sizeof(lrrFile), lrrFormatStr, c);
    if (pathLen < 0 || pathLen >= (int) sizeof(lrrFile)) {
      cerr << "ERROR: LRR file name for chr" << c << " is invalid or too long" << endl;
      exit(1);
    }
    cout << "Splitting LRR file: " << lrrFile << "... " << flush;
    FileUtils::AutoGzIfstream finLRR; finLRR.openOrExit(lrrFile);
    string lineLRR, lrrStr;
    while (getline(finLRR, lineLRR)) {
      MperChr[c]++;
      M++;
      istringstream issLRR(lineLRR);
      for (uint64 n = 0; n < N; n++) {
        if (!(issLRR >> lrrStr)) {
          cerr << "ERROR: " << lrrFile << " line " << MperChr[c] << " has only " << n
               << " values; expected " << N << endl;
          exit(1);
        }
        float lrr = NAN;
        if (sscanf(lrrStr.c_str(), "%f", &lrr) != 1) {
          if (lrrStr != "NA" && lrrStr != "nan") {
            cerr << "ERROR: " << lrrFile << " line " << MperChr[c]
                 << ": unable to parse LRR value '" << lrrStr << "'" << endl;
            exit(1);
          }
          lrr = NAN; // explicit missing-value token
        }
        if (fwrite(&lrr, sizeof(float), 1, foutPerBatch[famIndToBatch[n]]) != 1U) {
          cerr << "ERROR: Failed writing to " << tmpLRRperBatch(outPrefix, famIndToBatch[n])
               << endl;
          exit(1);
        }
      }
      if (issLRR >> lrrStr) {
        cerr << "ERROR: " << lrrFile << " line " << MperChr[c] << " has more than " << N
             << " values" << endl;
        exit(1);
      }
    }
    finLRR.close();
    cout << " done: " << MperChr[c] << " variants in " << timer.update_time() << " sec" << endl;
  }

  for (int b = 0; b < B; b++)
    if (fclose(foutPerBatch[b]) != 0) { // buffered data may fail to flush (e.g. disk full)
      cerr << "ERROR: Failed closing " << tmpLRRperBatch(outPrefix, b) << endl;
      exit(1);
    }
}

// Computes the top P principal components (right singular vectors) of the LRR matrix of
// each genotyping set, using only the FIRST batch of the set.
//
// For set s, the temporary LRR file of batch b = batchIndsPerSet[s][0] is read as an
// Nb x M column-major matrix (Nb = NperBatch[b] samples, M variants), NaN entries are
// replaced by 0, and sgesvd is called with JOBVT='O', which overwrites the matrix with
// the rows of V^T.  Row p of V^T is the p-th PC across variants.
//
// Returns PCsPerSet, where PCsPerSet[s] holds M*P floats in variant-major order:
//   PCsPerSet[s][m*P + p] = value of PC p at variant m.
//
// Also prints the cumulative fraction of variance explained by the top 10,20,...,100
// singular values (fewer if fewer than 100 singular values exist).
// Any allocation, I/O or SVD failure, or a request for more PCs than there are singular
// vectors (P > min(Nb, M)), prints an error and exits with status 1.
vector < vector <float> > computePCs(const char *outPrefix, const int P,
                                     const vector <string> &genotypingSetNames,
                                     const vector < vector <int> > &batchIndsPerSet,
                                     const vector <uint64> &NperBatch, const uint64 M) {
  Timer timer;

  if (P < 0) {
    cerr << "ERROR: Number of PCs must be non-negative" << endl;
    exit(1);
  }

  const int numSets = genotypingSetNames.size();
  // allocate storage for one M x P block per set (P PCs of length M; variant-major)
  vector < vector <float> > PCsPerSet(numSets, vector <float> (M * P));

  for (int s = 0; s < numSets; s++) { // iterate through genotyping sets
    cout << "Computing PCs for genotyping set " << genotypingSetNames[s] << endl;

    const int b = batchIndsPerSet[s][0]; // compute PCs using first batch in each set
    const uint64 Nb = NperBatch[b];
    const uint64 numSV = min(Nb, M); // number of singular values / rows of V^T produced
    if ((uint64) P > numSV) {
      cerr << "ERROR: Requested " << P << " PCs but only " << numSV
           << " are available for genotyping set " << genotypingSetNames[s] << endl;
      exit(1);
    }

    // allocate storage for LRR and SVD
    float *A = (float *) malloc(sizeof(float) * Nb * M); // A = M x Nb = Nb x M col-maj
    float *U = (float *) malloc(sizeof(float) * Nb * Nb); // U in column-major order
    float *singVals = (float *) malloc(sizeof(float) * Nb); // singular values
    if (A == NULL || U == NULL || singVals == NULL) {
      cerr << "ERROR: Unable to allocate memory for SVD" << endl;
      exit(1);
    }

    // read LRR data; the file must contain exactly Nb*M floats
    const string tmpLRRfile = tmpLRRperBatch(outPrefix, b);
    cout << "Opening temporary LRR file: " << tmpLRRfile << endl;
    FILE *finBin = fopen(tmpLRRfile.c_str(), "rb");
    if (finBin == NULL) {
      cerr << "ERROR: Unable to open " << tmpLRRfile << endl;
      exit(1);
    }
    const uint64 numRead = fread(A, sizeof(float), Nb * M, finBin);
    if (numRead != Nb * M || fgetc(finBin) != EOF) {
      cerr << "ERROR: Unexpected size of " << tmpLRRfile << endl;
      exit(1);
    }
    fclose(finBin);
    cout << "Read LRR values for " << Nb << " samples and " << M << " variants" << endl;

    // set NAN to 0 for PC computation
    for (uint64 nm = 0; nm < Nb*M; nm++)
      if (isnan(A[nm]))
        A[nm] = 0;

    // compute svd
    { // A (Nb x M in LAPACK) = U (Nb x Nb) * Sigma (Nb x Nb) * V' (Nb x M)
      char JOBU_ = 'A', JOBVT_ = 'O'; // overwrite input matrix with right singular vectors
      long long M_ = Nb, N_ = M;
      float *A_ = A;
      long long LDA_ = Nb;
      float *S_ = singVals;
      float *U_ = U;
      long long LDU_ = Nb;
      float *VT_ = NULL; // not referenced when JOBVT='O'
      long long LDVT_ = 1;
      // LAPACK requires LWORK >= max(3*min(m,n)+max(m,n), 5*min(m,n)); 5*M satisfies this
      // whenever M >= Nb, and the second term covers the Nb > M case.
      long long LWORK_ = max(5*M, 3*numSV + max(Nb, M));
      float *WORK_ = (float *) malloc(sizeof(float) * LWORK_);
      if (WORK_ == NULL) {
        cerr << "ERROR: Unable to allocate SVD workspace" << endl;
        exit(1);
      }
      long long INFO_;
      sgesvd_( &JOBU_, &JOBVT_, &M_, &N_, A_, &LDA_, S_, U_, &LDU_, VT_, &LDVT_,
               WORK_, &LWORK_, &INFO_ );
      free(WORK_);
      if (INFO_ != 0) {
        cout << "ERROR: SVD computation failed" << endl;
        exit(1);
      }
    }

    // output fraction of variance explained
    float s2sum = 0;
    for (uint64 i = 0; i < numSV; i++)
      s2sum += singVals[i] * singVals[i];
    float s2cur = 0;
    cout << "Fraction of variance explained by top 10,20,...,100 PCs:";
    const uint64 numReported = min((uint64) 100, numSV); // never read past the last value
    for (uint64 p = 1; p <= numReported; p++) {
      s2cur += singVals[p-1] * singVals[p-1];
      if (p % 10 == 0)
        printf(" %.3f", s2cur / s2sum);
    }
    cout << endl;

    // copy the first P rows of V^T (now stored in A with leading dimension Nb)
    for (uint64 m = 0; m < M; m++)
      for (int p = 0; p < P; p++)
        PCsPerSet[s][m*P + p] = A[p + m*Nb];

    free(singVals);
    free(U);
    free(A);

    cout << "Time for PCA: " << timer.update_time() << " sec" << endl;
  }

  return PCsPerSet;
}

void projPCs(const char *outPrefix, const vector < vector <float> > &PCsPerSet,
	     const vector <string> &genotypingSetNames,
	     const vector < vector <int> > &batchIndsPerSet, const vector <uint64> &NperBatch,
	     const uint64 M, const int P) {
  
  Timer timer;

  int S = genotypingSetNames.size();
  for (int s = 0; s < S; s++) { // iterate through genotyping sets
    cout << "Projecting PCs for genotyping set " << genotypingSetNames[s] << endl;
    const float *PCs = &PCsPerSet[s][0];

    for (uint k = 0; k < batchIndsPerSet[s].size(); k++) {
      int b = batchIndsPerSet[s][k]; // index of k-th batch in set s
      uint64 Nb = NperBatch[b];

      // read LRR data
      float *lrrAll = (float *) malloc(M * Nb * sizeof(float)); assert(lrrAll != NULL);
      string tmpLRRfile = tmpLRRperBatch(outPrefix, b);
      cout << "Opening temporary LRR file: " << tmpLRRfile << endl;
      FILE *finBin; finBin = fopen(tmpLRRfile.c_str(), "rb");
      assert(finBin != NULL);
      assert(fread(lrrAll, sizeof(float), M * Nb, finBin) == M * Nb);
      assert(fgetc(finBin) == EOF);
      fclose(finBin);
      cout << "Read LRR values for " << Nb << " samples and " << M << " variants" << endl;

      // allocate 1-byte denoised LRR output buffer
      char *lrrOut = (char *) malloc(M * Nb * sizeof(char)); assert(lrrOut != NULL);
      memset(lrrOut, NAN_CHAR, M * Nb * sizeof(char));

#pragma omp parallel for
      for (uint64 n = 0; n < Nb; n++) { // project top PCs; store projected lrr
	float coeffs[P];
	for (int p = 0; p < P; p++)
	  coeffs[p] = 0;
	for (uint64 m = 0; m < M; m++) {
	  float lrr = lrrAll[m*Nb + n];
	  if (!isnan(lrr))
	    for (int p = 0; p < P; p++)
	      coeffs[p] += PCs[m*P + p] * lrr;
	}
	// coeffs[p] should equal U[n + p*Nb] * S[p] for individuals in 0-th batch
	for (uint64 m = 0; m < M; m++) {
	  float lrr = lrrAll[m*Nb + n];
	  if (!isnan(lrr)) {
	    float lrr_proj = lrr;
	    for (int p = 0; p < P; p++)
	      lrr_proj -= coeffs[p] * PCs[m*P + p];
	    lrrOut[m*Nb + n] = charCrop(64 * lrr_proj);
	  }
	}
      }

      // write output
      string tmpDenoisedLRRfile = tmpDenoisedLRRperBatch(outPrefix, b);
      cout << "Writing temporary denoised LRR file: " << tmpDenoisedLRRfile << endl;
      FILE *foutBin = fopen(tmpDenoisedLRRfile.c_str(), "wb");
      assert(foutBin != NULL);
      fwrite(lrrOut, sizeof(char), M * Nb, foutBin);
      fclose(foutBin);
      remove(tmpLRRperBatch(outPrefix, b).c_str());
      free(lrrAll);
      free(lrrOut);

      cout << "Time for projecting PCs from group" << b << ": " << timer.update_time() << " sec"
	   << endl;
    }
  }
}

void mergeLRR(const char *outPrefix, const vector <int> &famIndToBatch,
	      const vector <uint64> &NperBatch, const vector <uint64> &MperChr,
	      const vector <string> &IDpairsGenoSet) {

  uint64 N = famIndToBatch.size();
  int B = NperBatch.size();
  FILE *finPerBatch[B];
  char *bufs[B];
  for (int b = 0; b < B; b++) {
    string tmpDenoisedLRRfile = tmpDenoisedLRRperBatch(outPrefix, b);
    cout << "Opening temporary denoised LRR file: " << tmpDenoisedLRRfile << endl;
    finPerBatch[b] = fopen(tmpDenoisedLRRfile.c_str(), "rb");
    assert(finPerBatch[b] != NULL);
    bufs[b] = (char *) malloc(NperBatch[b]);
    assert(bufs[b] != NULL);
  }

  char *lrrMerged = (char *) malloc(N);
  assert(lrrMerged != NULL);
  vector <int> varNums(N); vector <double> varSums(N);
  for (int c = 1; c <= 22; c++) {
    string mergedDenoisedLRRfile = string(outPrefix) + ".chr" + StringUtils::itos(c) + ".bin";
    cout << "Writing merged denoised LRR file: " << mergedDenoisedLRRfile << endl;
    FILE *fout = fopen(mergedDenoisedLRRfile.c_str(), "wb");
    assert(fout != NULL);
    for (uint64 m = 0; m < MperChr[c]; m++) {
      for (int b = 0; b < B; b++)
	assert(fread(bufs[b], 1, NperBatch[b], finPerBatch[b]) == NperBatch[b]);
      char *bufPtrs[B]; memcpy(bufPtrs, bufs, B*sizeof(bufs[0]));
      for (uint64 n = 0; n < N; n++)
	lrrMerged[n] = *bufPtrs[famIndToBatch[n]]++;
      fwrite(lrrMerged, sizeof(char), N, fout);
      // augment running totals of relative variance per sample
      int num = 0; double sum = 0, sum2 = 0;
      for (uint64 n = 0; n < N; n++)
	if (lrrMerged[n] != NAN_CHAR) {
	  num++;
	  sum += lrrMerged[n];
	  sum2 += sq(lrrMerged[n]);
	}
      if (num >= 10) {
	double mu = sum / num;
	double sigma2 = (sum2 - sq(sum)/num) / (num-1);
	if (sigma2 > 1e-6) {
	  double sigmaInv = 1/sqrt(sigma2);
	  for (uint64 n = 0; n < N; n++)
	    if (lrrMerged[n] != NAN_CHAR) {
	      varSums[n] += sq((lrrMerged[n] - mu) * sigmaInv);
	      varNums[n]++;
	    }
	}
      }
    }
    fclose(fout);
  }

  for (int b = 0; b < B; b++) {
    assert(fgetc(finPerBatch[b]) == EOF);
    fclose(finPerBatch[b]);
    remove(tmpDenoisedLRRperBatch(outPrefix, b).c_str());
    free(bufs[b]);
  }
  free(lrrMerged);

  // write relative noise per sample
  FileUtils::AutoGzOfstream foutStdScale;
  foutStdScale.openOrExit(string(outPrefix) + ".std_scale.txt");
  foutStdScale << "ID_1\tID_2\tgenotypingSet\tstd_scale\tnum_snps" << endl;
  for (uint64 n = 0; n < N; n++)
    foutStdScale << IDpairsGenoSet[n] << "\t" << sqrt(varSums[n]/varNums[n]) << "\t" << varNums[n]
		 << endl;
  foutStdScale.close();
}


// Entry point.  Prints usage and the command line, validates the five arguments, then
// runs the pipeline: defineBatches -> splitLRR -> computePCs -> projPCs -> mergeLRR,
// reporting the elapsed time of each stage.
//
// Arguments are validated with explicit checks rather than assert so that parsing and
// range checks still happen in an NDEBUG build.  Returns 0 on success; invalid arguments
// print an error and exit with status 1.
int main(int argc, char *argv[]) {

  printVersion();

  cout << "denoise_lrr:" << endl;
  cout << "- arg1: sample info file" << endl;
  cout << "- arg2: format string for LRR files (...chr%d...)" << endl;
  cout << "- arg3: number of LRR PCs to project" << endl;
  cout << "- arg4: threads" << endl;
  cout << "- arg5: output prefix (denoised LRR in 1-byte binary format; 1 file per chr)" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc != 6) {
    cout << "ERROR: 5 arguments required" << endl;
    exit(1);
  }

  const char *sampleInfoFile = argv[1];
  const char *lrrFormatStr = argv[2];
  int P = 0;
  if (sscanf(argv[3], "%d", &P) != 1) { // sscanf may return EOF, which is non-zero
    cout << "ERROR: Unable to parse number of PCs: " << argv[3] << endl;
    exit(1);
  }
  int threads = 0;
  if (sscanf(argv[4], "%d", &threads) != 1) {
    cout << "ERROR: Unable to parse number of threads: " << argv[4] << endl;
    exit(1);
  }
  const char *outPrefix = argv[5];

  FileUtils::requireReadable(sampleInfoFile);
  FileUtils::requireWriteable(outPrefix + string(".chr1.bin"));

  Timer timer; double t0 = timer.get_time();

  cout << "Number of top PCs to remove from LRR: " << P << endl;
  if (P <= 0) {
    cout << "ERROR: Number of PCs must be positive" << endl;
    exit(1);
  }
  cout << "Setting number of threads to " << threads << endl;
  if (threads <= 0) {
    cout << "ERROR: Number of threads must be positive" << endl;
    exit(1);
  }
  cout << endl;
  omp_set_num_threads(threads);
  mkl_set_num_threads(threads);

  vector <string> IDpairsGenoSet, genotypingSetNames;
  vector < vector <int> > batchIndsPerSet;
  vector <int> famIndToBatch;
  vector <uint64> NperBatch;
  defineBatches(IDpairsGenoSet, genotypingSetNames, batchIndsPerSet, famIndToBatch, NperBatch,
                sampleInfoFile);

  uint64 M;
  vector <uint64> MperChr;
  splitLRR(M, MperChr, lrrFormatStr, famIndToBatch, NperBatch.size(), outPrefix);
  cout << "Finished splitting LRR: " << timer.update_time() << " sec" << endl << endl;

  vector < vector <float> > PCsPerSet
    = computePCs(outPrefix, P, genotypingSetNames, batchIndsPerSet, NperBatch, M);
  cout << "Finished computing PCs: " << timer.update_time() << " sec" << endl << endl;

  projPCs(outPrefix, PCsPerSet, genotypingSetNames, batchIndsPerSet, NperBatch, M, P);
  cout << "Finished projecting PCs: " << timer.update_time() << " sec" << endl << endl;

  mergeLRR(outPrefix, famIndToBatch, NperBatch, MperChr, IDpairsGenoSet);
  cout << "Finished merging output: " << timer.update_time() << " sec" << endl << endl;

  cout << "Finished denoise_lrr; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
