#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <list>
#include <string>
#include <vector>

#include "function.h"
#include <plotter.h>

using namespace std;

// Lennard-Jones pair force in reduced units (epsilon = sigma = 1):
//   f = 24 * (2 / |r|^14 - 1 / |r|^8) * r
// r : separation vector, position of particle i minus position of particle j
// f : output, force acting on particle i (particle j feels -f)
// Positive fc means repulsion (f points along r). |r| must be non-zero;
// r = (0,0,0) yields non-finite components.
inline void LennardJones(const double r[3], double f[3])
{
  const double r2 = r[0]*r[0] + r[1]*r[1] + r[2]*r[2];
  // Build 1/r^8 and 1/r^14 from 1/r^2 by plain multiplication instead of two
  // pow() calls: this runs for every pair in the O(N^2) force loop.
  const double inv_r2 = 1.0 / r2;
  const double inv_r6 = inv_r2 * inv_r2 * inv_r2;
  const double fc = 24.0 * inv_r2 * inv_r6 * (2.0 * inv_r6 - 1.0);
  f[0] = r[0]*fc;
  f[1] = r[1]*fc;
  f[2] = r[2]*fc;
}

// Sums the pairwise Lennard-Jones forces on every particle into `a`.
//   r : positions, row i = (x, y, z) of particle i, expected inside [0, L)
//   a : output, reset to zero and then filled with the total force on each
//       particle (callers in this file use it directly as the acceleration)
//   L : edge length of the periodic cubic box
// Each unordered pair (i, j) is visited once; j receives the opposite force.
// Along each axis the separation is folded into [-L/2, L/2) so that the
// nearest periodic image of j is used (minimum-image convention).
void Force(const function2D<double>& r, function2D<double>& a, double L)
{
  a = 0;
  const double half = 0.5*L;
  for (int i = 0; i < r.size_N()-1; ++i) {
    for (int j = i+1; j < r.size_N(); ++j) {
      double sep[3];
      for (int d = 0; d < 3; ++d) {
        double dx = r(i,d) - r(j,d);
        if (dx >= half) dx -= L;
        if (dx < -half) dx += L;
        sep[d] = dx;
      }
      double pull[3];
      LennardJones(sep, pull);
      for (int d = 0; d < 3; ++d) {
        a[i][d] += pull[d];
        a[j][d] -= pull[d];
      }
    }
  }
}

// Velocity-Verlet algorithm written for very general system of variables
template <class storage, class functor>
void velocityVerlet(int N, int M, double dh, double L, storage& r, storage& v, storage& a, functor& F)
{
  F(r,a,L);// Computes acceleration "a" using function "F" which depends solely on posions "r".
  for (int i=0; i<N; i++){
    for (int k=0; k<M; k++){
      v[i][k] += 0.5*a[i][k]*dh;
      r[i][k] += v[i][k]*dh;
      
      // use periodic boundary conditions
      while (r[i][k] <  0) r[i][k] += L;
      while (r[i][k] >= L) r[i][k] -= L;
    }
  }
  F(r,a,L);
  for (int i=0; i<N; i++)
    for (int k=0; k<M; k++)
      v[i][k] += 0.5*a[i][k]*dh;
}

void Initialize(double L, double T, function2D<double>& r, function2D<double>& v)
{
  int N = r.size_N();
  // How many unit cells do we need?
  int M = 1;
  while (4*M*M*M < N) M++;
  double a0 = L / M;           // lattice constant of conventional unit cell

  // there are 4 atoms per unit cell filling the following positions.
  double atom_position[4][3] = {{0,0,0},{0.5,0.5,0.0},{0.5,0.0,0.5},{0.0,0.5,0.5}};

  // initialize positions
  int p = 0; 
  for (int iz = 0; iz < M; iz++)// unit cell
    for (int iy = 0; iy < M; iy++) 
      for (int ix = 0; ix < M; ix++) 
	for (int k=0; k<4; k++){    // 4 particles per unit cell
	  if (p < N) {
	    // particles are not placed close to the boundary but rather 0.25*a0 from the boundary!
	    r(p,0) = (ix + 0.25 + atom_position[k][0]) * a0;
	    r(p,1) = (iy + 0.25 + atom_position[k][1]) * a0;
	    r(p,2) = (iz + 0.25 + atom_position[k][2]) * a0;
	    p++;
	  }
	}
  // initialize velocities
  double T2 = sqrt(T);
  function1D<double> Vaverage(3);
  Vaverage=0;
  for (int i = 0; i<N; i++)
    for (int j = 0; j<3; j++){
      double phi = 2*M_PI*drand48();
      double rr = sqrt(-2*log(1-drand48()));
      double y = rr*cos(phi);
      v(i,j) = T2*y;
      Vaverage[j] += v(i,j);
    }
  // normalize
  Vaverage*=(1./N);
  // subtract the center of mass motion
  for (int i=0; i<N; i++)
    for (int j=0; j<3; j++) v(i,j) -= Vaverage[j];
}

// Instantaneous kinetic temperature (unit mass, k_B = 1):
//   T = sum_{i,k} v_ik^2 / (3 (N - 1))
// The denominator counts 3 degrees of freedom per particle minus the 3 of the
// centre-of-mass motion, which Initialize removes.
double instantaneousTemperature(function2D<double>& v){
  double sumSquares = 0;
  for (int p = 0; p < v.size_N(); ++p) {
    for (int d = 0; d < v.size_Nd(); ++d) {
      sumSquares += sqr(v[p][d]);
    }
  }
  const double degreesOfFreedom = 3*(v.size_N() - 1);
  return sumSquares/degreesOfFreedom;
}

// Multiplies every velocity by a common factor so that the instantaneous
// temperature becomes exactly T (a simple velocity-rescaling thermostat):
//   lambda = sqrt( 3 (N - 1) T / sum_{ij} v_ij^2 ),
// which uses the same 3(N-1) degrees of freedom as instantaneousTemperature.
// If all velocities are zero (e.g. the system was started at T = 0), there is
// nothing to rescale. v is then left unchanged; otherwise it would be
// multiplied by an infinite or NaN factor and every component would become NaN.
void rescaleVelocities(double T, function2D<double>& v){
  double v2 = accumulate(v,sqr<double>);       // v2 = sum_{ij} v_ij^2
  if (v2 <= 0) return;                         // all-zero velocities: nothing to scale
  double lambda = sqrt( 3*(v.size_N()-1)*T/v2);
  v *= lambda;
}


// Writes one line per particle to stdout with these fields:
//   the index, the position in units of the box length (x/L, y/L, z/L),
//   the velocity (vx, vy, vz) and the acceleration (ax, ay, az).
// Numbers are right-aligned in 25-character fields. Each three-value group is
// followed by four spaces, and every line is flushed with endl.
void Print(const function2D<double>& r, const function2D<double>& v, const function2D<double>& a, double L)
{
  using namespace std;
  // Emits three values in 25-wide fields followed by the group separator.
  auto group = [](double x, double y, double z) {
    cout<<setw(25)<<x<<" "<<setw(25)<<y<<" "<<setw(25)<<z<<"    ";
  };
  for (int n = 0; n < r.size_N(); ++n) {
    cout<<setw(5)<<n<<" ";
    group(r(n,0)/L, r(n,1)/L, r(n,2)/L);
    group(v(n,0), v(n,1), v(n,2));
    group(a(n,0), a(n,1), a(n,2));
    cout<<endl;
  }
}

// Ordering predicate used only to sort atom coordinates by z (index 2) in
// decreasing order. It returns 1 when r1 lies at larger z than r2, else 0,
// so list::sort puts the largest-z atoms first.
int cmp(const vector<double>& r1, const vector<double>& r2)
{
  const double z1 = r1[2];
  const double z2 = r2[2];
  return (z2 < z1) ? 1 : 0;
}

// Draws 3D geometry on a 2D Plotter using an oblique projection. A point
// (x, y, z) maps to plotter coordinates
//   ((x + c1 z) * pixsize, (y + c2 z) * pixsize),
// where c1 = cos(theta)/tan(alpha) and c2 = sin(theta)/tan(alpha).
// With alpha = pi/4 the receding axis keeps full length (cavalier projection).
class PlotPrimitive{
  double c1, c2;   // screen-x and screen-y contributions of the z axis
  double pixsize;  // scale from model units to plotter units
public:
  PlotPrimitive(double alpha, double theta, double pixsize_)
    : c1(cos(theta)/tan(alpha)), c2(sin(theta)/tan(alpha)), pixsize(pixsize_)
  {}
  // Projects (x, y, z) to integer plotter coordinates, truncating toward zero.
  void D3D2_Project(double x, double y, double z, int& ixp, int& iyp)
  {
    const double sx = x + c1*z;
    const double sy = y + c2*z;
    ixp = static_cast<int>(sx*pixsize);
    iyp = static_cast<int>(sy*pixsize);
  }
  // Draws the straight segment joining two 3D points.
  void Line(Plotter& plotter, double x0, double y0, double z0, double x1, double y1, double z1)
  {
    int ax, ay, bx, by;
    D3D2_Project(x0,y0,z0,ax,ay);
    D3D2_Project(x1,y1,z1,bx,by);
    plotter.line(ax,ay,bx,by);
  }
  // Draws a circle centred on a projected 3D point. The radius is in model
  // units and is scaled by pixsize only, not foreshortened by the projection.
  void Circle(Plotter& plotter, double x, double y, double z, double radius=0.45)
  {
    int cx, cy;
    D3D2_Project(x,y,z,cx,cy);
    plotter.circle(cx,cy,static_cast<int>(radius*pixsize));
  }
};

// Renders the current configuration:
//  - the periodic simulation box of edge L as a wireframe;
//  - every atom as a filled grey circle.
// Both use the oblique projection of PlotPrimitive (angles alpha, theta; scale
// pixsize). Atoms are painted from the largest z (farthest) to the smallest z
// (nearest) so that nearer atoms cover farther ones. They are shaded from dark
// (farthest) to bright (nearest) as a depth cue.
void Draw(double alpha, double theta, int pixsize, double L, Plotter& plotter, const function2D<double>& r)
{
  PlotPrimitive plot(alpha,theta,pixsize);
  plotter.erase(); // Clears plotting window
  // Box edges drawn before the atoms (the z = L face and two receding edges),
  // so that atoms are painted over them.
  plotter.filltype(0);
  plot.Line(plotter, 0,0,L, L,0,L);
  plot.Line(plotter, L,0,L, L,L,L);
  plot.Line(plotter, L,L,L, 0,L,L);
  plot.Line(plotter, 0,L,L, 0,0,L);
  plot.Line(plotter, 0,0,0, 0,0,L);
  plot.Line(plotter, 0,L,0, 0,L,L);

  // Atoms with larger z must be plotted first to get the right stacking order
  // in the 3D picture, so copy the coordinates and sort them by decreasing z.
  list<vector<double> > coord;
  for (int i=0; i<r.size_N(); i++){
    vector<double> r0(3);
    r0[0] = r[i][0];
    r0[1] = r[i][1];
    r0[2] = r[i][2];
    coord.push_back(r0);
  }
  coord.sort(cmp);

  plotter.filltype(1);
  if (!coord.empty()){  // front()/back() of an empty list would be undefined
    // After the sort the front atom is the farthest (largest z) and the back
    // atom the nearest (smallest z).
    const double zFar   = coord.front()[2];
    const double zNear  = coord.back()[2];
    const double spread = zNear - zFar;   // <= 0 for finite coordinates
    for (list<vector<double> >::iterator ri=coord.begin(); ri!=coord.end(); ++ri){
      // shade: 0 (dark) for the farthest atom, 1 (bright) for the nearest.
      double shade = 1.0;
      if (spread < 0) shade = ((*ri)[2]-zFar)/spread;
      // All atoms at one depth (spread == 0, e.g. a single atom) or NaN
      // coordinates would otherwise yield 0/0 or NaN, and converting NaN to
      // int is undefined behaviour; fall back to the brightest shade.
      if (!(shade >= 0 && shade <= 1)) shade = 1.0;
      const int color = static_cast<int>(shade*65530);
      plotter.fillcolor(color,color,color);// only grey colors used
      plot.Circle(plotter, (*ri)[0],(*ri)[1],(*ri)[2]);
    }
  }

  // Remaining box edges, drawn on top of the atoms
  plotter.filltype(0);
  plot.Line(plotter, L,0,0, L,0,L);
  plot.Line(plotter, L,L,0, L,L,L);
  plot.Line(plotter, 0,0,0, L,0,0);
  plot.Line(plotter, L,0,0, L,L,0);
  plot.Line(plotter, L,L,0, 0,L,0);
  plot.Line(plotter, 0,L,0, 0,0,0);
}

// Molecular dynamics of N Lennard-Jones particles (reduced units) in a
// periodic cubic box at density rho.
// Setup: positions start on an fcc lattice, and velocities are Gaussian at
// temperature T.
// Each of the MaxSteps velocity-Verlet steps of size dt:
//  - prints "step temperature" to stdout;
//  - redraws the configuration every 50 steps;
//  - rescales the velocities back to T every 200 steps.
// Options: -N, -rho, -T, -dt, -Ms, -h/--help.
int main(int argc, char *argv[], char *env[])
{
  int N=64;          // Number of particles
  double rho=1.0;    // Density of particles (number per unit volume)
  double T=1.0;      // Temperature: initial velocities and rescaling target
  int MaxSteps=3000; // Number of integration steps
  double dt = 0.01;  // Time-step
  for (int i=1; i<argc; i++){
    std::string str(argv[i]);
    // Value-taking options consume the following argument.
    if (str=="-N" && i<argc-1)   N = atoi(argv[++i]);
    if (str=="-rho" && i<argc-1) rho = atof(argv[++i]);
    if (str=="-T" && i<argc-1)   T = atof(argv[++i]);
    if (str=="-dt" && i<argc-1)  dt = atof(argv[++i]);
    if (str=="-Ms" && i<argc-1)  MaxSteps = atoi(argv[++i]);
    if (str=="-h" || str=="--help"){
      std::clog<<"********* Molecular dynamics for argon **************\n";
      std::clog<<"**                                                **\n";
      std::clog<<"**      Copyright Kristjan Haule, 26.09.2005      **\n";
      std::clog<<"****************************************************\n";
      std::clog<<"\n";
      std::clog<<argv[0]<<" [-N int] [-rho double] [-T double] [-dt double] [-Ms int] [-h]\n" ;
      std::clog<<"Options:   -N       Total number of particles ("<<N<<")\n";
      std::clog<<"           -rho     Particle density ("<<rho<<")\n";
      std::clog<<"           -T       Temperature ("<<T<<")\n";
      std::clog<<"           -Ms      Maximum number of steps ("<<MaxSteps<<")\n";
      std::clog<<"           -dt      Time-step ("<<dt<<")\n";
      std::clog<<"*****************************************************\n";
      return 0;
    }
  }
  // Reject input that would otherwise break the program:
  //  - N < 1 gives empty or invalid arrays, and N = 1 divides the temperature
  //    estimate by 3(N-1) = 0;
  //  - rho <= 0 gives an infinite or NaN box;
  //  - T < 0 gives NaN velocities.
  if (N < 2 || !(rho > 0) || !(T >= 0)){
    std::cerr<<"Invalid parameters: need N >= 2, rho > 0 and T >= 0 (see -h)\n";
    return 1;
  }

  double V = N/rho;          // volume of the cubic simulation box
  double L = pow(V, 1./3.);  // its edge length

  /************** Initialization of plotter **************************/
  PlotterParams params; // set a Plotter parameter
  params.setplparam ("PAGESIZE", const_cast<char *>("letter"));
  XPlotter plotter(cin, cout, cerr, params); // declare Plotter
  if (plotter.openpl () < 0){ // open Plotter
    cerr << "Couldn't open Plotter\n";
    return 1;
  }
  int pixsize=500;
  clog<<"L="<<L<<" "<<L*pixsize<<endl;
  plotter.fspace (-pixsize*L, -pixsize*L, pixsize*L*3, pixsize*L*3); // specify user coor system
  /*************************************************************************/

  // Data structures to store position, velocity and acceleration
  function2D<double> r(N,3), v(N,3), a(N,3);

  /************** Initialization of random number gen. *********************/
  long random_seed = static_cast<long>(time(0));
  srand48(random_seed);
  /*************************************************************************/
  Initialize(L,T,r,v);// Initialize the position and velocities of atoms

  Draw(0.25*M_PI, 0.25*M_PI, pixsize, L, plotter, r);
  for (int step=0; step<MaxSteps; step++){
    velocityVerlet(N, 3, dt, L, r, v, a, Force);// Actual simulation
    std::cout<<step<<" "<<instantaneousTemperature(v)<<std::endl;
    if (step%50==0){   // refresh the picture every 50 steps
      Draw(0.25*M_PI, 0.25*M_PI, pixsize, L, plotter, r);
    }
    if (step%200==0) rescaleVelocities(T,v); // restore temperature T every 200 steps
  }

  /************** Plotter Done**************************/
  clog<<"DONE"<<endl;
  if (plotter.closepl () < 0){ // close Plotter
    cerr << "Couldn't close Plotter\n";
    return 1;
  }
  /*****************************************************/
  return 0;
}
