#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cassert>
#include "zlib.h"

#include "VersionHeader.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

void process(double *dBuf, double lut[256], uchar *buf, uint bufLen, const uchar *zBuf,
	     uint zBufLen, uint Nbgen, uint M, uint mStart, uint mEnd, uint mSkip,
	     uint64 *maskBits, double callThresh, const vector <int> &bgenInds) {

  uLongf destLen = bufLen;
  if (uncompress(buf, &destLen, zBuf, zBufLen) != Z_OK || destLen != bufLen) {
    cout << "ERROR: uncompress() failed" << endl;
    exit(1);
  }
  uchar *bufAt = buf;
  uint Ncheck = bufAt[0]|(bufAt[1]<<8)|(bufAt[2]<<16)|(bufAt[3]<<24); bufAt += 4;
  assert(Ncheck==Nbgen);
  uint K = bufAt[0]|(bufAt[1]<<8); bufAt += 2;
  assert(K==2U);
  uint Pmin = *bufAt; bufAt++;
  assert(Pmin==2U);
  uint Pmax = *bufAt; bufAt++;
  assert(Pmax==2U);
  for (uint i = 0; i < Nbgen; i++) {
    uint ploidyMiss = *bufAt; bufAt++;
    assert(ploidyMiss==2U);
  }
  uint Phased = *bufAt; bufAt++;
  assert(Phased==0U || Phased==1U);
  uint B = *bufAt; bufAt++;
  assert(B==8U);

  double mean = 0;
  for (uint i = 0; i < Nbgen; i++) {
    dBuf[i] = lut[bufAt[0]] * (Phased==0U ? 2 : 1) + lut[bufAt[1]]; // dosage
    bufAt += 2;
    mean += dBuf[i];
  }
  mean /= 2*Nbgen;

  for (uint ip = 0; ip < bgenInds.size(); ip++)
    if (bgenInds[ip] != -1) {
      double dosage = dBuf[bgenInds[ip]];
      bool dropAll = 0.25<mean && mean<0.75;
      if (dropAll || (mean<0.5 && dosage>=callThresh) || (mean>=0.5 && dosage<=2-callThresh)) {
	for (uint m = mStart; m < mEnd; m++)
	  if (m != mSkip) {
	    uint64 x = ip * (uint64) M + m;
	    maskBits[x>>6ULL] |= 1ULL<<(x&63ULL);
	  }
      }
    }
}

// Reads exactly count items of size bytes from fin into dst; prints an error and exits if the
// file ends early or the read fails.  A zero count is a no-op.
static void bgenReadOrExit(void *dst, size_t size, size_t count, FILE *fin) {
  if (count != 0 && fread(dst, size, count, fin) != count) {
    cout << "ERROR: unexpected end of file or read error in bgen file" << endl;
    exit(1);
  }
}

// Repositions fin with fseek(); prints an error and exits if the seek fails.
static void bgenSeekOrExit(FILE *fin, long offset, int whence) {
  if (fseek(fin, offset, whence) != 0) {
    cout << "ERROR: seek failed in bgen file" << endl;
    exit(1);
  }
}

// Reads a 4-byte allele length followed by that many characters into alleleBuf (grown when
// too small), NUL-terminates the text, and returns the length.  Exits on a negative length.
static int bgenReadAlleleOrExit(vector <char> &alleleBuf, FILE *fin) {
  int len; bgenReadOrExit(&len, 4, 1, fin);
  if (len < 0) {
    cout << "ERROR: invalid allele length in bgen file" << endl;
    exit(1);
  }
  if ((size_t) len + 1 > alleleBuf.size())
    alleleBuf.resize(2*(size_t) len + 1);
  bgenReadOrExit(&alleleBuf[0], 1, len, fin);
  alleleBuf[len] = '\0';
  return len;
}

/**
 * Scans every variant of one BGEN file and, for each variant lying within maxBpDist of at
 * least one typed SNP, hands its genotype block to process() so that mask bits get set.
 *
 * maskBits    bit matrix updated in place (see process())
 * bgenFile    path of the .bgen file
 * bgenInds    for each phased sample, its index in the bgen sample order, or -1
 * bps         base-pair positions of the typed SNPs; searched by bisection, so they must be in
 *             ascending order
 * REFs, ALTs  alleles of the typed SNPs, parallel to bps
 * maxBpDist   distance (bp) within which a typed SNP counts as near a variant
 * callThresh  dosage threshold forwarded to process()
 *
 * The file must carry the "bgen" magic, zlib-compressed blocks and Layout 2 (asserted).
 * Prints an error and exits if the file cannot be opened, is truncated, or has inconsistent
 * length fields.
 * NOTE(review): multi-byte fields are read straight into host integers, which presumes a
 * little-endian host -- confirm for other targets.
 */
void processBgen(uint64 *maskBits, const string &bgenFile, const vector <int> &bgenInds,
		 const vector <int> &bps, const vector <string> &REFs, const vector <string> &ALTs,
		 int maxBpDist, double callThresh) {

  int M = bps.size();

  // read BGEN header
  cout << "Reading bgen file: " << bgenFile << endl;
  FILE *fin = fopen(bgenFile.c_str(), "rb");
  if (fin == NULL) {
    cout << "ERROR: unable to open bgen file: " << bgenFile << endl;
    exit(1);
  }
  uint offset; bgenReadOrExit(&offset, 4, 1, fin);
  uint L_H; bgenReadOrExit(&L_H, 4, 1, fin);
  uint Mbgen; bgenReadOrExit(&Mbgen, 4, 1, fin);
  uint Nbgen; bgenReadOrExit(&Nbgen, 4, 1, fin);
  char magic[5]; bgenReadOrExit(magic, 1, 4, fin); magic[4] = '\0';
  assert(magic==string("bgen"));
  // 16 header bytes have been consumed and the 4-byte flags field ends the header, so the
  // free data area in between is L_H-20 bytes long
  if (L_H < 20) {
    cout << "ERROR: invalid header length in bgen file" << endl;
    exit(1);
  }
  bgenSeekOrExit(fin, L_H-20, SEEK_CUR);
  uint flags; bgenReadOrExit(&flags, 4, 1, fin);
  uint CompressedSNPBlocks = flags&3;
  assert(CompressedSNPBlocks==1);
  uint Layout = (flags>>2)&0xf;
  assert(Layout==2);
  bgenSeekOrExit(fin, offset+4, SEEK_SET);

  // process variant data
  cout << "Reading data for " << Mbgen << " variants in bgen file" << endl;

  double lut[256];
  for (int i = 0; i <= 255; i++)
    lut[i] = i/255.0;

  // allele text buffers; grown on demand by bgenReadAlleleOrExit()
  vector <char> A1(65536+1), A2(65536+1);

  // Block buffers start at the size a well-formed block needs and are grown whenever the file
  // announces a larger block, so neither fread() nor uncompress() can write past the end.
  const size_t minBufLen = 3*(size_t) Nbgen+100;
  vector <uchar> zBuf(minBufLen), buf(minBufLen);
  vector <double> dBuf(Nbgen > 0 ? Nbgen : 1);

  int ctrIsect = 0;
  for (uint mbgen = 0; mbgen < Mbgen; mbgen++) {
    // the three length-prefixed identifier fields are skipped without being read
    ushort LS; bgenReadOrExit(&LS, 2, 1, fin);
    bgenSeekOrExit(fin, LS, SEEK_CUR);
    ushort LR; bgenReadOrExit(&LR, 2, 1, fin);
    bgenSeekOrExit(fin, LR, SEEK_CUR);
    ushort LC; bgenReadOrExit(&LC, 2, 1, fin);
    bgenSeekOrExit(fin, LC, SEEK_CUR);
    int bp; bgenReadOrExit(&bp, 4, 1, fin);
    ushort K; bgenReadOrExit(&K, 2, 1, fin);
    assert(K==2); // bi-allelic
    bgenReadAlleleOrExit(A1, fin);
    int LB = bgenReadAlleleOrExit(A2, fin);
    const char *allele1 = &A1[0], *allele2 = &A2[0];

    // read genotype data block (Layout 2)
    uint C; bgenReadOrExit(&C, 4, 1, fin);
    uint D; bgenReadOrExit(&D, 4, 1, fin);
    if (C < 4) { // C counts the 4-byte D field plus the compressed payload
      cout << "ERROR: invalid genotype block length in bgen file" << endl;
      exit(1);
    }
    uint zBufLen = C-4; uint bufLen = D;

    // NOTE: be careful about N != Nbgen!

    // check bp; skip if not near a typed SNP
    // NOTE: event affects bp-maxBpDist to bp+(LB-1)+maxBpDist inclusive... be careful about DELs
    int lo = 0, hi = M; // want lo=hi to be first position s.t. bps[.] >= bp-maxBpDist
    while (lo < hi) {
      int mid = (lo+hi)/2;
      if (bps[mid] < bp-maxBpDist)
        lo = mid+1;
      else
        hi = mid;
    }
    int mStart = lo, mEnd = lo, mSkip = -1;
    while (mEnd < M && bps[mEnd] <= bp+(LB-1)+maxBpDist) {
      // a typed SNP at the same position with the same allele pair (either order) is the
      // variant itself and is excluded from masking
      if (bps[mEnd] == bp) {
        if ((REFs[mEnd]==allele1 && ALTs[mEnd]==allele2) ||
            (REFs[mEnd]==allele2 && ALTs[mEnd]==allele1))
          mSkip = mEnd;
      }
      mEnd++;
    }

    if (mEnd > mStart) {
      if (zBufLen > zBuf.size()) zBuf.resize(zBufLen);
      if (bufLen > buf.size()) buf.resize(bufLen);
      bgenReadOrExit(&zBuf[0], 1, zBufLen, fin);
      process(&dBuf[0], lut, &buf[0], bufLen, &zBuf[0], zBufLen, Nbgen, M, mStart, mEnd, mSkip,
              maskBits, callThresh, bgenInds);
      ctrIsect++;
    }
    else
      bgenSeekOrExit(fin, zBufLen, SEEK_CUR);

    if (mbgen % 100000 == 99999)
      cout << "Processed " << mbgen+1 << " variants; " << ctrIsect << " near typed SNPs" << endl;
  }

  fclose(fin);

  cout << "Finished processing bgen file; found " << ctrIsect << " variants near typed SNPs"
       << endl;
}


// Builds the lookup key for a sample: its two ID columns joined by a single tab character.
string makeIDpair(const string &ID_1, const string &ID_2) {
  string key(ID_1);
  key += "\t";
  key += ID_2;
  return key;
}

/**
 * Maps each phased sample to its position in a bgen file's sample order.
 *
 * famInds         for each phased sample, its 0-based row in the sample info file
 * sampleInfoFile  text file with one header line, then rows starting with ID_1 ID_2 ID_imp
 * sampleFile      .sample file with two header lines, then rows starting with ID_1 ID_2
 *
 * A sample-info row is matched to the bgen sample list by its (ID_1, ID_2) pair first; if
 * that fails, its ID_imp is matched against the ID_2 column of the .sample file.
 *
 * Returns a vector parallel to famInds holding the bgen sample index, or -1 when the sample
 * is not in the bgen data.  Prints a warning with the number of unmatched samples.  Prints an
 * error and exits if a famInds entry does not refer to a row of the sample info file.
 */
vector <int> getBgenInds(const vector <int> &famInds, const char *sampleInfoFile,
			 const string &sampleFile) {

  cout << "Reading sample file: " << sampleFile << endl;
  FileUtils::AutoGzIfstream finSample; finSample.openOrExit(sampleFile);
  string line; getline(finSample, line); getline(finSample, line); // throw away headers
  string ID_1, ID_2;
  int bgenInd = 0;
  map <string, int> IDpairToBgenInd, IDimpToBgenInd;
  while (finSample >> ID_1 >> ID_2) {
    IDpairToBgenInd[makeIDpair(ID_1, ID_2)] = bgenInd;
    IDimpToBgenInd[ID_2] = bgenInd;
    getline(finSample, line); // discard the remaining columns of this row
    bgenInd++;
  }
  finSample.close();
  cout << "Read " << bgenInd << " samples in bgen data" << endl;

  cout << "Reading sample info file: " << sampleInfoFile << endl;
  FileUtils::AutoGzIfstream finSampleInfo; finSampleInfo.openOrExit(sampleInfoFile);
  getline(finSampleInfo, line); // throw away header
  string ID_imp;
  vector <int> famIndToBgenInd;
  while (finSampleInfo >> ID_1 >> ID_2 >> ID_imp) {
    bgenInd = -1;
    // a single find() per map replaces the former find-then-operator[] double lookup
    map <string, int>::iterator hit = IDpairToBgenInd.find(makeIDpair(ID_1, ID_2));
    if (hit != IDpairToBgenInd.end())
      bgenInd = hit->second;
    else {
      hit = IDimpToBgenInd.find(ID_imp);
      if (hit != IDimpToBgenInd.end())
        bgenInd = hit->second;
    }
    famIndToBgenInd.push_back(bgenInd);
    getline(finSampleInfo, line); // discard the remaining columns of this row
  }
  finSampleInfo.close();
  cout << "Read " << famIndToBgenInd.size() << " samples from sample info file" << endl;

  vector <int> bgenInds(famInds.size());
  for (size_t i = 0; i < famInds.size(); i++) {
    const int famInd = famInds[i];
    // famInds comes from another file; reject indices that do not name a sample-info row
    if (famInd < 0 || (size_t) famInd >= famIndToBgenInd.size()) {
      cout << "ERROR: phased sample index " << famInd
           << " is not a row of the sample info file" << endl;
      exit(1);
    }
    bgenInds[i] = famIndToBgenInd[famInd];
  }
  int numMissing = count(bgenInds.begin(), bgenInds.end(), -1);
  if (numMissing)
    cout << "WARNING: " << numMissing << " phased samples are not in bgen data" << endl;

  return bgenInds;
}

/**
 * Entry point of make_imp_masks.
 *
 * Reads the phased-sample indices from the IBD file header and the typed-SNP positions and
 * alleles from the bim file, runs processBgen() over every bgen file given on the command
 * line, then writes the accumulated bit mask (one bit per sample-by-SNP pair, packed 64 per
 * word) to the output file.
 *
 * Exits with status 1 on too few arguments, unparsable or out-of-range numeric arguments, or
 * any failure to open, read or write a file.
 */
int main(int argc, char *argv[]) {

  printVersion();

  cout << "make_imp_masks:" << endl;
  cout << "- arg1 = $NEARBY_IMP_MASK_FILE (output)" << endl;
  cout << "- arg2 = $SAMPLE_INFO_FILE" << endl;
  cout << "- arg3 = $BIM_FILE" << endl;
  cout << "- arg4 = $IBD_FILE" << endl;
  cout << "- arg5 = $MAX_BP_DIST" << endl;
  cout << "- arg6 = $DOSAGE_THRESH" << endl;
  cout << "- arg7+ = BGEN files (must end with .bgen; .sample files must also exist" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc < 8) {
    cout << "ERROR: 7+ arguments required" << endl;
    exit(1);
  }

  const char *nearbyImpMaskFile = checkInputFileExt(argv, 1, ".bin");
  const char *sampleInfoFile = argv[2];
  const char *bimFile = checkInputFileExt(argv, 3, ".bim");
  const char *ibdFile = checkInputFileExt(argv, 4, ".bin");

  // Parse outside of assert(): an assert argument is not evaluated when NDEBUG is defined,
  // and sscanf() reports an empty argument as EOF (-1), which is non-zero.
  int maxBpDist;
  if (sscanf(argv[5], "%d", &maxBpDist) != 1) {
    cout << "ERROR: unable to parse $MAX_BP_DIST: " << argv[5] << endl;
    exit(1);
  }
  double callThresh;
  if (sscanf(argv[6], "%lf", &callThresh) != 1) {
    cout << "ERROR: unable to parse $DOSAGE_THRESH: " << argv[6] << endl;
    exit(1);
  }

  FileUtils::requireWriteable(nearbyImpMaskFile);
  FileUtils::requireReadable(sampleInfoFile);
  FileUtils::requireReadable(bimFile);
  FileUtils::requireReadable(ibdFile);
  for (int a = 7; a < argc; a++) {
    checkInputFileExt(argv, a, ".bgen");
    FileUtils::requireReadable(argv[a]);
    // replace the trailing "bgen" with "sample" to get the companion file name
    FileUtils::requireReadable(string(argv[a]).substr(0, strlen(argv[a])-4) + "sample");
  }

  Timer timer; double t0 = timer.get_time();

  cout << "Distance threshold for masking genotype probes near imputed variants: " << maxBpDist
       << endl;
  if (maxBpDist < 0 || maxBpDist > 100) {
    cout << "ERROR: $MAX_BP_DIST must be between 0 and 100" << endl;
    exit(1);
  }
  cout << "Minimum minor-allele dosage threshold of nearby genotypes: " << callThresh << endl;
  if (!(callThresh > 0.01 && callThresh < 1)) {
    cout << "ERROR: $DOSAGE_THRESH must be greater than 0.01 and less than 1" << endl;
    exit(1);
  }
  cout << endl;

  // read fam inds of phased samples from header of IBD file
  cout << "Reading phased sample indices from IBD file: " << ibdFile << endl;
  FILE *finBin = fopen(ibdFile, "rb");
  if (finBin == NULL) {
    cout << "ERROR: unable to open IBD file: " << ibdFile << endl;
    exit(1);
  }
  int N; // number of phased samples to analyze
  if (fread(&N, sizeof(int), 1, finBin) != 1 || N < 0) {
    cout << "ERROR: unable to read number of phased samples from IBD file" << endl;
    exit(1);
  }
  vector <int> famInds(N);
  // skip the read when N == 0: an empty vector has no element 0 to take the address of
  if (N > 0 && fread(&famInds[0], sizeof(int), N, finBin) != (size_t) N) {
    cout << "ERROR: unable to read phased sample indices from IBD file" << endl;
    exit(1);
  }
  fclose(finBin);

  /***** read bim file to get SNP locations *****/
  cout << "Reading bim file: " << bimFile << endl;
  FileUtils::AutoGzIfstream finBim; finBim.openOrExit(bimFile);
  vector <int> bps; vector <string> REFs, ALTs;
  int chr, bp; double genpos; string rsID, REF, ALT;
  while (finBim >> chr >> rsID >> genpos >> bp >> REF >> ALT) {
    bps.push_back(bp);
    REFs.push_back(REF);
    ALTs.push_back(ALT);
  }
  finBim.close();
  int M = bps.size();
  cout << "Read " << M << " SNPs in bim file" << endl;

  // allocate memory for masks: N*M bits rounded up to whole 64-bit words, zero-initialized
  uint64 maskBitsULLs = (N * (uint64) M + 63)>>6;
  uint64 *maskBits = (uint64 *) calloc(maskBitsULLs, sizeof(uint64));
  if (maskBits == NULL) {
    cout << "ERROR: unable to allocate memory for masks" << endl;
    exit(1);
  }

  /***** loop through bgen+sample files *****/
  for (int a = 7; a < argc; a++) {
    const char *bgenFile = argv[a];
    string bgenSamplePrefix = string(bgenFile).substr(0, strlen(bgenFile)-5); // drop ".bgen"
    cout << "Processing bgen and sample files: " << bgenSamplePrefix << ".{bgen,sample}" << endl
         << endl;

    vector <int> bgenInds = getBgenInds(famInds, sampleInfoFile, bgenSamplePrefix + ".sample");

    processBgen(maskBits, bgenSamplePrefix + ".bgen", bgenInds, bps, REFs, ALTs, maxBpDist,
                callThresh);

    cout << endl;
  }

  // summarize and write output
  uint64 bitsSet = 0;
  for (uint64 i = 0; i < maskBitsULLs; i++)
    bitsSet += __builtin_popcountll(maskBits[i]);
  cout << "Fraction of bits set: " << bitsSet / (double) N / M << endl;

  cout << "Writing mask file for genotypes near imputed variants: " << nearbyImpMaskFile << endl;
  FILE *fout = fopen(nearbyImpMaskFile, "wb");
  if (fout == NULL) {
    cout << "ERROR: unable to open output file: " << nearbyImpMaskFile << endl;
    exit(1);
  }
  // fclose() flushes buffered data, so its result is part of the write check
  if (fwrite(maskBits, sizeof(uint64), maskBitsULLs, fout) != maskBitsULLs
      || fclose(fout) != 0) {
    cout << "ERROR: unable to write output file: " << nearbyImpMaskFile << endl;
    exit(1);
  }

  free(maskBits);

  cout << "Finished make_imp_masks; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
