#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
#include "numerov.h"
#include "zeroin.h"
#include "integrate.h"

using namespace std;

// Solves the radial Schroedinger equation for u(r) on an equidistant mesh.
// The mesh is stored reversed (index 0 is the outermost point), so the
// integration starts at Rmax and proceeds inwards towards the origin.
// Used as a function object: wave(E) returns the value of the solution at the
// innermost mesh point, whose zeros in E are located by the root finder.
// NOTE(review): the code indexes Solution[1] and Solution[last-2], so it
// assumes a mesh of at least 3 points -- confirm callers guarantee this.
class RadialWave{
  int N;                   // Number of mesh points
  vector<double> R;        // Radial mesh but reversed: R[0]=Rmax and R[N-1]=first point of the input mesh
  vector<double> Solution; // Non-normalized solution of Schroedinger equation (same reversed ordering)
  vector<double> rhs;      // Right-hand side used when solving the Schroedinger equation
  vector<double> Veff;     // Effective KS potential (including centrifugal part)
  vector<double> Veff0;    // Effective KS potential without centrifugal part
public:
  // Builds the reversed mesh from Rmesh and sets the two outermost starting
  // values u(r)=r*exp(-r) needed to start the Numerov recursion.
  // Members are initialized in declaration order (N first, then the vectors).
  RadialWave(const vector<double>& Rmesh) : N(Rmesh.size()), R(N), Solution(N), rhs(N), Veff(N), Veff0(N)
  {
    for (int i=0; i<N; i++) R[i] = Rmesh[N-1-i]; // The mesh is reversed
    Solution[0] = R[0]*exp(-R[0]); // Boundary (starting) points of integration by Numerov
    Solution[1] = R[1]*exp(-R[1]); // Boundary (starting) points of integration by Numerov
  }
  // This function-operator is used to find bound states. It is given to the root-finding routine.
  // For the trial energy E it integrates inwards and returns the solution at the innermost point.
  double operator()(double E)
  {
    double h = R[1]-R[0];                             // Negative step: the mesh runs from Rmax inwards
    for (int i=0; i<N-1; i++) rhs[i] = 2*(Veff[i]-E); // The RHS of the Schroedinger equation for chosen energy E
    rhs[N-1]=0;                                       // Veff is never set at the innermost point; keep RHS finite there
    Numerov(rhs, Solution.size()-1, h, Solution);     // Solving the radial Schroedinger equation on all but the innermost point
    int last = N-1;                                   // The innermost point needs extrapolation
    Solution[last] = Solution[last-1]*(2+h*h*rhs[last-1])-Solution[last-2];
    return Solution[last];                            // Value at the innermost mesh point
  }
  // Returns in rho the density of one (nl) state, rho_{nl}=u^2/(4*Pi*r^2), built from
  // the current Solution and normalized with integrate4. rho is resized to the mesh
  // size and is indexed in the caller's (non-reversed) ordering.
  void Density(vector<double>& rho)
  {
    const int npts = Solution.size();
    rho.resize(npts);
    for (int i=0; i<npts; i++) rho[i] = Solution[npts-1-i]*Solution[npts-1-i];      // The mesh outside this class is reversed!
    double norm = 1./integrate4<double>(rho, R[0]-R[1], rho.size());                // Normalization constant
    for (int i=1; i<npts; i++) rho[i] = rho[i]*norm/(4*M_PI*sqr(R[npts-i-1]));      // rho_{nl}=u^2/(4*Pi*r^2)
    rho[0] = 2*rho[1]-rho[2];                                                       // linear extrapolation to the first mesh point
  }
  // Sets the KS potential without centrifugal part: (-Z+U_H)/r + V_xc.
  // Uhartree and Vxc are given in the caller's (non-reversed) ordering.
  // The innermost point is skipped (it is r=0 for the mesh built in main, so 1/r would diverge).
  void SetVeff0(const vector<double>& Uhartree, const vector<double>& Vxc, int Z)
  { for (int i=0; i<N-1; i++) Veff0[i] = (-Z+Uhartree[N-1-i])/R[i]+Vxc[N-1-i];}
  // Adds the centrifugal term l(l+1)/(2 r^2) to Veff0 and stores the result in Veff.
  void AddCentrifugal(int l)
  { for (int i=0; i<N-1; i++) Veff[i] = Veff0[i] + 0.5*l*(l+1)/sqr(R[i]);}
  // Convenience: sets Veff0 and then Veff for angular momentum l.
  void SetVeff(const vector<double>& Uhartree, const vector<double>& Vxc, int l, int Z)
  { SetVeff0(Uhartree,Vxc,Z);
    AddCentrifugal(l); }
  // KS potential without centrifugal part at index i of the caller's (non-reversed) mesh.
  double V_KS0(int i) const {return Veff0[N-1-i];}
};

class ExchangeCorrelation{
//******************************************************************************/
//  Calculates Exchange&Correlation Energy and Potential                       */
//  type=0 - due to U.von.Barth and L.Hedin, J.Phys.C5, 1629 (1972)            */
//  type=1 - O.E.Gunnarsson and S.Lundqvist,  Phys.Rev.B                       */
//  type=2 - V.L.Moruzzi, J.F.Janak, and A.R.Williams, Calculated              */
//           Electronic Properties of Metals (New York, Pergamon Press, 1978)  */
//  type=3 - S.H.Vosko, L.Wilk, and M.Nusair, Can.J.Phys.58, 1200 (1980)       */
//  type=4 - Correlation in the form  gamma/(1+beta1*sqrt(rs)+beta2*rs) for    */
//           rs>1 and A*ln(rs)+B+C*rs*ln(rs)+D*rs for rs<=1. The constants are */
//           those of J.P.Perdew and A.Zunger, Phys.Rev.B 23, 5048 (1981)      */
//           (this header used to say "Perdew and Wang 1991").                 */
//                                                                             */
//  All functions take the density parameter rs. Vx/Vc are the potentials,     */
//  Ex the exchange energy per particle; ExVx and EcVc return the differences  */
//  (energy per particle) - (potential), i.e. Ex-Vx and Ec-Vc.                 */
//******************************************************************************/
  int type;
  double A, C;   // Parameters of the types 0,1,2; unused (kept at 0) for other types
  // In-class initialization of non-integral static members requires constexpr in ISO C++
  static constexpr double alphax = 0.610887057710857;//(3/(2 Pi))^(2/3)
  // Constants of type=4
  static constexpr double Aw = 0.0311;
  static constexpr double Bw = -0.048;
  static constexpr double Cw = 0.002;
  static constexpr double D  = -0.0116;
  static constexpr double gamma  = -0.1423;
  static constexpr double beta1  =  1.0529;
  static constexpr double beta2  =  0.3334;
  // Constants of type=3
  static constexpr double Ap  =  0.0621814;
  static constexpr double xp0 = -0.10498;
  static constexpr double bp  =  3.72744;
  static constexpr double cp  =  12.9352;
  static constexpr double Qp  =  6.1519908;
  static constexpr double cp1 =  1.2117833;
  static constexpr double cp2 =  1.1435257;
  static constexpr double cp3 = -0.031167608;
public:
  // type_ selects the parametrization listed in the header above.
  ExchangeCorrelation(int type_) : type(type_), A(0), C(0)
  {
    switch(type){
    case 0: C = 0.0504; A = 30; break;
    case 1: C = 0.0666; A = 11.4; break;
    case 2: C = 0.045;  A = 21; break;
    }
  }
  // Exchange potential
  double Vx(double rs) const {return -alphax/rs;}
  // Ex-Vx
  double ExVx(double rs) const {return 0.25*alphax/rs;}
  // Exchange energy per particle
  double Ex(double rs) const {return -0.75*alphax/rs;}
  // Correlation potential
  double Vc(double rs) const
  {
    if (type<3){
      double x = rs/A;
      return -0.5*C*log(1+1/x);
    }else if(type<4){// type=3 WVN
      double x=sqrt(rs);
      double xpx=x*x+bp*x+cp;
      double atnp=atan(Qp/(2*x+bp));
      double ecp = 0.5*Ap*(log(x*x/xpx)+cp1*atnp-cp3*(log((x-xp0)*(x-xp0)/xpx)+cp2*atnp));
      return ecp - Ap/6.*(cp*(x-xp0)-bp*x*xp0)/((x-xp0)*xpx);
    }else{
      if (rs>1){
        // Vc = ec - rs/3*d(ec)/d(rs) with ec = gamma/(1+beta1*sqrt(rs)+beta2*rs), which gives
        // gamma*(1+7/6*beta1*sqrt(rs)+4/3*beta2*rs)/(1+beta1*sqrt(rs)+beta2*rs)^2.
        // The 4/3 factor was missing, which made Vc discontinuous at rs=1.
        double sq = sqrt(rs);
        double den = 1+beta1*sq+beta2*rs;
        return gamma*(1+7/6.*beta1*sq+4/3.*beta2*rs)/(den*den);
      }
      else return Aw*log(rs)+Bw-Aw/3.+2/3.*Cw*rs*log(rs)+(2*D-Cw)*rs/3.;
    }
  }
  // Ec-Vc: correlation energy per particle minus correlation potential
  double EcVc(double rs) const
  {
    if (type<3){
      double x = rs/A;
      double epsilon = -0.5*C*((1+x*x*x)*log(1+1/x)+0.5*x-x*x-1/3.);
      return epsilon-Vc(rs);
    } else if (type<4){// type=3 WVN
      double x=sqrt(rs);
      return Ap/6.*(cp*(x-xp0)-bp*x*xp0)/((x-xp0)*(x*x+bp*x+cp));
    }else{
      // ec = gamma/(1+beta1*sqrt(rs)+beta2*rs) for rs>1; a spurious factor of 2 here
      // made ec jump by a factor of two at rs=1 relative to the rs<=1 branch.
      if (rs>1) return gamma/(1+beta1*sqrt(rs)+beta2*rs)-Vc(rs);
      else return Aw*log(rs)+Bw+Cw*rs*log(rs)+D*rs-Vc(rs);
    }
  }
};

// Returns the number of principal quantum numbers n to be used for an atom with Z electrons.
// The closed-form expression is the real root of n(n+1)(2n+1)/3 = Z (electrons in n
// completely filled shells), rounded up by adding 0.99 before truncation.
int NumberOfPrincipal(int Z)
{
  const double cubic_term = pow((54*Z+sqrt(-3+2916.*Z*Z))*3.,1/3.);
  const double shells = (-3+3/cubic_term+cubic_term)/6.+0.99;
  int count = static_cast<int>(shells);
  if (Z>36){
    ++count;                     // Because 5s is filled before 4f
  }
  return count;
}

// Class for storing the quantum numbers n,l, and Energy together.
// It is important to be able to sort those states by the energy.
class BState{
public:
  int n, l;   // Quantum numbers of the state
  double E;   // Energy of the state
  // Default state: n=l=0 and E=0. All members are initialized so that placeholder
  // entries of a vector<BState> never expose indeterminate quantum numbers.
  BState(): n(0), l(0), E(0){}
  BState(int n_, int l_, double E_): n(n_), l(l_), E(E_){}
};
// Comparison functor for BState: orders bound states by increasing energy.
// Needed to sort bound states by the energy. It is a const predicate returning
// bool so that it can also be used through a const comparator object.
class Cmp{
public:
  bool operator()(const BState& a, const BState& b) const {return a.E<b.E;}
};

// Given the input density rho, calculates the Hartree potential in the form U_H = V_H*r.
//   Zq       - value U_H must reach at the last mesh point
//   Rmesh    - equidistant radial mesh (its first two points define the step)
//   rho      - density on Rmesh; must have at least Rmesh.size() elements
//   Uhartree - output; resized to Rmesh.size() if it has a different size
// Does nothing when the mesh has fewer than 2 points.
void SolvePoisson(int Zq, const vector<double>& Rmesh, const vector<double>& rho, vector<double>& Uhartree)
{
  const int npts = Rmesh.size();
  if (npts<2) return;                               // Two starting values are needed below
  // Scratch storage reused between calls. It is resized on every call: sizing it only
  // in the static initializer would keep the size of the very first mesh forever.
  static vector<double> RHS;
  RHS.resize(Rmesh.size());
  if (Uhartree.size()!=Rmesh.size()) Uhartree.resize(Rmesh.size());
  for (int i=0; i<npts; i++) RHS[i] = -4*M_PI*Rmesh[i]*rho[i];
  Uhartree[0]=0;  Uhartree[1]=(Rmesh[1]-Rmesh[0]);// Boundary condition for U_H=V_H/r
  NumerovInhom(RHS, RHS.size(), Rmesh[1]-Rmesh[0], Uhartree); // Solving the 2nd order differential equation
  // adding homogeneous solution (linear in r) to satisfy boundary conditions: U(0)=0, U(last point)=Zq
  int ilast = Uhartree.size()-1;
  double U_last = Uhartree[ilast];
  double alpha = (Zq - U_last)/Rmesh[ilast];
  for (int i=0; i<npts; i++) Uhartree[i] += alpha*Rmesh[i];
}

/////////////////////////////////////////////
////// n=0	n=1	n=2	n=3	.....
//////---------------------------------------
////// l=0	l=0	l=0	l=0	.....
//////		l=1	l=1	l=1	.....
//////			l=2	l=2	.....
//////				l=3
////// E(n+1,l) > E(n,l) therefore we need to keep
////// track of E(n,l).
////// We start with E(n,l=n) > -0.5*Z^2/(l+1)^2-3.
//////////////////////////////////////////////
void FindBoundStates(int n0, int Z, double dEz, RadialWave& wave, vector<BState>& states)
{// Searches for bound state with given n,l. They are stored in vector<BState>
  for (int i=0; i<states.size(); i++) states[i].E=100;// Sets some high energy not to mix with bound states
  int j=0;
  for (int l=0; l<n0; l++){
    wave.AddCentrifugal(l);                        // Adds centrifugal part to effective KS potential
    double x = -0.5*Z*Z/sqr(l+1)-3.;               // starting guess for the lowest bound state
    for (int n=l; n<n0; n++){
      double v0 = wave(x), v1=v0;
      while(x<10.){                                 // Looks for zero even at positive frequencies somethimes
	x+=dEz;                                     // Proceeding in small steps to bracket all zeros
	v1 = wave(x);                               // New value of radial function at origin
	if (v0*v1<0) {                              // Changs sign?
	  double zero = zeroin(x-dEz,x,wave,1e-10); // Root-finder locates bound state very precisely
	  x = zero+1e-7;
	  states[j++] = BState(n,l,zero);           // Stores solution
	  clog<<"Found solutin for n="<<n<<" l="<<l<<" at "<<zero<<endl;
	  if (j>=states.size()) return;
	  break;
	}
      }
    }
  }
}

double BuildNewRho(int Z, const vector<BState>& states, RadialWave& wave, vector<double>& nrho)
{// Knowing the energies of eigenstates, finds chemical potential and new charge density
  static vector<double> drho(nrho.size());
  for (int k=0; k<nrho.size(); k++) nrho[k]=0;
  double Eb=0;        // Sum of eigenvalues
  int Nt=0;    // number of electrons added
  for (int k=0; k<states.size(); k++){
    int l = states[k].l;
    int dN = 2*(2*l+1);  // degeneracy of each radial wave level
    double ferm = Nt+dN<=Z  ? 1 : (Z-Nt)/(2.*(2.*l+1)); // if shell is not fully-filled, take only part of charge
    wave.AddCentrifugal(states[k].l);
    wave(states[k].E);
    wave.Density(drho);
    for (int om=0; om<nrho.size(); om++) nrho[om] += drho[om]*2*(2*l+1)*ferm;
    Eb += 2*(2*l+1)*ferm*states[k].E; // Sum of eigenvalues times degeneracy
    Nt += dN;
    clog<<"Adding orbital n="<<states[k].n<<" l="<<states[k].l<<" E="<<states[k].E<<" "<<Nt<<endl;
    if (Nt>=Z) break; // Finish when enough electrons added
  }
  return Eb;
}

// Converts the density rho into the parameter rs, defined by (4*Pi/3)*rs^3 = 1/rho.
inline double rs(double rho)
{
  const double cube = 3/(4*M_PI*rho);
  return pow(cube,1/3.);
}

// Self-consistent LDA calculation for a single atom.
// Parses the command line, builds an equidistant radial mesh and then iterates:
//   density -> Hartree and exchange-correlation potential -> bound states ->
//   new density -> energies -> linear mixing,
// until the relative change of the total energy is below `precision` or MaxSteps
// iterations were done. Progress is logged to clog; with -plot the final
// 4*Pi*r^2*rho(r) is written to cout.
// Returns 0 on success (and after printing help), 1 on invalid parameters.
int main(int argc, char *argv[], char *env[])
{
  int Z=2;                  // Number of electrons in the atom: Z=1,2,10,18,36
  double precision = 1e-10; // Stops when dE<precision
  int MaxSteps=100;         // Maximum number of SC steps
  double Rmax = 10.;        // Maximum Radius of integration
  double h = 0.001;         // Step in solving differential equation (dr of radial equidistant mesh)
  double admix = 0.3;       // Mixing parameter of linear mixing
  double dEz = 0.1;         // Step in serching for bound states
  bool plot = false;        // Whether to plot output density
  bool addShell = false;    // Whether to include one more principal quantum number
  int i=0;
  while (++i<argc){
    std::string str(argv[i]);
    if (str=="-Z" && i<argc-1) Z = atoi(argv[++i]);
    if (str=="-Rmax" && i<argc-1) Rmax = atof(argv[++i]);
    if (str=="-dR" && i<argc-1) h = atof(argv[++i]);
    if (str=="-dE" && i<argc-1) dEz = atof(argv[++i]);
    if (str=="-admix" && i<argc-1) admix = atof(argv[++i]);
    if (str=="-precision" && i<argc-1) precision = atof(argv[++i]);
    if (str=="-MaxSteps" && i<argc-1) MaxSteps = atoi(argv[++i]);
    if (str=="-addShell") addShell = true;
    if (str=="-plot") plot=true;
    if (str=="-h" || str=="--help"){
      std::clog<<"**************** LDA program for atoms ***************\n";
      std::clog<<"**                                                  **\n";
      std::clog<<"**      Copyright Kristjan Haule, 18.10.2005        **\n";
      std::clog<<"******************************************************\n";
      std::clog<<"\n";
      std::clog<<"atomLDA [-precision double] [] []\n" ;
      std::clog<<"Options:   -Z          Number of electrons ("<<Z<<")\n";
      std::clog<<"           -Rmax       Maximum radius in radial mesh ("<<Rmax<<")\n";
      std::clog<<"           -dR         Distance between equidistant point of radial mesh ("<<h<<")\n";
      std::clog<<"           -dE         Step in searching for bound states ("<<dEz<<")\n";
      std::clog<<"           -admix      Mixing parameter for linear mixing ("<<admix<<")\n";
      std::clog<<"           -precision  Stops when dE/E<precision ("<<precision<<")\n";
      std::clog<<"           -MaxSteps   Maximum number of SC steps ("<<MaxSteps<<")\n";
      std::clog<<"           -addShell   Adds one more main quantum number ("<<addShell<<")\n";
      std::clog<<"           -plot       Whether to plot outpur density ("<<plot<<")\n";
      std::clog<<"*****************************************************\n";
      return 0;
    }
  }

  // Validate the parameters that size the mesh and the shell structure.
  // The sign of the step is ignored (the number of points below always used |h|);
  // a zero step would divide by zero, and Z<1 has no bound states to fill.
  h = fabs(h);
  if (Z<1 || !(h>0) || !(Rmax>0)){
    std::cerr<<"atomLDA: -Z must be at least 1 and -dR, -Rmax must be positive\n";
    return 1;
  }

  int N = static_cast<int>(Rmax/fabs(h)+1.5); // Number of radial points
  if (N<3){                                   // The radial solver extrapolates from three mesh points
    std::cerr<<"atomLDA: radial mesh too small, decrease -dR or increase -Rmax\n";
    return 1;
  }
  // Building equidistant radial mesh
  vector<double> Rmesh(N);
  for (int i=0; i<N; i++) Rmesh[i] = i*h;
  Rmax = Rmesh[Rmesh.size()-1];
  
  RadialWave wave(Rmesh); // Basic class for solving radial Schroedinger equation

  vector<double> Uhartree(Rmesh.size()), Vxc(Rmesh.size()); // Hartree and exchange-correlation potential
  // Density (input of the current step) and new density (output of the current step)
  vector<double> rho(Rmesh.size()), nrho(Rmesh.size()); 
  vector<double> tmp(Rmesh.size());  // temporarily to calculate energies
  
  
  int n0 = NumberOfPrincipal(Z); // From Z, one can calculate which main quantum numbers n need to be incuded 
  if (addShell) n0++;
  
  vector<BState> states(static_cast<int>((Z+1)/2)+5); // States to be sorted for lowest few to build density from
  
  for (int i=0; i<Rmesh.size(); i++) {Uhartree[i]=0; Vxc[i]=0;}// Set Vxc and U_H to zero at start
  for (int i=0; i<Rmesh.size(); i++) rho[i]=exp(-Z/2.*Rmesh[i])*sqr(sqr(Z))/(64.*M_PI);// starting guess for density
  
  double Excor=0, Etot=0, pEtot=0, Eb=0, Ehartree, Epotential;
  ExchangeCorrelation XC(3);// WVN seems to be the best (look http://physics.nist.gov/PhysRefData/DFTdata/Tables/ptable.html)

  clog.precision(8);
  // Main iteration loop for self-consistency
  for (int itt=0; itt<MaxSteps; itt++){
    clog<<"***** Step "<<itt<<" *****"<<endl;
    SolvePoisson(Z, Rmesh, rho, Uhartree);// Given the input density rho, calculates the Hartree potential
    
    // Adding exchange-correlation part (the point r=0 is skipped; Vxc[0] stays 0)
    for (int i=1; i<Rmesh.size(); i++){
      double rs_ = rs(rho[i]);
      Vxc[i] = XC.Vx(rs_) + XC.Vc(rs_);
    }

    wave.SetVeff0(Uhartree, Vxc, Z);
    
    FindBoundStates(n0, Z, dEz, wave, states);
    Cmp cmp;
    if (itt>0) sort(states.begin(),states.end(),cmp); // Sort states (first time they are degenerate)
    Eb = BuildNewRho(Z, states, wave, nrho);// Gives new density and sum of eigenvalues
    
    // E_hartree = -1/2*V_hartree*n
    for (int i=0; i<Rmesh.size(); i++) tmp[i] = -0.5*4*M_PI*Rmesh[i]*Uhartree[i]*nrho[i];
    Ehartree = integrate4<double>(tmp,Rmesh[1]-Rmesh[0],tmp.size());
    
    // Adding exchange-correlation energy part
    tmp[0]=0;
    for (int i=1; i<Rmesh.size(); i++){
      double rs_ = rs(rho[i]);
      tmp[i] = 4*M_PI*sqr(Rmesh[i])*nrho[i]*(XC.ExVx(rs_)+XC.EcVc(rs_));
    }
    Excor = integrate4<double>(tmp,Rmesh[1]-Rmesh[0],tmp.size());
    // Enuclei = <-Z/r>
    for (int i=0; i<Rmesh.size(); i++) tmp[i] = -Z*rho[i]*4*M_PI*Rmesh[i];
    double Enuclei = integrate4<double>(tmp,Rmesh[1]-Rmesh[0],tmp.size());
    // Epotential = <V_KS>
    for (int i=0; i<Rmesh.size(); i++) tmp[i]  = 4*M_PI*sqr(Rmesh[i])*rho[i]*wave.V_KS0(i);
    Epotential = integrate4<double>(tmp,Rmesh[1]-Rmesh[0],tmp.size());

    // Total energy on output density
    Etot = Eb+Ehartree+Excor;
    double Ekin = Eb-Epotential;
    double Ecoul = Enuclei-Ehartree;
    clog<<left<<"Total E="<<setw(10)<<Etot<<" Ekin="<<Ekin<<"  E-Coulomb="<<setw(10)<<-Ehartree<<" Exch+Corr' E="<<setw(10)<<Excor<<" Epotential="<<Epotential<<" Enuclei="<<Enuclei<<" Ecoulomb'="<<Ecoul<<endl;
    
    for (int k=0; k<Rmesh.size(); k++) rho[k] = (1.-admix)*rho[k]+admix*nrho[k];// linear mixing!
    if (fabs((Etot-pEtot)/Etot)<precision) break;                               // if desired accuracy met, finish
    pEtot=Etot;
  }
  if (plot) for (int i=0; i<Rmesh.size(); i++) cout<<Rmesh[i]<<" "<<4*M_PI*sqr(Rmesh[i])*rho[i]<<endl;
  return 0;
}
