#include <cmath>
#include <iostream>
#include <vector>
#include <list>
#include <fstream>
using namespace std;

template <class functor>
void recurs(list<double>& x, list<double>& f, const functor& fun, list<double>::iterator xit, list<double>::iterator fit, double precision, int Max_level, int level){
  // Adaptively refines the mesh interval that starts at node xit (function
  // value fit) and ends at the node immediately after it.
  //
  // While the function values at the two ends of the current interval differ
  // by at least `precision` and `level` has not exceeded Max_level, the
  // midpoint is inserted into x and fun(midpoint) into f.  The left half is
  // then refined by a recursive call; the right half is handled by the next
  // iteration of the loop, one level deeper.  Starting from level 0, an
  // interval is therefore bisected at most Max_level+1 levels deep.
  //
  // std::list::insert never invalidates existing iterators, so xit/fit and
  // the iterators held by callers stay valid while nodes are added.
  for (;; ++level) {
    list<double>::iterator x_right = xit;
    ++x_right;
    list<double>::iterator f_right = fit;
    ++f_right;

    const bool too_deep = level > Max_level;
    const bool resolved = fabs(*fit - *f_right) < precision;
    if (too_deep || resolved) {
      return;
    }

    const double x_mid = 0.5 * (*xit + *x_right);
    const list<double>::iterator x_mid_it = x.insert(x_right, x_mid);
    const list<double>::iterator f_mid_it = f.insert(f_right, fun(x_mid));

    // Left half [*xit, x_mid]: refine recursively.
    recurs(x, f, fun, xit, fit, precision, Max_level, level + 1);

    // Right half [x_mid, *x_right]: continue in this loop (level is bumped
    // by the for-increment).
    xit = x_mid_it;
    fit = f_mid_it;
  }
}

template <class functor>
void mesh(list<double>& x, list<double>& f, const functor& fun, double a, double b, int Nmin=9, int Max_level=20, double precision=1e-3)
{
  // Builds an adaptive mesh of [a, b] for the function `fun` and appends the
  // nodes to `x` and the matching function values to `f`.
  //
  // A uniform coarse mesh of Nmin nodes (both endpoints included) is built
  // first.  Each coarse interval is then refined by recurs(), which keeps
  // inserting midpoints while neighbouring function values differ by at
  // least `precision`.  Each coarse interval is bisected at most Max_level+1
  // levels deep, because recurs() starts at level 0 and stops once level
  // exceeds Max_level.
  //
  // Nmin values below 2 are treated as 2 (just the endpoints).  A single node
  // would otherwise need the spacing (b-a)/0.
  //
  // Only the nodes created by this call are refined.  Anything already in x
  // and f is kept untouched in front of the new nodes.  The new mesh is built
  // in scratch lists and spliced onto the outputs at the very end, so if fun
  // throws, x and f are left unchanged.
  if (Nmin < 2) {
    Nmin = 2;
  }

  list<double> xs;
  list<double> fs;
  for (int i = 0; i < Nmin; i++) {
    const double x0 = a + (b-a)*i/(Nmin-1.);
    xs.push_back(x0);
    fs.push_back(fun(x0));
  }

  list<double>::iterator xit = xs.begin();
  list<double>::iterator fit = fs.begin();
  for (int i = 0; i + 1 < Nmin; i++) {
    // recurs() inserts nodes between xit and the next coarse node.  Remember
    // that coarse node before refining so the walk skips the inserted ones.
    // std::list::insert never invalidates existing iterators.
    list<double>::iterator xnext = xit;
    ++xnext;
    list<double>::iterator fnext = fit;
    ++fnext;
    recurs(xs, fs, fun, xit, fit, precision, Max_level, 0);
    xit = xnext;
    fit = fnext;
  }

  // O(1) transfer of the finished mesh onto the outputs (no copying).
  x.splice(x.end(), xs);
  f.splice(f.end(), fs);
}

double simps(const list<double>& x, const list<double>& f)
{
  // Integrates the tabulated function (x[i], f[i]) with the composite Simpson
  // rule for non-uniformly spaced nodes.
  //
  // Node triples (x0,x1,x2), (x2,x3,x4), ... are each integrated with the
  // three-point Simpson rule for unequal spacing (exact for quadratics).
  // When the node count is even, the last interval is left over and is
  // integrated with the trapezoidal rule.
  //
  // Only the first min(x.size(), f.size()) pairs are used, and fewer than two
  // nodes give 0.  A panel that contains a zero-width interval (a repeated
  // node) falls back to the trapezoidal rule instead of dividing by zero.
  const list<double>::size_type n = x.size() < f.size() ? x.size() : f.size();

  double sum = 0;
  list<double>::const_iterator xi = x.begin();
  list<double>::const_iterator fi = f.begin();
  list<double>::size_type remaining = n;  // usable nodes from xi onward

  while (remaining >= 3) {
    const double a = *xi, fa = *fi;
    ++xi; ++fi;
    const double b = *xi, fb = *fi;
    ++xi; ++fi;
    // xi/fi now sit on the panel's right end, which is also the next
    // panel's left end, so they are not advanced past it.
    const double c = *xi, fc = *fi;
    const double dh1 = b-a;
    const double dh2 = c-b;
    if (dh1 == 0.0 || dh2 == 0.0) {
      sum += 0.5*dh1*(fa+fb) + 0.5*dh2*(fb+fc);
    } else {
      sum += (dh1+dh2)/(6*dh1*dh2)*((3*b-2*a-c)*dh2*fa + (c-a)*(c-a)*fb + (2*c-3*b+a)*dh1*fc);
    }
    remaining -= 2;
  }

  if (remaining == 2) {
    // Even node count: one interval is left after the Simpson panels.
    const double a = *xi, fa = *fi;
    ++xi; ++fi;
    const double b = *xi, fb = *fi;
    sum += 0.5*(b-a)*(fa+fb);
  }
  return sum;
}
	     
double F0(double x)
{
  // Test integrand (1 - 2*sqrt(|x|)) * exp(-x^2).  It is even in x and has
  // a cusp at x = 0 from the sqrt(|x|) factor.
  const double gauss = exp(-x*x);
  const double prefactor = 1 - 2*sqrt(fabs(x));
  return prefactor*gauss;
}

double F1(double x)
{
  // Test integrand cos(x): smooth but oscillatory.
  const double value = cos(x);
  return value;
}

double F3(double x)
{
  // Test integrand x^(-1/2).  It has an integrable singularity at x = 0, so
  // callers keep the lower limit positive (x = 0 gives inf, x < 0 gives NaN).
  const double root = sqrt(x);
  return 1.0/root;
}

// Writes the mesh as two space-separated columns, "x f(x)", one node per line,
// to the file `filename`.  Stops at the shorter of the two lists.
static void write_mesh(const char* filename, const list<double>& x, const list<double>& f)
{
  ofstream out(filename);
  list<double>::const_iterator fi = f.begin();
  for (list<double>::const_iterator xi = x.begin(); xi != x.end() && fi != f.end(); ++xi, ++fi) {
    out << *xi << " " << *fi << '\n';
  }
}

int main()
{
  // M_PI is a POSIX extension, not standard C++; acos(-1) gives pi portably.
  const double pi = acos(-1.0);

  list<double> x, f;

  // F0 on [-5, 5].  IF0_exact is the integral over the whole real line,
  // 2*(sqrt(pi)/2 - Gamma(3/4)).  Truncating at |x| = 5 changes it only by
  // about 1e-11.
  const double IF0_exact = -0.3391897770124197*2;
  mesh(x, f, F0, -5., 5.);
  double integral = simps(x,f);
  cout<<"Total number of points used "<<x.size()<<" The integral is "<<integral<<" and the error is "<<fabs(integral-IF0_exact)<<endl;
  write_mesh("func1.dat", x, f);

  // F1 = cos on [0, 51*pi/2]; the exact integral is sin(51*pi/2) = -1.
  x.clear();
  f.clear();
  mesh(x, f, F1, 0, 51.*pi/2.);
  integral = simps(x,f);
  cout<<"Total number of points used "<<x.size()<<" The integral is "<<integral<<" and the error is "<<fabs(integral+1)<<endl;
  write_mesh("func2.dat", x, f);

  // F3 = 1/sqrt(x) on [1e-10, 1]; the exact integral is
  // 2*(1 - sqrt(1e-10)) = 1.99998.
  x.clear();
  f.clear();
  mesh(x, f, F3, 1e-10, 1.);
  integral = simps(x,f);
  const double I_Exact = 1.99998;
  cout<<"Total number of points used "<<x.size()<<" The integral is "<<integral<<" and the error is "<<fabs(integral-I_Exact)<<endl;
  write_mesh("func3.dat", x, f);

  return 0;
}
