#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <set>
#include <map>
#include <utility>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include "omp.h"

#include "VersionHeader.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// Small value record holding an error code, a confidence code and two float
// quantities (cMtot, cMedge).  Every field defaults to zero.
struct ErrInfo {
  char err;
  char conf;
  float cMtot;
  float cMedge;

  ErrInfo(char errIn = 0, char confIn = 0, float cMtotIn = 0, float cMedgeIn = 0)
    : err(errIn),
      conf(confIn),
      cMtot(cMtotIn),
      cMedge(cMedgeIn)
  {}
};

// One genotype record of the binary LRR/theta/geno file.  readConfs() fread()s whole rows of
// these records straight into memory, so the field order and bit-field widths below define the
// expected on-disk layout and must not be rearranged.
struct GenoInfo {
  char lrr;   // presumably an encoded log-R-ratio value -- TODO confirm against the file writer
  char theta; // presumably an encoded theta value -- TODO confirm against the file writer
  unsigned char geno: 2; // 2-bit genotype field (not read by the code shown in this file)
  unsigned char conf: 6; // 6-bit call confidence; readConfs() tests bits 4 and 5 of this field
};

// A match between the current haplotype and haplotype `hap`, covering SNP indices
// [mStart, mEnd) (mEnd is exclusive: consumers loop m < mEnd and index cMvec[mEnd-1]).
struct Match {
  int hap;
  ushort mStart;
  ushort mEnd;

  Match(int _hap=0, ushort _mStart=0, ushort _mEnd=0)
    : hap(_hap), mStart(_mStart), mEnd(_mEnd)
  {}

  // Ordering used by the active-set priority queue: mEnd is compared in reverse so that the
  // queue's top() is the match with the smallest mEnd; ties fall back to hap, then mStart.
  bool operator < (const Match &match) const {
    if (mEnd != match.mEnd)
      return match.mEnd < mEnd; // reversed on purpose (priority_queue is a max-heap)
    if (hap != match.hap)
      return hap < match.hap;
    return mStart < match.mStart;
  }
};
// Sort order: hap ascending, then mEnd ascending, then mStart descending.
bool comp_hap_end_start(const Match &match1, const Match &match2) {
  if (match1.hap != match2.hap)
    return match1.hap < match2.hap;
  if (match1.mEnd != match2.mEnd)
    return match1.mEnd < match2.mEnd;
  return match2.mStart < match1.mStart; // later start first
}
// Sort order: hap ascending, then mStart ascending, then mEnd ascending.
bool comp_hap_start_end(const Match &match1, const Match &match2) {
  if (match1.hap != match2.hap)
    return match1.hap < match2.hap;
  if (match1.mStart != match2.mStart)
    return match1.mStart < match2.mStart;
  return match1.mEnd < match2.mEnd;
}
// Sort order: mStart ascending, then mEnd descending, then hap ascending.
bool comp_start_endOpp_hap(const Match &match1, const Match &match2) {
  if (match1.mStart != match2.mStart)
    return match1.mStart < match2.mStart;
  if (match1.mEnd != match2.mEnd)
    return match2.mEnd < match1.mEnd; // larger end first
  return match1.hap < match2.hap;
}

// A Match together with its length measured on the cMvec scale, used as the key of the
// length-ordered active set.
struct Match_cM {
  Match match;
  float cMlen;

  // cMlen is the cMvec distance between the first and the last SNP covered by the match
  // (mEnd is exclusive, hence the mEnd-1).
  Match_cM(const Match &_match, const vector <float> &cMvec) : match(_match) {
    const float cMfirst = cMvec[_match.mStart];
    const float cMlast = cMvec[_match.mEnd-1];
    cMlen = cMlast - cMfirst;
  }

  // Longest match first; equal lengths are ordered by Match::operator<.
  bool operator < (const Match_cM &m) const {
    if (cMlen != m.cMlen)
      return m.cMlen < cMlen;
    return match < m.match;
  }
};

// Bit matrix with one row per haplotype and one bit per SNP, stored row-major in 64-bit words
// (row n occupies words [n*M64, (n+1)*M64), bit m of a row is bit (m&63) of word (m>>6)).
//
// The object owns the heap buffer allocated by its constructor and frees it in its destructor.
// Copying is therefore disabled: a member-wise copy would make two objects delete[] the same
// buffer.  (The copy operations are declared private and left undefined so that this works
// without requiring C++11 "= delete".)
class HapBitsT {
  uint64 *haploBitsT;    // owned buffer of Nhaps * M64 words
  uint64 Nhaps, M, M64;  // number of rows, bits per row, 64-bit words per row
  HapBitsT(const HapBitsT &);            // not implemented: non-copyable
  HapBitsT &operator=(const HapBitsT &); // not implemented: non-copyable
public:
  // Allocate an all-zero matrix of _Nhaps rows by _M bits.
  HapBitsT(uint64 _Nhaps, uint64 _M);
  ~HapBitsT(void);
  // Set / toggle / read bit m of row n.  No bounds checking is performed.
  void setBit(uint64 n, uint64 m);
  void flipBit(uint64 n, uint64 m);
  int getBit(uint64 n, uint64 m) const;
  // Total number of 64-bit words in the buffer (Nhaps * M64).
  uint64 getHaploBitsSize(void) const;
  // Pointer to the start of the whole buffer / to the first word of row n.
  uint64* getHaploBitsT(void) const;
  uint64* getHaploBitsTrow(uint64 n) const;
};

// Turn on bit m of row n.
void HapBitsT::setBit(uint64 n, uint64 m) {
  const uint64 wordIndex = n*M64 + (m>>6);
  const uint64 bitMask = 1ULL<<(m&63);
  haploBitsT[wordIndex] |= bitMask;
}

// Toggle bit m of row n.
void HapBitsT::flipBit(uint64 n, uint64 m) {
  const uint64 wordIndex = n*M64 + (m>>6);
  const uint64 bitMask = 1ULL<<(m&63);
  haploBitsT[wordIndex] ^= bitMask;
}

// Return bit m of row n as 0 or 1.
int HapBitsT::getBit(uint64 n, uint64 m) const {
  const uint64 word = haploBitsT[n*M64 + (m>>6)];
  return (word>>(m&63))&1;
}

// Number of 64-bit words in the whole buffer.
uint64 HapBitsT::getHaploBitsSize(void) const {
  const uint64 numWords = Nhaps * M64;
  return numWords;
}

// Pointer to the first word of the buffer (the object keeps ownership).
uint64* HapBitsT::getHaploBitsT(void) const {
  return this->haploBitsT;
}

// Pointer to the first word of row n (the object keeps ownership).
uint64* HapBitsT::getHaploBitsTrow(uint64 n) const {
  uint64 *rowStart = haploBitsT;
  rowStart += n*M64;
  return rowStart;
}

// Build an all-zero bit matrix with _Nhaps rows of _M bits; each row is padded up to a whole
// number of 64-bit words.
HapBitsT::HapBitsT(uint64 _Nhaps, uint64 _M)
  : Nhaps(_Nhaps), M(_M), M64((_M+63)/64)
{
  const uint64 numWords = Nhaps * M64;
  haploBitsT = new uint64[numWords](); // trailing () value-initializes every word to zero
}

// Release the buffer allocated by the constructor.
HapBitsT::~HapBitsT(void) {
  delete[] haploBitsT;
}

// Bit count capped at 1: returns 0 when no bit is set, 1 otherwise.
inline int popcount64_01(uint64 i) {
  return i ? 1 : 0;
}
// Bit count capped at 2: returns 0 for no bits, 1 for exactly one bit, 2 for two or more.
inline int popcount64_012(uint64 i) {
  if (i == 0)
    return 0;
  const uint64 lowestBitCleared = i & (i-1ULL); // zero iff exactly one bit was set
  return lowestBitCleared == 0 ? 1 : 2;
}

// Number of set bits in a 64-bit word, computed with the usual parallel bit-summing steps
// (2-bit sums, 4-bit sums, byte sums, then a multiply to add the eight bytes together).
inline int popcount64(uint64 i) {
  const uint64 everyOtherBit = 0x5555555555555555;
  const uint64 everyOtherPair = 0x3333333333333333;
  const uint64 lowNibbles = 0x0F0F0F0F0F0F0F0F;
  const uint64 byteOnes = 0x0101010101010101;
  i -= (i >> 1) & everyOtherBit;
  i = (i & everyOtherPair) + ((i >> 2) & everyOtherPair);
  i = (i + (i >> 4)) & lowNibbles;
  return (i * byteOnes) >> 56; // top byte of the product holds the sum of all bytes
}

// Weighted mismatch score for one 64-SNP word.
//   diff64                : XOR of the two haplotypes (1 = mismatch)
//   confLoBit*/confHiBit* : the two confidence bit-rows of each sample
// Mismatches at positions where either sample lacks its "hi" confidence bit are ignored.  Each
// remaining mismatch scores 4 if both samples also have the "lo" bit set, and 2 otherwise.
int compute_err_score(uint64 diff64, uint64 confLoBit1, uint64 confHiBit1, uint64 confLoBit2,
		      uint64 confHiBit2) {
  const uint64 bothConfident = confHiBit1 & confHiBit2;
  const uint64 bothHighestConf = confLoBit1 & confLoBit2;
  const uint64 countedDiffs = diff64 & bothConfident;
  const int numOrdinary = popcount64(countedDiffs & ~bothHighestConf);
  const int numHighestConf = popcount64(countedDiffs & bothHighestConf);
  return 2*numOrdinary + 4*numHighestConf;
}

// Join the two sample ID columns into a single tab-separated lookup key.
string makeIDpair(const string &ID_1, const string &ID_2) {
  string IDpair(ID_1);
  IDpair += "\t";
  IDpair += ID_2;
  return IDpair;
}

// Intersect the samples listed in a .fam file with those listed in a .sample file.
//
// Samples are identified by their first two whitespace-separated columns.  For every data row
// of the .sample file (its first two lines are skipped as headers), one entry is appended to
// inFam telling whether that sample also occurs in the .fam file; for the samples that do, the
// 0-based .fam row index is appended to famIndsDropMissing.  Both vectors are appended to, not
// cleared.  Returns the number of rows read from the .fam file.
int isectFamSample(vector <bool> &inFam, vector <int> &famIndsDropMissing, const string &famFile,
		   const string &sampleFile) {

  string ID_1, ID_2, line;

  // pass 1: map each .fam ID pair to its row index
  cout << "Reading fam file: " << famFile << endl;
  FileUtils::AutoGzIfstream finFam; finFam.openOrExit(famFile);
  map <string, int> IDpairToFamInd;
  int Nfam = 0;
  while (finFam >> ID_1 >> ID_2) {
    const string key = makeIDpair(ID_1, ID_2);
    IDpairToFamInd[key] = Nfam;
    Nfam++;
    getline(finFam, line); // discard the remaining columns
  }
  finFam.close();
  cout << "Read " << Nfam << " samples in .fam file" << endl;
  assert(Nfam == (int) IDpairToFamInd.size()); // fails if an ID pair is repeated

  // pass 2: look up each .sample row in the map
  cout << "Reading sample file: " << sampleFile << endl;
  FileUtils::AutoGzIfstream finSample; finSample.openOrExit(sampleFile);
  getline(finSample, line); // throw away header lines
  getline(finSample, line);
  while (finSample >> ID_1 >> ID_2) {
    const map <string, int>::iterator hit = IDpairToFamInd.find(makeIDpair(ID_1, ID_2));
    const bool present = (hit != IDpairToFamInd.end());
    inFam.push_back(present);
    if (present)
      famIndsDropMissing.push_back(hit->second);
    getline(finSample, line); // discard the remaining columns
  }
  finSample.close();
  cout << "Read " << inFam.size() << " phased samples in .sample file" << endl;
  cout << "Analyzing " << famIndsDropMissing.size() << " samples in intersection" << endl;

  return Nfam;
}

// Map every SNP of the haps file to its row in the .bim file.
//
// For the m-th SNP of the haps file, mToBimIndex[m] receives its 0-based .bim row and cMvec[m]
// receives 100 times the third .bim column.  After the last SNP one extra sentinel entry, the
// total number of .bim rows, is appended to mToBimIndex only (so it ends up one longer than
// cMvec).  Assertions check that each haps line has 4*Nsample characters after the five
// leading columns, that each SNP ID exists in the .bim file, and that .bim indices strictly
// increase while cMvec values never decrease.
void lookupSnpsInBim(vector <int> &mToBimIndex, vector <float> &cMvec, const string &bimFile,
		     const string &hapsFile, int Nsample) {

  int chr; string snpStr, line;

  // index the .bim file by SNP ID
  cout << "Reading bim file: " << bimFile << endl;
  FileUtils::AutoGzIfstream finBim; finBim.openOrExit(bimFile);
  map < string, pair <int, float> > snpToCoords;
  float genpos;
  int Mbim = 0;
  while (finBim >> chr >> snpStr >> genpos) {
    const int bimIndex = Mbim++;
    snpToCoords[snpStr] = make_pair(bimIndex, 100*genpos);
    getline(finBim, line); // discard the remaining columns
  }
  finBim.close();
  cout << "Read " << Mbim << " SNPs in .bim file" << endl;
  assert(Mbim == (int) snpToCoords.size()); // fails if a SNP ID is repeated

  // walk the haps file and translate each SNP ID
  cout << "Reading SNP IDs from haps file: " << hapsFile << endl;
  FileUtils::AutoGzIfstream finHaps; finHaps.openOrExit(hapsFile);
  int bp; string allele1, allele2;
  while (finHaps >> chr >> snpStr >> bp >> allele1 >> allele2) {
    getline(finHaps, line); // rest of the line: the haplotype columns
    assert(line.length() == 4U * Nsample);
    assert(snpToCoords.find(snpStr) != snpToCoords.end());
    const pair <int, float> &coords = snpToCoords[snpStr];
    if (!mToBimIndex.empty()) {
      assert(coords.first > mToBimIndex.back());
      assert(coords.second >= cMvec.back());
    }
    mToBimIndex.push_back(coords.first);
    cMvec.push_back(coords.second);
  }
  mToBimIndex.push_back(Mbim); // append sentinel value
  finHaps.close();
  cout << "Read " << cMvec.size() << " phased SNPs in .haps.gz file" << endl;
}

// Fill hapBitsT from the haps file.
//
// Each line is read as five leading columns followed by a fixed block of 4*inFam.size()
// characters; within the 4 characters belonging to sample i, offset 1 holds the first
// haplotype's allele and offset 3 the second's.  Samples with inFam[i] == false are skipped.
// The k-th retained sample fills rows 2k and 2k+1 of hapBitsT; a bit is set where the allele
// character is '1'.  The bit (column) index is the 0-based line number.
void readHaps(HapBitsT &hapBitsT, const string &hapsFile, const vector <bool> &inFam) {

  cout << "Reading haplotypes from haps file: " << hapsFile << endl;
  FileUtils::AutoGzIfstream finHaps; finHaps.openOrExit(hapsFile);
  const int Nsample = inFam.size();
  const int lineChars = 4*Nsample;
  char *hapLine = new char[lineChars];
  int chr, bp; string snpStr, allele1, allele2;
  for (int m = 0; finHaps >> chr >> snpStr >> bp >> allele1 >> allele2; m++) {
    finHaps.read(hapLine, lineChars);
    int hapRow = 0; // row of the first haplotype of the next retained sample
    for (int iSample = 0; iSample < Nsample; iSample++) {
      if (!inFam[iSample])
	continue;
      const char *sampleChars = hapLine + 4*iSample;
      if (sampleChars[1] == '1') // h1
	hapBitsT.setBit(hapRow, m);
      if (sampleChars[3] == '1') // h2
	hapBitsT.setBit(hapRow+1, m);
      hapRow += 2;
    }
  }
  delete[] hapLine;
  finHaps.close();
}

// Load genotype-call confidence bits from the binary LRR/theta/geno file into confBitsT.
//
//   confBitsT          : output bit matrix; sample n uses row 2n for bit 4 of GenoInfo::conf
//                        and row 2n+1 for bit 5
//   lrrThetaGenoFile   : binary file holding, for every .bim SNP, Nfam consecutive GenoInfo
//                        records
//   Nfam               : number of records per SNP row in the file
//   famIndsDropMissing : for each analysed sample, its record index within a row
//   mToBimIndex        : for each analysed SNP, its row in the file, followed by ONE trailing
//                        sentinel entry (as produced by lookupSnpsInBim)
//
// The sentinel is not a SNP and is skipped here: reading it would seek past the last row and
// then set bit M, which lies outside the M-bit rows of confBitsT (spilling into the next row,
// or past the buffer, whenever M is a multiple of 64).
// Any failure to open, seek or read the file is reported and terminates the program instead
// of silently continuing with stale data.
void readConfs(HapBitsT &confBitsT, const char *lrrThetaGenoFile, int Nfam,
	       const vector <int> &famIndsDropMissing, const vector <int> &mToBimIndex) {

  cout << "Reading genotype call missingness/confidence from " << lrrThetaGenoFile << endl;
  FILE *finBin = fopen(lrrThetaGenoFile, "rb");
  if (finBin == NULL) { // explicit check: an assert() would vanish in NDEBUG builds
    cerr << "ERROR: Unable to open " << lrrThetaGenoFile << endl;
    exit(1);
  }

  const int N = famIndsDropMissing.size();
  // drop the trailing sentinel entry of mToBimIndex
  const size_t numSnps = mToBimIndex.empty() ? 0 : mToBimIndex.size() - 1;

  if (N > 0 && Nfam > 0) { // otherwise there is nothing to read
    vector <GenoInfo> genoRow(Nfam);
    for (size_t m = 0; m < numSnps; m++) {
      const bool seekOk =
	fseek(finBin, (uint64) mToBimIndex[m] * Nfam * sizeof(GenoInfo), SEEK_SET) == 0;
      if (!seekOk || fread(&genoRow[0], sizeof(GenoInfo), Nfam, finBin) != (size_t) Nfam) {
	cerr << "ERROR: Unable to read row " << mToBimIndex[m] << " of " << lrrThetaGenoFile
	     << endl;
	fclose(finBin);
	exit(1);
      }
      for (int n = 0; n < N; n++) {
	int famInd = famIndsDropMissing[n];
	int h1 = 2*n, h2 = 2*n+1; // two "haps" for bit4 (16+/48+), bit5 (>=32)
	if ((genoRow[famInd].conf>>4)&1)
	  confBitsT.setBit(h1, m);
	if ((genoRow[famInd].conf>>5)&1)
	  confBitsT.setBit(h2, m);
      }
    }
  }
  fclose(finBin);
}

// find_IBD driver.
//
// Command line (exactly 5 arguments): PLINK bed/bim/fam prefix, haps/sample prefix,
// LRR/theta/geno .bin file, number of threads, output IBD file.
//
// After the inputs are read, two rounds are run.  Each round
//   1. sweeps a PBWT from the last SNP down to the first and, every `skip` SNPs, records up to
//      maxNumLong (5 in round 1, 10 in round 2) longest matches per haplotype as seeds;
//   2. extends every seed left and right in 64-SNP chunks using confidence-weighted mismatch
//      scores, then snaps the ends to the nearest mismatching SNP;
//   3. merges overlapping matches to the same haplotype;
//   4. keeps the matches that are among the maxNumLong longest (in cM) active matches at some
//      position.
// Round 1 then flips low-confidence haplotype bits that the top matches contradict and masks
// SNPs with a high IBD error rate; round 2 writes the retained matches to the output file.
//
// Invalid input (bad thread count, empty sample intersection, too many SNPs for the 16-bit
// match coordinates, unwritable output) is reported and ends the program with exit code 1.
int main(int argc, char *argv[]) {

  printVersion();

  cout << "find_IBD:" << endl;
  cout << "- arg1 = $BED_BIM_FAM_PREFIX" << endl;
  cout << "- arg2 = $HAPS_SAMPLE_PREFIX" << endl;
  cout << "- arg3 = $LRR_THETA_GENO_FILE" << endl;
  cout << "- arg4 = $THREADS" << endl;
  cout << "- arg5 = $IBD_FILE (output)" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc != 6) {
    cout << "ERROR: 5 arguments required" << endl;
    exit(1);
  }

  const char *plinkPrefix = argv[1];
  const char *hapsSamplePrefix = argv[2];
  const char *lrrThetaGenoFile = checkInputFileExt(argv, 3, ".bin");
  // parse outside of assert(): an asserted sscanf() is compiled away under NDEBUG (leaving
  // threads uninitialized) and would also accept the EOF (-1) return value
  int threads = 0;
  if (sscanf(argv[4], "%d", &threads) != 1 || threads <= 0) {
    cout << "ERROR: number of threads must be a positive integer; got: " << argv[4] << endl;
    exit(1);
  }
  const char *outFile = argv[5];

  FileUtils::requireReadable(plinkPrefix + string(".bim"));
  FileUtils::requireReadable(plinkPrefix + string(".fam"));
  FileUtils::requireReadable(hapsSamplePrefix + string(".haps.gz"));
  FileUtils::requireReadable(hapsSamplePrefix + string(".sample"));
  FileUtils::requireReadable(lrrThetaGenoFile);
  FileUtils::requireWriteable(outFile);

  cout << "Setting number of threads to " << threads << endl;
  cout << endl;
  omp_set_num_threads(threads);

  Timer timer; double t0 = timer.get_time();

  // set up sample indexing
  vector <bool> inFam; vector <int> famIndsDropMissing;
  const int Nfam = isectFamSample(inFam, famIndsDropMissing, plinkPrefix + string(".fam"),
                                  hapsSamplePrefix + string(".sample"));
  const int N = famIndsDropMissing.size();
  const int H = 2*N;
  if (N == 0) { // the per-haplotype averages below divide by H
    cout << "ERROR: no samples in intersection of .fam and .sample files" << endl;
    exit(1);
  }

  // set up SNP indexing
  vector <int> mToBimIndex; vector <float> cMvec;
  lookupSnpsInBim(mToBimIndex, cMvec, plinkPrefix + string(".bim"),
                  hapsSamplePrefix + string(".haps.gz"), inFam.size());
  const int M = cMvec.size(); assert(mToBimIndex.size() == M+1U);
  const int M64 = (M+63)/64;

  // Match::mStart/mEnd are ushort and hold SNP indices up to M during the analysis and .bim
  // indices up to the sentinel (number of .bim SNPs) in the output; larger values would be
  // silently truncated
  const int maxSnpIndex = static_cast<ushort>(-1);
  if (mToBimIndex.back() > maxSnpIndex) {
    cout << "ERROR: .bim file has " << mToBimIndex.back() << " SNPs; at most " << maxSnpIndex
         << " are supported" << endl;
    exit(1);
  }

  // read haplotypes
  HapBitsT hapBitsT(H, M);
  readHaps(hapBitsT, hapsSamplePrefix + string(".haps.gz"), inFam);

  // read genotype call confidences
  HapBitsT confBitsT(H, M);
  readConfs(confBitsT, lrrThetaGenoFile, Nfam, famIndsDropMissing, mToBimIndex);

  cout << "Finished reading input (time = " << timer.update_time() << " sec)" << endl;

  for (int iter = 1; iter <= 2; iter++) {

  int maxNumLong = 5*iter; // number of longest matches to store

  // set up PBWT
  vector < vector <Match> > matches(H);   // per haplotype: all seeds found so far
  vector <ushort> mEndMin(H);             // per haplotype: worst (smallest) mEnd of the last batch
  vector <char> prevNumLong(H);           // per haplotype: size of the last batch
  const int skip = 32; // skip between updating matches
  const int minLen = 128; // minimum match length

  // initialize work arrays
  int *a1 = new int[H], *d1 = new int[H], *a = new int[H], *b = new int[H], *d = new int[H],
    *e = new int[H];
  for (int n = 0; n < H; n++) {
    a1[n] = n;
    d1[n] = M;
  }

  double tPBWT = 0, tSeed = 0;

  /***** RUN PBWT *****/

  for (int m = M-1; m >= 0; m--) {
    // compute sort order and divergence array
    int u = 0, v = 0, p = m, q = m;
    for (int i = 0; i < H; i++) {
      if (d1[i] < p) p = d1[i];
      if (d1[i] < q) q = d1[i];
      if (hapBitsT.getBit(a1[i], m) == 0) {
        a[u] = a1[i]; d[u] = p; u++; p = M;
      }
      else {
        b[v] = a1[i]; e[v] = q; v++; q = M;
      }
    }
    memcpy(a1, a, u * sizeof(a1[0])); memcpy(a1+u, b, v * sizeof(a1[0]));
    memcpy(d1, d, u * sizeof(d1[0])); memcpy(d1+u, e, v * sizeof(d1[0]));

    tPBWT += timer.update_time();

    if (m % skip == 0) {
      //cout << "seed at " << m << endl;
      cout << "." << flush;
      // each iteration touches only matches[h], mEndMin[h], prevNumLong[h] for h = a1[i];
      // a1 is a permutation of the haplotypes, so iterations do not share state
#pragma omp parallel for
      for (int i = 0; i < H; i++) {
        Match hMatches[maxNumLong];
        char toErase[maxNumLong];
        int h = a1[i]; //if (h >= 10) continue;
        //cout << endl;
        // walk backward and forward, keeping track of divergence: populate new best matches
        int iUp = i-1, iDown = i+1;
        int mEndUp = iUp >= 0 ? d1[iUp+1] : -1;
        int mEndDown = iDown < H ? d1[iDown] : -1;
        int mEndPrevMinCtr = 0; // count number of new matches w/ previous worst mEnd
        int numLong = 0; // count number of new matches
        for (int k = 0; k < maxNumLong; k++) {
          if (mEndUp >= mEndDown) {
            if (mEndUp-m < minLen) break;
            hMatches[numLong++] = Match(a1[iUp], m, mEndUp);
            iUp--;
            mEndUp = iUp >= 0 ? min(mEndUp, d1[iUp+1]) : -1;
          }
          else {
            if (mEndDown-m < minLen) break;
            hMatches[numLong++] = Match(a1[iDown], m, mEndDown);
            iDown++;
            mEndDown = iDown < H ? min(mEndDown, d1[iDown]) : -1;
          }
          mEndPrevMinCtr += hMatches[k].mEnd == mEndMin[h];
        }
        int numLongPrev = prevNumLong[h];
        int mEndCurMin = numLong ? hMatches[numLong-1].mEnd : 0; // record new worst mEnd
        sort(hMatches, hMatches+numLong, comp_hap_end_start); // sort new matches by hap index

        vector <Match> &hMatchesAll = matches[h];
        int hMatchesAllSize = hMatchesAll.size();
        if (!hMatchesAll.empty() && numLong) {
          // for each match in prev batch, check if present in new batch w/ same mEnd
          int k = 0;
          memset(toErase, 0, numLongPrev);
          for (int j = hMatchesAllSize-numLongPrev; j < hMatchesAllSize; j++) {
            while (k < numLong && comp_hap_end_start(hMatches[k], hMatchesAll[j])) k++;
            if (k == numLong) break;
            if (hMatchesAll[j].hap == hMatches[k].hap && hMatchesAll[j].mEnd == hMatches[k].mEnd) {
              toErase[j - (hMatchesAllSize-numLongPrev)] = true; // mark for deletion
              if (hMatches[k].mEnd == mEndMin[h])
                mEndPrevMinCtr--; // update count of available new matches w/ previous worst mEnd
            }
          }
          // for each match in prev batch not marked for deletion and w/ prev worst mEnd
          if (mEndPrevMinCtr)
            for (int j = hMatchesAllSize-numLongPrev; j < hMatchesAllSize; j++)
              if (!toErase[j-(hMatchesAllSize-numLongPrev)] && hMatchesAll[j].mEnd == mEndMin[h]) {
                toErase[j-(hMatchesAllSize-numLongPrev)] = true; // mark for deletion
                mEndPrevMinCtr--;
                if (mEndPrevMinCtr == 0)
                  break;
              }
          // delete and collapse
          int jCollapsed = hMatchesAllSize-numLongPrev;
          for (int j = hMatchesAllSize-numLongPrev; j < hMatchesAllSize; j++) {
            if (!toErase[j-(hMatchesAllSize-numLongPrev)])
              hMatchesAll[jCollapsed++] = hMatchesAll[j];
          }
          hMatchesAll.resize(jCollapsed);
        }
        // copy in new batch
        hMatchesAll.insert(hMatchesAll.end(), hMatches, hMatches+numLong);
        mEndMin[h] = mEndCurMin;
        prevNumLong[h] = numLong;
      }
      tSeed += timer.update_time();
    }
  }
  delete[] a1;
  delete[] d1;
  delete[] a;
  delete[] b;
  delete[] d;
  delete[] e;

  cout << endl;
  cout << "Time for PBWT: " << tPBWT << endl;
  cout << "Time for seed-finding: " << tSeed << endl;

  long long matchCount = 0;
  for (int h = 0; h < H; h++) matchCount += matches[h].size();
  cout << "Average number of matches: " << matchCount / H << endl;


  matchCount = 0; // count matches used
  vector < vector <double> > fracAbove(maxNumLong, vector <double> (11)); // frac > X cM
  vector <int> errCounts(M), totCounts(M);
  vector < vector <ushort> > flips; if (iter == 1) flips.resize(H);

#pragma omp parallel for schedule(static)
  for (int h1 = 0; h1 < H; h1++) {
    const uint64* haploBitsT1 = hapBitsT.getHaploBitsTrow(h1);
    const uint64* confBitsLoT1 = confBitsT.getHaploBitsTrow(h1&~1);
    const uint64* confBitsHiT1 = confBitsT.getHaploBitsTrow((h1&~1)+1);

    /***** EXTEND SEEDS *****/

    for (int j = 0; j < (int) matches[h1].size(); j++) {
      Match &match = matches[h1][j];
      int h2 = match.hap;
      const uint64* haploBitsT2 = hapBitsT.getHaploBitsTrow(h2);
      const uint64* confBitsLoT2 = confBitsT.getHaploBitsTrow(h2&~1);
      const uint64* confBitsHiT2 = confBitsT.getHaploBitsTrow((h2&~1)+1);

      int mExtStart = -1, mExtEnd = -1;
      for (int dir = -1; dir <= 1; dir += 2) {
        int m64 = (dir==-1 ? match.mStart : match.mEnd)>>6;
        int cur_score = 0, best_score = 0, best_m64 = m64 - dir;
        while (m64 >= 0 && m64 < M64) {
          // keep the chunk score in an int: it can reach 4*64 = 256, which a char cannot hold
          const int err_score = compute_err_score(haploBitsT1[m64] ^ haploBitsT2[m64],
                                                  confBitsLoT1[m64], confBitsHiT1[m64],
                                                  confBitsLoT2[m64], confBitsHiT2[m64]);
          // +1 for a chunk without confident mismatches; every confident mismatch costs 2
          // (4 when both samples have the highest confidence)
          cur_score += 1 - err_score;
          if (cur_score > best_score) {
            best_score = cur_score;
            best_m64 = m64;
          }
          else if ((m64 - best_m64)*dir >= (13-(best_score-cur_score)))
            break; // require score to break even within the next several 64-bit chunks
          m64 += dir;
        }

        m64 = best_m64 + dir; // end at last 0-err seg; extend to the next err

        if (dir == -1) { // find exact start position
          if (m64 == -1)
            mExtStart = 0;
          else { // back up to last 1-err/2-err bit before first 0-err 64 on left
            int highMismatch = (m64<<6) + 63-__builtin_clzll(haploBitsT1[m64] ^ haploBitsT2[m64]);
            assert(hapBitsT.getBit(h1, highMismatch) != hapBitsT.getBit(h2, highMismatch));
            mExtStart = highMismatch + 1;
          }
        }
        else { // find exact end position
          if (m64 == M64)
            mExtEnd = M;
          else { // back up to first 1-err/2-err bit after last 0-err 64 on right
            int lowMismatch = (m64<<6) + __builtin_ctzll(haploBitsT1[m64] ^ haploBitsT2[m64]);
            assert(hapBitsT.getBit(h1, lowMismatch) != hapBitsT.getBit(h2, lowMismatch));
            mExtEnd = lowMismatch;
          }
        }
      }

      // update match boundaries with extensions
      match.mStart = mExtStart;
      match.mEnd = mExtEnd;
    }


    /***** DEDUP AND MERGE MATCHES *****/

    sort(matches[h1].begin(), matches[h1].end(), comp_hap_start_end);
    vector <Match> matchesMerged;
    if (!matches[h1].empty()) matchesMerged.push_back(matches[h1][0]);
    for (int j = 1; j < (int) matches[h1].size(); j++) {
      const Match &cur = matches[h1][j]; Match &prev = matchesMerged[matchesMerged.size()-1];
      if (cur.hap == prev.hap && cur.mStart <= prev.mEnd) { // redundant/overlaps w/ previous match
        if (cur.mEnd > prev.mEnd) // merge with previous match: set end to max of ends
          prev.mEnd = cur.mEnd;
      }
      else
        matchesMerged.push_back(cur);
    }


    /***** PRUNE TO TOP maxNumLong MATCHES PER POSITION *****/

    // in iter 1, identify errors to flip for iter 2 PBWT
    const float cMminEdge = 0.5;
    const int Nflip = 5;
    char votesFlip[M][2]; memset(votesFlip, 0, M*2); // [m][0] = agree votes, [m][1] = disagree
    // record errors for diagnostic output

    priority_queue <Match> activeMatchPQ; // active matches, earliest mEnd on top
    set <Match_cM> activeMatches;         // same matches, longest (in cM) first
    set <Match> usedMatches;
    sort(matchesMerged.begin(), matchesMerged.end(), comp_start_endOpp_hap); // sort by start
    uint jMerged = 0; // position in matchesMerged list
    int mPrev = 0;

    // iterate through matches in order of mStart
    while (jMerged < matchesMerged.size() || !activeMatchPQ.empty()) {
      int earliestActiveEnd = activeMatchPQ.empty() ? (1<<30) : (int) activeMatchPQ.top().mEnd;
      int incomingStart =
        (jMerged < matchesMerged.size()) ? (int) matchesMerged[jMerged].mStart : (1<<30);
      int mFirst = min(earliestActiveEnd, incomingStart);

      // augment used-set with top maxNumLong matches from activeMatches set

      //cout << "best to " << mFirst << ":";
      int topCtr = 0;
      for (set <Match_cM>::iterator it = activeMatches.begin();
           it != activeMatches.end() && topCtr < maxNumLong; it++, topCtr++) {

        // update error counts; downsample to speed up
        if (iter==1 && ((h1&63)==0) && !usedMatches.count(it->match)) {
          if (it->cMlen > 4) {
#pragma omp critical
            for (int m = it->match.mStart; m < it->match.mEnd; m++) {
              errCounts[m] += hapBitsT.getBit(h1, m) != hapBitsT.getBit(it->match.hap, m);
              totCounts[m]++;
            }
          }
        }

        // augment used-set
        usedMatches.insert(it->match);

        // update fraction of matches >X cM
        if ((h1&63)==0) {
#pragma omp critical
          for (int cMmin = 1; cMmin <= 10; cMmin++)
            fracAbove[topCtr][cMmin] += (mFirst - mPrev) * (it->cMlen > cMmin);
        }

      }

      // update flip votes from top Nflip matches from activeMatches set
      if (iter == 1) {
        topCtr = 0;
        for (set <Match_cM>::iterator it = activeMatches.begin();
             it != activeMatches.end() && topCtr < Nflip; it++, topCtr++)
          for (int m = mPrev; m < mFirst; m++)
            if (confBitsT.getBit((h1&~1)+1, m) == 0 // h1 not confident (high bit = 0)
                && cMvec[m] - cMvec[it->match.mStart] > cMminEdge
                && cMvec[it->match.mEnd-1] - cMvec[m] > cMminEdge)
              votesFlip[m][hapBitsT.getBit(h1, m) != hapBitsT.getBit(it->match.hap, m)]++;
        for (int m = mPrev; m < mFirst; m++)
          if (votesFlip[m][0] + votesFlip[m][1] >= 4 && votesFlip[m][0] <= 1)
            flips[h1].push_back(m);
      }

      // pop all matches with earliestActiveEnd; delete them from activeMatches set
      if (earliestActiveEnd <= incomingStart) {
        while (!activeMatchPQ.empty() && activeMatchPQ.top().mEnd == earliestActiveEnd) {
          activeMatches.erase(Match_cM(activeMatchPQ.top(), cMvec));
          activeMatchPQ.pop();
        }
      }
      // add all matches with incomingStart to active set
      if (incomingStart <= earliestActiveEnd) {
        while (jMerged < matchesMerged.size() && matchesMerged[jMerged].mStart == incomingStart) {
          activeMatchPQ.push(matchesMerged[jMerged]);
          activeMatches.insert(Match_cM(matchesMerged[jMerged], cMvec));
          jMerged++;
        }
      }
      mPrev = mFirst;
    } // [end iteration through matches in order of mStart]

    if (iter == 2) // collapse matches
      matches[h1] = vector <Match> (usedMatches.begin(), usedMatches.end());

#pragma omp atomic
    matchCount += usedMatches.size();
  }

  cout << "Average number of pruned matches: " << matchCount / H << endl;
  // fracAbove was accumulated only for haplotypes with (h1&63)==0, i.e. ceil(H/64) of them;
  // rounding up also keeps the divisor nonzero when H < 64
  const int numSampledHaps = (H+63)>>6;
  for (int cMmin = 1; cMmin <= 10; cMmin++) {
    printf(">%2dcM:", cMmin);
    for (int k = 0; k < maxNumLong; k++)
      printf("   %.2f", fracAbove[k][cMmin] / numSampledHaps / M);
    cout << endl;
  }

  cout << "Time for seed extension: " << timer.update_time() << endl;

  if (iter == 1) {
    // flip low-confidence bits with evidence for incorrectness from IBD haplotypes
    int totFlips = 0;
    for (int h = 0; h < H; h++) {
      totFlips += flips[h].size();
      for (int j = 0; j < (int) flips[h].size(); j++)
        hapBitsT.flipBit(h, flips[h][j]);
    }
    cout << "Average number of flips: " << totFlips / (double) H << endl;

    // output per-SNP error statistics; adjust maxErrRate to avoid dropping >5% of SNPs
    // (a SNP with totCounts == 0 yields 0/0 = NaN, which compares false and is never counted)
    double maxErrRate = 0.0025;
    for (double rate = 0.01; rate >= 0.0005; rate -= 0.001) {
      int Mbad = 0;
      for (int m = 0; m < M; m++)
        Mbad += errCounts[m] / (double) totCounts[m] > rate;
      printf("SNPs with IBD error rate > %.3f: %5d / %5d = %.3f\n",
             rate, Mbad, M, Mbad / (double) M);
      if (Mbad / (double) M > 0.05)
        maxErrRate = max(maxErrRate, rate + 0.001);
    }

    // mask bad SNPs by setting to 1 for all samples
    int Mbad = 0;
    for (int m = 0; m < M; m++) {
      if (errCounts[m] / (double) totCounts[m] > maxErrRate) {
        for (int h = 0; h < H; h++)
          hapBitsT.setBit(h, m);
        Mbad++;
      }
    }
    printf("Masked SNPs with IBD error rate > %.4f: %5d / %5d = %.3f\n",
           maxErrRate, Mbad, M, Mbad / (double) M);
  }
  else {
    cout << "Writing IBD output file: " << outFile << endl;
    FILE *foutBin = fopen(outFile, "wb");
    if (foutBin == NULL) { // explicit check: an assert() would vanish in NDEBUG builds
      cout << "ERROR: Unable to open output file: " << outFile << endl;
      exit(1);
    }
    fwrite(&N, sizeof(int), 1, foutBin);      // 1 int: N = 487409
    fwrite(&famIndsDropMissing[0], sizeof(int), N, foutBin); // N int: indices of phased samples (header1)
    // start of match blocks in binary file (after headers)
    uint64 seekPos = (1+N)*sizeof(int) + (H+1)*sizeof(uint64);
    for (int h = 0; h <= H; h++) {            // H+1 uint64: match block starts (header2)
      fwrite(&seekPos, sizeof(uint64), 1, foutBin);
      if (h < H)
        seekPos += matches[h].size() * sizeof(Match);
    }
    for (int h = 0; h < H; h++) { // match blocks
      for (int j = 0; j < (int) matches[h].size(); j++) { // convert mStart/mEnd to bim coordinates
        matches[h][j].mStart = mToBimIndex[matches[h][j].mStart];
        matches[h][j].mEnd = mToBimIndex[matches[h][j].mEnd];
      }
      if (!matches[h].empty()) // matches[h][0] must not be formed for an empty vector
        fwrite(&matches[h][0], sizeof(Match), matches[h].size(), foutBin);
    }
    fclose(foutBin);
  }

  } // end main iteration (2 rounds of PBWT)

  cout << "Finished find_IBD; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
