// g++ -O3 -fopenmp -Wall -static-libgcc -static-libstdc++ normalizeReadCounts.cpp -o normalizeReadCounts -I/n/groups/price/poru/HSPH_SVN/src/EAGLE -I/home/pl88/boost_1_58_0/install/include -L/n/groups/price/poru/external_software/libstdc++/usr/lib/gcc/x86_64-redhat-linux/4.8.5/ -L/n/groups/price/poru/external_software/zlib/zlib-1.2.11 -L/home/pl88/boost_1_58_0/install/lib -Wl,-Bstatic -lboost_iostreams -lz

#include <iostream>
#include <iomanip>
#include <fstream>
#include <sstream>
#include <vector>
#include <string>
#include <queue>
#include <set>
#include <map>
#include <utility>
#include <algorithm>
#include <numeric>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include <boost/math/distributions/negative_binomial.hpp>

#include "omp.h"

#include "Types.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "NumericUtils.cpp"
#include "Timer.cpp"

using namespace std;

#define MAX_ID 6100000

// A region in which reads were counted. Instances are read as raw bytes from the
// header of each read-count .bin.gz file (see unzipReadCounts), so the three int
// fields must keep exactly this order and layout.
// NOTE(review): the "38" suffix presumably means GRCh38 coordinates — confirm.
struct Region {
  int start38;  // region start coordinate
  int end38;    // region end coordinate
  int minMAPQ;  // presumably the minimum mapping quality of counted reads — confirm

  // Label of the form "<start38>_<end38>.Q<minMAPQ>".
  string toString(void) const {
    string label = StringUtils::itos(start38);
    label += "_";
    label += StringUtils::itos(end38);
    label += ".Q";
    label += StringUtils::itos(minMAPQ);
    return label;
  }
};

// Square of a single-precision value.
inline float sq(float x) {
  const float squared = x * x;
  return squared;
}

double negLogL(const vector <float> &expectRCs, const vector <ushort> &binRCs, float theta0) {
  const float clipProb = 1e-4; // for robust estimation (M-estimator?)
  double negLogLsum = 0;
  float r = 1 / theta0;
  int Nkept = binRCs.size();
  for (int j = 0; j < Nkept; j++) {
    float theta = expectRCs[j] * theta0;
    float p = 1 / (1+theta);
    float prob =
      boost::math::pdf(boost::math::negative_binomial_distribution <float> (r, p), binRCs[j]);
    negLogLsum += -logf(max(prob, clipProb));
  }
  return negLogLsum;
}

// Reads each sample's coverage scale and noise-neighbor list.
//
// IDscaleNeighborsFile is parsed record by record as: an int ID, a float scale (> 0),
// then the remainder of that line as whitespace-separated neighbor IDs.
//
// normIDs               out (appended): IDs from IDsToNormalize found in the file, in file order
// normNeighborIDs       out (appended): parallel to normIDs; the first (at most) nNbr
//                       neighbor IDs listed for each
// IDtoInvScale          out: MAX_ID entries; 1/scale for every ID in the file, NAN otherwise
// IDsKeep               out: IDsToNormalize plus every retained neighbor ID
// IDsToNormalize        IDs to normalize
// IDscaleNeighborsFile  input path (opened with FileUtils::AutoGzIfstream; exits if missing)
// nNbr                  maximum number of neighbors to retain per ID
void readNeighborNormData(vector <int> &normIDs, vector < vector <int> > &normNeighborIDs,
			  vector <float> &IDtoInvScale, set <int> &IDsKeep,
			  const set <int> &IDsToNormalize, const char *IDscaleNeighborsFile,
			  int nNbr) {
  IDtoInvScale = vector <float> (MAX_ID, NAN);
  IDsKeep = IDsToNormalize;
  FileUtils::AutoGzIfstream fin; fin.openOrExit(IDscaleNeighborsFile);
  int ID; float scale;
  while (fin >> ID >> scale) {
    // IDs index MAX_ID-sized tables here and downstream; reject out-of-range values
    // instead of writing out of bounds
    assert(0 <= ID && ID < MAX_ID);
    assert(scale > 0);
    IDtoInvScale[ID] = 1 / scale;
    string line; getline(fin, line); // remainder of the record: neighbor IDs

    if (!IDsToNormalize.count(ID)) continue;

    normIDs.push_back(ID);
    // read noise neighbors of ID
    vector <int> neighbors;
    istringstream iss(line);
    int ctrNbr = 0; int IDnbr;
    while (ctrNbr < nNbr && (iss >> IDnbr)) {
      assert(0 <= IDnbr && IDnbr < MAX_ID);
      ctrNbr++;
      neighbors.push_back(IDnbr);
      IDsKeep.insert(IDnbr);
    }
    normNeighborIDs.push_back(neighbors);
  }
  fin.close();
  cout << "Stored noise neighbors for " << normIDs.size() << " samples" << endl;
  cout << "Flagged " << IDsKeep.size() << " IDs for which to store read counts" << endl;
}

// Loads per-bin read counts for the samples in IDsKeep from a list of binary .bin.gz files.
//
// regionCountsBinGzListFile lists one file path per line. Each file is read as:
//   int R; Region regions[R];            -- header (R must agree across all files)
//   repeated: int ID; ushort counts[R];  -- one record per sample
// The regions are taken from the first file's header. Files are processed in parallel.
// Exits with an error on an empty file list, a bad/mismatched header, or a truncated
// kept record.
//
// IDtoReadCountsInd  out: MAX_ID entries; for each ID in IDsKeep that had a record in
//                    some file, its row index into the returned buffer; -1 otherwise
// regions            out: the R regions from the first file's header
// IDsKeep            IDs whose read counts should be stored
// Returns a new[]-allocated buffer of |IDsKeep| * R counts, row k holding the sample with
// IDtoReadCountsInd == k. Rows for IDs without a record are left uninitialized.
// The caller owns the buffer and must delete[] it.
ushort *unzipReadCounts(vector <int> &IDtoReadCountsInd, vector <Region> &regions,
			const set <int> &IDsKeep, const char *regionCountsBinGzListFile) {

  // read IDs; set up indexing
  vector <int> IDtoReadCountsIndPrelim(MAX_ID, -1); // from ID list; some may not have read counts
  IDtoReadCountsInd = IDtoReadCountsIndPrelim;
  int Nkeep = 0;
  for (set <int>::iterator it = IDsKeep.begin(); it != IDsKeep.end(); it++) {
    assert(0 <= *it && *it < MAX_ID); // IDs index MAX_ID-sized tables
    IDtoReadCountsIndPrelim[*it] = Nkeep++;
  }

  // read file list
  vector <string> files;
  {
    FileUtils::AutoGzIfstream finBinGzList; finBinGzList.openOrExit(regionCountsBinGzListFile);
    string file;
    while (getline(finBinGzList, file))
      files.push_back(file);
    finBinGzList.close();
    cout << "Read " << files.size() << " read count file paths" << endl;
  }
  if (files.empty()) { // the region header is taken from files[0] below
    cerr << "ERROR: no read count files listed in " << regionCountsBinGzListFile << endl;
    exit(1);
  }

  // read regions from header of first file
  int R = 0;
  {
    FileUtils::AutoGzIfstream finBinGz; finBinGz.openOrExit(files[0], std::ios_base::binary);
    if (!finBinGz.read((char *) &R, sizeof(int)) || R <= 0) {
      cerr << "ERROR: failed to read a positive region count from " << files[0] << endl;
      exit(1);
    }
    regions.resize(R);
    if (!finBinGz.read((char *) &regions[0], R * sizeof(Region))) {
      cerr << "ERROR: truncated region header in " << files[0] << endl;
      exit(1);
    }
    finBinGz.close();
    cout << "Read " << R << " regions in which reads were counted" << endl;
  }

  // unzip read counts in parallel (one file per iteration)
  ushort *readCounts = new ushort[Nkeep * (uint64) R];
  const uint64 regionBytes = R * sizeof(Region);
  const uint64 readCountBytes = R * sizeof(readCounts[0]);
  int ctr = 0;
#pragma omp parallel for
  for (uint f = 0; f < files.size(); f++) {
    FileUtils::AutoGzIfstream finBinGz; finBinGz.openOrExit(files[f], std::ios_base::binary);
    // scratch space for skipped bytes (region header, or records of samples not kept)
    vector <char> ignoreBytes(max(regionBytes, readCountBytes));
    // header: record sizes depend on R, so it must match the first file; regions are skipped
    int fileR = -1;
    if (!finBinGz.read((char *) &fileR, sizeof(int)) || fileR != R
	|| !finBinGz.read(&ignoreBytes[0], regionBytes)) {
      cerr << "ERROR: " << files[f] << " has a missing or mismatched region header" << endl;
      exit(1);
    }
    // read WES read counts per sample
    int ID;
    while (finBinGz.read((char *) &ID, sizeof(int))) {
      assert(0 <= ID && ID < MAX_ID); // guards the table lookup below
      const int ind = IDtoReadCountsIndPrelim[ID];
      if (ind == -1) // don't keep this sample
	finBinGz.read(&ignoreBytes[0], readCountBytes);
      else {
	if (!finBinGz.read((char *) readCounts + ind*readCountBytes, readCountBytes)) {
	  cerr << "ERROR: truncated read counts for ID " << ID << " in " << files[f] << endl;
	  exit(1);
	}
	IDtoReadCountsInd[ID] = ind; // sample has read counts
#pragma omp atomic
	ctr++;
      }
    }
    finBinGz.close();
  }
  cout << "Stored WES read counts for " << ctr << " samples" << endl;

  return readCounts;
}

// NOTE: need to restrict to same-sex if analyzing chrX or chrY
bool normalizeSample(float *normRCs, float *baselineRCs, uint *sumRCs, float *sumBaselineRCs,
		     int normID, const vector <int> &normNeighborIDs,
		     const vector <float> &IDtoInvScale, const vector <int> &IDtoReadCountsInd,
		     uint64 R, int t, int threads, ushort *readCounts) {

  int ind = IDtoReadCountsInd[normID];
  if (ind == -1) return false; // read counts not available
  float invScale = IDtoInvScale[normID];

  int rStart = R*t / threads, rEnd = R*(t+1) / threads;

  int num = 0;
  float *sumsChunk = new float[rEnd-rStart];
  memset(sumsChunk, 0, (rEnd-rStart)*sizeof(sumsChunk[0]));

  for (uint k = 0; k < normNeighborIDs.size(); k++) {
    int indNbr = IDtoReadCountsInd[normNeighborIDs[k]];
    if (indNbr != -1) {
      num++;
      float invScaleNbr = IDtoInvScale[normNeighborIDs[k]];
      const ushort *readCountsNbr = readCounts + indNbr*R;
      for (int r = rStart; r < rEnd; r++) {
	float nbrCvgAdjRC = invScaleNbr * readCountsNbr[r];
	sumsChunk[r-rStart] += nbrCvgAdjRC;
      }
    }
  }

  for (int r = rStart; r < rEnd; r++) {
    baselineRCs[r] = sumsChunk[r-rStart] / num;
    normRCs[r] = invScale * readCounts[ind*R + r] / baselineRCs[r];
    sumRCs[r] += readCounts[ind*R + r];
    sumBaselineRCs[r] += baselineRCs[r];
  }

  delete[] sumsChunk;

  return true;
}


int main(int argc, char *argv[]) {

  if (argc != 8) {
    cerr << "Usage:" << endl;
    cerr << "- arg1 = ID list to normalize" << endl;
    cerr << "- arg2 = ID scale neighbors (trimmed) file for noise-neighbor normalization" << endl;
    cerr << "- arg3 = 100bp-bin read count .bin.gz list file" << endl;
    cerr << "- arg4 = number of noise neighbors to use for normalization" << endl;
    cerr << "- arg5 = threads" << endl;
    cerr << "- arg6 = baselineScales output file (binary)" << endl;
    cerr << "- arg7 = tmp read count output file (binary)" << endl;
    exit(1);
  }
  const char *IDlistFile = argv[1];
  const char *IDscaleNeighborsFile = argv[2];
  const char *regionCountsBinGzListFile = argv[3];
  int nNbr; sscanf(argv[4], "%d", &nNbr);
  int threads; sscanf(argv[5], "%d", &threads);
  const char *outFile = argv[6];
  const char *outFileRCs = argv[7];

  Timer timer;
  
  cout << "Setting number of threads to " << threads << endl << endl;
  omp_set_num_threads(threads);

  // read IDs to normalize
  set <int> IDsToNormalize;
  {
    FileUtils::AutoGzIfstream finIDlist; finIDlist.openOrExit(IDlistFile);
    int ID;
    while (finIDlist >> ID)
      IDsToNormalize.insert(ID);
    finIDlist.close();
    cout << "Read " << IDsToNormalize.size() << " IDs to normalize" << endl;
  }  

  // read noise neighbor data
  vector <int> normIDs; // may be a subset of IDsKeep
  vector < vector <int> > normNeighborIDs;
  vector <float> IDtoInvScale;
  set <int> IDsKeep;
  readNeighborNormData(normIDs, normNeighborIDs, IDtoInvScale, IDsKeep, IDsToNormalize,
		       IDscaleNeighborsFile, nNbr);
  cout << "\nTime for reading noise neighbor data: " << timer.update_time() << " sec\n" << endl;

  // read 100bp bin counts
  vector <int> IDtoReadCountsInd;
  vector <Region> regions;
  ushort *readCounts =
    unzipReadCounts(IDtoReadCountsInd, regions, IDsKeep, regionCountsBinGzListFile);
  cout << "\nTime for unzipping read counts: " << timer.update_time() << " sec\n" << endl;

  // allocate storage for normalized data
  const int Nnorm = normIDs.size();
  const uint64 R = regions.size();
  float *normRCs = new float[Nnorm * R];
  float *baselineRCs = new float[Nnorm * R];

  // allocate storage for running totals
  int numRCs = 0;
  vector <uint> sumRCs(R);
  vector <float> sumBaselineRCs(R);

  /***** run normalization (using noise neighbors -> baselineRCs, normRCs) *****/

#pragma omp parallel for
  for (int t = 0; t < threads; t++)
    for (int i = 0; i < Nnorm; i++) {
      bool success = normalizeSample(normRCs + i*R, baselineRCs + i*R,
				     &sumRCs[0], &sumBaselineRCs[0],
				     normIDs[i], normNeighborIDs[i], IDtoInvScale,
				     IDtoReadCountsInd, R, t, threads, readCounts);
      if (t == 0) numRCs += success;
    }

  cout << "\nTime for normalizing read counts: " << timer.update_time() << " sec\n" << endl;


  /***** compute bin-level statistics *****/

  vector <float> meanRCs(R), meanBaselineRCs(R);
  for (uint r = 0; r < R; r++) {
    meanRCs[r] = sumRCs[r] / (float) numRCs;
    meanBaselineRCs[r] = sumBaselineRCs[r] / numRCs;
  }

  vector <float> numNormRCs(R), sumNormRCs(R), sum2NormRCs(R), sdNormRCs(R);
  vector <float> sum2BaselineRCs(R), coeffVarBaselineRCs(R);
#pragma omp parallel for
  for (int t = 0; t < threads; t++) {
    int rStart = R*t / threads, rEnd = R*(t+1) / threads;
    for (int i = 0; i < Nnorm; i++)
      if (IDtoReadCountsInd[normIDs[i]] != -1)
	for (int r = rStart; r < rEnd; r++) {
	  if (baselineRCs[i*R + r] > meanBaselineRCs[r] * 0.8f && // check baselineScale reasonable
	      baselineRCs[i*R + r] < meanBaselineRCs[r] * 1.333333f) {
	    numNormRCs[r]++;
	    sumNormRCs[r] += normRCs[i*R + r];
	    sum2NormRCs[r] += sq(normRCs[i*R + r]);
	  }
	  sum2BaselineRCs[r] += sq(baselineRCs[i*R + r]);
	}
    for (int r = rStart; r < rEnd; r++) {
      sdNormRCs[r] = sqrtf((sum2NormRCs[r] - sq(sumNormRCs[r])/numNormRCs[r]) / (numNormRCs[r]-1));
      float sdBaselineRCs_r = sqrtf((sum2BaselineRCs[r] - sq(sumBaselineRCs[r])/numRCs)
				    / (numRCs-1));
      coeffVarBaselineRCs[r] = sdBaselineRCs_r / meanBaselineRCs[r];
    }
  }

  // fit negative binomial parameters (note: downstream analyses will restrict to noCommonSV bins)
  vector <float> theta0s(R);
  //vector <string> fitInfo(R);
#pragma omp parallel for
  for (uint r = 0; r < R; r++) {
    // populate RCs and expected RCs for the current bin + compute moment-based estimate of theta0
    vector <ushort> binRCs; binRCs.reserve(Nnorm);
    vector <float> expectRCs; expectRCs.reserve(Nnorm);
    double MOMsum = 0, MOMnum = 0;
    for (int i = 0; i < Nnorm; i++) {
      int ind = IDtoReadCountsInd[normIDs[i]];
      if (ind != -1) {
	ushort RC = readCounts[ind*R + r];
	float normRC = normRCs[i*R + r];
	float expectRC = baselineRCs[i*R + r] / IDtoInvScale[normIDs[i]];
	float invExpectRC = 1 / expectRC;
	MOMsum += min(sq(normRC-1), 10*invExpectRC) - invExpectRC;
	MOMnum++;
	binRCs.push_back(RC);
	expectRCs.push_back(expectRC);
      }
    }
    float theta0_MOM_uncropped = MOMsum / MOMnum;
    const float theta0_min = 0.0001f, theta0_max = 1;
    float theta0_MOM = max(theta0_min, theta0_MOM_uncropped);

    // quadratic iteration
    vector < pair <double, float> > yx;
    for (float mult = 0.5; mult <= 1.5; mult += 0.5) {
      float xCur = theta0_MOM * mult;
      yx.push_back(make_pair(negLogL(expectRCs, binRCs, xCur), xCur));
    }
    float xLast = theta0_MOM, xCur;
    while (true) {
      sort(yx.begin(), yx.end());
      float x[3] = {yx[0].second, yx[1].second, yx[2].second};
      double y[3] = {yx[0].first, yx[1].first, yx[2].first};
      xCur =
	- (sq(x[2]) * (y[0] - y[1]) + sq(x[1]) * (y[2] - y[0]) + sq(x[0]) * (y[1] - y[2]))
	/ (2 * (x[2] * (y[1] - y[0]) + x[1] * (y[0] - y[2]) + x[0] * (y[2] - y[1])));
      xCur = min(max(xCur, theta0_min), theta0_max);
      if (isnan(xCur)) xCur = xLast; // error (usually same value used twice) => take last estimate
      if (fabsf(xCur - xLast) < theta0_min) // converged
	break;
      if (yx.size() == 20) // ran out of iters; use best attempt
	xCur = yx[0].second;
      yx.push_back(make_pair(negLogL(expectRCs, binRCs, xCur), xCur));
      xLast = xCur;
    }
    theta0s[r] = xCur;
    
  }
  cout << "\nTime for computing bin statistics: " << timer.update_time() << " sec\n" << endl;

  // write output
  FILE *fout = fopen(outFile, "wb"); assert(fout != NULL);
  FILE *foutRCs = fopen(outFileRCs, "wb"); assert(fout != NULL);
  fwrite(&R, sizeof(R), 1, fout);
  fwrite(&regions[0], sizeof(regions[0]), R, fout);
  fwrite(&meanRCs[0], sizeof(meanRCs[0]), R, fout);
  fwrite(&sdNormRCs[0], sizeof(sdNormRCs[0]), R, fout);
  fwrite(&meanBaselineRCs[0], sizeof(meanBaselineRCs[0]), R, fout);
  fwrite(&coeffVarBaselineRCs[0], sizeof(coeffVarBaselineRCs[0]), R, fout);
  fwrite(&theta0s[0], sizeof(theta0s[0]), R, fout);
  vector <uchar> baselineScaleBytes(R);
  for (int i = 0; i < Nnorm; i++) {
    int ind = IDtoReadCountsInd[normIDs[i]];
    if (ind != -1) {
      fwrite(&normIDs[i], sizeof(normIDs[0]), 1, fout); // ID
      float coverageScale = 1 / IDtoInvScale[normIDs[i]];
      fwrite(&coverageScale, sizeof(coverageScale), 1, fout); // coverage scale factor
      for (uint r = 0; r < R; r++) // 1-byte values: baselineScale
        baselineScaleBytes[r]
	  = min(255, max(0, (int) (128 * baselineRCs[i*R + r] / meanBaselineRCs[r] + 0.5f)));
      fwrite(&baselineScaleBytes[0], sizeof(baselineScaleBytes[0]), R, fout);
      fwrite(&readCounts[ind*R], sizeof(readCounts[0]), R, foutRCs); // read counts -> tmp file
    }
  }
  assert(!ferror(fout));
  assert(!ferror(foutRCs));
  fclose(fout);
  fclose(foutRCs);

  cout << "\nTime for writing output: " << timer.update_time() << " sec\n" << endl;

  delete[] readCounts;
  delete[] normRCs;
  delete[] baselineRCs;
  
  return 0;
}
