#include <iostream>
#include <sstream>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <numeric>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <cmath>

#include "zlib.h"

#include "VersionHeader.hpp"
#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// Running sums for the three genotype clusters (index g = 0, 1, 2 as produced by bedToGeno)
// of one variant within one genotyping set; X is theta and Y is LRR. The per-cluster count,
// means, standard deviations and theta/LRR correlation are derived from these sums.
// Value-initialize (ClusterSums() or vector <ClusterSums>(S)) to start from all zeros.
struct ClusterSums {
  int num[3];      // number of samples contributing to cluster g
  double sumX[3];  // sum of theta
  double sumY[3];  // sum of LRR
  double sumXX[3]; // sum of theta^2
  double sumXY[3]; // sum of theta*LRR
  double sumYY[3]; // sum of LRR^2
};

// Maps a 2-bit PLINK .bed genotype code (0..3) to the genotype value used throughout this
// file. Values > 2 are treated as a missing genotype by the callers (they test g > 2 / g <= 2).
const int bedToGeno[4] = {
  2, // code 0
  3, // code 1 (missing)
  1, // code 2
  0  // code 3
};

// Reads the packed genotype row of variant m (ceil(N/4) bytes) from an open .bed file into
// bedLineIn, skipping the 3-byte header. bedLineIn must hold at least (N+3)>>2 bytes.
// M is not used. The fseek/fread results are not checked.
void readBedLine(uchar *bedLineIn, FILE *finBed, uint64 M, uint64 N, uint64 m) {
  (void) M;
  const uint64 bytesPerVariant = (N+3)>>2;
  fseek(finBed, 3 + m*bytesPerVariant, SEEK_SET);
  fread(bedLineIn, 1, bytesPerVariant, finBed);
}

// Splits line on whitespace (the same rule as operator>> on a string) and returns the
// non-empty tokens in order.
vector <string> tokenize(const string &line) {
  vector <string> result;
  istringstream stream(line);
  for (string word; stream >> word; )
    result.push_back(word);
  return result;
}

/*
void require(bool test, string errorStr) {
  if (!test) {
    cout << "ERROR: " << errorStr << endl;
    exit(1);
  }
}
*/
// require(test, errorStr): if test is false, print "ERROR: <errorStr>" to stdout and exit(1).
// Implemented as a macro (replacing the function above) so that errorStr -- usually a string
// concatenation -- is evaluated only when the test fails; the function version would build the
// message on every call, including the per-value checks inside checkLRRtheta's loops.
// NOTE: expands to a bare braced if-statement, so a trailing ';' is optional (some call sites
// omit it) and using it as the unbraced body of an if/else would mis-bind a following else.
#define require(test, errorStr) if (!(test)) { cout << "ERROR: " << (errorStr) << endl; exit(1); }

// Joins ID_1 and ID_2 with a tab to form the key used in IDpairs / IDpairToFamInd.
string makeIDpair(const string &ID_1, const string &ID_2) {
  string joined(ID_1);
  joined += '\t';
  joined += ID_2;
  return joined;
}

// True iff lo <= x <= hi (both ends inclusive). Any NaN argument yields false.
bool inRange(double x, double lo, double hi) {
  const bool aboveLo = x >= lo;
  const bool belowHi = hi >= x;
  return aboveLo && belowHi;
}

// Square of x (cheaper and clearer than pow(x, 2)).
inline double sq(double x) {
  const double v = x;
  return v * v;
}

/*
 * Reads and validates the sample info file. Its header must be exactly
 *   ID_1 ID_2 ID_imp genotypingSet inUniformAncestrySet
 * followed by one whitespace-separated 5-field line per sample. Each sample is assigned
 * famInd = its index in IDpairs (file order). For every sample this appends:
 *   IDpairs                "ID_1\tID_2" (must be unique; also keyed in IDpairToFamInd)
 *   IDimpToFamInd          ID_imp -> famInd (must be unique), except when ID_imp is "NA"
 *   famIndToGenotypingSet  index into genotypingSetNames, which collects the distinct
 *                          genotypingSet values in order of first appearance
 * The 5th column must be "0" or "1". Exits via require() on any format error and prints a
 * summary to stdout.
 */
void checkSampleInfo(vector <string> &IDpairs, map <string, int> &IDpairToFamInd,
                     map <string, int> &IDimpToFamInd, vector <string> &genotypingSetNames,
                     vector <int> &famIndToGenotypingSet, const char *sampleInfoFile) {

  cout << "====> Checking sample info file: " << sampleInfoFile << endl << endl;
  FileUtils::AutoGzIfstream fin; fin.openOrExit(sampleInfoFile);
  string line;
  // check header
  getline(fin, line);
  vector <string> tokens = tokenize(line);
  require(tokens.size()==5U && tokens[0]+" "+tokens[1]+" "+tokens[2]+" "+tokens[3]+" "+tokens[4]==
          "ID_1 ID_2 ID_imp genotypingSet inUniformAncestrySet",
          "sample info header must match:\nID_1 ID_2 ID_imp genotypingSet inUniformAncestrySet");
  map <string, int> setNameToInd; // map genotyping set names to genotypingSetNames indices
  vector <int> setSizes;          // number of samples in each genotyping set
  // check sample lines
  int lineCtr = 1, ctrInUniformAncestrySet = 0;
  while (getline(fin, line)) {
    lineCtr++;
    tokens = tokenize(line);
    require(tokens.size()==5U, "line " + StringUtils::itos(lineCtr) + " contains " +
            StringUtils::itos(tokens.size()) + " tokens; expected 5");
    const string &ID_1 = tokens[0], &ID_2 = tokens[1], &IDimp = tokens[2],
      &genotypingSetName = tokens[3], &inUniformAncestrySet = tokens[4];
    string IDpair = makeIDpair(ID_1, ID_2);
    int famInd = IDpairs.size();
    IDpairs.push_back(IDpair);
    require(IDpairToFamInd.find(IDpair)==IDpairToFamInd.end(),
            "duplicate ID_1 ID_2 pair in line " + StringUtils::itos(lineCtr));
    IDpairToFamInd[IDpair] = famInd;
    // "NA" is exempt from the uniqueness check (presumably "no imputed ID"), so it must not
    // become a lookup key: getFamInds() looks up .sample ID_2 values in IDimpToFamInd, and a
    // stored "NA" would map such an entry to whichever sample with IDimp "NA" came last.
    require(IDimp=="NA" || IDimpToFamInd.find(IDimp)==IDimpToFamInd.end(),
            "duplicate IDimp in line " + StringUtils::itos(lineCtr));
    if (IDimp != "NA")
      IDimpToFamInd[IDimp] = famInd;
    if (setNameToInd.find(genotypingSetName) == setNameToInd.end()) {
      setNameToInd[genotypingSetName] = genotypingSetNames.size();
      genotypingSetNames.push_back(genotypingSetName);
      setSizes.push_back(0);
    }
    famIndToGenotypingSet.push_back(setNameToInd[genotypingSetName]);
    setSizes[famIndToGenotypingSet.back()]++;
    require(inUniformAncestrySet=="0" || inUniformAncestrySet=="1",
            "non-binary 5th column (inUniformAncestrySet) in line " + StringUtils::itos(lineCtr));
    ctrInUniformAncestrySet += inUniformAncestrySet=="1";
  }
  fin.close();
  cout << "Read sample info for " << IDpairs.size() << " samples" << endl;
  cout << "Read " << genotypingSetNames.size() << " genotyping sets:" << endl;
  for (uint i = 0; i < genotypingSetNames.size(); i++)
    cout << "  " << genotypingSetNames[i] << " (N = " << setSizes[i] << ")"
         << endl;
  cout << "Read " << ctrInUniformAncestrySet << " / " << IDpairs.size()
       << " samples in uniform ancestry set" << endl;
  cout << endl << endl;
}

void checkFam(const string &famFile, const vector <string> &IDpairs) {

  cout << "====> Checking fam file: " << famFile << endl << endl;
  uint numLines = FileUtils::AutoGzIfstream::lineCount(famFile);
  require(numLines == IDpairs.size(), "fam file has " + StringUtils::itos(numLines)
	  + " lines; expected " + StringUtils::itos(IDpairs.size()));
  FileUtils::AutoGzIfstream fin; fin.openOrExit(famFile);
  string line;
  int lineCtr = 0;
  while (getline(fin, line)) {
    lineCtr++;
    vector <string> tokens = tokenize(line);
    require(tokens.size()==6U, "line " + StringUtils::itos(lineCtr) + " contains " +
	    StringUtils::itos(tokens.size()) + " tokens; expected 6");
    const string &FID = tokens[0], &IID = tokens[1];
    require(makeIDpair(FID, IID) == IDpairs[lineCtr-1],
	    "FID IID in line " + StringUtils::itos(lineCtr)
	    + " of fam file does not match ID_1 ID_2 in line " + StringUtils::itos(lineCtr+1)
	    + " of sample info file");
  }
  fin.close();
  cout << "Verified that fam file contains same IDs as sample info file" << endl;
  cout << endl << endl;
}

/*
 * Reads and validates the PLINK .bim file. Each line must have 6 whitespace-separated fields.
 * Genetic (column 3) and base-pair (column 4) positions must be non-decreasing. SNP IDs
 * (column 2) and "pos_A1_A2" keys (columns 4, 5, 6) must be unique.
 *
 * For each variant this appends its SNP ID to snpStrs and records its 0-based line index in
 * snpStrToBimInd and posA1A2ToBimInd. The file must contain at least one variant. The ratio of
 * the last genetic position to the last base-pair position must lie in (0.5e-8, 3e-8), as a
 * check that column 3 is in Morgans. Warns when several variants share a base-pair position.
 * Exits via require() on any failed check.
 */
void checkBim(vector <string> &snpStrs, map <string, int> &snpStrToBimInd,
              map <string, int> &posA1A2ToBimInd, const string &bimFile) {

  cout << "====> Checking bim file: " << bimFile << endl << endl;
  FileUtils::AutoGzIfstream fin; fin.openOrExit(bimFile);
  string line;
  int lineCtr = 0;
  double posMorgansPrev = -1e9;
  int posPrev = 0;
  set <int> posSet; // distinct base-pair positions (to count shared positions)
  while (getline(fin, line)) {
    lineCtr++;
    vector <string> tokens = tokenize(line);
    require(tokens.size()==6U, "line " + StringUtils::itos(lineCtr) + " contains " +
            StringUtils::itos(tokens.size()) + " tokens; expected 6");
    //int chr = StringUtils::stoi(tokens[0]);
    const string &snpStr = tokens[1];
    double posMorgans = StringUtils::stod(tokens[2]);
    int pos = StringUtils::stoi(tokens[3]);
    const string &A1 = tokens[4], &A2 = tokens[5];
    string posA1A2 = StringUtils::itos(pos) + "_" + A1+"_"+A2;
    require(posMorgans >= posMorgansPrev,
            "genetic map coordinate (column 3) decreased in line " + StringUtils::itos(lineCtr));
    require(pos >= posPrev,
            "base pair coordinate (column 4) decreased in line " + StringUtils::itos(lineCtr));
    require(snpStrToBimInd.find(snpStr)==snpStrToBimInd.end(),
            "duplicate snpID (column 2) in line " + StringUtils::itos(lineCtr));
    snpStrs.push_back(snpStr);
    snpStrToBimInd[snpStr] = lineCtr-1;
    require(posA1A2ToBimInd.find(posA1A2)==posA1A2ToBimInd.end(),
            "duplicate SNP (POS_A1_A2) in line " + StringUtils::itos(lineCtr));
    posA1A2ToBimInd[posA1A2] = lineCtr-1;
    posSet.insert(pos);
    posMorgansPrev = posMorgans;
    posPrev = pos;
  }
  fin.close();
  // With no variants the scale check below would divide by posPrev = 0 and report a
  // misleading genetic-map error, so report the real problem instead.
  require(lineCtr > 0, "bim file contains no variants");
  double scale = posMorgansPrev / posPrev;
  require(0.5e-8 < scale && scale < 3e-8,
          "column 3 must contain genetic map coordinates in Morgan units");
  cout << "Read data for " << snpStrToBimInd.size() << " variants" << endl;
  if (posSet.size() < snpStrs.size())
    cout << "WARNING: " << snpStrs.size()-posSet.size()
         << " duplicate positions -- multiallelic variants not recommended for analysis" << endl;
  cout << endl << endl;
}

FILE *openCheckBed(const string &bedFile, uint64 M, uint64 N) {

  cout << "====> Checking bed file: " << bedFile << endl << endl;
  FILE *finBed = fopen(bedFile.c_str(), "rb");
  require(finBed != NULL, "unable to open file");
  uint64 bedBytesExpected = M * ((N+3)>>2) + 3;
  fseek(finBed, 0, SEEK_END);
  require((uint64) ftell(finBed) == bedBytesExpected,
	  "incorrect number of bytes in file: expected " + StringUtils::itos(M) + "*"
	  + StringUtils::itos((N+3)>>2) + "+3");
  fseek(finBed, 0, SEEK_SET);
  uchar header[3];
  fread(header, 1, 3, finBed);
  require(header[0]==0x6c && header[1]==0x1b && header[2]==0x01,
	  "incorrect first three bytes of bed file");
  cout << "Verified first three bytes and file size" << endl;
  cout << endl << endl;
  return finBed;
}

void checkLRRtheta(const char *outPrefix, const char *lrrFile, const char *thetaFile,
		   const vector <string> &genotypingSetNames,
		   const vector <int> &famIndToGenotypingSet, const vector <string> &IDpairs,
		   const vector <string> &snpStrs, uint64 M, uint64 N, FILE *finBed) {

  cout << "====> Checking LRR and theta files: " << lrrFile << " " << thetaFile << endl << endl;
  int S = genotypingSetNames.size();
  vector <double> setSizes(S);
  for (uint64 n = 0; n < N; n++)
    setSizes[famIndToGenotypingSet[n]]++;
  vector <int> sampleMissingLRR(N), sampleMissingTheta(N), sampleMissingGeno(N);
  int ctrHighMissingLRR = 0, ctrHighMissingTheta = 0, ctrHighMissingGeno = 0, ctrHighLargeLRR = 0;

  FileUtils::AutoGzIfstream finLRR; finLRR.openOrExit(lrrFile);
  FileUtils::AutoGzIfstream finTheta; finTheta.openOrExit(thetaFile);
  FileUtils::AutoGzOfstream foutSnpStats;
  foutSnpStats.openOrExit(outPrefix + string(".snp_stats.txt.gz"));
  foutSnpStats << "snpID" << "\t" << "genotypingSet";
  for (int g = 0; g <= 2; g++) {
    foutSnpStats << "\t" << "n" << g << "\t" << "muTheta" << g << "\t" << "muLRR" << g
		 << "\t" << "stdTheta" << g << "\t" << "stdLRR" << g << "\t" << "rho" << g;
  }
  foutSnpStats << "\t" << "fracMissLRR"
	       << "\t" << "fracMissTheta"
	       << "\t" << "fracMissGeno"
	       << "\t" << "fracLargeLRR"
	       << endl;
  int ctrHomMinor5 = 0, ctrHomMinor5nonMonotonic = 0;
  uint64 ctrTheta01 = 0, ctrThetaNotNan = 0;
  string lrrStr, thetaStr;
  uchar *bedLineIn = (uchar *) malloc((N+3)>>2); assert(bedLineIn != NULL);  
  for (uint64 m = 0; m < M; m++) {
    if ((m+1)%100 == 0) cout << "." << flush;
    string lineLRR, lineTheta;
    require(getline(finLRR, lineLRR), "LRR file only contains " + StringUtils::itos(m)
	    + " lines; expected " + StringUtils::itos(M));
    require(getline(finTheta, lineTheta), "theta file only contains " + StringUtils::itos(m)
	    + " lines; expected " + StringUtils::itos(M));
    istringstream issLRR(lineLRR), issTheta(lineTheta);

    vector <ClusterSums> clusterSumsAll(S);
    fread(bedLineIn, 1, (N+3)>>2, finBed);
    vector <int> ctrMissingLRR(S), ctrMissingTheta(S), ctrMissingGeno(S), ctrLargeLRR(S);

    for (uint64 n = 0; n < N; n++) {
      require(issLRR >> lrrStr,
	      "LRR file line " + StringUtils::itos(m+1) + " only contains "
	      + StringUtils::itos(n) + " tokens; expected " + StringUtils::itos(N));
      require(issTheta >> thetaStr,
	      "theta file line " + StringUtils::itos(m+1) + " only contains "
	      + StringUtils::itos(n) + " tokens; expected " + StringUtils::itos(N));
      double lrr = NAN, theta = NAN;
      require(sscanf(lrrStr.c_str(), "%lf", &lrr) || lrrStr=="NA" || lrrStr=="nan",
	      "LRR file line " + StringUtils::itos(m+1) + " has invalid entry in column "
	      + StringUtils::itos(n+1) + ": " + lrrStr);
      require(sscanf(thetaStr.c_str(), "%lf", &theta) || thetaStr=="NA" || thetaStr=="nan",
	      "theta file line " + StringUtils::itos(m+1) + " has invalid entry in column "
	      + StringUtils::itos(n+1) + ": " + thetaStr);
      require(isnan(theta) || (0 <= theta && theta <= 1),
	      "theta file line " + StringUtils::itos(m+1) + " has out-of-bounds entry in column "
	      + StringUtils::itos(n+1) + ": " + thetaStr);

      int s = famIndToGenotypingSet[n];
      int g = bedToGeno[(bedLineIn[n>>2]>>((n&3)<<1))&3];

      if (isnan(lrr)) { ctrMissingLRR[s]++; sampleMissingLRR[n]++; }
      if (isnan(theta)) { ctrMissingTheta[s]++; sampleMissingTheta[n]++; }
      if (g > 2) { ctrMissingGeno[s]++; sampleMissingGeno[n]++; }
      if (!isnan(lrr) && (lrr < -2 || lrr > 2)) ctrLargeLRR[s]++;
      if (!isnan(theta)) ctrThetaNotNan++;
      if (theta==0 || theta==1) ctrTheta01++;
      
      if (!isnan(lrr) && !isnan(theta) && g <= 2) {
	ClusterSums &cs = clusterSumsAll[s];
	cs.num[g]++;
	cs.sumX[g] += theta;
	cs.sumY[g] += lrr;
	cs.sumXX[g] += sq(theta);
	cs.sumXY[g] += theta*lrr;	
	cs.sumYY[g] += sq(lrr);
      }
    }
    require(!(issLRR >> lrrStr), "LRR file line " + StringUtils::itos(m+1)
	    + " contains more than " + StringUtils::itos(N) + " tokens");
    require(!(issTheta >> thetaStr), "theta file line " + StringUtils::itos(m+1)
	    + " contains more than " + StringUtils::itos(N) + " tokens");

    for (int s = 0; s < S; s++) {
      foutSnpStats << snpStrs[m] << "\t" << genotypingSetNames[s];
      ClusterSums &cs = clusterSumsAll[s];
      for (int g = 0; g <= 2; g++) {
	double muX = cs.sumX[g] / cs.num[g], muY = cs.sumY[g] / cs.num[g];
	double stdX = sqrt((cs.sumXX[g] - cs.sumX[g]*muX) / (cs.num[g]-1));
	double stdY = sqrt((cs.sumYY[g] - cs.sumY[g]*muY) / (cs.num[g]-1));
	double rho = (cs.sumXY[g] - cs.sumX[g]*muY) / ((cs.num[g]-1) * stdX * stdY);
	foutSnpStats << "\t" << cs.num[g] << "\t" << muX << "\t" << muY
		     << "\t" << stdX << "\t" << stdY << "\t" << rho;
      }
      foutSnpStats << "\t" << ctrMissingLRR[s]/setSizes[s]
		   << "\t" << ctrMissingTheta[s]/setSizes[s]
		   << "\t" << ctrMissingGeno[s]/setSizes[s]
		   << "\t" << ctrLargeLRR[s]/setSizes[s]
		   << endl;
      if (cs.num[0] >= 5 && cs.num[1] >= 5 && cs.num[2] >= 5) {
	ctrHomMinor5++;
	double muX[3];
	for (int g = 0; g <= 2; g++) muX[g] = cs.sumX[g] / cs.num[g];
	if ((muX[1]-muX[0])*(muX[2]-muX[1]) < 0)
	  ctrHomMinor5nonMonotonic++;
      }
      ctrHighMissingLRR += inRange(ctrMissingLRR[s]/setSizes[s], 0.1, 0.999999);
      ctrHighMissingTheta += inRange(ctrMissingTheta[s]/setSizes[s], 0.1, 0.999999);
      ctrHighMissingGeno += inRange(ctrMissingGeno[s]/setSizes[s], 0.1, 0.999999);
      ctrHighLargeLRR += inRange(ctrLargeLRR[s]/setSizes[s], 0.01, 1);
    }
  }
  cout << endl;
  free(bedLineIn);
  require(!(finLRR >> lrrStr),
	  "LRR file contains more than " + StringUtils::itos(M) + " lines");
  require(!(finTheta >> thetaStr),
	  "theta file contains more than " + StringUtils::itos(M) + " lines");
  finLRR.close();
  finTheta.close();
  foutSnpStats.close();
  if (ctrHomMinor5nonMonotonic)
    cout << "WARNING: " << ctrHomMinor5nonMonotonic << " of " << ctrHomMinor5
	 << " clusters with hom-minor count >= 5 have non-monotonic theta values" << endl;
  if (ctrHighMissingLRR)
    cout << "WARNING: " << ctrHighMissingLRR << " of " << (M*S)
	 << " SNP x genotypingSet pairs have high LRR missingness >=0.1 and <1" << endl;
  if (ctrHighMissingTheta)
    cout << "WARNING: " << ctrHighMissingTheta << " of " << (M*S)
	 << " SNP x genotypingSet pairs have high theta missingness >=0.1 and <1" << endl;
  if (ctrHighMissingGeno)
    cout << "WARNING: " << ctrHighMissingGeno << " of " << (M*S)
	 << " SNP x genotypingSet pairs have high genotype missingness >=0.1 and <1" << endl;
  if (ctrHighLargeLRR)
    cout << "WARNING: " << ctrHighLargeLRR << " of " << (M*S)
	 << " SNP x genotypingSet pairs have >1% frequency of |LRR|>2" << endl;
  if (ctrTheta01 >= 0.001 * ctrThetaNotNan)
    cout << "WARNING: " << ctrTheta01 / (double) ctrThetaNotNan << " of theta values are 0 or 1"
	 << endl;
  
  FileUtils::AutoGzOfstream foutSampleStats;
  foutSampleStats.openOrExit(outPrefix + string(".sample_stats.txt.gz"));
  foutSampleStats << "FID\tIID\tmissLRR\tmissTheta\tmissGeno" << endl;
  int highMissLRR = 0, highMissTheta = 0, highMissGeno = 0;
  for (uint64 n = 0; n < N; n++) {
    highMissLRR += sampleMissingLRR[n]>0.1*N;
    highMissTheta += sampleMissingTheta[n]>0.1*N;
    highMissGeno += sampleMissingGeno[n]>0.1*N;
    foutSampleStats << IDpairs[n]
		    << "\t" << sampleMissingLRR[n] / (double) N
		    << "\t" << sampleMissingTheta[n] / (double) N
		    << "\t" << sampleMissingGeno[n] / (double) N
		    << endl;
  }
  foutSampleStats.close();
  if (highMissLRR)
    cout << "WARNING: " << highMissLRR << " samples have >0.1 missing LRR " << endl;
  if (highMissTheta)
    cout << "WARNING: " << highMissTheta << " samples have >0.1 missing theta " << endl;
  if (highMissGeno)
    cout << "WARNING: " << highMissGeno << " samples have >0.1 missing genotypes " << endl;

  cout << endl << endl;
}

vector <int> getFamInds(const string &sampleFile, const map <string, int> &IDpairToFamInd,
			const map <string, int> &IDimpToFamInd, bool checkIDimp) {

  FileUtils::AutoGzIfstream finSample; finSample.openOrExit(sampleFile);
  string line;
  // check header
  getline(finSample, line);
  vector <string> tokens = tokenize(line);
  require(tokens.size()>=3U && tokens[0]+" "+tokens[1]+" "+tokens[2]=="ID_1 ID_2 missing",
	  "sample header line 1 must begin with:\nID_1 ID_2 missing");
  getline(finSample, line);
  tokens = tokenize(line);
  require(tokens.size()>=3U && tokens[0]+" "+tokens[1]+" "+tokens[2]=="0 0 0",
	  "sample header line 2 must begin with:\n0 0 0");
  // check IDs
  vector <int> famInds;
  int lineCtr = 2;
  int ctrNotInFam = 0;
  while (getline(finSample, line)) {
    lineCtr++;
    tokens = tokenize(line);
    require(tokens.size()>=3U,
	    "sample file line " + StringUtils::itos(lineCtr) + " does not contain >=3 tokens");
    const string &ID_1 = tokens[0], &ID_2 = tokens[1];
    string IDpair = makeIDpair(ID_1, ID_2);
    map <string, int>::const_iterator it = IDpairToFamInd.find(IDpair);
    map <string, int>::const_iterator it2 = IDimpToFamInd.find(ID_2);
    if (it == IDpairToFamInd.end() && (!checkIDimp || it2 == IDimpToFamInd.end())) {
      famInds.push_back(-1);
      ctrNotInFam++;
    }
    else
      famInds.push_back(it != IDpairToFamInd.end() ? it->second : it2->second);
  }
  finSample.close();
  if (ctrNotInFam)
    cout << "WARNING: sample file contains " << ctrNotInFam << " IDs not in sample info file"
	 << endl;
  return famInds;
}

void checkHapsSample(const string &phasedStatsFile, const char *hapsSamplePrefix,
		     const map <string, int> &IDpairToFamInd,
		     const map <string, int> &snpStrToBimInd, uint64 M, uint64 N,
		     FILE *finBed) {

  cout << "====> Checking haps.gz and sample files: " << hapsSamplePrefix << ".{haps.gz,sample}"
       << endl << endl;
  
  vector <int> famInds = getFamInds(hapsSamplePrefix + string(".sample"), IDpairToFamInd,
				    IDpairToFamInd, false);
  uint Nphased = famInds.size();
  cout << "Read IDs of " << Nphased << " samples in sample file" << endl;
  int ctrNotInFam = count(famInds.begin(), famInds.end(), -1);
  if (Nphased-ctrNotInFam < N)
    cout << "WARNING: " << N-Nphased+ctrNotInFam
	 << " samples without phased haplotypes will be excluded from analysis" << endl;

  FileUtils::AutoGzIfstream finHaps; finHaps.openOrExit(hapsSamplePrefix + string(".haps.gz"));
  FileUtils::AutoGzOfstream foutPhasedStats; foutPhasedStats.openOrExit(phasedStatsFile);
  foutPhasedStats << "snpID\tAF\tAFphased\tr2" << endl;
  uchar *bedLineIn = (uchar *) malloc((N+3)>>2); assert(bedLineIn != NULL);  
  int ctrLowR2 = 0;
  string line;
  int lineCtr = 0;
  while (getline(finHaps, line)) {
    lineCtr++;
    istringstream iss(line);
    string chr, snpStr, pos, A1, A2;
    require(iss >> chr >> snpStr >> pos >> A1 >> A2,
	    "haps.gz file line " + StringUtils::itos(lineCtr) + " has too few fields");
    getline(iss, line);
    require(line.length() == Nphased*4,
	    "haps.gz file line " + StringUtils::itos(lineCtr) + " has wrong length");
    map <string, int>::const_iterator it = snpStrToBimInd.find(snpStr);
    require(it != snpStrToBimInd.end(), "haps.gz file line " + StringUtils::itos(lineCtr)
	    + " contains snpID not in bim file");
    readBedLine(bedLineIn, finBed, M, N, it->second);
    double num = 0, sumX = 0, sumY = 0, sumXX = 0, sumXY = 0, sumYY = 0;
    for (uint i = 0; i < Nphased; i++) {
      require(line[4*i]==' ' && (line[4*i+1]=='0'||line[4*i+1]=='1') &&
	      line[4*i+2]==' ' && (line[4*i+3]=='0'||line[4*i+3]=='1'),
	      "haps.gz file line " + StringUtils::itos(lineCtr) + " has sample "
	      + StringUtils::itos(i+1) + " incorrectly formatted");
      int gPhased = (line[4*i+1]=='1') + (line[4*i+3]=='1');
      int n = famInds[i];
      if (n == -1) continue; // sample not in fam file
      int g = bedToGeno[(bedLineIn[n>>2]>>((n&3)<<1))&3];
      if (g <= 2) {
	num++;
	sumX += g;
	sumY += gPhased;
	sumXX += sq(g);
	sumXY += g*gPhased;
	sumYY += sq(gPhased);
      }
    }
    double AF = sumX/(2*num), AFphased = sumY/(2*num);
    double r2 = sq(sumXY*num - sumX*sumY) / ((sumXX*num - sq(sumX)) * (sumYY*num - sq(sumY)));
    if (r2 < 0.9)
      ctrLowR2++;
    foutPhasedStats << snpStr << "\t" << AF << "\t" << AFphased << "\t" << r2 << endl;
  }
  free(bedLineIn);
  finHaps.close();
  foutPhasedStats.close();
  uint Mphased = lineCtr;
  cout << "Read phased haplotypes for " << Mphased << " variants in haps.gz file" << endl;
  if (ctrLowR2)
    cout << "WARNING: " << ctrLowR2 << " variants have r2<0.9 vs. bed file genotypes" << endl;
  if (Mphased < 0.9*M)
    cout << "WARNING: >10% of variants in bim file not present in haps.gz file" << endl;

  cout << endl << endl;
}

void checkBgenSample(vector <int> &bimSnpFound, FileUtils::AutoGzOfstream &foutImputedStats,
		     const char *bgenFile, const map <string, int> &IDpairToFamInd,
		     const map <string, int> &IDimpToFamInd,
		     const map <string, int> &posA1A2ToBimInd, uint64 M, uint64 N,
		     FILE *finBed) {

  string bgenSamplePrefix = string(bgenFile).substr(0, strlen(bgenFile)-5);
  cout << "====> Checking bgen and sample files: " << bgenSamplePrefix << ".{bgen,sample}" << endl
       << endl;

  vector <int> famInds = getFamInds(bgenSamplePrefix + ".sample", IDpairToFamInd, IDimpToFamInd,
				    true);
  uint Nimputed = famInds.size();
  cout << "Read IDs of " << Nimputed << " samples in sample file" << endl;
  int ctrNotInFam = count(famInds.begin(), famInds.end(), -1);
  if (Nimputed-ctrNotInFam < N)
    cout << "WARNING: " << N-Nimputed+ctrNotInFam
	 << " samples without imputed data will not have probes near other SNPs masked" << endl;

  // check bgen header
  FILE *fin = fopen((bgenSamplePrefix + ".bgen").c_str(), "rb"); assert(fin != NULL);
  uint offset; fread(&offset, 4, 1, fin);
  uint L_H; fread(&L_H, 4, 1, fin);
  uint Mbgen; fread(&Mbgen, 4, 1, fin);
  assert(Mbgen != 0);
  uint Nbgen; fread(&Nbgen, 4, 1, fin);
  require(Nbgen==Nimputed, "number of samples in bgen header does not match sample file");
  char magic[5]; fread(magic, 1, 4, fin); magic[4] = '\0';
  require(magic==string("bgen"), "incorrect magic bytes");
  fseek(fin, L_H-20, SEEK_CUR);
  uint flags; fread(&flags, 4, 1, fin);
  uint CompressedSNPBlocks = flags&3;
  require(CompressedSNPBlocks==1, "CompressedSNPBlocks flag required to be 1")
  uint Layout = (flags>>2)&0xf;
  require(Layout==2, "Layout required to be 2 (BGEN v1.2)");
  fseek(fin, offset+4, SEEK_SET);

  // check variant data
  cout << "Reading data for " << Mbgen << " variants in bgen file" << endl;

  double lut[256];
  for (int i = 0; i <= 255; i++)
    lut[i] = i/255.0;

  uchar *bedLineIn = (uchar *) malloc((N+3)>>2); assert(bedLineIn != NULL);  
    
  int maxLA = 65536, maxLB = 65536;
  char *A1 = (char *) malloc(maxLA+1);
  char *A2 = (char *) malloc(maxLB+1);
  char chrStr[65536];

  uchar *zBuf = (uchar *) malloc(3*Nbgen+100);
  uchar *buf = (uchar *) malloc(3*Nbgen+100);
  int ctrInBim = 0, ctrCommonInBim = 0, ctrLowR2 = 0;
  for (uint mbgen = 0; mbgen < Mbgen; mbgen++) {
    ushort LS; fread(&LS, 2, 1, fin);
    //fread(snpID, 1, LS, fin); snpID[LS] = '\0';
    fseek(fin, LS, SEEK_CUR);
    ushort LR; fread(&LR, 2, 1, fin);
    //fread(rsID, 1, LR, fin); rsID[LR] = '\0';
    fseek(fin, LR, SEEK_CUR);
    ushort LC; fread(&LC, 2, 1, fin);
    fread(chrStr, 1, LC, fin); chrStr[LC] = '\0'; //fseek(fin, LC, SEEK_CUR);
    int pos; fread(&pos, 4, 1, fin);
    ushort K; fread(&K, 2, 1, fin);
    require(K==2, "non-biallelic variant found in bgen file");
    int LA; fread(&LA, 4, 1, fin);
    if (LA > maxLA) {
      maxLA = 2*LA;
      free(A1);
      A1 = (char *) malloc(maxLA+1);
    }
    fread(A1, 1, LA, fin); A1[LA] = '\0';
    int LB; fread(&LB, 4, 1, fin);
    if (LB > maxLB) {
      maxLB = 2*LB;
      free(A2);
      A2 = (char *) malloc(maxLB+1);
    }
    fread(A2, 1, LB, fin); A2[LB] = '\0';

    uint C; fread(&C, 4, 1, fin); //cout << "C: " << C << endl;
    uint D; fread(&D, 4, 1, fin); //cout << "D: " << D << endl;
    uint zBufLen = C-4; uint bufLen = D;

    int bimInd = -1; map <string, int>::const_iterator it;
    it = posA1A2ToBimInd.find(StringUtils::itos(pos)+"_"+A1+"_"+A2);
    if (it != posA1A2ToBimInd.end()) bimInd = it->second;
    it = posA1A2ToBimInd.find(StringUtils::itos(pos)+"_"+A2+"_"+A1);
    if (it != posA1A2ToBimInd.end()) bimInd = it->second;

    if (bimInd != -1) {
      fread(zBuf, 1, zBufLen, fin);
      uLongf destLen = bufLen;
      require(uncompress(buf, &destLen, zBuf, zBufLen) == Z_OK && destLen == bufLen,
	      "uncompress() failed");
      uchar *bufAt = buf;
      uint Ncheck = bufAt[0]|(bufAt[1]<<8)|(bufAt[2]<<16)|(bufAt[3]<<24); bufAt += 4;
      assert(Ncheck==Nimputed);
      uint K = bufAt[0]|(bufAt[1]<<8); bufAt += 2;
      assert(K==2U);
      uint Pmin = *bufAt; bufAt++;
      assert(Pmin==2U);
      uint Pmax = *bufAt; bufAt++;
      assert(Pmax==2U);
      for (uint i = 0; i < Nimputed; i++) {
	uint ploidyMiss = *bufAt; bufAt++;
	require(ploidyMiss==2U, "bgen file from imputation required to have no missingness");
      }
      uint Phased = *bufAt; bufAt++;
      assert(Phased==0U || Phased==1U);
      uint B = *bufAt; bufAt++;
      require(B==8U, "bgen file required to use 8-bit coding");
      
      readBedLine(bedLineIn, finBed, M, N, bimInd);
      double num = 0, sumX = 0, sumY = 0, sumXX = 0, sumXY = 0, sumYY = 0;
      for (uint i = 0; i < Nimputed; i++) {
	double gImp = lut[bufAt[0]] * (Phased==0U ? 2 : 1) + lut[bufAt[1]];
	bufAt += 2;
	int n = famInds[i];
	if (n == -1) continue; // sample not in fam file
	int g = bedToGeno[(bedLineIn[n>>2]>>((n&3)<<1))&3];
	if (g <= 2) {
	  num++;
	  sumX += g;
	  sumY += gImp;
	  sumXX += sq(g);
	  sumXY += g*gImp;
	  sumYY += sq(gImp);
	}
      }
      double AF = sumX/(2*num), AFimp = sumY/(2*num);
      double r2 = sq(sumXY*num - sumX*sumY) / ((sumXX*num - sq(sumX)) * (sumYY*num - sq(sumY)));
      ctrInBim++;
      bimSnpFound[bimInd] = 1;
      if (AFimp > 0.05 && AFimp < 0.95) {
	ctrCommonInBim++;
	if (r2 < 0.9)
	  ctrLowR2++;
      }
      foutImputedStats << chrStr << ":" << pos << "_" << A1 << "_" << A2
		       << "\t" << AF << "\t" << AFimp << "\t" << r2 << endl;
    }
    else
      fseek(fin, zBufLen, SEEK_CUR);
  }
  free(bedLineIn);
  free(A1);
  free(A2);
  free(zBuf);
  free(buf);
  fclose(fin);

  cout << "Checked consistency between PLINK and BGEN genotypes for " << ctrInBim << " variants"
       << endl;
  if (ctrLowR2)
    cout << "WARNING: " << ctrLowR2 << " of " << ctrCommonInBim
	 << " common variants have r2<0.9 vs. bed file genotypes" << endl;
  cout << endl << endl;
}


/*
 * check_inputs entry point. First checks that every input is readable and the output prefix
 * is writeable. Then runs the consistency checks in order:
 *   sample info -> fam -> bim -> bed -> LRR/theta -> haps/sample -> each bgen/sample pair
 * Writes $OUTPUT_PREFIX.{snp,sample,phased,imputed}_stats.txt.gz. Failed require() checks
 * exit with status 1.
 */
int main(int argc, char *argv[]) {

  printVersion();

  cout << "check_inputs:" << endl;
  cout << "- arg1 = $OUTPUT_PREFIX" << endl; // .{snp,sample,phased,imputed}_stats.txt.gz
  cout << "- arg2 = $SAMPLE_INFO_FILE" << endl;
  cout << "- arg3 = $BED_BIM_FAM_PREFIX" << endl;
  cout << "- arg4 = $LRR_FILE" << endl;
  cout << "- arg5 = $THETA_FILE" << endl;
  cout << "- arg6 = $HAPS_SAMPLE_PREFIX" << endl;
  cout << "- arg7+ = BGEN files (must end with .bgen; .sample files must also exist)" << endl;
  cout << endl;

  printCmd(argc, argv);

  if (argc < 8) {
    cout << "ERROR: 7+ arguments required" << endl;
    exit(1);
  }

  const char *outPrefix = argv[1];
  const char *sampleInfoFile = argv[2];
  const char *plinkPrefix = argv[3];
  const char *lrrFile = argv[4];
  const char *thetaFile = argv[5];
  const char *hapsSamplePrefix = argv[6];

  // fail fast on unreadable inputs / unwriteable output before doing any work
  FileUtils::requireWriteable(outPrefix + string(".sample_stats.txt.gz"));
  FileUtils::requireReadable(sampleInfoFile);
  FileUtils::requireReadable(plinkPrefix + string(".bed"));
  FileUtils::requireReadable(plinkPrefix + string(".bim"));
  FileUtils::requireReadable(plinkPrefix + string(".fam"));
  FileUtils::requireReadable(lrrFile);
  FileUtils::requireReadable(thetaFile);
  FileUtils::requireReadable(hapsSamplePrefix + string(".haps.gz"));
  FileUtils::requireReadable(hapsSamplePrefix + string(".sample"));
  for (int a = 7; a < argc; a++) {
    require(strlen(argv[a])>=5 && strcmp(argv[a]+strlen(argv[a])-5, ".bgen")==0,
            "arg 7+ must end in .bgen, but arg " + StringUtils::itos(a) + " = " + string(argv[a]));
    FileUtils::requireReadable(argv[a]);
    // "<prefix>.bgen" -> "<prefix>.sample"
    FileUtils::requireReadable(string(argv[a]).substr(0, strlen(argv[a])-4) + "sample");
  }

  Timer timer; double t0 = timer.get_time();

  vector <string> IDpairs;
  map <string, int> IDpairToFamInd, IDimpToFamInd;
  vector <string> genotypingSetNames;
  vector <int> famIndToGenotypingSet;
  checkSampleInfo(IDpairs, IDpairToFamInd, IDimpToFamInd, genotypingSetNames,
                  famIndToGenotypingSet, sampleInfoFile);

  checkFam(plinkPrefix + string(".fam"), IDpairs);
  uint64 N = IDpairs.size();

  vector <string> snpStrs;
  map <string, int> snpStrToBimInd, posA1A2ToBimInd;
  checkBim(snpStrs, snpStrToBimInd, posA1A2ToBimInd, plinkPrefix + string(".bim"));
  uint64 M = snpStrToBimInd.size();

  FILE *finBed = openCheckBed(plinkPrefix + string(".bed"), M, N);

  // reads the bed rows sequentially from just past the header
  checkLRRtheta(outPrefix, lrrFile, thetaFile, genotypingSetNames, famIndToGenotypingSet, IDpairs,
                snpStrs, M, N, finBed);

  checkHapsSample(outPrefix + string(".phased_stats.txt.gz"), hapsSamplePrefix, IDpairToFamInd,
                  snpStrToBimInd, M, N, finBed);

  FileUtils::AutoGzOfstream foutImputedStats;
  foutImputedStats.openOrExit(outPrefix + string(".imputed_stats.txt.gz"));
  foutImputedStats << "snp\tAF\tAFimp\tr2" << endl;
  vector <int> bimSnpFound(M); // 1 if the bim variant was found in any bgen file
  for (int a = 7; a < argc; a++)
    checkBgenSample(bimSnpFound, foutImputedStats, argv[a],
                    IDpairToFamInd, IDimpToFamInd, posA1A2ToBimInd, M, N, finBed);
  int ctrInBim = accumulate(bimSnpFound.begin(), bimSnpFound.end(), 0);
  cout << "Found " << ctrInBim << " variants in BGEN data out of " << M << " on SNP-array" << endl;
  if (ctrInBim < 0.9*M)
    cout << "WARNING: >10% of variants in bim file not present in bgen file" << endl;
  foutImputedStats.close();

  fclose(finBed);

  cout << "Finished check_inputs; total time = " << timer.get_time()-t0 << " sec" << endl;

  return 0;
}
