" Step 5: Finds bound states of H2+ molecule "

from scipy import *
from pylab import *
from scipy import integrate, optimize, interpolate


def create_Xi_mesh(p):
    """Build the logarithmic mesh for the prolate coordinate xi.

    The mesh has 200 points from 1+1e-7 up to 1+30/p, clustered
    logarithmically towards xi=1.

    Returns (Xi, Xim): Xi is ascending, Xim is the same mesh reversed
    (descending), used for integrating inwards from large xi.
    """
    upper_exponent = log10(30./p)
    ascending = 1 + logspace(-7, upper_exponent, 200)
    return (ascending, ascending[::-1])

def create_mu_mesh():
    " Half mesh for eta~cos(theta): 50 points from just below 1 down to 0 "
    start, stop = 1 - 1e-7, 0.
    return linspace(start, stop, 50)

def create_full_mu_mesh():
    " Full mesh for eta~cos(theta): 50 points from just below 1 to just above -1 "
    edge = 1e-7
    return linspace(1 - edge, -1 + edge, 50)

def give_xx0(p,A,etam0):
    """Starting values [X, X'] at the large-xi end of the mesh.

    X is taken as exp(-p*xi) (times the constant factor exp(0.05*A)), so
    X' = -p*X. Despite its name, `etam0` is a xi value: callers in this file
    pass Xim[0], the largest xi of the mesh.
    """
    amplitude = exp(-p*etam0 + 0.05*A)
    return [amplitude, -p*amplitude]
    
def give_yy0(p,A,m):
    """Starting values [Y, Y'] at eta=1.

    Y(1) is normalised to 1. Setting eta=1 in
    -(1-eta^2)Y'' + 2(m+1)eta Y' + (A-p^2 eta^2)Y = 0 removes the Y'' term and
    forces 2(m+1)Y'(1) + (A-p^2)Y(1) = 0. That is the only slope for which Y''
    stays finite.
    """
    slope = (p**2-A)/(2*(m+1))
    return [1.0, slope]

def Xderiv(xxd, xi, A, p2, m, R):
    """odeint right-hand side for the xi part of the H2+ equation.

    Solves (xi^2-1)*X'' + 2*(m+1)*xi*X' + [A-p^2*xi^2+2*R*xi]*X = 0 written
    as a first-order system in (X, X'). p2 is p**2.
    """
    X, dX = xxd[0], xxd[1]
    numerator = -2*(m+1)*xi*dX - (A - p2*xi**2 + 2*R*xi)*X
    return [dX, numerator/(xi**2-1)]

def Yderiv(yyd, mu, A, p2, m):
    """odeint right-hand side for the eta part of the H2+ equation.

    Solves -(1-mu^2)*Y'' + 2*(m+1)*mu*Y' + [A-p^2*mu^2]*Y = 0 written as a
    first-order system in (Y, Y'). p2 is p**2.
    """
    Y, dY = yyd[0], yyd[1]
    d2Y = (2*(m+1)*mu*dY + (A - p2*mu**2)*Y)/(1-mu**2)
    return [dY, d2Y]


def Shoot1(A, p, m, mu, gerade):
    """Shooting function for the eta equation at a trial separation constant A.

    Integrates Y from eta=1 (Y(1)=1, regular slope from give_yy0) along the
    mesh `mu`, whose last point is eta=0. It then returns one component
    there:
      gerade == 1 -> Y'(0)  (a symmetric solution needs Y'(0)=0)
      gerade == 0 -> Y(0)   (an antisymmetric solution needs Y(0)=0)
    A root of this function in A is an eigenvalue.
    """
    start = give_yy0(p, A, m)
    trajectory = integrate.odeint(Yderiv, start, mu, args=(A, p**2, m))
    at_zero = trajectory[-1]
    # column 0 is Y, column 1 is Y'; gerade selects which must vanish
    return at_zero[gerade]

def FindA(p, m, mu, gerade, lmax=2., nA=10):
    """Find the separation constant A(p) of the eta equation.

    Scans A over nA equally spaced values, starting at p**2/2+p**3/12 and
    going down to -lmax*(lmax+1). It stops at the FIRST sign change of
    Shoot1 and refines it with Brent's method, so only the first root on the
    scan grid is returned.

    Parameters
    ----------
    p, m, mu, gerade : as for Shoot1.
    lmax : float, optional
        Sets the lower end of the scan, -lmax*(lmax+1). The default (2.)
        reproduces the original hard-coded value.
    nA : int, optional
        Number of scan points. The default (10) reproduces the original value.

    Returns
    -------
    A : float

    Raises
    ------
    ValueError
        If Shoot1 does not change sign anywhere on the scan grid. The old
        code ended up calling brentq on a same-sign bracket, which fails
        with an uninformative message.
    """
    Al = linspace(p**2/2+p**3/12, -lmax*(lmax+1), nA)
    prev_A, prev_val = None, None
    for A in Al:
        val = Shoot1(A, p, m, mu, gerade)
        if val == 0:
            # Exact hit on the grid: brentq cannot bracket it, and the old
            # sign test (product < 0) would have skipped it.
            return A
        if prev_val is not None and val*prev_val < 0:
            return optimize.brentq(Shoot1, prev_A, A,
                                   args=(p, m, mu, gerade), xtol=1e-17)
        prev_A, prev_val = A, val
    raise ValueError("FindA: Shoot1 has no sign change for A in [%g, %g] "
                     "(p=%g, m=%s, gerade=%s); widen lmax or increase nA"
                     % (Al[-1], Al[0], p, m, gerade))

def Shoot2(p,m,mu,gerade,R):
    """Shooting function for the xi equation at trial decay constant p.

    First determines A(p) from the eta equation (FindA). It then integrates X
    inwards from the largest xi of the mesh, starting from the decaying
    asymptote give_xx0. The returned value is X' linearly extrapolated to
    xi=1 from the two mesh points nearest xi=1. Roots of this quantity in p
    are taken as bound states (the intended condition is X'(1)=0).
    """
    A = FindA(p, m, mu, gerade)
    Xi, Xim = create_Xi_mesh(p)
    start = give_xx0(p, A, Xim[0])
    trajectory = integrate.odeint(Xderiv, start, Xim, args=(A, p**2, m, R))
    dX = trajectory[:, 1]
    # Xim is descending, so its last two entries are the ones closest to xi=1.
    x_last, x_prev = Xim[-1], Xim[-2]
    return dX[-1] + (dX[-2] - dX[-1])*(1.0 - x_last)/(x_prev - x_last)

def FindBoundStates(m,mu,gerade,R):
    pl = linspace(R/2+0.6,R/4,40)
    Solutions=[]
    p_value_zero=0.
    for ip,p in enumerate(pl):
        value_zero = Shoot2(p,m,mu,gerade,R)
        if ip>0 and value_zero*p_value_zero<0:
            p = optimize.brentq(Shoot2, pl[ip-1], pl[ip], args=(m, mu, gerade,R),xtol=1e-17)
            A = FindA(p,m,mu,gerade)
            print 'Found bound state p=', p, 'A=', A
            Solutions.append([p,A,m,gerade])
        p_value_zero=value_zero
    return Solutions

def GiveY(p,A,m,gerade):
    """Angular factor Y(eta) of a solution on the full eta mesh.

    Integrates from eta=1 to eta=-1. The reduced function is normalised with
    Simpson's rule (abs() because the mesh is descending, which makes the
    integral negative). The result is then multiplied by (1-eta^2)^(m/2).
    `gerade` is not used here.

    Returns (mu, Y), with mu descending from ~1 to ~-1.
    """
    eta = create_full_mu_mesh()
    trajectory = integrate.odeint(Yderiv, give_yy0(p, A, m), eta, args=(A, p**2, m))
    Y = trajectory[:, 0]
    Y *= 1/sqrt(abs(integrate.simps(Y**2, eta)))
    Y *= (1-eta**2)**(m/2.)
    return (eta, Y)

def GiveX(p, A, m, gerade, R=None):
    """Radial factor X(xi) of a solution, returned on an ascending xi mesh.

    Integrates inwards from the largest xi of create_Xi_mesh(p), starting on
    the decaying asymptote. Normalises the reduced function with Simpson's
    rule, fixes the overall sign so that the value closest to xi=1 is
    positive, and multiplies by (xi^2-1)^(m/2).

    Parameters
    ----------
    p, A, m : as found by FindBoundStates.
    gerade : unused; kept for signature symmetry with GiveY.
    R : float, optional
        Internuclear distance. The original code read a module-level global
        `R` (set by the __main__ block), which raised NameError when the
        module was imported. If omitted, that global is still used for
        backward compatibility.

    Returns
    -------
    (Xi, X) : Xi ascending and X ordered to match it.

    Raises
    ------
    ValueError
        If R is not given and no module-level R exists.
    """
    if R is None:
        R = globals().get('R')
        if R is None:
            raise ValueError("GiveX: the internuclear distance R must be "
                             "passed (no module-level R is defined)")
    (Xi, Xim) = create_Xi_mesh(p)
    xx0 = give_xx0(p, A, Xim[0])
    X = integrate.odeint(Xderiv, xx0, Xim, args=(A, p**2, m, R))[:, 0]
    # Xim is descending, so the integral is negative; abs() undoes that.
    norm = integrate.simps(X**2, Xim)
    X *= sign(X[-1])/sqrt(abs(norm))
    X *= (Xim**2-1)**(m/2.)
    return (Xi, X[::-1])

def prolate(xyz, R):
    """Convert cartesian coordinates to prolate spheroidal (xi, eta, phi).

    The foci are at z = +-R/2. Lengths are scaled by R/2, and xi^2, eta^2 are
    the two roots of t^2 - (1+r^2) t + z^2 = 0. The sign of eta follows the
    sign of z. Works element-wise if x, y, z are numpy arrays.

    Parameters
    ----------
    xyz : sequence (x, y, z)
        Passed as one positional tuple, as before. The former Python-2-only
        tuple-parameter syntax is replaced so the module also parses on
        Python 3.
    R : float
        Internuclear distance.

    Returns
    -------
    (xi, eta, phi)
    """
    import numpy as np
    x, y, z = xyz
    r2 = (x**2+y**2+z**2)/(R/2)**2
    z2 = (z/(R/2))**2
    # Both sqrt arguments are >= 0 analytically. Rounding can make them
    # slightly negative near the foci, which would yield NaN, so clamp at 0.
    sq = np.sqrt(np.maximum((1+r2)**2-4*z2, 0.))

    xi = np.sqrt(0.5*(1+r2+sq))
    et = np.sqrt(np.maximum(0.5*(1+r2-sq), 0.))*np.sign(z)
    ph = np.arctan2(y, x)
    return (xi, et, ph)


def coordy0(R,Nx,Ny,XZrang):
    """Prolate coordinates of a regular grid in the y=0 (XZ) plane.

    x and z each span [-XZrang*R, XZrang*R], with Nx points in x and Ny
    points in z. Points are ordered with x as the outer loop and z as the
    inner loop, so per-point values reshape to (Nx, Ny) indexed [ix, iz].

    The original body read an undefined name `XZrng` instead of the
    parameter `XZrang`. It only worked when a global of that name happened
    to exist (as in the __main__ block), and otherwise raised NameError.

    Returns a list of (xi, eta, phi) tuples from prolate().
    """
    half_width = XZrang*R
    xl = linspace(-half_width, half_width, Nx)
    zl = linspace(-half_width, half_width, Ny)
    y = 0.0
    return [prolate((x, y, z), R) for x in xl for z in zl]

if __name__ == '__main__':
    R=2.0
    Plot2D=True
    Plot3D=True
    
    mu = create_mu_mesh()
    
    Solutions=[]
    for m in range(2):
        for gerade in range(1,-1,-1):
            Solutions += FindBoundStates(m,mu,gerade,R)
    
    print 'Solutions:'
    for (p,A,m,gerade) in Solutions:
        Ene = -4*p**2/R**2 + 2./R
        print 'E=', Ene, 'p=', p, 'A=', A, 'm=', m, 'gerade=', gerade
        
    for (p,A,m,gerade) in Solutions:
        Ene = -4*p**2/R**2 + 2./R
        (mu,Y) = GiveY(p,A,m,gerade)
        (Xi,X) = GiveX(p,A,m,gerade)
    
        if Plot2D:
            subplot(2,1,1)
            plot(mu,Y,lw=2)
            subplot(2,1,2)
            plot(Xi,X,lw=2)
            show()
            
        if Plot3D:
            Ys = interpolate.UnivariateSpline(mu[::-1],Y[::-1],s=0)
            Xs = interpolate.UnivariateSpline(Xi,X,s=0)
            
            XZrng=1.
            if m==1: XZrng=3.
            
            Nx,Ny=100,100
            coord=coordy0(R,Nx,Ny,XZrng)
            wf=[]
            for i,(xit,mut,phit) in enumerate(coord):
                psi = Ys(mut) * Xs(xit) * cos(m*phit)
                wf.append( psi )
            wf=array(wf).reshape((Nx,Ny))
            print 'E=', Ene, 'p=', p, 'A=', A, 'm=', m, 'gerade=', gerade
            imshow(wf,extent=[-XZrng,XZrng,-XZrng,XZrng])
            show()
     
