from scipy import *
from scipy import integrate, optimize
from pylab import *

def Schrod_deriv(y, r, l, E, Z):
    """Right-hand side of the first-order system for u(r).

    Given y = [u, u'] at radius r, returns dy/dr = [u', u''] with
    u'' = (l*(l+1)/r**2 - 2*Z/r - E) * u.
    The argument order (y, r, ...) is the callback signature that
    integrate.odeint expects.
    """
    u, du = y[0], y[1]
    centrifugal = l*(l+1)/r**2
    coulomb = 2*Z/r
    return [du, (centrifugal - coulomb - E)*u]

def Shoot(E,R,l,Z):
    """Integrate inward at trial energy E and return u extrapolated to r=0.

    The ODE defined by Schrod_deriv is integrated over the reversed grid
    (from R[-1] down to R[0]) starting from u=0 and a tiny slope, the
    solution is normalized so that the integral of u**2 over r is one, and
    the value at r=0 is obtained by linear extrapolation through the two
    innermost grid points.  FindBoundStates looks for sign changes / zeros
    of this return value as a function of E.

    R is expected in increasing order (the script passes an ascending
    logarithmic grid); l and Z are forwarded unchanged to Schrod_deriv.
    """
    Rb = R[::-1]
    y0 = [0.0, -1e-7]  # u and u' at the outermost grid point
    y = integrate.odeint(Schrod_deriv, y0, Rb, args=(l, E, Z))[:,0]
    # simps takes the samples first and the abscissae second; the previous
    # call had them swapped.  Rb is decreasing, so the signed integral is
    # negative -- abs() gives the positive norm needed for the square root.
    norm = abs(integrate.simps(y**2, x=Rb))
    scale = sqrt(norm)
    f0 = y[-1]/scale  # innermost point
    f1 = y[-2]/scale  # next-to-innermost point
    # straight line through (Rb[-1], f0) and (Rb[-2], f1), evaluated at r=0
    return f0 + (f1-f0)*(0.0-Rb[-1])/(Rb[-2]-Rb[-1])
    
def FindBoundStates(R,l,Z,nmax,MaxSteps,E0,dE):
    """Scan upward in energy from E0 and collect zeros of Shoot.

    Starting at E0 the energy is advanced by dE at most MaxSteps times; after
    every step dE is divided by 1.091.  Whenever Shoot changes sign
    between two consecutive energies, the zero inside that interval is
    refined with optimize.brentq and stored as the tuple (l, energy).

    The scan ends early once nmax zeros have been stored or a refined zero
    is positive.  Returns the list of (l, energy) tuples in the order found.
    """
    found = []
    energy = E0
    u_prev = Shoot(energy, R, l, Z)
    for _ in range(MaxSteps):
        energy = energy + dE
        u_here = Shoot(energy, R, l, Z)
        sign_changed = u_prev*u_here < 0.0
        if sign_changed:
            # bracket is [energy - dE, energy], i.e. the step just taken
            root = optimize.brentq(Shoot, energy-dE, energy, xtol=1e-16,
                                   args=(R, l, Z))
            found.append((l, root))
            if len(found) >= nmax or root > 0:
                break
        u_prev = u_here
        dE = dE/1.091
    return found

def SolveSchroedinger(E,R,l,Z):
    """Return the normalized solution u on the grid R for energy E.

    The ODE defined by Schrod_deriv is integrated over the reversed grid
    (from R[-1] down to R[0]) starting from u=0 and a tiny slope.  The
    result is flipped back so that it is aligned with R and scaled so that
    the integral of u**2 over r equals one.

    R is expected in increasing order (the script passes an ascending
    logarithmic grid); l and Z are forwarded unchanged to Schrod_deriv.
    """
    Rb = R[::-1]
    y0 = [0.0, -1e-7]  # starting slope should be quite small
    y = integrate.odeint(Schrod_deriv, y0, Rb, args=(l, E, Z))[:,0]
    # simps takes the samples first and the abscissae second; the previous
    # call had them swapped.  Rb is decreasing, so the signed integral is
    # negative -- abs() gives the positive norm needed for the square root.
    norm = abs(integrate.simps(y**2, x=Rb))
    return y[::-1]/sqrt(norm)
    
def cmpE(x,y):
    """Three-way comparison of two (l, E) tuples for sorting.

    If the second entries (energies) differ by more than 1e-4 they decide
    the order; otherwise the energies are treated as equal and the first
    entries (l) decide.  Returns -1, 0 or 1.

    The result is computed without the `cmp` builtin so the function works
    on both Python 2 and Python 3 (e.g. through functools.cmp_to_key).
    """
    if abs(x[1]-y[1])>1e-4:
        a, b = x[1], y[1]
    else:
        a, b = x[0], y[0]
    # same value cmp(a, b) would give
    return int(a > b) - int(a < b)
    

# --- Parameters -----------------------------------------------------------
# Z enters Schrod_deriv through the -2*Z/r term (presumably the nuclear
# charge -- confirm).
Z = 1.
# Upper limit on the number of energy steps in one FindBoundStates scan.
MaxSteps=1000
# For each l the search below asks for at most nmax-l states.
nmax=3
# l runs over range(lmax).
lmax=3
# Cap used by the filling loop further down when accumulating rho.
Ntot=36

# Radial grid: 500 logarithmically spaced points from 1e-6 to 10**2.2.
R = logspace(-6,2.2,500)

# Starting energy and initial step of the scan for l=0.
E0 = -1.2*Z**2
dE = abs(E0)/12.
bstates=[]
for l in range(lmax):
    cbst = FindBoundStates(R,l,Z,nmax-l,MaxSteps,E0,dE)
    bstates += cbst
    # The next l starts its scan at the first energy found for this l.
    # NOTE(review): cbst[0] raises IndexError if no state was found.
    E0 = cbst[0][1]
    dE = abs(E0)/12.
    #print bstates

# Order the (l, E) tuples with the cmpE comparator (Python 2 sort signature).
bstates.sort(cmpE)
print 'bstates=', bstates

# Accumulate rho on the grid R by walking through the sorted states until
# the running count N reaches Ntot.
rho=zeros(len(R),dtype=float)
N=0
for b in bstates:
    l=b[0]
    E=b[1]
    u = SolveSchroedinger(E,R,l,Z)
    # Weight of this (l, E) entry; presumably 2 spins times (2l+1) values
    # of m -- confirm.
    dN= 2*(2*l+1)
    if N+dN<Ntot:
        ferm=1.
    else:
        # Fraction of dN that still fits below Ntot; adding 0.0 forces
        # float division under Python 2.
        ferm=(Ntot-N)/(dN+0.0)
    drho = u**2 * dN*ferm/(4*pi*R**2)
    rho += drho
    # N is advanced by the full dN even when only a fraction was added.
    N += dN
    print 'adding states', b, 'ferm=', ferm, N, Ntot
    if N>=Ntot: break


# Plot 4*pi*r**2*rho(r) against r.
plot(R,4*pi*rho*R**2, 'o-')
grid()
show()

# NOTE(review): `sys` is not imported by name in the visible part of the
# file; presumably it arrives through one of the star imports -- confirm.
sys.exit(0)
# Everything below is unreachable while the sys.exit(0) call above is in
# place: it would plot u(r) for every state found, labelled by l.
for b in bstates:
    l=b[0]
    E=b[1]
    u = SolveSchroedinger(E,R,l,Z)
    plot(R,u,label=('l=%s' % (l,)))
legend(loc='best')
show()

    
