#ifndef FUNCTION_
#define FUNCTION_
#include <algorithm>
#include <cmath>
#include <complex>
#include <iomanip>
#include <iostream>
#include <limits>
#include "sutil.h"
#include "sblas.h"

//////////////////////////// Hierarchy of classes: ///////////////////////////////////////
//
//                                     base class = function<>
//                                               |
//                                   -------------------------
//                                   |                       |
//                              function1D<>             spline1D<>
//

//*****************************************//
// Classes also used in this header file   //
//*****************************************//
class intpar; // interpolation parameter used by operator(); defined elsewhere.
              // Its members i (interval index) and p (weight of the right point) are read below.

typedef std::complex<double> dcomplex; // double precision complex number (return type of spline1D::Fourier)

//*******************************************************//
// Classes and functions implemented in this header file //
//*******************************************************//
template<class T> class function;
template<class T> class function1D;
template<class T> class spline1D;
//********************************************************************************//
// Base class for the derived classes function1D<T> and spline1D<T> (and for      //
// function2D<T>, which is not defined in this file). It is also used as a proxy  //
// class for function2D: function2D<T> consists of arrays of function<T> rather   //
// than of function1D<T>.                                                         //
// Memory is allocated in a fortran-like fashion for better performance.          //
// Linear interpolation is implemented with the operator() that takes one         //
// argument (class intpar).                                                       //
//********************************************************************************//
template<class T>
class function{
protected:
  T *f;      // pointer to the data; allocated and released by the derived classes
  int N0, N; // N0: number of allocated elements (capacity), N: number of elements in use
  // Constructors are protected so that function<T> cannot be instantiated on its own:
  // this class is meant to be used only as a base class.
  function() : f(NULL), N0(0), N(0) {};
  explicit function(int N_) : f(NULL), N0(N_), N(N_) {}; // records the size only; the derived class allocates f
  ~function(){};
  // The base part of a copy always starts empty, so it can never share (and later
  // double-free) the source buffer; the derived class allocates and fills its own.
  function(const function&) : f(NULL), N0(0), N(0) {};
public:
  // OPERATORS
  T& operator[](int i) {Assert(i>=0 && i<N,"Out of range in function[]"); return f[i];}
  const T& operator[](int i) const {Assert(i>=0 && i<N,"Out of range in function[]"); return f[i];}
  T operator()(const intpar& ip) const; // linear interpolation
  function& operator+=(const function& m); // element-wise sum over the common length
  function& operator*=(const T& m);        // multiplies every element by m
  function& operator=(const T& c); // used for initialization: sets every element to c
  // SHORT FUNCTIONS
  const T& last() const {return f[N-1];} // last element in use (requires N>0)
  int size() const { return N;}          // number of elements in use
  int fullsize() const { return N0;}     // number of allocated elements
  T* MemPt() {return f;}                 // raw pointer to the data
  const T* MemPt() const{return f;}
  // OTHER FUNCTIONS
  T accumulate();                        // sum of all elements in use
  template <class U> friend class spline1D;
};

//******************************************************************//
// One dimensional functions derived from function<T>. It has its   //
// own constructors and destructor.                                 //
//******************************************************************//
template <class T>
class function1D : public function<T>{
public:
  // CONSTRUCTORS AND DESTRUCTORS
  function1D(){};                 // default constructor (empty function) exists for making arrays
  explicit function1D(int N_);    // allocates N_ elements (values are left default-initialized)
  ~function1D();                  // releases the storage
  function1D(const function1D& m);// copy constructor (deep copy)
  // INITIALIZATION ROUTINES
  void resize(int N_);            // sets the size; reallocates (without keeping old values) only when growing
  // OPERATORS
  function1D& operator=(const function1D& m); // copy assignment (deep copy)
  function1D& operator=(const T& c) {function<T>::operator=(c); return *this;}// used for initialization
};
//************************************************************************//
// One dimensional spline function derived from function<T>. It has its   //
// own constructors and destructor.                                       //
//************************************************************************//
template <class T>
class spline1D : public function<T>{
  T* f2;       // second derivatives of the spline at the mesh points
  double *dxi; // mesh spacings x_{j+1}-x_j (entry j belongs to interval [x_j, x_{j+1}])
public:
  // CONSTRUCTORS AND DESTRUCTORS
  spline1D() : f2(NULL), dxi(NULL) {};// default constructor (empty spline) exists for allocating arrays
  //  explicit spline1D(int N_);    // basic constructor
  // Builds a cubic spline through the values fu given on the mesh om.
  // df0 and dfN are the first derivatives at the two ends (clamped spline). If either
  // one is left at its default value std::numeric_limits<double>::max(), natural end
  // conditions (zero second derivative at both ends) are used instead.
  spline1D(const mesh1D& om, const function<T>& fu, // the function being splined
           const T& df0=std::numeric_limits<double>::max(), // first derivative at the left end
           const T& dfN=std::numeric_limits<double>::max()); // first derivative at the right end
  ~spline1D();                  // destructor
  spline1D(const spline1D& m);  // copy constructor (deep copy)
  // INITIALIZATION ROUTINES
  void resize(int N_);          // sets the size; reallocates (without keeping old values) only when growing
  // OPERATORS
  T operator()(const intpar& ip) const; // spline interpolation
  spline1D& operator=(const spline1D& m); // copy assignment (deep copy)
  //  spline1D& operator=(const T& c) {function<T>::operator=(c); return *this;}// used for initialization
  // ADVANCED FUNCTIONS
  T integrate();                                 // integral of the spline over the whole mesh
  dcomplex Fourier(double om, const mesh1D& xi); // integral of exp(i*om*x) times the spline
};

// function ////////////////////////////////////////////////////////////////
// Routine for linear interpolation
template<class T>
inline T function<T>::operator()(const intpar& ip) const
{
  // Linear interpolation between the samples ip.i and ip.i+1:
  //   value = f[i] + p*(f[i+1]-f[i]),  p = ip.p (presumably in [0,1]).
  const T* left = f + ip.i;
  return left[0] + ip.p*(left[1] - left[0]);
}

template<class T> 
inline function<T>& function<T>::operator+=(const function& m)
{
  // Element-wise sum f[i] += m[i].
  _LOG(if (N!=m.size()) std::cerr << "Functions not of equal length! Can't sum!" << std::endl;)
  // Only the common length is summed, so a shorter m is never read out of range.
  const int n = std::min(N, m.size());
  for (int i=0; i<n; i++) f[i] += m[i];
  return *this;
}

template<class T>
inline function<T>& function<T>::operator*=(const T& m)
{
  // Scale every element in use by m, walking the buffer front to back
  // (order matters if m refers to one of the elements).
  T* p = f;
  for (int left = N; left > 0; --left, ++p) *p *= m;
  return *this;
}

template <class T>
inline T function<T>::accumulate()
{
  // Sum of the first N elements, accumulated front to back.
  T total(0);
  const T* p = f;
  for (int left = N; left > 0; --left) total += *p++;
  return total;
}

template <class T>
inline function<T>& function<T>::operator=(const T& c)
{
  _LOG(if (N<=0) cerr << "Size of function is non positive! "<<N<<std::endl;)
  // Broadcast c into every element in use; nothing is written when N<=0.
  if (N > 0) std::fill(f, f + N, c);
  return *this;
}

// function1D ////////////////////////////////////////////////////////////
template<class T>
inline function1D<T>::function1D(int N_) : function<T>(N_)
{
  // The base constructor only records the sizes; the buffer is allocated here.
  // 'this->' is required: f lives in a dependent base class and is not found
  // by unqualified name lookup inside a template.
  this->f = new T[N_];
}

template<class T>
inline function1D<T>::~function1D()
{
  // Release the buffer and reset the bookkeeping ('this->' needed for
  // members of the dependent base class).
  delete[] this->f;
  this->f = NULL;
  this->N = 0;
  this->N0 = 0;
}

template<class T>
inline void function1D<T>::resize(int n)
{
  // Sets the size to n. Reallocates only when n exceeds the capacity N0,
  // in which case the previous values are NOT preserved.
  if (n > this->N0){
    // Allocate before releasing the old buffer so that a failing 'new' does
    // not leave a dangling pointer behind (deleting NULL is a no-op).
    T* fresh = new T[n];
    delete[] this->f;
    this->f = fresh;
    this->N0 = n;
  }
  this->N = n;
}

template<class T>
inline function1D<T>::function1D(const function1D& m) : function<T>()
{
  // Deep copy: the base starts empty, so resize allocates a private buffer.
  resize(m.N);
  std::copy(m.f, m.f + m.N, this->f);
}

template <class T>
inline function1D<T>& function1D<T>::operator=(const function1D<T>& m)
{
  // Deep copy assignment. Self-assignment is a no-op: std::copy must not be
  // asked to copy a range onto itself.
  if (this == &m) return *this;
  resize(m.N);
  std::copy(m.f, m.f + m.N, this->f);
  return *this;
}

// spline1D ////////////////////////////////////////////////////////////
// template<class T>
// inline spline1D<T>::spline1D(int N_) : function<T>(N_)
// {
//   f = new T[N_];
//   f2 = new T[N_];
//   dxi = new double[N_];
// }

template<class T>
inline spline1D<T>::~spline1D()
{
  // Release all three buffers ('this->' needed for members of the dependent base).
  delete[] this->f; this->f = NULL;
  delete[] f2;      f2 = NULL;
  delete[] dxi;     dxi = NULL;
  this->N = 0;
  this->N0 = 0;
}

template<class T>
inline void spline1D<T>::resize(int n)
{
  // Sets the size to n. Reallocates f, f2 and dxi only when n exceeds the
  // capacity N0; the previous values are then NOT preserved.
  if (n > this->N0){
    // Each buffer is replaced by allocating the new one before freeing the old
    // one, so a failing 'new' never leaves a dangling pointer (delete[] NULL is a no-op).
    T* fnew = new T[n];
    delete[] this->f;
    this->f = fnew;
    T* f2new = new T[n];
    delete[] f2;
    f2 = f2new;
    double* dxinew = new double[n];
    delete[] dxi;
    dxi = dxinew;
    this->N0 = n;
  }
  this->N = n;
}

template<class T>
inline spline1D<T>::spline1D(const spline1D& m) : function<T>(), f2(NULL), dxi(NULL)
{
  // Deep copy. f2 and dxi must start as NULL: resize() and the destructor
  // delete[] them, and they would otherwise hold indeterminate pointers.
  resize(m.N);
  std::copy(m.f, m.f + m.N, this->f);
  std::copy(m.f2, m.f2 + m.N, f2);
  std::copy(m.dxi, m.dxi + m.N, dxi);
}

template <class T>
inline spline1D<T>& spline1D<T>::operator=(const spline1D<T>& m)
{
  // Deep copy assignment; self-assignment is a no-op (no overlapping std::copy).
  if (this == &m) return *this;
  resize(m.N);
  std::copy(m.f, m.f + m.N, this->f);
  std::copy(m.f2, m.f2 + m.N, f2);
  std::copy(m.dxi, m.dxi + m.N, dxi);
  return *this;
}

template <class T>
T spline1D<T>::operator()(const intpar& ip) const
{
  // Cubic-spline value in interval ip.i at relative position p=ip.p (q=1-p):
  // linear interpolation plus the second-derivative correction.
  const int i = ip.i;
  const double p = ip.p, q = 1-ip.p;
  const T* y = this->f; // base-class member of a dependent base: needs 'this->'
  return q*y[i] + p*y[i+1] + dxi[i]*dxi[i]*(q*(q*q-1)*f2[i] + p*(p*p-1)*f2[i+1])/6.;
}

// Builds the cubic spline through the samples fu on the mesh om. The second
// derivatives f2 are obtained from a symmetric tridiagonal system solved with
// LAPACK dptsv. Clamped ends (first derivatives df0, dfN) are used when both are
// supplied; otherwise natural ends (f2=0 at both ends) are used.
template <class T>
inline spline1D<T>::spline1D(const mesh1D& om, const function<T>& fu, const T& df0, const T& dfN)  : f2(NULL), dxi(NULL)
{
  if (om.size()!=fu.size()) std::cerr<<"Sizes of om and f are different in spline setup"<<std::endl;
  resize(om.size()); // allocates f, f2 and dxi
  const int n = this->N;
  T* y = this->f;
  std::copy(fu.f, fu.f+n, y);
  if (n < 2){
    // With fewer than two points there is no interval to interpolate on:
    // store the data and zero the auxiliary arrays instead of reading om[1]/fu[1].
    std::fill(f2, f2+n, T(0));
    std::fill(dxi, dxi+n, 0.0);
    return;
  }
  function1D<T> diag(n), offdiag(n-1); // matrix is stored as diagonal values + offdiagonal values
  // Below, matrix and rhs is setup (rhs is stored directly in f2)
  diag[0] = (om[1]-om[0])/3.;
  double dfu0 = (fu[1]-fu[0])/(om[1]-om[0]);
  f2[0] = dfu0-df0;
  for (int i=1; i<n-1; i++){
    diag[i] = (om[i+1]-om[i-1])/3.;
    double dfu1 = (fu[i+1]-fu[i])/(om[i+1]-om[i]);
    f2[i] = dfu1-dfu0;
    dfu0 = dfu1;
  }
  diag[n-1] = (om[n-1]-om[n-2])/3.;
  f2[n-1] = dfN - (fu[n-1]-fu[n-2])/(om[n-1]-om[n-2]);
  for (int i=0; i<n-1; i++) offdiag[i] = (om[i+1]-om[i])/6.;
  // The system of symmetric tridiagonal equations is solved by lapack
  int one=1, info=0;
  int dim = n; // order of the full system; also the leading dimension of f2
  if (df0==std::numeric_limits<double>::max() || dfN==std::numeric_limits<double>::max()){
    int size = n-2;// natural splines: solve only for the interior points
    dptsv_(&size, &one, diag.MemPt()+1, offdiag.MemPt()+1, f2+1, &dim, &info);
    f2[0]=0; f2[n-1]=0;
  } else  dptsv_(&dim, &one, diag.MemPt(), offdiag.MemPt(), f2, &dim, &info);

  if (info!=0) std::cerr<<"dptsv return an error "<<info<<std::endl;
  // Setup of other necessary information for doing splines.
  for (int i=0; i<n-1; i++) dxi[i] = (om[i+1]-om[i]);
  dxi[n-1] = 0; // no interval starts at the last point; keep the entry initialized
}

template <class T>
inline T spline1D<T>::integrate()
{
  // Exact integral of the cubic spline: per interval
  //   h/2*(y_i+y_{i+1}) - h^3/24*(f2_i+f2_{i+1}),  h = dxi[i].
  const T* y = this->f; // base-class member of a dependent base: needs 'this->'
  T sum=0;
  for (int i=0; i<this->N-1; i++) sum += 0.5*dxi[i]*(y[i+1]+y[i]-(f2[i+1]+f2[i])*dxi[i]*dxi[i]/12.);
  return sum;
}

// Integral of exp(i*om*x) times the spline, evaluated analytically per interval:
//   sum_i dxi[i]*exp(i*om*xi[i]) * Int_0^1 exp(i*u*t) s(xi[i]+t*dxi[i]) dt,  u = om*dxi[i].
// Assumes xi is the mesh the spline was built on (TODO confirm at call sites).
// For |u|<1e-4 a first-order Taylor expansion in u replaces the closed form,
// which suffers from cancellation for small u.
template <class T>
inline dcomplex spline1D<T>::Fourier(double om, const mesh1D& xi)
{
  const dcomplex ii(0,1);
  const T* y = this->f; // base-class member of a dependent base: needs 'this->'
  dcomplex sum=0;
  dcomplex val;
  for (int i=0; i<this->N-1; i++) {
    double u = om*dxi[i], u2=u*u, u4=u2*u2;
    if (std::fabs(u)<1e-4){// Taylor expansion for small u
      val = 0.5*(1.+ii*(u/3.))*y[i]+0.5*(1.+ii*(2*u/3.))*y[i+1];
      val -= dxi[i]*dxi[i]/24.*(f2[i]*(1.+ii*7.*u/15.)+f2[i+1]*(1.+ii*8.*u/15.));
    }else{
      dcomplex eiu(std::cos(u),std::sin(u)); // exp(i*u)
      val  = (y[i]*(1.+ii*u-eiu) + y[i+1]*((1.-ii*u)*eiu-1.))/u2;
      val += dxi[i]*dxi[i]/(6.*u4)*(f2[i]*(eiu*(6+u2)+2.*(u2-3.*ii*u-3.))+f2[i+1]*(6+u2+2.*eiu*(u2+3.*ii*u-3.)));
    }
    sum += dxi[i]*dcomplex(std::cos(om*xi[i]),std::sin(om*xi[i]))*val;
  }
  return sum;
}

template <class T>
std::ostream& operator<<(std::ostream& stream, const function<T>& f)
{
  // One line per element: "<index> <value>", the value padded to the width
  // that was set on the stream before the call.
  const int w = stream.width();
  const int count = f.size();
  for (int k = 0; k < count; ++k)
    stream << k << " " << std::setw(w) << f[k] << std::endl;
  return stream;
}

// Prints a two-column table "om[i]  f[i]", each column padded to 'width'.
// On a size mismatch a warning is issued and only the common rows are printed,
// so the shorter object is never read out of range.
template<class meshx, class functionx>
void print(std::ostream& stream, const meshx& om, const functionx& f, int width)
{
  if (om.size()!=f.size()) std::cerr<<"Can't print objectc of different size!"<<std::endl;
  int n = static_cast<int>(om.size());
  const int nf = static_cast<int>(f.size());
  if (nf < n) n = nf;
  for (int i=0; i<n; i++)
    stream <<std::setw(width)<<om[i]<<std::setw(width)<<f[i]<<'\n'; // '\n': no flush per row
}
// Prints a three-column table "om[i]  f1[i]  f2[i]", each column padded to 'width'.
// On a size mismatch a warning is issued and only the common rows are printed.
template<class meshx, class functionx, class functiony>
void print(std::ostream& stream, const meshx& om, const functionx& f1, const functiony& f2, int width)
{
  if (om.size()!=f1.size() || om.size()!=f2.size()) std::cerr<<"Can't print objectc of different size!"<<std::endl;
  int n = static_cast<int>(om.size());
  const int n1 = static_cast<int>(f1.size());
  const int n2 = static_cast<int>(f2.size());
  if (n1 < n) n = n1;
  if (n2 < n) n = n2;
  for (int i=0; i<n; i++)
    stream <<std::setw(width)<<om[i]<<std::setw(width)<<f1[i]<<std::setw(width)<<f2[i]<<'\n';
}
// Prints a four-column table "om[i]  f1[i]  f2[i]  f3[i]", each column padded to 'width'.
// On a size mismatch a warning is issued and only the common rows are printed.
template<class meshx, class functionx, class functiony, class functionz>
void print(std::ostream& stream, const meshx& om, const functionx& f1, const functiony& f2, const functionz& f3, int width)
{
  if (om.size()!=f1.size() || om.size()!=f2.size() || om.size()!=f3.size()) std::cerr<<"Can't print objectc of different size!"<<std::endl;
  int n = static_cast<int>(om.size());
  const int n1 = static_cast<int>(f1.size());
  const int n2 = static_cast<int>(f2.size());
  const int n3 = static_cast<int>(f3.size());
  if (n1 < n) n = n1;
  if (n2 < n) n = n2;
  if (n3 < n) n = n3;
  for (int i=0; i<n; i++)
    stream <<std::setw(width)<<om[i]<<std::setw(width)<<f1[i]<<std::setw(width)<<f2[i]<<std::setw(width)<<f3[i]<<'\n';
}
// Prints a five-column table "om[i]  f1[i]  f2[i]  f3[i]  f4[i]", each column padded to 'width'.
// On a size mismatch a warning is issued and only the common rows are printed.
template<class meshx, class functionx, class functiony, class functionz, class functionw>
void print(std::ostream& stream, const meshx& om, const functionx& f1, const functiony& f2, const functionz& f3, const functionw& f4, int width)
{
  if (om.size()!=f1.size() || om.size()!=f2.size() || om.size()!=f3.size() || om.size()!=f4.size())
    std::cerr<<"Can't print objectc of different size!"<<std::endl;
  int n = static_cast<int>(om.size());
  const int n1 = static_cast<int>(f1.size());
  const int n2 = static_cast<int>(f2.size());
  const int n3 = static_cast<int>(f3.size());
  const int n4 = static_cast<int>(f4.size());
  if (n1 < n) n = n1;
  if (n2 < n) n = n2;
  if (n3 < n) n = n3;
  if (n4 < n) n = n4;
  for (int i=0; i<n; i++)
    stream <<std::setw(width)<<om[i]<<std::setw(width)<<f1[i]<<std::setw(width)<<f2[i]<<std::setw(width)<<f3[i]<<std::setw(width)<<f4[i]<<'\n';
}

// Real x: log|(b-x)/(x-a)|, i.e. log|b-x| - log|x-a|.
inline double logpart(double a, double b, double x)
{
  const double ratio = (b-x)/(x-a);
  return log(fabs(ratio));
}
// Complex x: log((b-x)/(a-x)), using the principal branch of the complex log.
template <class complexn>
inline complexn logpart(double a, double b, const complexn& x)
{
  const complexn upper = b - x;
  const complexn lower = a - x;
  return log(upper/lower);
}
template <class T, class meshx>
T KramarsKronig(const function<double>& fi, const meshx& om, const T& x, int i0, double S0)
{
  T sum=0;
  for (int j=0; j<i0-1; j++) sum += (fi[j]-S0)*om.Dh(j)/(om[j]-x);
  if (i0>0)                  sum += (fi[i0-1]-S0)*(om.Dh(i0-1)+0.5*om.Dh(i0))/(om[i0-1]-x);
  if (i0<om.size()-1)        sum += (fi[i0+1]-S0)*(om.Dh(i0+1)+0.5*om.Dh(i0))/(om[i0+1]-x);
  for (int j=i0+1; j<om.size(); j++) sum += (fi[j]-S0)*om.Dh(j)/(om[j]-x);
  if (x!=om.last() && x!=om[0]) sum += S0*logpart(om[0],om.last(),x);
  return sum/M_PI;
}
#endif
