#include <iostream>
#include <iomanip>
#include <vector>
#include <fstream>
#include <list>
#include <algorithm>
#include <cstdlib>
#include "numerov.h"
#include "integrate.h"
#include "kmesh.h"
#include "bessel.h"
#include "function.h"

// Prototypes of the Fortran LAPACK routines called below. They are declared with C linkage
// and a trailing underscore in the symbol name; every argument is passed by pointer.
extern "C" {
  // dsygvd: used by Eigensystem() with ITYPE=1, JOBZ="V", UPLO="U".
  // A (in/out) and B (in/out) are N x N matrices with leading dimensions LDA/LDB,
  // W receives N values, WORK/IWORK are scratch arrays of length LWORK/LIWORK,
  // INFO is the status code (0 is treated as success by the caller).
  void dsygvd_(const int* ITYPE, const char* JOBZ, const char* UPLO, const int* N,
               double* A, const int* LDA, double* B, const int* LDB, double* W, double* WORK, const int* LWORK,
               int* IWORK, const int* LIWORK, int* INFO);
  // dggev: used by GeneralEigenvalues() with JOBVL="N", JOBVR="N".
  // The caller forms its results as ALPHAR[i]/BETA[i] and inspects ALPHAI[i] for
  // non-negligible imaginary parts; VL/VR are passed but not read back by the caller.
  void dggev_(const char* JOBVL, const char* JOBVR, int* N, double* A, int* LDA, double* B, int* LDB,
	      double* ALPHAR, double* ALPHAI, double* BETA,
	      double* VL, int* LDVL, double* VR, int* LDVR, double* WORK, int* LWORK, int* INFO);
};

using namespace std;

int Eigensystem(int N, function1D<double>& Energy, const function2D<double>& Olap, function2D<double>& F)
{// C++ wrapper function for general eigenvalue problem
  static function2D<double> tOlap(N,N);
  int lwork = 1 + 6*N + 2*N*N + 10;
  static function1D<double> work(lwork);
  int liwork = 3 + 5*N + 10;
  static function1D<int> iwork(liwork);
  int itype=1;
  int info=0;
  tOlap = Olap;
  int lF = F.size_Nd();
  int lO = Olap.size_Nd();
  dsygvd_(&itype, "V", "U", &N, F.MemPt(), &lF, tOlap.MemPt(), &lO, Energy.MemPt(), work.MemPt(), &lwork, iwork.MemPt(), &liwork, &info);
  if (info) cerr<<"Not sucessfull solving eigenvalue problem with dsygvd "<<info<<endl;
  return info;
}

int GeneralEigenvalues(function1D<double>& Energy, const function2D<double>& Olap, function2D<double>& F)
{
  // Eigenvalues of the generalized (not necessarily symmetric) problem F x = E Olap x via LAPACK dggev.
  //
  // Energy : output; Energy[i] = alphar[i]/beta[i] for i < N, then the whole array is sorted ascending.
  //          A zero beta[i] is not treated specially and yields a non-finite entry.
  // Olap   : right-hand-side matrix; its order defines N. It is copied because dggev overwrites B.
  // F      : left-hand-side matrix; overwritten by dggev.
  // Returns the LAPACK info code (0 on success). Eigenvalues whose imaginary part exceeds 1e-5
  // in magnitude are reported on cerr but their real part is still stored.
  //
  // All scratch storage is sized per call from the current N. (Function-local statics would keep
  // the size of the very first call and be too small for a later, larger problem.)
  int N = Olap.size_N();
  function2D<double> tOlap(N,N);
  int lwork = 16*N*N;
  function1D<double> work(lwork);
  int info=0;
  tOlap = Olap;
  int lF = F.size_Nd();
  int lO = Olap.size_Nd();
  function1D<double> alphar(N), alphai(N), beta(N);
  function2D<double> VL(N,N), VR(N,N);   // not referenced by dggev for JOBVL=JOBVR="N", but valid pointers are passed
  dggev_("N", "N", &N, F.MemPt(), &lF, tOlap.MemPt(), &lO, alphar.MemPt(), alphai.MemPt(), beta.MemPt(),
         VL.MemPt(),&N,VR.MemPt(),&N,work.MemPt(),&lwork,&info);
  // The message names the routine that was actually called (it previously said dsygvd).
  if (info) cerr<<"Not sucessfull solving eigenvalue problem with dggev "<<info<<endl;
  for (int i=0; i<N; i++) {
    Energy[i] = alphar[i]/beta[i];
    if (fabs(alphai[i])>1e-5) cerr<<"Eigenvalues nonreal! "<<i<<endl;
  }
  sort(Energy.MemPt(),Energy.MemPt()+Energy.size());
  return info;
}
class FccLattice{ // Class for storing reciprocal lattice
  double LatConst, Volume;
  dvector3 a0, a1, a2;    // Primitive vectors of fcc lattice
  dvector3 b0, b1, b2;    // Primitive vectors of reciprocal lattice
  dvector3 GammaPoint, LPoint, KPoint, XPoint, WPoint; // Special points in 1IRB
  vector<dvector3> Kmesh, kmesh;
public:
  double Vol(){return Volume;}
  int Ksize(){return Kmesh.size();}
  int ksize(){return kmesh.size();}
  const dvector3& K(int i){return Kmesh[i];} // can not be changed, only read
  const dvector3& k(int i){return kmesh[i];} // can not be changed, only read
  FccLattice(double LatConst_) : LatConst(LatConst_)
  {
    a0 = dvector3(0.5*LatConst,0.5*LatConst,0);
    a1 = dvector3(0.5*LatConst,0,0.5*LatConst);
    a2 = dvector3(0,0.5*LatConst,0.5*LatConst);
    Volume = fabs(Vproduct(a0,a1,a2));// Volume
    clog<<"Volume is "<<Volume<<endl;
    b0 = (2*M_PI/Volume)*cross(a1,a0);
    b1 = (2*M_PI/Volume)*cross(a0,a2);
    b2 = (2*M_PI/Volume)*cross(a2,a1);
    // Special points in Brillouin zone
    double brs = 2*M_PI/LatConst;
    GammaPoint = dvector3(0,0,0);
    LPoint = dvector3(0.5*brs,0.5*brs,0.5*brs);
    KPoint = dvector3(0.75*brs,0.75*brs,0);
    XPoint = dvector3(1*brs,0, 0);
    WPoint = dvector3(1*brs,0.5*brs,0);
  }
  void GenerateReciprocalVectors(int q, double CutOffK)
  {
    // Many reciprocal vectors are generated and later only the sortest are used
    list<dvector3> Kmesh0;
    for (int n=-q; n<q; n++){
      for (int l=-q; l<q; l++){
	for (int m=-q; m<q; m++){
	  Kmesh0.push_back(n*b0+l*b1+m*b2);
	}
      }
    }
    Kmesh0.sort(cmp); // Sorting according to the length of vector. Shortest will be kept
    int Ksize=0;
    for (list<dvector3>::const_iterator l=Kmesh0.begin(); l!=Kmesh0.end(); l++,Ksize++) if (l->length()>CutOffK) break;
    Kmesh.resize(Ksize);
    int j=0;
    for (list<dvector3>::const_iterator l=Kmesh0.begin(); l!=Kmesh0.end() && j<Ksize; l++,j++) Kmesh[j]=*l;
    clog<<"K-mesh size="<<Kmesh.size()<<endl;
  }
  void ChoosePointsInFBZ(int nkp){// Chooses the path in the 1BZ we will use
    kmesh.resize(nkp);
    int N0=kmesh.size()/4;
    for (int i=0; i<N0; i++) kmesh[i]      = GammaPoint + (XPoint-GammaPoint)*i/(N0-1.);
    for (int i=0; i<N0; i++) kmesh[N0+i]   = XPoint + (LPoint-XPoint)*i/(N0-1.);
    for (int i=0; i<N0; i++) kmesh[N0*2+i] = LPoint + (GammaPoint-LPoint)*i/(N0-1.);
    for (int i=0; i<N0; i++) kmesh[N0*3+i] = GammaPoint + (KPoint-GammaPoint)*i/(N0-1.);
  }
};

// Closed-form parametrization of the effective potential as a function of the radius R.
// (The literature source of the fit is not recorded in this file.)
// It is the sum of a screened term 29*exp(...) and a polynomial correction up to R^4.
double VeffP(double R)
{
  // Exponent of the screened term
  const double screening = -2.3151241717834*pow(R,0.81266614122432)+ 2.1984250222603e-2*pow(R,4.2246376280056);
  // Terms are accumulated in the same left-to-right order as the closed formula
  double potential = 29*exp(screening);
  potential -= 0.15595606773483*R;
  potential -= 3.1350051440417e-3*R*R;
  potential += 5.1895222293006e-2*pow(R,3);
  potential -= 2.8027608685637e-2*pow(R,4);
  return potential;
}

class PartialWave{// Class for solving SCH equation
  int Z;
  vector<double> Rmesh;
  vector<double> rhs_MT, ur, urp, temp, inhom; // For solving SCH equation
  vector<double> dlogPsi, dlogPsip, Psi, Psip, PsipPsip;
  vector<vector<double> > Psi_l, Psip_l;
public:
  PartialWave(int N, double RMuffinTin, int Z_, int lMax): Z(Z_), Rmesh(N), rhs_MT(N), ur(N), urp(N), temp(N), inhom(N),
							   dlogPsi(lMax+1), dlogPsip(lMax+1), Psi(lMax+1), Psip(lMax+1), PsipPsip(lMax+1),
							   Psi_l(lMax+1), Psip_l(lMax+1)
  {
    // Building equidistant radial mesh
    double dh = RMuffinTin/(N-1.);
    for (int i=0; i<N; i++) Rmesh[i] = i*dh;
    clog<<"RmuffinTin="<<Rmesh[Rmesh.size()-1]<<endl;
  }
  int Rsize(){return Rmesh.size();}
  double psi(int l){return Psi[l];}
  double psip(int l){return Psip[l];}
  double psi(int l, int r){ return Psi_l[l][r];}
  double psip(int l, int r){ return Psip_l[l][r];}
  double dlog(int l){return dlogPsi[l];}
  double dlogp(int l){return dlogPsip[l];}
  double PP(int l){return PsipPsip[l];}
  double startSol(int Z, int l, double r) // good choice for starting Numerov algorithm
  { return pow(r,l+1)*(1-Z*r/(l+1));}
  double SolveSCHEquation(double Enu)
  {
    for (int l=0; l<Psi.size(); l++){
      rhs_MT[0]=0;
      for (int i=1; i<Rmesh.size(); i++){
	double Veff = -VeffP(Rmesh[i])/Rmesh[i] + 0.5*l*(l+1)/sqr(Rmesh[i]);
	rhs_MT[i] = 2*(Veff-Enu);
      }
      double dh = Rmesh[1]-Rmesh[0];
      ur[0]=0;
      ur[1]=startSol(Z, l, dh);
      // Solving radial SCH equation
      Numerov(rhs_MT, ur.size(), dh, ur);
      // Normalizing the result 
      for (int i=0; i<Rmesh.size(); i++) temp[i] = sqr(ur[i]);
      double norm = 1./sqrt(integrate4<double>(temp, dh, temp.size()));
      for (int i=0; i<ur.size(); i++) ur[i] *= norm;
      Psi_l[l] = ur;// storing for future (density)
      for (int ir=1; ir<Rmesh.size(); ir++) Psi_l[l][ir]/=Rmesh[ir]; // Psi=u/r!
      Psi_l[l][0] = (Psi_l[l][1]*(Rmesh[2]-Rmesh[0])-Psi_l[l][2]*(Rmesh[1]-Rmesh[0]))/(Rmesh[2]-Rmesh[1]);//extrapolation to zero
      // Energy derivative of Psi
      for (int i=0; i<Rmesh.size(); i++) inhom[i] = -2.*ur[i];
      urp[0]=0;
      urp[1]=startSol(Z,l,dh);
      NumerovGen(rhs_MT, inhom, urp.size(), dh, urp);
      for (int i=0; i<ur.size(); i++) temp[i] = ur[i]*urp[i];
      double alpha = integrate4<double>(temp, dh, temp.size());
      for (int i=0; i<ur.size(); i++) urp[i] -= alpha*ur[i];
      Psip_l[l] = urp;      // Store it
      for (int ir=1; ir<Rmesh.size(); ir++) Psip_l[l][ir]/=Rmesh[ir];// Psip=up/r!
      Psip_l[l][0] = (Psip_l[l][1]*(Rmesh[2]-Rmesh[0])-Psip_l[l][2]*(Rmesh[1]-Rmesh[0]))/(Rmesh[2]-Rmesh[1]);// extrapolating
    
      for (int i=0; i<urp.size(); i++) temp[i] = urp[i]*urp[i];
      PsipPsip[l] = integrate4<double>(temp, dh, temp.size()); // (Psip,Psip)

      // Here we estimate the derivatives at the Muffin-Tin boundary
      int N0 = ur.size()-1;
      double RMuffinTin = Rmesh[N0];
      double v1 = rhs_MT[N0]*ur[N0];
      double v0 = rhs_MT[N0-1]*ur[N0-1];
      double w1 = rhs_MT[N0]*urp[N0]+inhom[N0];
      double w0 = rhs_MT[N0-1]*urp[N0-1]+inhom[N0-1];
      double dudr  = (ur[N0]-ur[N0-1])/dh + 0.125*dh*(3*v1+v0);
      double dupdr = (urp[N0]-urp[N0-1])/dh + 0.125*dh*(3*w1+w0);
      dlogPsi[l]  = RMuffinTin*dudr/ur[N0] -  1;
      dlogPsip[l] = RMuffinTin*dupdr/urp[N0] - 1;
      Psi[l]  = ur[N0]/RMuffinTin;
      Psip[l] = urp[N0]/RMuffinTin;
      
      /// TEST_ONLY
      //      clog<<"This quantity is "<<(dlogPsi[l]-dlogPsip[l])*RMuffinTin*Psi[l]*Psip[l]<<endl;
    }
  }
};

int main(int argc, char *argv[], char *env[])
{
  int Z=29;                    // Number of electrons in the Cu atom
  double E_start = -0.2;      // Where to start searching for Energy
  double dE = 1e-3;            // Step in serching for bound states
  int nE = 500;                // Number of enrgy steps when searching for energy bands
  double LatConst = 6.8219117; // Lattic constant
  double RMuffinTin = 2.41191; // Muffin tin radius - Touching spheres
  int lMax=5;                  // Maximum l considered in calculation
  int N = 1001;                // Number of points in radial mesh
  int nkp = 40;                // Number of k-points in 1BZ
  double CutOffK=3.;           // Largest lengt of reciprocal vectors K (only shorter vec. are taken into account)
  
  int i=0;
  while (++i<argc){
    std::string str(argv[i]);
    if (str=="-dE" && i<argc-1) dE = atof(argv[++i]);
    if (str=="-h" || str=="--help"){
      std::clog<<"**************** APW program for Cu-fcc **************\n";
      std::clog<<"**                                                  **\n";
      std::clog<<"**      Copyright Kristjan Haule, 18.10.2005        **\n";
      std::clog<<"******************************************************\n";
      std::clog<<"\n";
      std::clog<<"apwCu [-dE double] [] []\n" ;
      std::clog<<"Options:   -Z          Number of electrons ("<<Z<<")\n";
      std::clog<<"           -dE         Step in searching for states ("<<dE<<")\n";
      std::clog<<"*****************************************************\n";
      return 0;
    }
  }

    
  clog.precision(10);
  // For solving SCH equation
  PartialWave wave(N, RMuffinTin, Z, lMax);
  // Generates and stores momentum points
  FccLattice fcc(LatConst);                  // Information about lattice
  fcc.GenerateReciprocalVectors(4, CutOffK); // Reciprocal bravais lattice is builded
  fcc.ChoosePointsInFBZ(nkp);                // Chooses the path in the 1BZ we will use


  double Enu = 0, VKSi = 0;
  // Here Schroedinger equation is solved with energy Enu
  wave.SolveSCHEquation(Enu);
  
  // Storage for matrices
  function2D<double> Olap_I(fcc.Ksize(),fcc.Ksize()); // Overlap in interstitials
  function2D<double> Olap(fcc.Ksize(),fcc.Ksize()), Ham(fcc.Ksize(),fcc.Ksize());

  // Overlap in the interstitials can be calculated outside
  for (int i=0; i<fcc.Ksize(); i++){
    Olap_I(i,i) = 1 - 4*M_PI*sqr(RMuffinTin)*RMuffinTin/(3.*fcc.Vol());
    for (int j=i+1; j<fcc.Ksize(); j++){
      double KKl = (fcc.K(i)-fcc.K(j)).length();
      Olap_I(i,j) = -4*M_PI*sqr(RMuffinTin)*bessel_j(1,KKl*RMuffinTin)/(KKl*fcc.Vol());
      Olap_I(j,i) = Olap_I(i,j);
    }
  }


    
  // For electron density
  vector<vector<vector<double> > > PsiPsip(fcc.Ksize());
  for (int iK=0; iK<fcc.Ksize(); iK++){
    PsiPsip[iK].resize(wave.Rsize());
    for (int ir=0; ir<wave.Rsize(); ir++)
      PsiPsip[iK][ir].resize(lMax+1);
  }
  function2D<double> omegal(fcc.Ksize(),lMax+1), C1(fcc.Ksize(),lMax+1);
  vector<double> C2l(lMax+1);
  vector<function2D<double> > RhoMT(wave.Rsize());
  for (int ir=0; ir<wave.Rsize(); ir++) RhoMT[ir].resize(fcc.Ksize(),fcc.Ksize());
  function1D<double> Energy(fcc.Ksize());

  
  // Main loop over k points
  for (int ik=0; ik<fcc.ksize(); ik++){
    dvector3 k = fcc.k(ik);
    clog<<"k="<<ik<<endl;
    
    // Bessel functions can be calculated only ones for each K-point
    for (int iK=0; iK<fcc.Ksize(); iK++)
      for (int il=0; il<=lMax; il++){
	double Dl, jl;
	dlog_bessel_j(il, (k+fcc.K(iK)).length()*RMuffinTin, Dl, jl);
	omegal(iK,il) = -wave.psi(il)/wave.psip(il)*(Dl-wave.dlog(il))/(Dl-wave.dlogp(il));
	C1(iK,il) = sqrt(4*M_PI*(2*il+1)/fcc.Vol())*jl/(wave.psi(il)+omegal(iK,il)*wave.psip(il));

	for (int ir=0; ir<wave.Rsize(); ir++)
	  PsiPsip[iK][ir][il] = wave.psi(il,ir)+omegal(iK,il)*wave.psip(il,ir);
      }
    // Parts of the Hamiltonian matrix which do not depend on energy, are calculated
    for (int iK=0; iK<fcc.Ksize(); iK++){
      for (int jK=0; jK<fcc.Ksize(); jK++){
	dvector3 qi(k+fcc.K(iK));
	dvector3 qj(k+fcc.K(jK));
	
	double qi_len = qi.length();
	double qj_len = qj.length();
	double argv = (qi_len*qj_len==0) ? 1. : qi*qj/(qi_len*qj_len);
	
	double olapMT=0, hamMT=0;
	for (int il=0; il<=lMax; il++){
	  C2l[il] = C1(iK,il)*C1(jK,il)*Legendre(il,argv);
	  olapMT += C2l[il]*(1.+omegal(iK,il)*omegal(jK,il)*wave.PP(il));
	  hamMT  += C2l[il]*0.5*(omegal(iK,il)+omegal(jK,il));
	}
	Olap(iK,jK) = olapMT + Olap_I(iK,jK);
	Ham(iK,jK) = (0.25*(qi*qi+qj*qj) + VKSi - Enu)*Olap_I(iK,jK) + hamMT;
	/*
	for (int ir=0; ir<wave.Rsize(); ir++){
	  double sum=0;
	  for (int il=0; il<=lMax; il++) sum += C2l[il]*PsiPsip[iK][ir][il]*PsiPsip[jK][ir][il];
	  RhoMT[ir][iK][jK] = sum/(4*M_PI);
	}
	*/
      }
    }
    /*
    for (int i=0; i<fcc.Ksize(); i++)
      for (int j=0; j<fcc.Ksize(); j++)
	if (fabs(Olap(i,j)-Olap(j,i))>1e-6) cerr<<"Overlap not hermitian"<<endl;
    for (int i=0; i<fcc.Ksize(); i++)
      for (int j=0; j<fcc.Ksize(); j++)
	if (fabs(Ham(i,j)-Ham(j,i))>1e-6) cerr<<"Hamiltonian not hermitian"<<endl;
    */
    Eigensystem(fcc.Ksize(), Energy, Olap, Ham);
    
    cout<<setw(10)<<ik/static_cast<double>(fcc.ksize()-1)<<" ";
    for (int iK=0; iK<Energy.size(); iK++) cout<<setw(12)<<Energy[iK]<<" ";
    cout<<endl;
  }
  return 0;
}
