from __future__ import print_function

from scipy import *
from pylab import *


def Schoedinger_deriv(u, r, l, eps, Z):
    """Right-hand side of the radial Schroedinger equation as a first-order system.

    The second-order equation integrated here is

        u''(r) = 2 * (l*(l+1)/(2*r**2) - Z/r - eps) * u(r),

    i.e. -u''/2 + [l(l+1)/(2r^2) - Z/r] u = eps u, which appears to be the
    hydrogen-like radial equation in atomic units. The argument order
    (state, independent variable, *extra) matches what ``odeint`` expects.

    Parameters
    ----------
    u : sequence of two numbers
        Current state ``[u(r), u'(r)]``.
    r : float
        Radius at which the derivative is evaluated (must be non-zero).
    l : int
        Angular-momentum quantum number.
    eps : float
        Trial energy.
    Z : float
        Nuclear charge.

    Returns
    -------
    array
        ``[u'(r), u''(r)]``.
    """
    radial, slope = u[0], u[1]
    # centrifugal barrier - Coulomb attraction - energy
    effective = l*(l+1)/(2*r**2) - Z/r - eps
    return array([slope, 2*effective*radial])

def normalize(u, r_mesh):
    """Return ``u`` scaled to unit norm on ``r_mesh``, with its sign flipped.

    The norm is ``sqrt(|integral of u**2 dr|)``, computed with Simpson's rule.
    Taking ``abs()`` makes the result independent of mesh direction: a mesh
    that runs from large to small r gives a negative integral.

    The returned function is ``-u / norm``. The negation is presumably a plotting
    sign convention for solutions integrated inward from large r (TODO: confirm).

    Parameters
    ----------
    u : array
        Sampled wave function, one value per point of ``r_mesh``.
    r_mesh : array
        Sample points (increasing or decreasing).

    Returns
    -------
    array
        ``-u / sqrt(|integral u**2 dr|)``.
    """
    # scipy.integrate.simps was deprecated in SciPy 1.12 and removed in 1.14
    # in favour of scipy.integrate.simpson. Use simpson when this SciPy has it,
    # and fall back to simps on old SciPy versions (< 1.6).
    simpson_rule = getattr(integrate, 'simpson', None) or integrate.simps
    nrm = simpson_rule(u*u, x=r_mesh)
    return -u/sqrt(abs(nrm))

def Shoot(eps, r_mesh, l, Z):
    """Integrate the radial equation along r_mesh and extrapolate u to r = 0.

    The integration starts at ``r_mesh[0]`` with ``u = 0`` and ``u' = 1`` and
    follows the mesh in the given order. The caller passes a mesh running from
    large r toward r = 0, so this integrates inward. The value ``u(0)`` is then
    estimated by linear extrapolation through the last two mesh points.

    At an eigenvalue ``eps`` the extrapolated value crosses zero, so this
    function is suitable as the residual for a root finder.

    Parameters
    ----------
    eps : float
        Trial energy.
    r_mesh : array
        Integration points; at least two, with the last two distinct.
    l : int
        Angular-momentum quantum number.
    Z : float
        Nuclear charge.

    Returns
    -------
    float
        Linearly extrapolated ``u(r = 0)``.
    """
    start = array([0., 1.])  # u = 0, du/dr = 1 at the first mesh point
    solution = integrate.odeint(Schoedinger_deriv, start, r_mesh, args=(l, eps, Z))
    u_values = solution[:, 0]
    u_last, u_prev = u_values[-1], u_values[-2]
    r_last, r_prev = r_mesh[-1], r_mesh[-2]
    # straight line through (r_prev, u_prev) and (r_last, u_last), evaluated at r = 0
    return u_last + (u_prev - u_last)*(0.0 - r_last)/(r_prev - r_last)

def FindBoundStates(l, Z, r_mesh, nmax, eps_start):
    """Locate the nmax-l lowest roots of ``Shoot`` in energy for angular momentum l.

    Starting at ``eps_start``, the energy is stepped upward until ``Shoot``
    changes sign. The root is then refined with Brent's method inside the last
    step. The scan for the next state resumes from the upper end of that bracket.
    The step ``0.3/(n+1)**3`` shrinks with ``n``, presumably to follow the
    ~1/n**3 level spacing of hydrogen-like spectra.

    NOTE(review): the upward scan has no upper limit. If ``Shoot`` never changes
    sign, the inner loop does not terminate.

    Parameters
    ----------
    l : int
        Angular-momentum quantum number.
    Z : float
        Nuclear charge.
    r_mesh : array
        Radial mesh passed through to ``Shoot``.
    nmax : int
        States are indexed ``n = l .. nmax-1``.
    eps_start : float
        Energy at which the scan begins (should lie below the lowest state).

    Returns
    -------
    list of [n, l, energy]
        One entry per located state, in increasing energy order.
    """
    energy = eps_start
    found = []
    for level in range(l, nmax):
        step = 0.3/(level+1)**3
        previous = current = Shoot(energy, r_mesh, l, Z)
        # advance until two consecutive samples bracket a sign change
        while previous*current > 0:
            energy += step
            previous, current = current, Shoot(energy, r_mesh, l, Z)
        root = optimize.brentq(Shoot, energy - step, energy, args=(r_mesh, l, Z))
        found.append([level, l, root])
    return found
        
if __name__ == '__main__':

    # Radial mesh: 100 logarithmically spaced points running from r = 1e2 down
    # to r = 1e-6. Shoot integrates along it in this order, i.e. inward.
    r_mesh = logspace(2, -6, 100)

    # Nuclear charge, and n: angular momenta l = 0..n-1 are treated, and for
    # each l the states with index l..n-1 are searched for.
    Z = 1
    n = 8

    # Starting guess, just below the lowest energy.
    eguess = -0.6*Z**2

    # bstates[l] holds the list of [n, l, energy] found for that l.
    bstates = []
    for l in range(n):
        print(l)
        bstates.append(FindBoundStates(l, Z, r_mesh, n, eguess))
        # The lowest level at angular momentum l+1 lies above the lowest level
        # at l (1s < 2p < 3d < ... for hydrogen-like atoms). So the ground
        # energy found for l is a safe starting point for the next l's scan.
        eguess = bstates[l][0][2]

    # Print every bound state, recompute its radial function and plot it.
    for l in range(n):
        for ni in range(0, n-l):
            print(l, ni, bstates[l][ni])
            eps = bstates[l][ni][2]
            # same initial condition (u = 0, u' = 1 at r_mesh[0]) as Shoot
            ub = integrate.odeint(Schoedinger_deriv, [0., 1.], r_mesh, args=(l, eps, Z))
            u = normalize(ub[:, 0], r_mesh)
            plot(r_mesh, u)

    # Kept inside the main guard so that importing this module does not open
    # an (empty) plot window.
    axis([0., 50, -0.3, 0.8])
    show()
