#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <list>
#include <string>
#include <vector>
#include <plotter.h>
#include "function.h"

using namespace std;

// Entry of the Verlet (neighbour) list: the indices (i, j) of a particle pair together with
// their separation vector dr and its squared length r2. In this program dr = r_i - r_j in the
// minimum-image convention, as filled in by UpdatePairList / UpdatePairSeparations.
class VerletComponent{
  double r2;     // squared separation |dr|^2
  double dr[3];  // separation vector (3 components)
public:
  int i, j; // particle indices; public rather than behind accessors to speed up the force loop
  // Stores a copy of the 3 components of dr_ and the squared distance r2_.
  VerletComponent(int i_, int j_, const double dr_[], double r2_) : r2(r2_), i(i_), j(j_)
  {dr[0]=dr_[0]; dr[1]=dr_[1]; dr[2]=dr_[2];}
  // Overwrites the stored separation vector and squared distance; the indices are kept.
  void Update(const double dr_[3], double r2_)
  {dr[0]=dr_[0]; dr[1]=dr_[1]; dr[2]=dr_[2]; r2=r2_;}
  // Separation vector (pointer to the 3 stored components).
  const double* Dr() const {return dr;}
  // NOTE: despite the name this is the *squared* distance |dr|^2.
  double Distance() const {return r2;}
};
// Force due to the Lennard-Jones potential V(s) = 4 (s^-12 - s^-6) (reduced units), where s = |r|.
// Given the separation vector r, stores in f the force -grad_r V:
//   f = 24 (2 s^-14 - s^-8) r
// The integer powers of 1/s^2 are built by multiplication rather than std::pow, which is
// far more expensive and sits in the innermost loop of the force evaluation.
inline void LennardJones(const double r[3], double f[3])
{
  const double r2   = r[0]*r[0] + r[1]*r[1] + r[2]*r[2];
  const double inv2 = 1.0 / r2;
  const double inv6 = inv2 * inv2 * inv2;  // s^-6
  const double inv8 = inv6 * inv2;         // s^-8
  // 24 (2 s^-14 - s^-8) = 24 s^-8 (2 s^-6 - 1)
  const double fc = 24.0 * inv8 * (2.0 * inv6 - 1.0);
  f[0] = r[0]*fc;
  f[1] = r[1]*fc;
  f[2] = r[2]*fc;
}
// Minimum-image separation of two particles in a cubic periodic box of side L.
// Writes dr[k] = r1[k] - r2[k], shifted by at most one box length so that it lies in
// [-L/2, L/2) whenever the raw difference is within [-1.5 L, 1.5 L) (e.g. both positions
// wrapped into [0, L)), and returns the squared length of dr.
template <class vect>
inline double computeDistance(const vect& r1, const vect& r2, double dr[], double L){
  const double halfBox = 0.5*L;
  double sum = 0;
  int axis = 0;
  while (axis < 3) {
    double d = r1[axis] - r2[axis];
    if (d >= halfBox)      d -= L;  // the image one box to the left is closer
    else if (d < -halfBox) d += L;  // the image one box to the right is closer
    dr[axis] = d;
    sum += d*d;
    ++axis;
  }
  return sum;
}
// Updates distance between particles which are in Verlet's list
void UpdatePairSeparations(const function2D<double>& r, list<VerletComponent>& verletList, double L) {
  double dr[3];
  for (list<VerletComponent>::iterator li=verletList.begin(); li!=verletList.end(); li++){
    double r2 = computeDistance(r[li->i], r[li->j], dr, L);
    li->Update(dr,r2);
  }
}
// Rebuilds the Verlet list from scratch by checking all distinct pairs (i < j), O(N^2).
// Every pair whose minimum-image distance is below rMax is stored with its separation.
// rMax is larger than the force cut-off rCutOff, so between rebuilds the list can be reused
// with only its separations refreshed (UpdatePairSeparations); Force then keeps only the
// pairs that are closer than rCutOff.
void UpdatePairList(list<VerletComponent>& VerletList, const function2D<double>& r, double L, double rMax){
  const double listRadiusSqr = rMax*rMax;
  const int nParticles = r.size_N();
  double sep[3];
  VerletList.clear();
  for (int first = 0; first < nParticles; ++first) {
    for (int second = first + 1; second < nParticles; ++second) {
      const double d2 = computeDistance(r[first], r[second], sep, L);
      if (d2 < listRadiusSqr) VerletList.push_back(VerletComponent(first, second, sep, d2));
    }
  }
}
// Computes Force on each particle due to Lennard-Jones interaction
void Force(const list<VerletComponent>& verletList, function2D<double>& a, double rCutOff)
{
  double rCutOff2 = sqr(rCutOff);
  a=0;
  double fij[3];  // force
  for (list<VerletComponent>::const_iterator li=verletList.begin(); li!=verletList.end(); li++){// loop over list of close pairs
    if (li->Distance()<rCutOff2){ // the pair is still very close
      LennardJones(li->Dr(),fij);    // compute LenardJones force for the pair
      for (int k=0; k < 3; k++) { // and update force between the pair
	a[li->i][k] += fij[k];
	a[li->j][k] -= fij[k];
      }
    }
  }
}

// Advances the system by one velocity-Verlet step of length dh.
//   N, M       : number of particles and number of components per particle
//   L          : side of the periodic box; positions are folded back into [0, L)
//   rCutOff    : forwarded to F
//   verletList : pair list; its separations are refreshed after the positions move
//   r, v, a    : positions, velocities and accelerations, updated in place
//   F          : called as F(verletList, a, rCutOff) to fill a from the pair separations
// NOTE(review): a is recomputed at the start of every step even though the previous call
// already left it at these positions; main() computes no forces before the first step and
// relies on this evaluation.
template <class storage, class functor>
void velocityVerlet(int N, int M, double dh, double L, double rCutOff, list<VerletComponent>& verletList,
                    storage& r, storage& v, storage& a, functor& F)
{
  // accelerations at the current positions
  F(verletList, a, rCutOff);

  for (int p = 0; p < N; ++p) {
    for (int c = 0; c < M; ++c) {
      v[p][c] += 0.5*a[p][c]*dh;   // first half kick
      r[p][c] += v[p][c]*dh;       // drift with the half-step velocity
      // periodic boundary conditions: fold back into the primary box [0, L)
      while (r[p][c] <  0) r[p][c] += L;
      while (r[p][c] >= L) r[p][c] -= L;
    }
  }

  // accelerations at the new positions
  UpdatePairSeparations(r, verletList, L);
  F(verletList, a, rCutOff);

  // second half kick
  for (int p = 0; p < N; ++p)
    for (int c = 0; c < M; ++c)
      v[p][c] += 0.5*a[p][c]*dh;
}

void Initialize(double L, double T, function2D<double>& r, function2D<double>& v)
{
  int N = r.size_N();
  // How many unit cells do we need?
  int M = 1;
  while (4*M*M*M < N) M++;
  double a0 = L / M;           // lattice constant of conventional unit cell

  // there are 4 atoms per unit cell filling the following positions.
  double atom_position[4][3] = {{0,0,0},{0.5,0.5,0.0},{0.5,0.0,0.5},{0.0,0.5,0.5}};

  // initialize positions
  int p = 0; 
  for (int iz = 0; iz < M; iz++)// unit cell
    for (int iy = 0; iy < M; iy++) 
      for (int ix = 0; ix < M; ix++) 
	for (int k=0; k<4; k++){    // 4 particles per unit cell
	  if (p < N) {
	    // particles are not placed close to the boundary but rather 0.25*a0 from the boundary!
	    r(p,0) = (ix + 0.25 + atom_position[k][0]) * a0;
	    r(p,1) = (iy + 0.25 + atom_position[k][1]) * a0;
	    r(p,2) = (iz + 0.25 + atom_position[k][2]) * a0;
	    p++;
	  }
	}
  // initialize velocities
  double T2 = sqrt(T);
  function1D<double> Vaverage(3);
  Vaverage=0;
  for (int i = 0; i<N; i++)
    for (int j = 0; j<3; j++){
      double phi = 2*M_PI*drand48();
      double rr = sqrt(-2*log(1-drand48()));
      double y = rr*cos(phi);
      v(i,j) = T2*y;
      Vaverage[j] += v(i,j);
    }
  // normalize
  Vaverage*=(1./N);
  // subtract the center of mass motion
  for (int i=0; i<N; i++)
    for (int j=0; j<3; j++) v(i,j) -= Vaverage[j];
}

// Kinetic temperature from equipartition in the program's reduced units:
//   T = sum_{i,k} v_ik^2 / (3 (N - 1)),
// where 3(N-1) degrees of freedom remain once the centre-of-mass motion has been removed.
// NOTE(review): for N == 1 this evaluates 0/0.
double instantaneousTemperature(const function2D<double>& v){
  double sumV2 = 0;
  int i = 0;
  while (i < v.size_N()) {
    for (int k = 0; k < v.size_Nd(); ++k) sumV2 += sqr(v[i][k]);
    ++i;
  }
  return sumV2/(3*(v.size_N() - 1));
}

// Rescales all velocities by a common factor so that instantaneousTemperature(v) becomes T:
//   lambda = sqrt(3 (N-1) T / sum_{ij} v_ij^2)
// If the velocity field is identically zero (or the sum is not a positive number) there is
// nothing to scale. v is then left unchanged instead of being multiplied by an infinite
// factor, which would turn every component into NaN.
void rescaleVelocities(double T, function2D<double>& v){
  double v2 = accumulate(v,sqr<double>);       // v2 = sum_{ij} v_ij^2
  if (!(v2 > 0)) return;
  double lambda = sqrt( 3*(v.size_N()-1)*T/v2);
  v *= lambda;
}


void Print(const function2D<double>& r, const function2D<double>& v, const function2D<double>& a, double L)
{
  using namespace std;
  for (int i=0; i<r.size_N(); i++){
    cout<<setw(5)<<i<<" "<<setw(25)<<r(i,0)/L<<" "<<setw(25)<<r(i,1)/L<<" "<<setw(25)<<r(i,2)/L<<"    ";
    cout<<setw(25)<<v(i,0)<<" "<<setw(25)<<v(i,1)<<" "<<setw(25)<<v(i,2)<<"    ";
    cout<<setw(25)<<a(i,0)<<" "<<setw(25)<<a(i,1)<<" "<<setw(25)<<a(i,2)<<"    ";
    cout<<endl;
  }
}

// Ordering predicate for list::sort: true when r1 has a larger z coordinate (index 2) than r2,
// so sorting puts the atom with the largest z first. Only z is compared. Returns bool, as a
// comparator should.
bool cmp(const std::vector<double>& r1, const std::vector<double>& r2)
{ return r1[2] > r2[2]; }

class PlotPrimitive{
  // Helper for 3D plotting: maps 3D points onto the 2D plotter plane with an oblique
  // projection (Cavalier for alpha = pi/4) and draws lines and circles there.
  double c1, c2;   // projection coefficients: a point's z is added as (c1*z, c2*z) to (x, y)
  double pixsize;  // scale factor from simulation length units to plotter units
public:
  // alpha sets the foreshortening of the z axis (through 1/tan(alpha));
  // theta is the in-plane direction of the projected z axis.
  PlotPrimitive(double alpha, double theta, double pixsize_)
    : c1(cos(theta)/tan(alpha)), c2(sin(theta)/tan(alpha)), pixsize(pixsize_)
  {}
  // Projects (x, y, z) to integer plotter coordinates (ixp, iyp); truncates toward zero.
  void D3D2_Project(double x, double y, double z, int& ixp, int& iyp) const
  {
    ixp = static_cast<int>((x + c1*z)*pixsize);
    iyp = static_cast<int>((y + c2*z)*pixsize);
  }
  // Draws the projection of the 3D segment (x0,y0,z0)-(x1,y1,z1).
  void Line(Plotter& plotter, double x0, double y0, double z0, double x1, double y1, double z1) const
  {
    int ix0, ix1, iy0, iy1;
    D3D2_Project(x0,y0,z0,ix0,iy0);
    D3D2_Project(x1,y1,z1,ix1,iy1);
    plotter.line(ix0,iy0,ix1,iy1);
  }
  // Draws a circle of the given radius (in simulation units) centred at the projected point.
  void Circle(Plotter& plotter, double x, double y, double z, double radius=0.45) const
  {
    int ix, iy;
    D3D2_Project(x,y,z,ix,iy);
    plotter.circle(ix,iy,static_cast<int>(radius*pixsize));
  }
};

// Renders the configuration r (box side L) with an oblique projection (angles alpha, theta;
// scale pixsize, see PlotPrimitive). Drawing order:
//  1. the box edges lying behind the atoms,
//  2. the atoms as filled grey circles, from largest to smallest z, so that nearer atoms
//     overlap farther ones,
//  3. the remaining box edges on top.
// The grey level runs from 0 for the atom with the largest z (farthest) to 65530 for the
// smallest z (nearest).
void Draw(double alpha, double theta, int pixsize, double L, Plotter& plotter, const function2D<double>& r)
{
  PlotPrimitive plot(alpha,theta,pixsize);
  plotter.erase(); // Clears plotting window
  // Box edges behind the atoms
  plotter.filltype(0);
  plot.Line(plotter, 0,0,L, L,0,L);
  plot.Line(plotter, L,0,L, L,L,L);
  plot.Line(plotter, L,L,L, 0,L,L);
  plot.Line(plotter, 0,L,L, 0,0,L);
  plot.Line(plotter, 0,0,0, 0,0,L);
  plot.Line(plotter, 0,L,0, 0,L,L);

  // Atoms with larger z must be drawn first for the right stacking order in the 3D plot,
  // so copy the coordinates and sort them by decreasing z (see cmp).
  list<vector<double> > coord;
  for (int i=0; i<r.size_N(); i++){
    vector<double> r0(3);
    r0[0] = r[i][0];
    r0[1] = r[i][1];
    r0[2] = r[i][2];
    coord.push_back(r0);
  }
  coord.sort(cmp);

  plotter.filltype(1);
  if (!coord.empty()){
    // After the descending sort the first atom has the largest z and the last the smallest.
    const double zFar  = coord.front()[2];
    const double zNear = coord.back()[2];
    const double zSpan = zNear - zFar;   // <= 0; exactly 0 when all atoms share one z
    for (list<vector<double> >::const_iterator ri=coord.begin(); ri!=coord.end(); ++ri){
      // Grey level grows from dark (far) to bright (near). Without any z spread the ratio
      // would be 0/0 = NaN, and converting NaN to int is undefined behaviour, so use 0.
      int color = 0;
      if (zSpan != 0) color = static_cast<int>(((*ri)[2]-zFar)/zSpan*65530);
      plotter.fillcolor(color,color,color);// only grey colors used
      plot.Circle(plotter, (*ri)[0],(*ri)[1],(*ri)[2]);
    }
  }

  // Remaining box edges, drawn on top of the atoms
  plotter.filltype(0);
  plot.Line(plotter, L,0,0, L,0,L);
  plot.Line(plotter, L,L,0, L,L,L);
  plot.Line(plotter, 0,0,0, L,0,0);
  plot.Line(plotter, L,0,0, L,L,0);
  plot.Line(plotter, L,L,0, 0,L,0);
  plot.Line(plotter, 0,L,0, 0,0,0);
}

void Print(const list<VerletComponent>& verletList)
{
  for (list<VerletComponent>::const_iterator li=verletList.begin(); li!=verletList.end(); li++){
    cout<<setw(4)<<li->i<<" "<<setw(4)<<li->j<<" "<<setw(20)<<li->Distance()<<" ";
    cout<<setw(20)<<li->Dr()[0]<<" "<<setw(20)<<li->Dr()[1]<<" "<<setw(20)<<li->Dr()[2]<<endl;
  }
}

// Molecular-dynamics simulation of N Lennard-Jones particles (argon in reduced units) in a
// periodic cubic box of side L = (N/rho)^(1/3). The system is integrated with velocity Verlet
// and a Verlet pair list, and the configuration is drawn every 50 steps in an X window (libplot).
// Prints "L=<box side>" and then "<step> <instantaneous temperature>" for every step on stdout.
// Velocities are rescaled to the target temperature T every 200 steps. Run with -h for options.
int main(int argc, char *argv[], char *env[])
{
  int N=64;// Number of particles
  double rho=1.0;// Density of particles (number per unit volume)
  double T=1.0;  // Temperature (target of the periodic velocity rescaling)
  double dt = 0.01; // Time-step
  int nSteps = 3000; // Number of time steps to simulate
  // Verlet update list parameters
  double rCutOff = 2.5;     // cut-off distance of the Lennard-Jones potential and force
  double rMax = 3.3;        // maximum separation to include in pair list (must be >= rCutOff)
  int updateInterval = 10;  // number of time steps between updates of pair list
  int i=0;
  while (++i<argc){
    std::string str(argv[i]);
    if (str=="-N" && i<argc-1)  N = atoi(argv[++i]);
    if (str=="-rho" && i<argc-1) rho = atof(argv[++i]);
    if (str=="-T" && i<argc-1) T = atof(argv[++i]);
    if (str=="-dt" && i<argc-1) dt = atof(argv[++i]);
    if (str=="-steps" && i<argc-1) nSteps = atoi(argv[++i]);
    if (str=="-rCutOff" && i<argc-1) rCutOff = atof(argv[++i]);
    if (str=="-rMax" && i<argc-1) rMax = atof(argv[++i]);
    if (str=="-updateInt" && i<argc-1) updateInterval = atoi(argv[++i]);
    if (str=="-h" || str=="--help"){
      std::clog<<"********* Molecular dynamics for argon **************\n";
      std::clog<<"**                                                **\n";
      std::clog<<"**      Copyright Kristjan Haule, 26.09.2005      **\n";
      std::clog<<"****************************************************\n";
      std::clog<<"\n";
      std::clog<<"Usage: "<<argv[0]<<" [options]\n";
      std::clog<<"Options:   -N          Total number of particles ("<<N<<")\n";
      std::clog<<"           -rho        Particle density ("<<rho<<")\n";
      std::clog<<"           -T          Temperature ("<<T<<")\n";
      std::clog<<"           -dt         Time-step ("<<dt<<")\n";
      std::clog<<"           -steps      Number of time steps ("<<nSteps<<")\n";
      std::clog<<"           -rCutOff    Cut-off distance of the Lennard-Jones force ("<<rCutOff<<")\n";
      std::clog<<"           -rMax       Maximum pair separation kept in Verlet list ("<<rMax<<")\n";
      std::clog<<"           -updateInt  Update Interval for Verlet list ("<<updateInterval<<")\n";
      std::clog<<"           -h          Print this help\n";
      std::clog<<"*****************************************************\n";
      return 0;
    }
  }

  // Reject parameters that would crash or silently corrupt the simulation:
  //  N < 2          : the temperature uses 3(N-1) degrees of freedom;
  //  rho <= 0       : no finite box;  T < 0 : sqrt(T) when drawing velocities;
  //  updateInt < 1  : division by zero in the step % updateInterval test;
  //  rMax < rCutOff : pairs inside the force cut-off would be missing from the pair list.
  if (N < 2 || !(rho > 0) || !(T >= 0) || nSteps < 0 || updateInterval < 1 || !(rMax >= rCutOff)){
    std::cerr<<"Invalid parameters: need N>=2, rho>0, T>=0, steps>=0, updateInt>=1 and rMax>=rCutOff (see -h)\n";
    return 1;
  }

  double V = N/rho;
  double L = pow(V, 1./3.);
  cout<<"L="<<L<<endl;
  
  /************** Initialization of plotter **************************/
  PlotterParams params; // set a Plotter parameter
  params.setplparam ("PAGESIZE", const_cast<char *>("letter"));
  XPlotter plotter(cin, cout, cerr, params); // declare Plotter
  if (plotter.openpl () < 0){ // open Plotter      
    cerr << "Couldn't open Plotter\n";
    return 1;
  }
  int pixsize=500;
  plotter.fspace (-pixsize*L, -pixsize*L, pixsize*L*3, pixsize*L*3); // specify user coor system
  /*************************************************************************/

  // Data structures to store position, velocity and acceleration
  function2D<double> r(N,3), v(N,3), a(N,3);
  
  /************** Initialization of random number gen. *********************/
  long random_seed = static_cast<long>(time(0));
  srand48(random_seed);
  /*************************************************************************/
  Initialize(L,T,r,v);// Initialize the position and velocities of atoms
  list<VerletComponent> verletList;

  /************** Actual simulation ****************************************/
  Draw(0.25*M_PI, 0.25*M_PI, pixsize, L, plotter, r);
  for (int step=0; step<nSteps; step++){
    if (step%updateInterval==0) UpdatePairList(verletList, r, L, rMax); // update list only from time to time
    else UpdatePairSeparations(r, verletList, L);                       // always update distances between particles which are very close
    velocityVerlet(N, 3, dt, L, rCutOff, verletList, r, v, a, Force);   // one step in solving the differential equation
    std::cout<<step<<" "<<instantaneousTemperature(v)<<std::endl;
    if  (step%50==0){
      Draw(0.25*M_PI, 0.25*M_PI, pixsize, L, plotter, r);
    }
    if (step%200==0) rescaleVelocities(T,v);
  }
  /*************************************************************************/

  /************** Plotter Done**************************/
  clog<<"DONE"<<endl;
  if (plotter.closepl () < 0){ // close Plotter
    cerr << "Couldn't close Plotter\n";
    return 1;
  }
  /*****************************************************/
  return 0;
}
