// g++ -O3 -fopenmp -Wall -static-libgcc -static-libstdc++ normPhaseRegionsPSVs.cpp -o normPhaseRegionsPSVs -I/n/groups/price/poru/HSPH_SVN/src/EAGLE -I/home/pl88/boost_1_58_0/install/include -L/n/groups/price/poru/external_software/libstdc++/usr/lib/gcc/x86_64-redhat-linux/4.8.5/ -L/n/groups/price/poru/external_software/zlib/zlib-1.2.11 -L/home/pl88/boost_1_58_0/install/lib -Wl,-Bstatic -lboost_iostreams -lz

#include <iostream>
#include <iomanip>
#include <fstream>
#include <sstream>
#include <vector>
#include <string>
#include <queue>
#include <set>
#include <map>
#include <utility>
#include <algorithm>
#include <numeric>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include "omp.h"

#include "Types.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "NumericUtils.cpp"
#include "MemoryUtils.cpp"
#include "Timer.cpp"

// Number of IBS haplotype neighbors kept per haplotype (row width of the hapNbrs table).
#define N_HAP_NBRS 10
// Size of the sample-ID -> WES-row lookup table; IDs used to index it must lie in [0, MAX_ID).
#define MAX_ID 6100000

using namespace std;


/// A PSV (presumably "paralogous sequence variant") keyed by chromosome, hg38 position and base string.
struct PSV {
  int chr;
  int bp38;
  string base;

  /// Formats the PSV as "chr<chr>_<bp38>_<base>", the same layout parsed from the PSVCN header.
  string toString(void) const {
    string label("chr");
    label += StringUtils::itos(chr);
    label += "_";
    label += StringUtils::itos(bp38);
    label += "_";
    label += base;
    return label;
  }
};

/// Per-region (per-PSV) phasing diagnostics written to the .stats.txt output.
struct RegionStats {
  // accuracy estimates derived from trio child vs. held-out mid-parent dipCNs
  double r2prePhasing;
  double r2postPhasing;
  // weight on a sample's own dipCN residual, chosen on a grid over [0, 0.5] during phasing
  float phasingParam;
  // tab-separated min / 1st quartile / median / 3rd quartile / max of the region's hapCNs
  string summary;
  // currently disabled fields: meanReads, missFrac, LD_r2, LD_region
};

/// A child and its two parents, stored as indices into the kept-sample list.
struct Trio {
  int childInd, parentInds[2];
  /// parentInds[0] = father, parentInds[1] = mother; every index defaults to 0.
  Trio(int _childInd = 0, int _fatherInd = 0, int _motherInd = 0) : childInd(_childInd) {
    parentInds[0] = _fatherInd;
    parentInds[1] = _motherInd;
  }
};

/// Square of x.
inline double sq(double x) {
  const double squared = x * x;
  return squared;
}

/**
 * Squared Pearson correlation of x and y (taken by value because both are centered in place).
 * Assumes y has at least x.size() elements; returns NaN (0/0) when either series has zero variance.
 */
double r2(vector <double> x, vector <double> y) {
  const int n = x.size();
  const double xMean = NumericUtils::mean(x);
  const double yMean = NumericUtils::mean(y);
  // center both series on their own means
  for (int k = 0; k < n; k++) {
    x[k] -= xMean;
    y[k] -= yMean;
  }
  const double cross = NumericUtils::dot(&x[0], &y[0], n);
  const double normProduct = NumericUtils::norm2(&x[0], n) * NumericUtils::norm2(&y[0], n);
  return sq(cross) / normProduct;
}

/**
 * Five-number summary of the non-NaN values in _x[0.._N).
 *
 * Returns "min\t1stQ\tmedian\t3rdQ\tmax". The quartiles are the order statistics at indices N/4,
 * N/2 and 3N/4 of the sorted non-NaN values (N = their count). When no value is available
 * (_N <= 0 or every entry NaN), all five fields are NaN instead of reading out of bounds.
 */
string summary(const float *_x, int _N) {
  vector <float> x;
  x.reserve(_N > 0 ? _N : 0);
  for (int i = 0; i < _N; i++)
    if (!std::isnan(_x[i]))
      x.push_back(_x[i]);

  ostringstream oss;
  if (x.empty()) { // nothing to summarize: previously indexed x[0] / x[-1]
    const float missing = NAN;
    oss << missing << "\t" << missing << "\t" << missing << "\t" << missing << "\t" << missing;
    return oss.str();
  }

  sort(x.begin(), x.end());
  const int N = x.size();
  oss << x[0] << "\t" << x[N/4] << "\t" << x[N/2] << "\t" << x[3*N/4] << "\t" << x[N-1];
  return oss.str();
}

// 16-bit unsigned integer alias. NOTE(review): appears unused in this file.
typedef unsigned short uint16;

/**
 * Haplotype x SNP bit matrix stored one row per haplotype ("transposed").
 *
 * Row n occupies M64 = ceil(M/64) consecutive uint64 words; the allele of haplotype n at SNP m is
 * bit (m & 63) of word n*M64 + (m >> 6). The buffer is allocated with ALIGNED_MALLOC_UINT64S and
 * zero-filled in the constructor, and released in the destructor.
 *
 * The class owns that raw buffer, so copying is disabled: an implicit copy would leave two
 * objects freeing the same pointer.
 */
class HapBitsT {
  uint64 *haploBitsT;
  uint64 Nhaps, M, M64;
  // non-copyable: declared but intentionally not defined (C++98 equivalent of "= delete")
  HapBitsT(const HapBitsT &);
  HapBitsT &operator=(const HapBitsT &);
public:
  HapBitsT(uint64 _Nhaps, uint64 _M);
  ~HapBitsT(void);
  void setBit(uint64 n, uint64 m);
  void flipBit(uint64 n, uint64 m);
  int getBit(uint64 n, uint64 m) const;
  uint64 getNhaps(void) const;
  uint64 getM(void) const;
  uint64 getHaploBitsSize(void) const;
  uint64* getHaploBitsT(void) const;
  uint64* getHaploBitsTrow(uint64 n) const;
};

// Sets SNP m of haplotype n to 1.
void HapBitsT::setBit(uint64 n, uint64 m) { haploBitsT[n*M64 + (m>>6)] |= 1ULL<<(m&63); }

// Toggles SNP m of haplotype n.
void HapBitsT::flipBit(uint64 n, uint64 m) { haploBitsT[n*M64 + (m>>6)] ^= 1ULL<<(m&63); }

// Returns the bit (0/1) at SNP m of haplotype n.
int HapBitsT::getBit(uint64 n, uint64 m) const { return (haploBitsT[n*M64 + (m>>6)]>>(m&63))&1; }

uint64 HapBitsT::getNhaps(void) const { return Nhaps; }

uint64 HapBitsT::getM(void) const { return M; }

// Total number of uint64 words in the buffer (Nhaps * M64).
uint64 HapBitsT::getHaploBitsSize(void) const { return Nhaps * M64; }

// Raw pointer to the whole buffer (still owned by this object).
uint64* HapBitsT::getHaploBitsT(void) const { return haploBitsT; }

// Raw pointer to the first word of haplotype n's row.
uint64* HapBitsT::getHaploBitsTrow(uint64 n) const { return haploBitsT + n*M64; }

HapBitsT::HapBitsT(uint64 _Nhaps, uint64 _M) {
  Nhaps = _Nhaps;
  M = _M;
  M64 = (M+63)/64;
  haploBitsT = ALIGNED_MALLOC_UINT64S(Nhaps * M64);
  memset(haploBitsT, 0, Nhaps * M64 * sizeof(haploBitsT[0]));
}

HapBitsT::~HapBitsT(void) { ALIGNED_FREE(haploBitsT); }

/**
 * Runs the positional Burrows-Wheeler transform (PBWT) over all haplotypes in hapBitsT.
 *
 * isFwd = true  : SNPs are processed in increasing index order (0 .. M-1).
 * isFwd = false : SNPs are processed in decreasing index order (M-1 .. 0).
 * Each step stably partitions the current haplotype order a1 by the allele at SNP mBit
 * (0-allele haplotypes first, each group keeping its previous relative order). After processing
 * mBit, a1 therefore orders haplotypes by their alleles at all SNPs processed so far, with mBit
 * as the most significant key.
 *
 * Whenever mBit is a key of mPhasedToRegionInds, a copy of that order is stored in
 * lexSorts[mBit][isFwd], allocated here with new int[H]; the caller is responsible for delete[].
 * Other lexSorts entries are not touched, so calls with different isFwd write disjoint slots.
 */
void pbwt(const HapBitsT &hapBitsT, bool isFwd,
	  const map < int, vector <int> > &mPhasedToRegionInds, int *lexSorts[][2]) {

  int H = hapBitsT.getNhaps(), M = hapBitsT.getM();

  // initialize work arrays
  // a1/d1: current order and divergence values; a/d: 0-allele partition; b/e: 1-allele partition
  // NOTE(review): the divergence values (d1, d, e, p, q) never influence a1 and are not returned;
  // they look vestigial and could likely be removed.
  int *a1 = new int[H], *d1 = new int[H], *a = new int[H], *b = new int[H], *d = new int[H],
    *e = new int[H];
  for (int n = 0; n < H; n++) {
    a1[n] = n; // start from the identity order
    d1[n] = M;
  }

  /***** RUN PBWT *****/
  // m is a countdown; mBit maps it to the SNP actually processed at this step
  for (int m = M-1; m >= 0; m--) {
    const int mBit = isFwd ? M-1-m : m; // process SNPs in forward or reverse order
    // compute sort order and divergence array
    int u = 0, v = 0, p = m, q = m;
    for (int i = 0; i < H; i++) {
      if (d1[i] < p) p = d1[i];
      if (d1[i] < q) q = d1[i];
      if (hapBitsT.getBit(a1[i], mBit) == 0) {
	a[u] = a1[i]; d[u] = p; u++; p = M;
      }
      else {
	b[v] = a1[i]; e[v] = q; v++; q = M;
      }
    }
    // new order = 0-allele partition followed by 1-allele partition
    memcpy(a1, a, u * sizeof(a1[0])); memcpy(a1+u, b, v * sizeof(a1[0]));
    memcpy(d1, d, u * sizeof(d1[0])); memcpy(d1+u, e, v * sizeof(d1[0]));
    
    if (mPhasedToRegionInds.find(mBit) != mPhasedToRegionInds.end()) {
      // store sort order at specified SNPs
      lexSorts[mBit][isFwd] = new int[H];
      memcpy(lexSorts[mBit][isFwd], a1, H * sizeof(int));
    }
  }
  delete[] a1;
  delete[] d1;
  delete[] a;
  delete[] b;
  delete[] d;
  delete[] e;
}


// Number of mismatches recorded on each side of the focal SNP when scoring an IBS match.
const int MAX_ERR = 2;

/**
 * One candidate haplotype neighbor and its IBS score around a focal SNP.
 * Ordering puts larger len first (descending) and breaks ties by ascending hap, so a
 * std::set<IBSmatch> keeps its best candidates at begin().
 */
struct IBSmatch {
  int hap;                    // neighbor haplotype index
  double cMerrs[2][MAX_ERR];  // [0 = left, 1 = right][i]: cM offset of the (i+1)-th mismatch (or of the chromosome end)
  double len;                 // combined match-length score (larger = better)

  bool operator < (const IBSmatch &match2) const {
    if (len == match2.len)
      return hap < match2.hap;
    return match2.len < len; // descending by score
  }
};

/**
 * Scores how far haplotypes h1 and h2 stay identical-by-state around SNP m0.
 *
 * For each side (k = 0: scanning left from m0; k = 1: scanning right from m0; m0 itself is
 * checked on both sides), records in cMerrs[k][i] the genetic offset cMs[m] - cMs[m0] of the
 * (i+1)-th mismatching SNP. If fewer than MAX_ERR mismatches occur before the end of the
 * chromosome, the remaining slots hold the offset of the last SNP in that direction. Left offsets
 * are <= 0, assuming cMs is non-decreasing. Blocks of 64 SNPs with no differences are skipped
 * via a word-wide XOR.
 *
 * Each side's length is a 50/50 blend of the offsets to its 1st and 2nd mismatch. The score is
 * len = 0.25 * (left + right) + 0.75 * min(left, right), plus a tiny deterministic per-pair
 * jitter used to break ties.
 *
 * @param hapBitsT  haplotype bit matrix
 * @param h1, h2    haplotype rows being compared (h2 is stored as matchInfo.hap)
 * @param m0        focal SNP index
 * @param cMs       genetic position of each SNP (indexed like hapBitsT's SNPs)
 */
IBSmatch computeIBS(const HapBitsT &hapBitsT, int h1, int h2, int m0, const vector <double> &cMs) {
  int M = hapBitsT.getM();

  IBSmatch matchInfo; matchInfo.hap = h2;

  const double weight0err = 0.5, weightTotalLen = 0.25;

  const uint64 *haploBitsRow1 = hapBitsT.getHaploBitsTrow(h1);
  const uint64 *haploBitsRow2 = hapBitsT.getHaploBitsTrow(h2);

  for (int k = 0; k < 2; k++) {
    int dir = k==0 ? -1 : 1;
    int mEnd = dir==1 ? M : -1;
    int errs = 0;
    uint64 diffBits = 0;
    for (int m = m0; m >= 0 && m < M; m += dir) {
      // fill 64-bit xor buffer at m0 or when entering new block: m%64==0 (fwd) or m%64=63 (rev)
      if (m == m0 || (dir==1 && (m&63)==0) || (dir==-1 && (m&63)==63)) {
        diffBits = haploBitsRow1[m>>6] ^ haploBitsRow2[m>>6];
        if (diffBits == 0) { // all bits match in this block of 64 SNPs
          if (dir == 1) m |= 63; // skip ahead to next 64-SNP block
          else m &= ~63; // skip back to previous 64-SNP block
          continue;
        }
      }

      if ((diffBits>>(m&63))&1) {
        matchInfo.cMerrs[k][errs] = cMs[m] - cMs[m0];
        errs++;
        if (errs == MAX_ERR)
          break;
      }
    }
    // too few mismatches before the chromosome end: pad with the distance to the last SNP
    while (errs < MAX_ERR) {
      matchInfo.cMerrs[k][errs] = cMs[mEnd-dir] - cMs[m0];
      errs++;
    }
  }
  double lenLeft = weight0err*-matchInfo.cMerrs[0][0] + (1-weight0err)*-matchInfo.cMerrs[0][1];
  double lenRight = weight0err*matchInfo.cMerrs[1][0] + (1-weight0err)*matchInfo.cMerrs[1][1];
  double lenTot = lenLeft + lenRight;
  double lenMin = min(lenLeft, lenRight);
  // randomize for sorting: deterministic per-pair jitter. The product is taken in 64 bits because
  // h1*h2 overflows int (undefined behavior) once both indices exceed ~46340.
  const long long pairHash = (long long) h1 * h2 % 1000003;
  matchInfo.len = weightTotalLen*lenTot + (1-weightTotalLen)*lenMin + pairHash*1e-12;

  return matchInfo;
}

float hapNbrMean(const float *dipCNvec, const float *hapCNvec, int h, int (*hapNbrs)[N_HAP_NBRS]) {
  float sum = 0, num = 0;
  for (int nbr = 0; nbr < N_HAP_NBRS; nbr++) {
    int hNbr = hapNbrs[h][nbr];
    if (!isnanf(dipCNvec[hNbr/2])) {
      sum += hapCNvec[hNbr];
      num++;
    }
  }
  return sum/num;
}

/**
 * Reads the PSV copy-number matrix.
 *
 * Expected layout: a header "ID <PSV_1> ... <PSV_R>" where each PSV is named
 * chr<chr>_<bp38>_<base>, then one row per WES sample: "<ID> <CN_1> ... <CN_R>".
 * Tokens that do not parse as a float are kept as NaN (missing) and counted.
 *
 * @param N_WES              out: number of sample rows read
 * @param IDtoReadCountsInd  out: size-MAX_ID table mapping sample ID -> row index (-1 if absent)
 * @param PSVs               out: PSVs parsed from the header, in column order
 * @param PSVCNfile          path to a .txt[.gz] file (opened via FileUtils::AutoGzIfstream)
 * @return PSVCNs[row][psv]
 *
 * Exits with an error on a malformed PSV name, a sample ID outside [0, MAX_ID), or a truncated
 * row. These conditions previously caused stack-buffer overflow, out-of-bounds writes, or silently
 * reused tokens. Parsing happens outside assert(), so it still runs under -DNDEBUG.
 */
vector < vector <float> > readPSVCNs(int &N_WES, vector <int> &IDtoReadCountsInd,
                                     vector <PSV> &PSVs, const char *PSVCNfile) {

  vector < vector <float> > PSVCNs;

  N_WES = 0;
  IDtoReadCountsInd = vector <int> (MAX_ID, -1);
  PSVs.clear();

  // read regions from file header
  FileUtils::AutoGzIfstream finPSVCNs; finPSVCNs.openOrExit(PSVCNfile);
  {
    string line;
    finPSVCNs >> line; assert(line == "ID");
    getline(finPSVCNs, line);
    istringstream iss(line);
    string PSVname;
    PSV psv;
    while (iss >> PSVname) {
      // %s reads a substring of PSVname, so a buffer of its length (+NUL) can never overflow
      vector <char> buf(PSVname.size() + 1);
      if (sscanf(PSVname.c_str(), "chr%d_%d_%s", &psv.chr, &psv.bp38, &buf[0]) != 3) {
        cerr << "ERROR: Malformed PSV name in header (expected chr#_#_base): " << PSVname << endl;
        exit(1);
      }
      psv.base = &buf[0];
      PSVs.push_back(psv);
    }
    cout << "Read " << PSVs.size() << " PSVs to phase" << endl;
  }

  // read WES read counts per sample
  long long ctrMissing = 0;
  {
    int R = PSVs.size();
    int ID;
    while (finPSVCNs >> ID) {
      if (ID < 0 || ID >= MAX_ID) { // would index IDtoReadCountsInd out of bounds
        cerr << "ERROR: Sample ID " << ID << " outside supported range [0, " << MAX_ID << ")"
             << endl;
        exit(1);
      }
      assert(IDtoReadCountsInd[ID] == -1); // IDs must be unique
      IDtoReadCountsInd[ID] = N_WES++;
      vector <float> PSVCNvec(R, NAN);
      string token;
      for (int r = 0; r < R; r++) {
        if (!(finPSVCNs >> token)) {
          cerr << "ERROR: Truncated PSVCN row for ID " << ID << endl;
          exit(1);
        }
        ctrMissing += (sscanf(token.c_str(), "%f", &PSVCNvec[r]) != 1);
      }
      PSVCNs.push_back(PSVCNvec);
    }
  }
  finPSVCNs.close();
  cout << "Read PSVCNs for " << N_WES << " samples" << endl;
  cout << "Overall missingness: " << ctrMissing / (double) N_WES / PSVs.size() << endl;

  return PSVCNs;
}

void processSNPs(map < int, vector <int> > &mPhasedToRegionInds, vector <double> &cMvec,
		 const char *chrStr, const char *bimFile, const char *phasedSnpsFile,
		 const char *lift38File, const vector <PSV> &PSVs) {

  // set up cM lookup
  map <int, double> bp_to_cM;
  {
    ifstream finBim(bimFile);
    int c; string snpStr; double genpos; int bp; string lineStr;
    int Mbim = 0;
    while (finBim >> c >> snpStr >> genpos >> bp) {
      getline(finBim, lineStr);
      bp_to_cM[bp] = 100*genpos;
      Mbim++;
    }
    finBim.close();
    cout << "Read " << Mbim << " SNPs in bim file" << endl;
  }
  
  // set up SNP index lookup (read phased SNP list)
  map <int, int> bpToSnpInd;
  int Mphased = 0;
  {
    FileUtils::AutoGzIfstream finSnps; finSnps.openOrExit(phasedSnpsFile);
    int c; string snpStr; int bp; string lineStr;
    while (finSnps >> c >> snpStr >> bp) {
      getline(finSnps, lineStr);
      bpToSnpInd[bp] = Mphased++;
      cMvec.push_back(bp_to_cM[bp]);
    }
    finSnps.close();
    cout << "Read " << Mphased << " phased SNPs" << endl;
  }

  // read lifted hg38 coordinates
  vector <int> bps38(Mphased, 1<<30);
  {
    FileUtils::AutoGzIfstream finLift38; finLift38.openOrExit(lift38File);
    string cStr; int bp38_1, bp38, bp;
    while (finLift38 >> cStr >> bp38_1 >> bp38 >> bp)
      if (cStr == chrStr && bpToSnpInd.find(bp) != bpToSnpInd.end())
	bps38[bpToSnpInd[bp]] = bp38;
    finLift38.close();
    int ctr38 = 0; for (int m = 0; m < Mphased; m++) if (bps38[m] < (1<<30)) ctr38++;
    cout << "Read hg38 coordinates for " << ctr38 << " phased SNPs" << endl;
  }

  // set up map from closest-SNPs to region indices
  for (uint r = 0; r < PSVs.size(); r++) {
    int bp38 = PSVs[r].bp38;
    int mBest = 0;
    for (int m = 1; m < Mphased; m++)
      if (abs(bp38 - bps38[m]) < abs(bp38 - bps38[mBest]))
	mBest = m;
    mPhasedToRegionInds[mBest].push_back(r);
  }
  cout << "Analyzing haplotype neighbors for " << PSVs.size() << " regions at "
       << mPhasedToRegionInds.size() << " closest SNPs" << endl;  
}

/**
 * Returns IDtoReadCountsInd[ID] when ID is a valid index into the table, otherwise -1 (treated as
 * "no WES data"). Guards lookups of IDs, e.g. negative ones, that readPSVCNs never stored.
 */
static int lookupReadCountsInd(int ID, const vector <int> &IDtoReadCountsInd) {
  if (ID < 0 || ID >= (int) IDtoReadCountsInd.size()) return -1;
  return IDtoReadCountsInd[ID];
}

/**
 * Chooses which phased samples to keep and which trios to use.
 *
 * Reads sample IDs (first column, after two header lines) from phasedSampleFile, and trios as
 * whitespace-separated (child, father, mother) IDs from trioFile. A trio is used only if all three
 * members are phased and have WES read counts. A sample is kept if it belongs to a used trio or if
 * its phased index is a multiple of downsample (10 when mainIter == 1, otherwise 1, i.e. all).
 *
 * @param keptIDs             out (appended): IDs of kept samples, in phased-file order
 * @param keptReadCountInds   out: WES row per kept sample, or -1 if none (or ID <= 0)
 * @param trios               out (appended): used trios as indices into keptIDs
 * @param isPhasedSampleKept  out: per phased sample, whether it was kept
 * @param IDtoReadCountsInd   sample ID -> WES row (-1 if absent), as built by readPSVCNs
 */
void processSamples(vector <int> &keptIDs, vector <int> &keptReadCountInds, vector <Trio> &trios,
                    vector <bool> &isPhasedSampleKept, const char *phasedSampleFile,
                    const char *trioFile, const vector <int> &IDtoReadCountsInd, int mainIter=2) {

  // read sample IDs in phased haps file
  int Nphased = 0;
  vector <int> phasedIDs;
  map <int, int> IDtoPhasedInd;
  {
    FileUtils::AutoGzIfstream finSample; finSample.openOrExit(phasedSampleFile);
    string lineStr; getline(finSample, lineStr); getline(finSample, lineStr); // ignore header
    int ID;
    while (finSample >> ID) {
      phasedIDs.push_back(ID);
      IDtoPhasedInd[ID] = Nphased++;
      getline(finSample, lineStr);
    }
    finSample.close();
    cout << "Read " << Nphased << " phased samples" << endl;
  }

  // read trios
  vector <bool> inTrio(Nphased);
  vector < vector <int> > trioIDs;
  {
    FileUtils::AutoGzIfstream finTrios; finTrios.openOrExit(trioFile);
    vector <int> IDs(3);
    while (finTrios >> IDs[0] >> IDs[1] >> IDs[2]) {
      bool good = true;
      for (int k = 0; k < 3; k++)
        if (IDtoPhasedInd.find(IDs[k]) == IDtoPhasedInd.end()
            || lookupReadCountsInd(IDs[k], IDtoReadCountsInd) == -1)
          good = false;
      if (good) {
        trioIDs.push_back(IDs);
        for (int k = 0; k < 3; k++)
          inTrio[IDtoPhasedInd[IDs[k]]] = true;
      }
    }
    cout << "Using " << trioIDs.size() << " trios" << endl;
  }

  // select samples to keep
  const int downsample = (mainIter == 1 ? 10 : 1);
  isPhasedSampleKept = vector <bool> (Nphased);
  map <int, int> IDtoKeptInd;
  for (int i = 0; i < Nphased; i++)
    if (i % downsample == 0 || inTrio[i]) {
      isPhasedSampleKept[i] = true;
      IDtoKeptInd[phasedIDs[i]] = keptIDs.size();
      keptIDs.push_back(phasedIDs[i]);
    }
  int Nkept = keptIDs.size();
  keptReadCountInds = vector <int> (Nkept, -1);
  for (int iKept = 0; iKept < Nkept; iKept++)
    if (keptIDs[iKept] > 0)
      keptReadCountInds[iKept] = lookupReadCountsInd(keptIDs[iKept], IDtoReadCountsInd);
  cout << "Phasing " << Nkept << " samples" << endl;

  // map trio IDs to indices in keptIDs (every trio member is kept, so the lookups succeed)
  for (uint t = 0; t < trioIDs.size(); t++)
    trios.push_back(Trio(IDtoKeptInd[trioIDs[t][0]], IDtoKeptInd[trioIDs[t][1]],
                         IDtoKeptInd[trioIDs[t][2]]));
}

/**
 * Fills dipCNvec[i] for every kept sample i with PSVCNs[row][r], where row is
 * keptReadCountInds[i]; samples without WES data (row == -1) get NaN.
 * dipCNvec must have room for keptReadCountInds.size() entries.
 */
void computeDipCNs(float *dipCNvec, const vector < vector <float> > &PSVCNs, int r,
                   const vector <int> &keptReadCountInds) {
  const size_t nKept = keptReadCountInds.size();
  for (size_t i = 0; i < nKept; ++i) {
    const int wesRow = keptReadCountInds[i];
    dipCNvec[i] = (wesRow == -1) ? NAN : PSVCNs[wesRow][r];
  }
}

int main(int argc, char *argv[]) {

  if (argc != 12) {
    cerr << "Usage:" << endl;
    cerr << "- arg1 = chrStr (chr##)" << endl;
    cerr << "- arg2 = PSVCN .txt[.gz] file" << endl;
    cerr << "- arg3 = bim file" << endl;
    cerr << "- arg4 = phased snps.txt.gz file" << endl;
    cerr << "- arg5 = hg38 lift file" << endl;
    cerr << "- arg6 = phased sample file" << endl;
    cerr << "- arg7 = phased hap64.bin file" << endl;
    cerr << "- arg8 = trio file" << endl;
    cerr << "- arg9 = PBWT band width" << endl;
    cerr << "- arg10 = threads" << endl;
    cerr << "- arg11 = output prefix (chr#_#_#.Q#.txt.gz)" << endl;
    exit(1);
  }
  const char *chrStr = argv[1]; assert(chrStr[0]=='c');
  const char *PSVCNfile = argv[2];
  const char *bimFile = argv[3];
  const char *phasedSnpsFile = argv[4];
  const char *lift38File = argv[5];
  const char *phasedSampleFile = argv[6];
  const char *phasedHap64File = argv[7];
  const char *trioFile = argv[8];
  int pbwtBand; sscanf(argv[9], "%d", &pbwtBand);
  int threads; sscanf(argv[10], "%d", &threads);
  const char *outPrefix = argv[11];

  Timer timer;
  
  cout << "Setting number of threads to " << threads << endl << endl;
  omp_set_num_threads(threads);

  int N_WES;
  vector <int> IDtoReadCountsInd;
  vector <PSV> PSVs;
  vector < vector <float> > PSVCNs = readPSVCNs(N_WES, IDtoReadCountsInd, PSVs, PSVCNfile);
  cout << "\nTime for reading PSVCNs: " << timer.update_time() << " sec\n" << endl;

  map < int, vector <int> > mPhasedToRegionInds; // list of regions closest to SNP
  vector <double> cMvec;
  processSNPs(mPhasedToRegionInds, cMvec, chrStr, bimFile, phasedSnpsFile, lift38File, PSVs);
  int Mphased = cMvec.size();
  int R = PSVs.size();
  cout << "\nTime for processing SNPs: " << timer.update_time() << " sec\n" << endl;

  vector <Trio> trios;
  vector <int> keptIDs, keptReadCountInds;
  vector <bool> isPhasedSampleKept;
  processSamples(keptIDs, keptReadCountInds, trios, isPhasedSampleKept, phasedSampleFile,
		 trioFile, IDtoReadCountsInd);
  int Ntrios = trios.size();
  cout << "\nTime for processing samples: " << timer.update_time() << " sec\n" << endl;

  // read haploBitsT
  int H = 2*keptIDs.size();
  HapBitsT hapBitsT(H, Mphased);
  uint64 M64 = (Mphased+63)/64;
  FILE *finBin = fopen(phasedHap64File, "rb");
  int Nkept = 0;
  for (uint i = 0; i < isPhasedSampleKept.size(); i++) {
    if (isPhasedSampleKept[i]) {
      assert(fread(hapBitsT.getHaploBitsT() + Nkept*2*M64, sizeof(uint64), 2*M64, finBin)
	     == 2*M64);
      Nkept++;
    }
    else
      assert(fseek(finBin, 2*M64*sizeof(uint64), SEEK_CUR) == 0);
  }
  assert(finBin);
  assert(fgetc(finBin) == EOF);
  fclose(finBin);
  assert(H == 2*Nkept);
  cout << "\nTime for reading haploBitsT: " << timer.update_time() << " sec\n" << endl;

  // run PBWT; save lexicographic sorts at closest-probes
  int *lexSorts[Mphased][2]; memset(lexSorts, 0, Mphased*2*sizeof(int *));
#pragma omp parallel for
  for (int isFwd = 0; isFwd < 2; isFwd++)
    pbwt(hapBitsT, isFwd, mPhasedToRegionInds, lexSorts);

  cout << "\nTime for PBWT: " << timer.update_time() << " sec\n" << endl;

  // open output files
  FileUtils::AutoGzOfstream foutStats, //foutDosage, foutDosageNbrMean, foutFam, foutMap,
    /*foutMapNbrMean,*/ foutHapCN;
  foutStats.openOrExit(outPrefix + string(".stats.txt"));
  /*
  foutDosage.openOrExit(outPrefix + string(".dosage.txt"));
  foutDosageNbrMean.openOrExit(outPrefix + string(".nbrMean.dosage.txt"));
  foutFam.openOrExit(outPrefix + string(".fam"));
  foutMap.openOrExit(outPrefix + string(".map"));
  foutMapNbrMean.openOrExit(outPrefix + string(".nbrMean.map"));
  */
  foutHapCN.openOrExit(outPrefix + string(".hapCN.txt"));
  foutHapCN << "ID";
  for (int i = 0; i < Nkept; i++) {
    //foutFam << keptIDs[i] << "\t" << keptIDs[i] << "\t0\t0\t0\t-9" << endl;
    foutHapCN << "\t" << keptIDs[i] << "_1" << "\t" << keptIDs[i] << "_2";
  }
  foutHapCN << std::fixed << setprecision(3) << endl;
  //foutDosage << std::fixed << setprecision(3);

  /***** RUN NEIGHBOR-FINDING + PHASING *****/

  double tNbr = 0, tPhase = 0, tWrite = 0;

  // allocate storage for haplotype neighbors of each haplotype
  int (*hapNbrs)[N_HAP_NBRS] = new int[H][N_HAP_NBRS];
  // allocate storate for lookup: position of each haplotype in lexSorts[m][1]
  int *fwdInd = new int[H];
  // allocate storage for read counts, dipCNs, hapCNs
  float *dipCNvec = new float[Nkept];
  float *dipCNnbrMean = new float[Nkept];
  float *hapCNvec = new float[H];
  float *hapCNupdates = new float[H];

  vector <RegionStats> regionStats(R);

  for (map < int, vector <int> >::iterator it = mPhasedToRegionInds.begin();
       it != mPhasedToRegionInds.end(); it++) {

    cout << "." << flush;

    int m = it->first; const vector <int> &regionInds = it->second;

    /***** find haplotype neighbors at SNP m *****/

    for (int i = 0; i < H; i++) // set up lookup in PBWT sort order
      fwdInd[lexSorts[m][1][i]] = i;
    
#pragma omp parallel for
    for (int i = 0; i < H; i++) {
      set <IBSmatch> bestMatches;
      set <int> hNbrs0;

      const int h = lexSorts[m][0][i]; // haplotype currently being analyzed
      const int i01[2] = {i, fwdInd[h]};
      
      for (int k = 0; k < 2; k++) {
	for (int j = max(0, i01[k]-pbwtBand); j <= min(H-1, i01[k]+pbwtBand); j++) {
	  if (j == i01[k]) continue;
	  const int hNbr = lexSorts[m][k][j];
	  if (keptReadCountInds[hNbr/2] == -1) continue; // require hNbr to have WES data
	  if (k == 0) hNbrs0.insert(hNbr);
	  if (k == 1 && hNbrs0.count(hNbr)) continue; // don't retry neighbors already considered
	  bestMatches.insert(computeIBS(hapBitsT, h, hNbr, m, cMvec));
	  if (bestMatches.size() > N_HAP_NBRS)
	    bestMatches.erase(--bestMatches.end());
	}
      }
      // store best matches
      int nbr = 0;
      for (set <IBSmatch>::iterator it = bestMatches.begin(); it != bestMatches.end(); it++)
	hapNbrs[h][nbr++] = it->hap;
      assert(nbr == N_HAP_NBRS);
    }

    tNbr += timer.update_time();

    /***** phase all regions with m as closest-SNP [within sub-iters omp parallel for] *****/

    for (uint k = 0; k < regionInds.size(); k++) {
      int r = regionInds[k];

      computeDipCNs(dipCNvec, PSVCNs, r, keptReadCountInds);

      float dipCNchildren[Ntrios], dipCNparents[Ntrios][2], dipCNmidParents[Ntrios];

      // store initial child + mid-parent dipCNs; set left-out parents to missing for now
      for (int t = 0; t < Ntrios; t++) {
	dipCNchildren[t] = dipCNvec[trios[t].childInd];
	for (int parent = 0; parent < 2; parent++) {
	  dipCNparents[t][parent] = dipCNvec[trios[t].parentInds[parent]];
	  dipCNvec[trios[t].parentInds[parent]] = NAN; // mask parent
	}
	dipCNmidParents[t] = (dipCNparents[t][0] + dipCNparents[t][1]) / 2;
      }
      // initialize hapCNs to dipCNs/2 (or 0.5 if missing)
      for (int h = 0; h < H; h++)
	hapCNvec[h] = isnanf(dipCNvec[h/2]) ? 0.5f : 0.5f*dipCNvec[h/2];

      // run phasing iterations
      float &param = regionStats[r].phasingParam;
      for (int iter = 1; iter <= 10; iter++) {
	// optimize R2(refined child dipCN, average dipCN of held-out parents)
	double r2best = 0;
	vector <double> dipCNs0, dipCNs1, dipCNsEst;
	for (int t = 0; t < Ntrios; t++)
	  if (!isnanf(dipCNchildren[t]) && !isnanf(dipCNmidParents[t])) {
	    dipCNs0.push_back(dipCNchildren[t]);
	    dipCNs1.push_back(dipCNmidParents[t]);
	    int iRefined = trios[t].childInd; // refine child dipCN
	    dipCNsEst.push_back(hapNbrMean(dipCNvec, hapCNvec, 2*iRefined, hapNbrs) +
				hapNbrMean(dipCNvec, hapCNvec, 2*iRefined+1, hapNbrs));
	  }
	vector <double> dipCNs0refine(dipCNs0.size());
	for (float p = 0; p <= 0.5f; p += 0.005f) {
	  for (int t = 0; t < (int) dipCNs0.size(); t++)
	    dipCNs0refine[t] = dipCNsEst[t] + 2 * p * (dipCNs0[t] - dipCNsEst[t]);
	  double r2test = r2(dipCNs0refine, dipCNs1);
	  if (r2test > r2best) {
	    r2best = r2test;
	    param = p;
	  }
	}

	// update hapCNs using optimized weight parameter
	const int numBatches = 10;
	for (int b = 1; b <= numBatches; b++) {
	  int iStart = (b-1)*Nkept/numBatches, iEnd = b*Nkept/numBatches;
#pragma omp parallel for
	  for (int i = iStart; i < iEnd; i++) {
	    dipCNnbrMean[i] = 0;
	    for (int h = 2*i; h < 2*i+2; h++) {
	      hapCNupdates[h] = hapNbrMean(dipCNvec, hapCNvec, h, hapNbrs);
	      dipCNnbrMean[i] += hapCNupdates[h];
	    }
	    if (!isnanf(dipCNvec[i]))
	      for (int h = 2*i; h < 2*i+2; h++)
		hapCNupdates[h] += param * (dipCNvec[i] - dipCNnbrMean[i]);
	  }
	  // apply all updates computed for hapCNs in the current batch
	  for (int h = iStart*2; h < iEnd*2; h++)
	    hapCNvec[h] = hapCNupdates[h];
	}
      }

      // output estimated R2 to truth pre- and post-phasing
      vector <double> dipCNs0, dipCNs1, dipCNs0refine;
      for (int t = 0; t < Ntrios; t++)
	if (!isnanf(dipCNchildren[t]) && !isnanf(dipCNmidParents[t])) {
	  dipCNs0.push_back(dipCNchildren[t]);
	  dipCNs1.push_back(dipCNmidParents[t]);
	  int iRefined = trios[t].childInd; // refine child dipCN
	  dipCNs0refine.push_back(hapCNvec[2*iRefined] + hapCNvec[2*iRefined+1]);
	}
      regionStats[r].r2prePhasing = sqrt(2*r2(dipCNs0, dipCNs1));
      regionStats[r].r2postPhasing = 2*r2(dipCNs0refine, dipCNs1) / regionStats[r].r2prePhasing;

      // update estimates in left-out parents: use masked dipCNs (but don't bother propagating)
      for (int t = 0; t < Ntrios; t++)
	for (int parent = 0; parent < 2; parent++) {
	  int iLeftOut = trios[t].parentInds[parent];
	  dipCNvec[iLeftOut] = dipCNparents[t][parent]; // unmask parent
	  if (!isnanf(dipCNvec[iLeftOut])) {
	    for (int h = 2*iLeftOut; h < 2*iLeftOut+2; h++)
	      hapCNvec[h] += param * (dipCNvec[iLeftOut] - dipCNnbrMean[iLeftOut]);
	  }
	}

      tPhase += timer.update_time();

      // output hapCNs
      foutHapCN << PSVs[r].toString();
      for (int h = 0; h < H; h++)
	foutHapCN << "\t" << hapCNvec[h];
      foutHapCN << endl;
      /*
      // output dipCNs in PLINK 1 dosage format
      foutDosage << cnvRegionName << "\t" << "A1" << "\t" << "A2";
      foutDosageNbrMean << cnvRegionName << ".nbrMean\t" << "A1" << "\t" << "A2";
      foutMap << chrStr << "\t" << cnvRegionName << "\t" << 0 << "\t"
	      << regions[r].start38 << endl;
      foutMapNbrMean << chrStr << "\t" << cnvRegionName << ".nbrMean\t" << 0 << "\t"
		     << regions[r].start38 << endl;
      float minDosage = hapCNvec[0]+hapCNvec[1], maxDosage = minDosage;
      for (int i = 0; i < Nkept; i++) {
	float dosage = hapCNvec[2*i] + hapCNvec[2*i+1];
	if (dosage < minDosage) minDosage = dosage;
	if (dosage > maxDosage) maxDosage = dosage;
      }
      for (int i = 0; i < Nkept; i++) {
	foutDosage << "\t" << 2*(hapCNvec[2*i]+hapCNvec[2*i+1]-minDosage)/(maxDosage-minDosage);
	foutDosageNbrMean << "\t" << max(0.0f, min(2.0f, 2*(dipCNnbrMean[i]-minDosage)/(maxDosage-minDosage)));
      }
      foutDosage << endl;
      foutDosageNbrMean << endl;
      */
      tWrite += timer.update_time();
	
      regionStats[r].summary = summary(hapCNvec, H);
    }
  }

  // print region stats
  foutStats << "region"
	    << "\t" << "r2prePhasing"
	    << "\t" << "r2postPhasing"
	    << "\t" << "phasingParam"
	    << "\t" << "min\t1stQ\tmedian\t3rdQ\tmax";
  foutStats << endl;
  
  for (int r = 0; r < R; r++) {
    foutStats << PSVs[r].toString()
	      << "\t" << regionStats[r].r2prePhasing
	      << "\t" << regionStats[r].r2postPhasing
	      << "\t" << regionStats[r].phasingParam
	      << "\t" << regionStats[r].summary;
    foutStats << endl;
  }
  foutStats.close(); /*foutDosage.close(); foutDosageNbrMean.close(); foutFam.close();
		       foutMap.close(); foutMapNbrMean.close();*/ foutHapCN.close();

  cout << endl;
  cout << "Time for finding haplotype neighbors: " << tNbr << endl;
  cout << "Time for phasing: " << tPhase << endl;
  cout << "Time for writing output: " << tWrite << endl;

  for (int m = 0; m < Mphased; m++)
    for (int isFwd = 0; isFwd < 2; isFwd++)
      if (lexSorts[m][isFwd] != NULL)
	delete[] lexSorts[m][isFwd];
  delete[] dipCNvec;
  delete[] dipCNnbrMean;
  delete[] hapCNvec;
  delete[] hapCNupdates;
  delete[] fwdInd;
  delete[] hapNbrs;

  return 0;
}
