// g++ -O3 -fopenmp -Wall -static-libgcc -static-libstdc++ findNeighborsEUR.cpp -o findNeighborsEUR -I/n/groups/price/poru/HSPH_SVN/src/EAGLE -I/home/pl88/boost_1_58_0/install/include -L/n/groups/price/poru/external_software/libstdc++/usr/lib/gcc/x86_64-redhat-linux/4.8.5/ -L/n/groups/price/poru/external_software/zlib/zlib-1.2.11 -L/home/pl88/boost_1_58_0/install/lib -Wl,-Bstatic -lboost_iostreams -lz

#include <iostream>
#include <iomanip>
#include <sstream>
#include <fstream>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <cstdio>
#include <cassert>
#include <cmath>

#include "omp.h"

#include "FileUtils.cpp"
#include "StringUtils.cpp"
#include "Timer.cpp"

using namespace std;

// Returns the square of x (single-precision multiply).
inline float sq(float x) {
  const float squared = x * x;
  return squared;
}

// Limits z to the interval [-zMax, zMax]: values above zMax become zMax,
// values below -zMax become -zMax, everything in between passes through.
float crop(float z, float zMax) {
  const float lowerBound = -zMax;
  // pick the larger of lowerBound and z (lowerBound wins when they compare equal)
  const float raised = (lowerBound < z) ? z : lowerBound;
  // pick the smaller of zMax and the raised value (zMax wins when they compare equal)
  return (raised < zMax) ? raised : zMax;
}


/**
 * For every individual in one batch, finds the Nnbr individuals whose cropped
 * z-score profiles have the smallest squared Euclidean distance to it.
 *
 * Command line (exactly 7 arguments):
 *   arg1  batch number b; rows with (row index % B) == b form the batch
 *   arg2  total batches B
 *   arg3  zMax; every z-score is cropped to [-zMax, zMax] before use
 *   arg4  number of neighbors to report per batch individual
 *   arg5  data file: 5 header lines (each starting with "N R"), then one row
 *         per individual: ID scale z_1 ... z_R
 *   arg6  OpenMP thread count
 *   arg7  output file
 *
 * The data file is read twice: once to keep the batch rows in memory, once to
 * stream every row past the batch and accumulate distances.
 *
 * Individuals listed in the non-stringent-British file are never used as
 * neighbors; related pairs and the individual itself get distance 1e9 so they
 * sort last.
 *
 * Output: one line per batch individual,
 *   ID <tab> scale { <tab> nbrID <tab> nbrScale <tab> dist/(2R) } x neighbors
 *
 * Returns 0 on success, 1 on a usage error or an unreadable/truncated data file.
 */
int main(int argc, char *argv[]) {

  if (argc != 8) {
    cerr << "ERROR: 7 arguments required" << endl;
    cerr << "- arg1: batch number" << endl;
    cerr << "- arg2: total batches" << endl;
    cerr << "- arg3: zMax: crop value for z-scores" << endl;
    cerr << "- arg4: number of neighbors" << endl;
    cerr << "- arg5: ID_scale_zdepths file" << endl;
    cerr << "- arg6: threads" << endl;
    cerr << "- arg7: output file" << endl;
    return 1;
  }

  // parse numeric arguments; refuse to continue on anything sscanf cannot convert
  int b, B, Nnbr, threads;
  float zMax;
  if (sscanf(argv[1], "%d", &b) != 1 || sscanf(argv[2], "%d", &B) != 1
      || sscanf(argv[3], "%f", &zMax) != 1 || sscanf(argv[4], "%d", &Nnbr) != 1
      || sscanf(argv[6], "%d", &threads) != 1) {
    cerr << "ERROR: batch number, total batches, zMax, number of neighbors and threads"
         << " must be numeric" << endl;
    return 1;
  }
  const char *dataFile = argv[5]; //"ID_scale_zdepths.txt.gz";
  const char *outFileNbrs = argv[7]; //"ID_scale_neighbors";

  // B == 0 would divide by zero in (n % B); b outside [0,B) selects nobody
  if (B <= 0 || b < 0 || b >= B) {
    cerr << "ERROR: need total batches > 0 and 0 <= batch number < total batches" << endl;
    return 1;
  }
  // omp_set_num_threads requires a positive thread count
  if (threads < 1) {
    cerr << "ERROR: threads must be at least 1" << endl;
    return 1;
  }

  cout << "Computing nearest neighbors for batch " << b << " mod " << B << endl;

  omp_set_num_threads(threads);
  cout << "Using " << threads << " threads" << endl;

  cout << "Cropping 'z-score' values to zMax = " << zMax << endl;

  Timer timer;

  FileUtils::AutoGzIfstream fin;

  // read non-stringent British IDs (first line of the file is skipped)
  set <int> nonBritIDs;
  fin.openOrExit("/mnt/project/lohdata/ploh/WES_CNVs/noise_neighbors/inputs/remove.nonStringentBritish.FID_IID.40709.txt");
  string line; getline(fin, line);
  int ID;
  while (fin >> ID) nonBritIDs.insert(ID);
  fin.close();

  // read and store related pairs symmetrically (first line skipped; only the
  // first two fields of each line are used)
  map < int, set <int> > relatedIDs;
  fin.openOrExit("/mnt/project/lohdata/ploh/WES_CNVs/noise_neighbors/inputs/ukb4070_rel_s488374.dat");
  getline(fin, line);
  int ID1, ID2;
  while (fin >> ID1 >> ID2) {
    relatedIDs[ID1].insert(ID2);
    relatedIDs[ID2].insert(ID1);
    getline(fin, line);
  }
  fin.close();

  // read numbers of individuals and regions
  fin.openOrExit(dataFile);
  int N, R;
  if (!(fin >> N >> R)) {
    cerr << "ERROR: could not read numbers of individuals and regions from " << dataFile
         << endl;
    return 1;
  }
  getline(fin, line); // throw away first header line: mu

  if (!(fin >> N >> R)) {
    cerr << "ERROR: could not read second header line of " << dataFile << endl;
    return 1;
  }
  getline(fin, line); // throw away second header line (don't bother with sigma2ratios)

  // skip additional header rows containing chr bpStart bpEnd for each extracted region;
  // the values are parsed (to advance the stream) but not needed afterwards
  for (int h = 0; h < 3; h++) {
    if (!(fin >> N >> R)) {
      cerr << "ERROR: could not read region header line " << (h+1) << " of " << dataFile
           << endl;
      return 1;
    }
    int regionValue;
    for (int r = 0; r < R; r++)
      fin >> regionValue;
  }
  if (N < 0 || R < 0) {
    cerr << "ERROR: negative number of individuals or regions in " << dataFile << endl;
    return 1;
  }

  // count batch members using the same N that drives the loops below, so the
  // batch buffers can never be overrun
  int Nbatch = 0;
  for (int n = 0; n < N; n++)
    if (n % B == b)
      Nbatch++;

  cout << "Reading data for " << Nbatch << " / " << N << " indivs at " << R << " regions"
       << endl;

  // allocate memory for individuals in batch (sizes computed in 64 bits)
  vector <int> IDs(N);
  vector <bool> isNonBrit(N);
  vector <float> scales(N);
  vector <float> zs(Nbatch*(long long) R);

  // store z-scores for individuals in batch
  int nonBritCount = 0;
  float scale, z;
  vector <int> keepInds;
  keepInds.reserve(Nbatch);
  for (int n = 0; n < N; n++) {
    if (!(fin >> ID >> scale)) {
      cerr << endl << "ERROR: could not read ID and scale for row " << n << " of "
           << dataFile << endl;
      return 1;
    }
    IDs[n] = ID;
    isNonBrit[n] = nonBritIDs.count(ID);
    nonBritCount += isNonBrit[n];
    scales[n] = scale;
    if (n % B == b) {
      // 64-bit offset: (batch index * R) can exceed the range of int
      const long long zOff = (long long) keepInds.size() * R;
      keepInds.push_back(n);
      for (int r = 0; r < R; r++) {
        if (!(fin >> z)) {
          cerr << endl << "ERROR: could not read z-score " << r << " for row " << n
               << " of " << dataFile << endl;
          return 1;
        }
        zs[zOff + r] = crop(z, zMax);
      }
    }
    else
      getline(fin, line); // not in batch: skip the rest of the row
    if (n%100==0)
      cout << "." << flush;
  }
  fin.close();
  cout << endl << "Read data for " << keepInds.size() << " / " << N << " indivs in batch ("
       << timer.update_time() << " sec)" << endl;
  cout << "Flagged " << nonBritCount << " non-stringent British indivs to exclude from neighbors"
       << endl;
  assert((int) keepInds.size() == Nbatch);

  // stream z-scores for all individuals
  const float INF_DIST = 1e9f; // sentinel distance for excluded candidates
  vector <float> dists(Nbatch*(long long) N); // value-initialized to 0
  fin.openOrExit(dataFile);
  for (int h = 0; h < 5; h++) getline(fin, line); // throw away header lines
  vector <float> zsCur(R);
  for (int n = 0; n < N; n++) {
    if (!isNonBrit[n]) {
      if (!(fin >> ID >> scale)) { // already stored; re-read only to advance the stream
        cerr << endl << "ERROR: could not re-read ID and scale for row " << n << " of "
             << dataFile << endl;
        return 1;
      }
      for (int r = 0; r < R; r++) {
        if (!(fin >> z)) {
          cerr << endl << "ERROR: could not re-read z-score " << r << " for row " << n
               << " of " << dataFile << endl;
          return 1;
        }
        zsCur[r] = crop(z, zMax);
      }
      const set <int> &relIDs = relatedIDs[ID];
#pragma omp parallel for
      for (int i = 0; i < Nbatch; i++) {
        // 64-bit offsets: i*N and i*R can exceed the range of int
        const long long distIdx = i * (long long) N + n;
        if (relIDs.count(IDs[keepInds[i]]))
          dists[distIdx] = INF_DIST; // set dist to INF if related
        else {
          const long long zOff = i * (long long) R;
          // accumulate locally (same summation order as adding into dists directly)
          float sumSq = 0;
          for (int r = 0; r < R; r++)
            sumSq += sq(zsCur[r] - zs[zOff + r]);
          dists[distIdx] = sumSq;
        }
      }
    }
    getline(fin, line); // consume the rest of the row (the whole row if it was skipped)
    if (n%100==0)
      cout << "." << flush;
  }
  fin.close();
  cout << endl << "Computed distances for " << Nbatch << " / " << N << " indivs in batch ("
       << timer.update_time() << " sec)" << endl;

  // never report more neighbors than there are candidate slots
  const int nOut = max(0, min(Nnbr, N));
  if (nOut < Nnbr)
    cerr << "WARNING: requested " << Nnbr << " neighbors but only " << N
         << " indivs are available; reporting " << nOut << endl;

  FileUtils::AutoGzOfstream fout; fout.openOrExit(outFileNbrs);
  fout << std::fixed;
  vector < pair <float, int> > distIDs(N);
  for (int i = 0; i < Nbatch; i++) {
    const int n_i = keepInds[i];
    const long long rowOff = i * (long long) N;
    fout << IDs[n_i] << "\t" << std::setprecision(3) << scales[n_i];
    // sort and output best matches
    for (int n = 0; n < N; n++)
      distIDs[n] = make_pair(isNonBrit[n] ? INF_DIST : dists[rowOff + n], n);
    distIDs[n_i].first = INF_DIST; // never pick the individual itself
    // only the first nOut entries are reported; ties are broken by index through
    // pair comparison, so this prefix matches what a full sort would give
    partial_sort(distIDs.begin(), distIDs.begin() + nOut, distIDs.end());

    for (int j = 0; j < nOut; j++) {
      const int n = distIDs[j].second;
      fout << "\t" << IDs[n]
           << "\t" << std::setprecision(3) << scales[n]
           << "\t" << std::setprecision(2) << dists[rowOff + n]/(2*R);
    }
    fout << endl;

  }
  fout.close();
  cout << "Found neighbors and wrote output (" << timer.update_time() << " sec)" << endl;

  return 0;
}
