from scipy import *
from pylab import *


def Schoedinger_deriv(u, r, l, eps, Z):
    """Right-hand side of the radial Schroedinger equation as a first-order system.

    With u = [u(r), u'(r)], the second-order equation

        u''(r) = 2*(l*(l+1)/(2*r**2) - Z/r - eps) * u(r)

    is returned as the pair [u'(r), u''(r)], which is the form integrate.odeint
    expects (state first, independent variable r second).
    """
    value, slope = u[0], u[1]
    barrier = l*(l+1)/(2*r**2)  # centrifugal term
    curvature = 2*(barrier - Z/r - eps)*value
    return array([slope, curvature])

def normalize(u, r_mesh):
    """Normalize a radial function so that the integral of u**2 over r_mesh is 1.

    The integral is evaluated with Simpson's rule. Its absolute value is used
    because r_mesh may be ordered from large to small r, which makes the raw
    integral negative. The result is returned with its sign flipped
    (-u/sqrt(norm)); the overall sign does not affect u**2.
    """
    u2 = u*u
    # scipy.integrate.simps was deprecated and removed in SciPy 1.14; use its
    # replacement simpson when available, falling back for old SciPy versions.
    simpson = getattr(integrate, 'simpson', None)
    if simpson is not None:
        nrm = simpson(u2, x=r_mesh)
    else:
        nrm = integrate.simps(u2, r_mesh)
    return -u/sqrt(abs(nrm))

def Shoot(eps, r_mesh, l, Z):
    """Integrate the radial equation at energy eps and return u extrapolated to r = 0.

    Integration runs along r_mesh starting at r_mesh[0] with u = 0, u' = 1.
    The intent is that r_mesh[0] stands in for r = infinity, so r_mesh is
    expected to run from large to small r. The value at r = 0 is obtained by
    linear extrapolation through the last two mesh points. Energies at which
    this value crosses zero satisfy the u(0) = 0 boundary condition.
    """
    start = array([0., 1.])  # u and u' at the starting (large-r) end of the mesh
    solution = integrate.odeint(Schoedinger_deriv, start, r_mesh, args=(l, eps, Z))
    u_last, u_prev = solution[-1, 0], solution[-2, 0]
    r_last, r_prev = r_mesh[-1], r_mesh[-2]
    # straight line through the two innermost points, evaluated at r = 0
    return u_last + (u_prev - u_last)*(0.0 - r_last)/(r_prev - r_last)

def FindBoundStates(l, Z, r_mesh, nmax, eps_start):
    """Locate nmax-l bound-state energies of the radial equation for angular momentum l.

    Energies are scanned upward from eps_start in steps of 0.3/(n+1)**3 until
    Shoot changes sign. The bracketed root is then refined with
    optimize.brentq. Each result is stored as [n, l, energy], with n running
    from l to nmax-1. The scan for the next state resumes from the upper end
    of the previous bracket.
    """
    energy = eps_start
    found = []
    for n in range(l, nmax):
        step = 0.3/(n+1)**3
        previous = current = Shoot(energy, r_mesh, l, Z)
        # march upward until the extrapolated u(0) changes sign
        while previous*current > 0:
            energy += step
            previous, current = current, Shoot(energy, r_mesh, l, Z)
        root = optimize.brentq(Shoot, energy - step, energy, args=(r_mesh, l, Z))
        found.append([n, l, root])
    return found

def addl(a, b):
    """Return a + b; for lists this is concatenation, so it can serve as a
    reduce() step that flattens a list of lists."""
    combined = a + b
    return combined

def cmp_energy(a, b):
    """Three-way comparison of two bound states by energy (element 2).

    Returns -1, 0 or 1, like the Python 2 builtin cmp. That builtin no longer
    exists in Python 3, so the comparison is spelled out here. The int()
    conversions keep this working for numpy scalars, whose booleans cannot be
    subtracted directly.
    """
    return int(a[2] > b[2]) - int(a[2] < b[2])

def ComputeDensity(Z, r_mesh, bstates):
    """Build the electron density by filling bound states in order of increasing energy.

    Z        -- number of electrons to place; also passed as the nuclear charge
                to Schoedinger_deriv when the wavefunctions are recomputed.
    r_mesh   -- radial mesh; wavefunctions are integrated along it from r_mesh[0].
    bstates  -- list of lists of [n, l, eps] bound states (one inner list per l).

    Each multiplet holds up to 2*(2l+1) electrons. The last multiplet reached
    is filled only partially so that exactly Z electrons are placed.

    Returns an array rho of len(r_mesh) with rho = sum occ * u**2/(4*pi*r**2).
    """
    # Flatten into a fresh list: never aliases or mutates the caller's lists,
    # and an empty bstates simply yields zero density.
    states = sorted((state for group in bstates for state in group),
                    key=lambda state: state[2])
    rho = zeros(len(r_mesh), dtype=float)  # initialize density
    N = 0  # number of electrons added to density so far
    for state in states:
        if N >= Z:  # all electrons placed; skip recomputing empty states
            break
        l = state[1]
        eps = state[2]
        ub = integrate.odeint(Schoedinger_deriv, [0, 1e-10], r_mesh, args=(l, eps, Z))  # recompute u
        u = normalize(ub[:, 0], r_mesh)
        u2 = u**2/(4*pi*r_mesh**2)  # angular normalization
        degeneracy = 2*(2*l+1)  # spin times (2l+1) orientations
        # fill the whole multiplet, or only the electrons that remain
        occupancy = degeneracy if Z - N > degeneracy else Z - N
        rho = rho + u2*occupancy
        N = N + occupancy
        print(N)
    return rho


def Poisson_derivative(y, r, rho_spline):
    """Right-hand side of the radial Poisson equation as a first-order system.

    The state is y = [U(r), U'(r)] and the equation is U''(r) = -4*pi*r*rho(r),
    with rho evaluated from the spline representation rho_spline.
    Returns [U'(r), U''(r)].
    """
    density = interpolate.splev(r, rho_spline)
    first = y[1]
    second = -4*pi*r*density
    return array([first, second])

def SolvePoisson(Z, Rmesh, rho_spline):
    """Solve U''(r) = -4*pi*r*rho(r) for the Hartree function U on Rmesh.

    A particular solution is integrated from Rmesh[0] with U = 0 and unit
    slope. The homogeneous solution alpha*r is then added so that the value
    at the last mesh point equals Z.
    """
    solution = integrate.odeint(Poisson_derivative, [0, 1.0], Rmesh, args=(rho_spline,))
    particular = solution[:, 0]
    # slope of the homogeneous term that fixes U(Rmesh[-1]) = Z
    correction = (Z - particular[-1])/Rmesh[-1]
    return particular + correction*Rmesh



if __name__ == '__main__':

    # Radial mesh running from large r (100) down to small r (1e-6); the
    # Schroedinger solver integrates inward along it.
    r_mesh = logspace(2, -6, 100)

    # Z: nuclear charge, also the number of electrons placed by ComputeDensity.
    # n: number of principal levels searched; l runs over 0..n-1.
    Z = 2
    n = 1

    # Starting guess for the lowest energy (below the hydrogen-like 1s level -Z**2/2).
    eguess = -0.6*Z**2

    # bstates[l] holds the [n, l, energy] bound states found for angular momentum l.
    bstates = []
    for l in range(n):
        print(l)
        bstates.append(FindBoundStates(l, Z, r_mesh, n, eguess))
        # The lowest level rises with l (E(1s) < E(2p) < E(3d) < ...), so the
        # lowest energy found at this l is a safe starting point for l+1.
        eguess = bstates[l][0][2]

    # Print the bound states (same output on Python 2 and 3).
    for l in range(n):
        for ni in range(n - l):
            print("%s %s %s" % (l, ni, bstates[l][ni]))

    rho = ComputeDensity(Z, r_mesh, bstates)

    # Radial charge distribution 4*pi*r**2*rho(r).
    nr = rho*4*pi*r_mesh**2

    plot(r_mesh, nr, 'ro-')
    show()

    # Splines need increasing abscissae, so reverse the inward-ordered mesh.
    Rmesh = r_mesh[::-1]
    nrho = rho[::-1]
    rho_spline = interpolate.splrep(Rmesh, nrho, s=0)  # interpolating cubic spline

    Uhartree = SolvePoisson(Z, Rmesh, rho_spline)  # Hartree potential times r

    plot(Rmesh, Uhartree, 'bs-')
    show()

    # New effective potential: V_eff(r) = (-Z + U_H(r))/r
    Veff0 = (-Z*ones(len(Rmesh)) + Uhartree)/Rmesh
    Veff = interpolate.splrep(Rmesh, Veff0, s=0)

    plot(Rmesh, Veff0, 'go-')
    axis([0, 5., -10, 0])
    show()
