// g++ -O3 -Wall -static-libgcc -static-libstdc++ computeReadCountSketches.cpp -o computeReadCountSketches -I/n/groups/price/poru/HSPH_SVN/src/EAGLE -I/home/pl88/boost_1_58_0/install/include -L/n/groups/price/poru/external_software/libstdc++/usr/lib/gcc/x86_64-redhat-linux/4.8.5/ -L/n/groups/price/poru/external_software/zlib/zlib-1.2.11 -L/home/pl88/boost_1_58_0/install/lib -Wl,-Bstatic -lboost_iostreams -lz

#include <iostream>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <vector>
#include <map>
#include <set>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cassert>

#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// One entry of the per-chromosome region table stored in the binary count-file header.
// NOTE: the field order and types are part of the on-disk layout (sizeof(Region) is used
// to skip over the header), so they must not be rearranged.
struct Region {
  int start38, end38, minMAPQ;

  // Builds the label "<start38>_<end38>.Q<minMAPQ>" for this region.
  string toString(void) const {
    string label = StringUtils::itos(start38);
    label += "_";
    label += StringUtils::itos(end38);
    label += ".Q";
    label += StringUtils::itos(minMAPQ);
    return label;
  }
};

// Returns the median, over chromosomes 1..22, of extractedCounts[chr] / chrMeanCounts[chr].
// Both arrays are indexed by chromosome number (index 0 is never read).  With 22 values
// the median is the average of the 11th and 12th smallest ratios; the result is narrowed
// from double to float on return.
float medianScale(double *chrMeanCounts, int *extractedCounts) {
  const int numAutosomes = 22;
  double perChrRatio[numAutosomes];
  for (int idx = 0; idx < numAutosomes; ++idx) {
    const int chr = idx + 1;
    perChrRatio[idx] = extractedCounts[chr] / chrMeanCounts[chr];
  }
  std::sort(perChrRatio, perChrRatio + numAutosomes);
  const double lowerMid = perChrRatio[numAutosomes/2 - 1];
  const double upperMid = perChrRatio[numAutosomes/2];
  return static_cast<float>((lowerMid + upperMid) / 2);
}

int main(int argc, char *argv[]) {

  Timer timer;

  if (argc != 8) {
    cerr << "Usage:" << endl;
    cerr << "- arg1: bpRes (e.g., 1kb)" << endl;
    cerr << "- arg2: min depth in bpRes bin" << endl;
    cerr << "- arg3: max depth in bpRes bin" << endl;
    cerr << "- arg4: quantile to use for deciding most-variable-depth bins (sigma2ratio)" << endl;
    cerr << "- arg5: pilot per-100bp mean read count file" << endl;
    cerr << "- arg6: output header file" << endl;
    cerr << "- arg7: output data file" << endl;
    exit(1);
  }

  int bpRes; sscanf(argv[1], "%d", &bpRes);
  double minMeanDepth; sscanf(argv[2], "%lf", &minMeanDepth);
  double maxMeanDepth; sscanf(argv[3], "%lf", &maxMeanDepth);
  double sigma2ratioQuantile; sscanf(argv[4], "%lf", &sigma2ratioQuantile);
  const char *pilotMeanReadCountFile = argv[5];
  const char *outFileHeader = argv[6];
  const char *outFileData = argv[7];
  
  FileUtils::AutoGzIfstream fin;

  /***** CREATE REGION LIST, EXCLUDING REGIONS OVERLAPPING CNVs / VNTRs *****/

  // mark regions overlapping 1KGP3 SVs / 1000G 30x CNVs
  vector < vector <bool> > overlapsCNV(23, vector <bool> ((int) (250e6/bpRes)));
  for (int i = 0; i < 2; i++) {
    fin.openOrExit(i==0 ? "/mnt/project/lohdata/ploh/WES_CNVs/noise_neighbors/inputs/1KGP3.svs.MAF_gt_0.001.txt" : "/mnt/project/lohdata/ploh/WES_CNVs/noise_neighbors/inputs/1000G_30x.CNVs.MAF_gt_0.001.txt");
    int numCNVs = 0, totLen = 0; string chrStr; 
    while (fin >> chrStr) {
      if (chrStr == "chrX" || chrStr == "chrY") continue;
      numCNVs++;
      int chr, bpStart, bpEnd;//, bpLen; string vntrStr;
      sscanf(chrStr.c_str(), "chr%d", &chr);
      fin >> bpStart >> bpEnd;// >> vntrStr >> bpLen;
      for (int bin = bpStart/bpRes; bin <= bpEnd/bpRes; bin++)
	overlapsCNV[chr][bin] = true;
      totLen += bpEnd - bpStart;
    }
    fin.close();
    cout << "Read " << numCNVs << " autosomal CNVs spanning " << totLen*1e-6 << " Mb" << endl;
  }

  // mark regions overlapping 118 coding VNTRs
  int numVNTRs = 0, totLen = 0; string chrStr;
  fin.openOrExit("/mnt/project/lohdata/ploh/WES_CNVs/noise_neighbors/inputs/118_vntr_list.txt");
  string vntrStr;
  while (fin >> vntrStr) {
    numVNTRs++;
    int chr, bpStart, bpEnd;
    sscanf(vntrStr.c_str(), "hg38_%d_%d_%d", &chr, &bpStart, &bpEnd);
    for (int bin = bpStart/bpRes; bin <= bpEnd/bpRes; bin++)
      overlapsCNV[chr][bin] = true;
    totLen += bpEnd - bpStart;    
  }
  fin.close();
  cout << "Read " << numVNTRs << " autosomal VNTRs spanning " << totLen*1e-6 << " Mb" << endl;

  // read region spans; create extract list
  vector < map <int, double> > binMeanDepths(23); // [chr][kb-scale bin] -> mean depth across bin
  int R; // number of 100bp regions in which reads were counted
  set < pair <int, int> > binSet; // set of selected kb-scale bins passing read depth filters
  vector <int> binInd; // 0 if not selected; 1..binSet.size() if assigned to a bin
  vector <int> regionChr; // chromosome of each 100bp region
  int Rextract = 0;
  vector <int> Rchr(23);
  for (int iter = 1; iter <= 2; iter++) { // read file twice: 1 = merge depths; 2 = select bins
    R = 0;
    fin.openOrExit(pilotMeanReadCountFile);
    int chr, bpStart; double meanReadCount;
    while (fin >> chr >> bpStart >> meanReadCount) {
      if (iter == 1) regionChr.push_back(chr);
      R++;
      if (chr > 22) continue;
      if (iter == 1) Rchr[chr]++;
      if (iter==1) {
	if (!overlapsCNV[chr][bpStart/bpRes])
	  binMeanDepths[chr][bpStart/bpRes] += meanReadCount*76*2/bpRes; // 1 read counted per pair
      }
      else {
	double binMeanDepth = binMeanDepths[chr][bpStart/bpRes];
	if (minMeanDepth <= binMeanDepth && binMeanDepth <= maxMeanDepth) {
	  binSet.insert(make_pair(chr, bpStart/bpRes));
	  binInd[R-1] = binSet.size()-1;
	  Rextract++;
	}
      }
    }
    fin.close();
    if (iter == 1)
      binInd = vector <int> (R, -1);
  }
  
  int Bextract = binSet.size();
  cout << "Using " << Bextract << " bins containing " << Rextract << " / " << R << " regions"
       << endl;

  /***** compute mean read count (in 100bp non-CNV regions) per chr for "medianLRR" later *****/
  vector <double> chrMeanCounts(23);
  vector < pair <int, int> > binVec(binSet.begin(), binSet.end());
  for (int s = 0; s < Bextract; s++) {
    int chr = binVec[s].first, bin = binVec[s].second;
    chrMeanCounts[chr] += binMeanDepths[chr][bin] / (76*2) * bpRes;
  }
  for (int chr = 1; chr <= 22; chr++)
    cout << "chr" << chr << " bins: " << Rchr[chr] << " mean count: " << chrMeanCounts[chr]
	 << endl;


  const int normBatches = 1; // number of batches to use for normalizing regions

  /***** extract data for 1st normBatches; normalize each individual (across regions) *****/

  int N = 0;
  const int MAX_N = 5000*normBatches;
  float *depths = new float[MAX_N*Bextract]; // note: scale is relative (not absolute) here
  memset(depths, 0, MAX_N*Bextract*sizeof(depths[0]));
  unsigned short *counts = new unsigned short[R];
  string line;

  for (int batch = 0; batch < normBatches; batch++) {
    FileUtils::AutoGzIfstream finBinGz[23];
    for (int chr = 1; chr <= 22; chr++) {
      char buf[200];
      sprintf(buf, "/mnt/project/lohdata/ploh/WES_CNVs/WES_read_counts/results/WES_15K_batch%d/WES_15K_batch%d.100bp.chr%d.bin.gz", batch, batch, chr);
      finBinGz[chr].openOrExit(buf, std::ios_base::binary);
      // ignore header
      uint64 headerBytes = sizeof(int) + Rchr[chr] * sizeof(Region);
      char *header = new char[headerBytes]; finBinGz[chr].read(header, headerBytes);
      assert(Rchr[chr] == *(int *) header);
      delete[] header;
    }
    int ID;
    while (finBinGz[1].read((char *) &ID, sizeof(int))) {
      //cout << ID << endl;
      for (int chr = 2; chr <= 22; chr++) {
	int ID2 = 0;
	finBinGz[chr].read((char *) &ID2, sizeof(int));
	//cout << ID2 << endl;
	assert(ID == ID2);
      }
      int Rcum = 0;
      for (int chr = 1; chr <= 22; chr++) {
	finBinGz[chr].read((char *) &counts[Rcum], 2*Rchr[chr]);
	Rcum += Rchr[chr];
      }
      float *depthsRow = depths + N*Bextract;
      int extractedCounts[23]; memset(extractedCounts, 0, 23*4);//float mean_depthExtract = 0;
      for (int r = 0; r < R; r++) {
	if (binInd[r] != -1) {
	  float count = counts[r]; // data line contains counts for MAPQ>=0 / MAPQ>=1
	  depthsRow[binInd[r]] += count;
	  extractedCounts[regionChr[r]] += count;//mean_depthExtract += count;
	}
      }
      //mean_depthExtract /= Bextract;
      float invScale = 1 / medianScale(&chrMeanCounts[0], extractedCounts)/*mean_depthExtract*/;
      for (int s = 0; s < Bextract; s++)
	depthsRow[s] *= invScale;
      N++;
    }
    for (int chr = 1; chr <= 22; chr++)
      finBinGz[chr].close();
    cout << "Read batch " << batch << " (" << timer.update_time() << " sec)" << endl;
  }
  cout << "Read " << N << " indivs; normalizing by region" << endl;
  
  /***** normalize each region (across individuals) *****/

  float x[N];
  const float ratioMult = 100; // just for more readable output
  vector <float> mus(Bextract), sigma2s(Bextract), sigma2ratios(Bextract), invRootMeans(Bextract);
  for (int s = 0; s < Bextract; s++) {
    for (int n = 0; n < N; n++) x[n] = depths[n*Bextract+s];
    float mu = 0, s2 = 0;
    for (int n = 0; n < N; n++) mu += x[n];
    mu /= N;
    for (int n = 0; n < N; n++) s2 += (x[n]-mu)*(x[n]-mu);
    s2 /= (N-1);
    mus[s] = mu; sigma2s[s] = s2; sigma2ratios[s] = ratioMult*s2/mu;
    //float invStd = 1 / sqrtf(s2);
    float invRootMean = 1 / sqrtf(mu);
    invRootMeans[s] = invRootMean;
    for (int n = 0; n < N; n++) depths[n*Bextract+s] = (x[n]-mu) * invRootMean/*invStd*/;
  }
  cout << "Normalized by region (" << timer.update_time() << " sec)" << endl;

  sort(sigma2ratios.begin(), sigma2ratios.end());
  float sigma2ratioMin = sigma2ratios[(int) (sigma2ratioQuantile*Bextract)];
  vector <bool> want(Bextract); int Rwant = 0;
  for (int s = 0; s < Bextract; s++)
    if (ratioMult*sigma2s[s]/mus[s] > sigma2ratioMin) {
      want[s] = true;
      Rwant++;
    }
  cout << "Restricting to " << Rwant << " regions with sigma2ratio > " << sigma2ratioMin
       << endl;
  float sigma2ratioMedian = sigma2ratios[Bextract/2];
  cout << "Rescaling to approximate z-scores based on median sigma2ratio = " << sigma2ratioMedian
       << endl;
  float scaleOverall = 1/sqrtf(sigma2ratioMedian/ratioMult);

  /***** stream batches; write output *****/
  
  FileUtils::AutoGzOfstream fout;
  fout.openOrExit(outFileData);
  fout << std::setprecision(3) << std::fixed;

  N = 0;
  for (int batch = 0; batch < 10; batch++) {
    FileUtils::AutoGzIfstream finBinGz[23];
    for (int chr = 1; chr <= 22; chr++) {
      char buf[200];
      sprintf(buf, "/mnt/project/lohdata/ploh/WES_CNVs/WES_read_counts/results/WES_15K_batch%d/WES_15K_batch%d.100bp.chr%d.bin.gz", batch, batch, chr);
      finBinGz[chr].openOrExit(buf, std::ios_base::binary);
      // ignore header
      uint64 headerBytes = sizeof(int) + Rchr[chr] * sizeof(Region);
      char *header = new char[headerBytes]; finBinGz[chr].read(header, headerBytes); delete[] header;
    }
    int ID;
    while (finBinGz[1].read((char *) &ID, sizeof(int))) {
      for (int chr = 2; chr <= 22; chr++) {
	int ID2 = 0;
	finBinGz[chr].read((char *) &ID2, sizeof(int));
	assert(ID == ID2);
      }
      int Rcum = 0;
      for (int chr = 1; chr <= 22; chr++) {
	finBinGz[chr].read((char *) &counts[Rcum], 2*Rchr[chr]);
	Rcum += Rchr[chr];
      }
      // process data for sample
      float *depthsRow = depths;
      memset(depthsRow, 0, Bextract*sizeof(depthsRow[0]));
      int extractedCounts[23]; memset(extractedCounts, 0, 23*4);//float mean_depthExtract = 0;
      for (int r = 0; r < R; r++) {
	if (binInd[r] != -1) {
	  float count = counts[r]; // data line contains counts for MAPQ>=0 / MAPQ>=1
	  depthsRow[binInd[r]] += count;
	  extractedCounts[regionChr[r]] += count;//mean_depthExtract += count;
	}
      }
      //mean_depthExtract /= Bextract;
      float invScale = 1 / medianScale(&chrMeanCounts[0], extractedCounts)/*mean_depthExtract*/;
      N++;

      // write output (overall read depth + scaled depth at selected regions)
      fout << ID << "\t" << std::setprecision(3) << 1 / invScale << std::setprecision(2);
      for (int s = 0; s < Bextract; s++)
	if (want[s])
	  fout << "\t" << scaleOverall * (depthsRow[s]*invScale - mus[s]) * invRootMeans[s];
      fout << endl;
    }
    for (int chr = 1; chr <= 22; chr++)
      finBinGz[chr].close();
    cout << "Read batch " << batch << " (" << timer.update_time() << " sec)" << endl;
  }
  cout << "Normalized " << N << " indivs" << endl;
  fout.close();


  // write header lines (now that N is known)

  fout.openOrExit(outFileHeader);
  fout << std::setprecision(3) << std::fixed;

  fout << N << "\t" << Rwant;
  for (int s = 0; s < Bextract; s++)
    if (want[s])
      fout << "\t" << mus[s];
  fout << endl;
  fout << N << "\t" << Rwant;
  for (int s = 0; s < Bextract; s++)
    if (want[s])
      fout << "\t" << ratioMult*sigma2s[s]/mus[s];
  fout << endl;
  for (int i = 0; i < 3; i++) { // output chr bpStart bpEnd for each extracted bin
    fout << N << "\t" << Rwant;
    for (int s = 0; s < Bextract; s++)
      if (want[s])
	fout << "\t" << (i==0 ? binVec[s].first : ((binVec[s].second+(i==2))*bpRes));
    fout << endl;
  }
  fout.close();

  delete[] depths;
  delete[] counts;

  return 0;
}
