from scipy import *
from scipy import weave
from scipy import integrate
from scipy import optimize
from pylab import *

# C++ kernel run through weave.inline (blitz converters) by ComputeSchrod.
# Numerov recurrence for u'' = F u on a grid of constant step dh, using the
# auxiliary w_i = (1 - dh**2/12 * F_i) * u_i:
#     w_i = 2*w_{i-1} - w_{i-2} + dh**2 * F_{i-1} * u_{i-1}.
# Reads F(0..Nmax-1) and the two seed values Solution(0), Solution(1), and
# writes Solution(2..Nmax-1) in place. Only dh**2 is used, so the sign of dh
# is irrelevant.
code_Numerov="""
// void Numerov(const container& F, int Nmax, double dh, container& Solution)
  double dx = dh;
  double h2 = dx*dx;
  double h12 = h2/12;
  
  double w0 = (1-h12*F(0))*Solution(0);
  double Fx = F(1);
  double w1 = (1-h12*Fx)*Solution(1);
  double Phi = Solution(1);
  
  double w2;
  for (int i=2; i<Nmax; i++){
    w2 = 2*w1 - w0 + h2*Phi*Fx;
    w0 = w1;
    w1 = w2;
    Fx = F(i);
    Phi = w2/(1-h12*Fx);
    Solution(i) = Phi;
  }

"""
# C++ kernel run through weave.inline by ComputeSchrod: fills
#     F(i) = l(l+1)/R(i)**2 - 2*Z/R(i) - E
# for every grid point. With this F the radial equation u'' = F u reads
# -u'' + [l(l+1)/r**2 - 2Z/r] u = E u.
code_Sch="""
   for (int i=0; i<Nmax; i++){
      F(i) = l*(l+1)/(R(i)*R(i))-2*Z/R(i)-E;
   }
"""

def ComputeSchrod(E,R,l,Z,Normaliza=False):
    """Integrate the radial equation u'' = F u with the Numerov kernel.

    F(r) = l(l+1)/r**2 - 2Z/r - E is filled by the compiled ``code_Sch``
    kernel, and ``code_Numerov`` then integrates from R[0] towards R[-1].

    Parameters
    ----------
    E : float
        Trial energy.
    R : ndarray
        Equally spaced radial grid. The driver script passes it in
        descending order, so integration starts at large r.
    l : int
        Orbital angular-momentum quantum number.
    Z : float
        Charge in the -2Z/r term.
    Normaliza : bool, optional
        If True, rescale the result so the Romberg integral of u**2 over
        the grid is 1. This requires len(R) == 2**k + 1 (scipy romb).

    Returns
    -------
    ndarray
        u sampled on R (u = 0 at R[0], seeded with 1e-7 at R[1]).
    """
    Nmax = len(R)
    # Only dh**2 enters the Numerov kernel, so a descending grid (dh
    # computed from R[0]-R[-1]) works as well as an ascending one.
    dh=(R[0]-R[-1])/(Nmax-1.)

    # F(r) = l(l+1)/r**2 - 2Z/r - E, computed by the compiled code_Sch kernel.
    F = zeros(Nmax,dtype=float)
    weave.inline(code_Sch, ['F', 'Nmax', 'R', 'E', 'l', 'Z'],type_converters=weave.converters.blitz, compiler = 'gcc')

    # Seeds: u(R[0]) = 0 and a tiny nonzero u(R[1]). The equation is linear,
    # so the overall scale is arbitrary.
    Solution = zeros(Nmax,dtype=float)
    Solution[1]=1e-7
    weave.inline(code_Numerov, ['F', 'Nmax', 'dh', 'Solution'],type_converters=weave.converters.blitz, compiler = 'gcc')

    # Test the actual parameter. An undefined name `Normalize` here would
    # raise NameError on every call, even when normalisation is not requested.
    if Normaliza:
        # abs() compensates for the negative spacing R[1]-R[0] of a
        # descending grid.
        Norm = abs(integrate.romb(Solution**2,R[1]-R[0]))
        Solution *= 1./sqrt(Norm)
    return Solution

def Shoot(E,R,l,Z):
    """Return the trial solution for energy E, linearly extrapolated to r = 0.

    The solution from ComputeSchrod is extended to the origin through its
    last two grid points. FindBoundStates looks for energies where this
    value changes sign.
    """
    wave = ComputeSchrod(E,R,l,Z)
    r_end, r_before = R[-1], R[-2]
    u_end, u_before = wave[-1], wave[-2]
    rise = u_before - u_end
    distance_to_origin = 0.0 - r_end
    run = r_before - r_end
    # Same evaluation order as a single-line (rise*distance)/run expression.
    return u_end + rise * distance_to_origin / run

    
def comp(x,y):
    """Three-way comparator for bound states stored as [l, E].

    States are ordered by energy E. Energies closer than 1e-5 are treated
    as degenerate and ordered by the l quantum number instead.

    Returns -1, 0 or 1, like the Python 2 ``cmp`` builtin. ``cmp`` itself is
    not used because it does not exist in Python 3.
    """
    def _three_way(a, b):
        # Equivalent of cmp(a, b) for numbers: -1, 0 or 1.
        return (a > b) - (a < b)

    if abs(x[1]-y[1])<1e-5:
        # Degenerate energies: order by l.
        return _three_way(x[0],y[0])
    # Different energies: order by energy.
    return _three_way(x[1],y[1])

def FindBoundStates(Z,R,n_max,lmax,Nmax=1000):
    E0=-1.2*Z**2
    Eb=[]
    for l in range(lmax+1):
        dE = abs(E0)/12.
        u0 = Shoot(E0,R,l,Z)
        Ebl=[]
        for i in range(Nmax):
            E0 += dE
            u1 = Shoot(E0,R,l,Z)
            if u0*u1<0:
                Ebound = optimize.brentq(Shoot, E0-dE, E0, args=(R,l,Z), xtol=1e-16)
                Ebl.append([l,Ebound])
                print 'Found bound state at ', Ebound
                if len(Ebl)>=(n_max-l): break
            u0=u1
            dE = dE/1.091
        Eb += Ebl
        E0 = Ebl[0][1]
    
    Eb.sort(comp)
    return Eb

    

Z_screened=1.
Z=82
n_max=5
lmax=3

#R = logspace(-7,2.3,1000)
#R0 = linspace(1e-7,50,2**10+1)
R0 = linspace(1e-7,90,2**12+1)
R=R0[::-1]

Eb = FindBoundStates(Z_screened,R,n_max,lmax)

rho = zeros(len(R),dtype=float)
Nelec=0
for i,(l,Ene) in enumerate(Eb):
    print 'state=', l, Ene
    u = ComputeSchrod(Ene,R,l,Z_screened,True)

    dN = 2*(2*l+1)
    if Z >= Nelec+dN:
        ferm=1.
    else:
        ferm= (Z-Nelec)/(dN+0.0)
        
    Nelec += dN*ferm
    rho += u*u*dN*ferm / (4*pi*R**2)
    
    if Nelec>=Z: break

    print 'Adding ', dN*ferm, 'electrons of l=', l, 'Ene=', Ene
    
print 'Nelec=', Nelec
print 'Norm=', integrate.romb(rho*4*pi*R**2,R[0]-R[1])
plot(R,rho*4*pi*R**2, 'o-')
show()
