
import numpy as np
from scipy.special import ellipk, ellipe
from matplotlib.pyplot import *

def frfz(r,z,r1,z1):
    """Evaluate the radial and axial kernels for one source radius.

    Scalar routine (it branches on ``r``): ``(r, z)`` is the evaluation
    point and ``(r1, z1)`` the source radius and height.

    Returns the pair ``(fr, fz)``.  ``fr`` is the integer 0 when ``r`` is
    exactly 0, which avoids the division by ``r``.
    """
    dz = z - z1
    dz2 = dz*dz
    d_minus = (r - r1)**2 + dz2
    d_plus = (r + r1)**2 + dz2
    # Argument handed to scipy's ellipe / ellipk.
    m = -4*r*r1/d_minus
    e_m = ellipe(m)
    denom = d_plus*np.sqrt(d_minus)
    fz = 4*dz*r1*e_m/denom
    if r == 0:
        # On the axis the radial kernel is set to zero explicitly.
        return 0, fz
    k_m = ellipk(m)
    c = r**2 - r1**2 - dz2
    fr = 2*r1/r*(c*e_m + d_plus*k_m)/denom
    return fr, fz

# Rebind frfz to a NumPy ufunc built from the scalar function above
# (4 inputs, 2 outputs) so that it can be called with array arguments.
frfz = np.frompyfunc(frfz,4,2)
                   

# Evaluate both kernels at the point (r, z) = (0.1, 0.1) for N = 100 source
# radii r1 evenly spaced on [0, a], then plot them against r1.
z1 = 0
a = 1
z = 0.1
r = 0.1
N = 100
r1 = np.linspace(0,a,N)
fr,fz = frfz(r,z,r1,z1)
figure()
# Upper panel: fr versus r1.
subplot(211)
plot(r1,fr)
ylabel('fr')
grid()
# Lower panel: fz versus r1.
subplot(212)
plot(r1,fz)
ylabel('fz')
xlabel('r1')
grid()
                   

# Same two-panel plot as before, now at z = 0.01 (closer to the plane z1 = 0)
# and with N = 500 source radii.
z1 = 0
a = 1
z = 0.01
r = 0.1
N = 500
r1 = np.linspace(0,a,N)
fr,fz = frfz(r,z,r1,z1)
figure()
# Upper panel: fr versus r1.
subplot(211)
plot(r1,fr)
ylabel('fr')
grid()
# Lower panel: fz versus r1.
subplot(212)
plot(r1,fz)
ylabel('fz')
xlabel('r1')
grid()
                   

# Trapezoidal rule over r1 for the samples currently held in the module-level
# r1, fr, fz and N: interior samples at full weight, end samples at half weight.
dr = r1[1]-r1[0]
Iz = (np.sum(fz[1:N-1])+0.5*(fz[0]+fz[N-1]))*dr
Ir = (np.sum(fr[1:N-1])+0.5*(fr[0]+fr[N-1]))*dr
M = 1
# Scale the two integrals by M/(4*pi).
Hz = M/(4*np.pi)*Iz
Hr = M/(4*np.pi)*Ir
                   

def HrHz(r,z,z1,a,M,N):
    """Trapezoidal estimate of ``(Hr, Hz)`` at the point ``(r, z)``.

    The kernels given by ``frfz`` are sampled at ``N`` equally spaced source
    radii between 0 and ``a`` (source height ``z1``), summed with the
    trapezoidal rule and scaled by ``M/(4*pi)``.
    """
    radii = np.linspace(0,a,N)
    step = radii[1]-radii[0]
    kern_r,kern_z = frfz(r,z,radii,z1)
    scale = M/(4*np.pi)

    def trapeze(samples):
        # Interior samples at full weight, the two end samples at half weight.
        return np.sum(samples[1:N-1])+0.5*(samples[0]+samples[N-1])

    Hz = scale*trapeze(kern_z)*step
    Hr = scale*trapeze(kern_r)*step
    return Hr,Hz
                   

# Numerical Hz at r = 0 for P = 100 heights z in [1e-3, 2], each computed by
# HrHz with a fixed number N = 500 of source radii.
P = 100
z = np.linspace(1e-3,2,P)
Hz = np.zeros(P)
r = 0
z1 = 0
a = 1
M = 1
N = 500
for k in range(P):
    # Only the axial component is stored; hr is discarded.
    hr,hz = HrHz(r,z[k],z1,a,M,N)
    Hz[k] = hz
def Hz_axe(z,z1,a,M):
    """Closed-form Hz used as the on-axis reference for the numerical result.

    Parameters are the height(s) ``z`` (scalar or array), the source-plane
    height ``z1``, the radius ``a`` and the amplitude ``M``.

    The sign of ``z - z1`` is taken with ``np.sign`` instead of
    ``(z-z1)/abs(z-z1)``, so the value exactly on the plane ``z == z1`` is 0
    rather than a division by zero.
    """
    dz = z-z1
    return M/2*(np.sign(dz)-dz/(a**2+dz**2)**0.5)
# Overlay the numerical Hz (blue) and the values returned by Hz_axe (red)
# as functions of z.
Hz_exact = Hz_axe(z,z1,a,M)
figure()
plot(z,Hz,'b')
plot(z,Hz_exact,'r')
grid()
xlabel('z')
ylabel('Hz')
                   

def HrHz_adapt(r,z,z1,a,M,Nmin,tol):
    """Call ``HrHz`` with a doubling number of points until |H| settles.

    Starts at ``Nmin`` points and doubles the count at every pass.  The loop
    ends as soon as the relative change of the field norm between two
    successive passes is no longer greater than ``tol``.

    Returns ``(Hr, Hz, N)`` where ``N`` is the last point count used.
    """
    n_pts = Nmin
    h_r,h_z = HrHz(r,z,z1,a,M,n_pts)
    norm_prev = np.sqrt(h_r*h_r+h_z*h_z)
    while True:
        n_pts *= 2
        # Every pass recomputes the whole sum from scratch.
        h_r,h_z = HrHz(r,z,z1,a,M,n_pts)
        norm = np.sqrt(h_r*h_r+h_z*h_z)
        if not abs((norm-norm_prev)/norm)>tol:
            return h_r,h_z,n_pts
        norm_prev = norm
                    

# Numerical Hz at r = 0 for P = 100 heights z in [1e-3, 1] using HrHz_adapt
# (initial count Nmin = 10, tolerance 1e-2); the point count reached for each
# height is recorded in list_N.
P = 100
z = np.linspace(1e-3,1,P)
Hz = np.zeros(P)
list_N = np.zeros(P)
r = 0
z1 = 0
a = 1
M = 1
Nmin = 10
tol=1e-2
for k in range(P):
    hr,hz,N = HrHz_adapt(r,z[k],z1,a,M,Nmin,tol)
    Hz[k] = hz
    list_N[k] = N
def Hz_axe(z,z1,a,M):
    """Closed-form Hz used as the on-axis reference for the numerical result.

    Parameters are the height(s) ``z`` (scalar or array), the source-plane
    height ``z1``, the radius ``a`` and the amplitude ``M``.

    The sign of ``z - z1`` is taken with ``np.sign`` instead of
    ``(z-z1)/abs(z-z1)``, so the value exactly on the plane ``z == z1`` is 0
    rather than a division by zero.
    """
    dz = z-z1
    return M/2*(np.sign(dz)-dz/(a**2+dz**2)**0.5)
# Overlay the numerical Hz (blue) and the values returned by Hz_axe (red)
# as functions of z.
Hz_exact = Hz_axe(z,z1,a,M)
figure()
plot(z,Hz,'b')
plot(z,Hz_exact,'r')
grid()
xlabel('z')
ylabel('Hz')
                   

# Point count stored in list_N for each height z, on a logarithmic N axis.
figure()
plot(z,list_N)
xlabel('z')
ylabel('N')
yscale('log')
grid()
                   
                     
def HrHz_partiel(r,z,z1,a,M,N,sum_fr,sum_fz):
    """Extend running kernel sums to a grid of ``N`` intervals on [0, a].

    Only the odd-indexed nodes ``dr, 3*dr, ...`` (with ``dr = a/N``) are
    evaluated; their kernel values are added to the incoming ``sum_fr`` and
    ``sum_fz``.

    Returns ``(Hr, Hz, sum_fr, sum_fz)`` with the updated sums.
    """
    step = a/N
    new_r,new_z = frfz(r,z,np.arange(1,N,2)*step,z1)
    total_r = sum_fr+np.sum(new_r)
    total_z = sum_fz+np.sum(new_z)
    scale = M/(4*np.pi)
    Hr = scale*(total_r*step)
    Hz = scale*(total_z*step)
    return Hr,Hz,total_r,total_z
                     

def HrHz_complet(r,z,z1,a,M,N):
    """Full trapezoidal sum over [0, a] divided into ``N`` intervals.

    Returns ``(Hr, Hz, sum_fr, sum_fz)``.  The last two items are the
    weighted kernel sums before multiplication by the step and by
    ``M/(4*pi)``.
    """
    # Kernel values at the two ends of the interval (half weight below).
    fr_lo,fz_lo = frfz(r,z,0,z1)
    fr_hi,fz_hi = frfz(r,z,a,z1)
    step = a/N
    # Interior nodes step, 2*step, ..., (N-1)*step.
    inner_r,inner_z = frfz(r,z,np.arange(1,N)*step,z1)
    sum_fr = np.sum(inner_r)+0.5*(fr_lo+fr_hi)
    sum_fz = np.sum(inner_z)+0.5*(fz_lo+fz_hi)
    scale = M/(4*np.pi)
    return scale*(sum_fr*step),scale*(sum_fz*step),sum_fr,sum_fz
                     

def HrHz_iter(r,z,z1,a,M,tol):
    """Refine the trapezoidal estimate of ``(Hr, Hz)`` by interval doubling.

    A full sum over 10 intervals is computed once with ``HrHz_complet``.
    Each later pass doubles the interval count and only adds the new nodes
    through ``HrHz_partiel``.  The loop ends as soon as the relative change
    of the field norm between two passes is no longer greater than ``tol``.

    Returns ``(Hr, Hz, N)`` where ``N`` is the final interval count.
    """
    n_int = 10
    h_r,h_z,acc_r,acc_z = HrHz_complet(r,z,z1,a,M,n_int)
    norm_prev = np.sqrt(h_r*h_r+h_z*h_z)
    while True:
        n_int *= 2
        h_r,h_z,acc_r,acc_z = HrHz_partiel(r,z,z1,a,M,n_int,acc_r,acc_z)
        norm = np.sqrt(h_r*h_r+h_z*h_z)
        if not abs((norm-norm_prev)/norm)>tol:
            return h_r,h_z,n_int
        norm_prev = norm
    
                     

# Numerical Hz at r = 0 for P = 100 heights z in [1e-3, 1] using HrHz_iter
# with tolerance 1e-2; the interval count reached is recorded in list_N.
P = 100
z = np.linspace(1e-3,1,P)
Hz = np.zeros(P)
list_N = np.zeros(P)
r = 0
z1 = 0
a = 1
M = 1
# NOTE: Nmin is assigned here but is not passed to HrHz_iter in the loop below.
Nmin = 10
tol=1e-2
for k in range(P):
    hr,hz,N = HrHz_iter(r,z[k],z1,a,M,tol)
    Hz[k] = hz
    list_N[k] = N
def Hz_axe(z,z1,a,M):
    """Closed-form Hz used as the on-axis reference for the numerical result.

    Parameters are the height(s) ``z`` (scalar or array), the source-plane
    height ``z1``, the radius ``a`` and the amplitude ``M``.

    The sign of ``z - z1`` is taken with ``np.sign`` instead of
    ``(z-z1)/abs(z-z1)``, so the value exactly on the plane ``z == z1`` is 0
    rather than a division by zero.
    """
    dz = z-z1
    return M/2*(np.sign(dz)-dz/(a**2+dz**2)**0.5)
# Overlay the numerical Hz (blue) and the values returned by Hz_axe (red)
# as functions of z.
Hz_exact = Hz_axe(z,z1,a,M)
figure()
plot(z,Hz,'b')
plot(z,Hz_exact,'r')
grid()
xlabel('z')
ylabel('Hz')
                   

# Interval count stored in list_N for each height z, on a logarithmic N axis.
figure()
plot(z,list_N)
xlabel('z')
ylabel('N')
yscale('log')
grid()
                   

def champAimant(r,z,a,zs,zn,M,mu0=1,tol=1e-2):
    """Return ``(Hr, Hz, Br, Bz)`` of the cylindrical magnet at ``(r, z)``.

    The field H is the sum of two ``HrHz_iter`` contributions: the face at
    ``zs`` with amplitude ``-M`` and the face at ``zn`` with amplitude ``+M``,
    both of radius ``a``.

    Outside the magnet ``B = mu0*H``.  Inside it (``|r| < a`` and ``z``
    strictly between the two faces) ``B = mu0*(H + M)``.  The magnetization
    is added once to the total axial field; adding it to each face
    contribution would count it twice and make Bz discontinuous across the
    faces.  The magnetization points from the ``zs`` face to the ``zn``
    face, so its axial component is ``-M`` when ``zn < zs``.

    ``mu0`` is the factor relating B to H + M and ``tol`` is the relative
    tolerance forwarded to ``HrHz_iter``.
    """
    Hr1,Hz1,_ = HrHz_iter(r,z,zs,a,-M,tol)
    Hr2,Hz2,_ = HrHz_iter(r,z,zn,a,M,tol)
    Hr = Hr1+Hr2
    Hz = Hz1+Hz2
    Br = mu0*Hr
    dedans = abs(r)<a and min(zs,zn)<z<max(zs,zn)
    if dedans:
        # Axial magnetization, oriented from the zs face towards the zn face.
        Mz = M if zn>zs else -M
        Bz = mu0*(Hz+Mz)
    else:
        Bz = mu0*Hz
    return (Hr,Hz,Br,Bz)


def fleche(x,y,sens,long,style='k-'):
    """Draw an arrowhead near the middle of the polyline ``(x, y)``.

    The head sits at the point of index ``len(x)//2 + sens`` and is aligned
    with the segment coming from index ``len(x)//2``.  It consists of two
    strokes of length ``long``, each at 150 degrees from that direction,
    plotted with the matplotlib format string ``style``.

    Polylines too short to provide both points are skipped silently instead
    of raising ``IndexError`` (or wrapping around through a negative index).
    """
    npts = len(x)
    n = npts//2
    if not 0 <= n+sens < npts:
        # Shift the base point so that the tip index stays inside the line.
        n -= sens
    if not (0 <= n < npts and 0 <= n+sens < npts):
        return
    xa, ya = x[n], y[n]
    xb, yb = x[n+sens], y[n+sens]
    # Direction of the segment a -> b.
    phi = np.arctan2(yb-ya, xb-xa)
    ouverture = np.pi/2+np.pi/3
    for angle in (phi-ouverture, phi+ouverture):
        plot([xb, xb+long*np.cos(angle)], [yb, yb+long*np.sin(angle)], style)
                    

def ligneH(a,zs,zn,M,ri,zi,sens,rmax,zmax,dmin):
    """Trace a line of H from ``(ri, zi)`` with fixed-length steps.

    At every point the field is obtained from ``champAimant(r,z,a,zs,zn,M)``
    and the point is moved by 0.01 along the unit vector of H when
    ``sens`` is +1, or against it when ``sens`` is -1.

    The trace stops when the new point is within ``dmin`` of the plane
    ``zs`` or ``zn`` while ``|r| < a``, or when ``|r| > rmax`` or
    ``|z| > zmax``.  It also stops if the field norm is zero or not finite.
    Without that check the coordinates would turn into NaN, every stop test
    would be false, and the loop would never end.

    Returns ``(ligne_r, ligne_z)`` as NumPy arrays of the visited points
    (the point that triggered the stop is not included).
    """
    ds = 0.01*sens
    r, z = ri, zi
    ligne_r = []
    ligne_z = []
    while True:
        ligne_z.append(z)
        ligne_r.append(r)
        Hr,Hz,_,_ = champAimant(r,z,a,zs,zn,M)
        H = np.sqrt(Hr**2+Hz**2)
        if not (np.isfinite(H) and H > 0):
            # No usable direction here: end the line.
            break
        r += Hr/H*ds
        z += Hz/H*ds
        pres_face = abs(r)<a and (abs(z-zs)<dmin or abs(z-zn)<dmin)
        if pres_face or abs(r)>rmax or abs(z)>zmax:
            break
    return (np.array(ligne_r),np.array(ligne_z))
# Figure of the H lines of a magnet of radius a whose faces are at zs and zn.
a=1
M = 1 
zs=-2
zn=2
# NOTE: this module-level ds is not passed to ligneH, which sets its own step.
ds = 0.05
figure(figsize=(8,8))
zmax = 10
rmax = 10
longfleche = 0.2
dmin = 0.1 # minimum approach distance to the faces
for ri in [0,0.2,0.4,0.6,0.8,1.0]:
    # Each line is drawn twice, at +r and -r; z is the horizontal axis.
    # Start at zn + dmin, stepping along the field.
    sens = 1
    ligne_r,ligne_z = ligneH(a,zs,zn,M,ri,zn+dmin,sens,rmax,zmax,dmin)
    plot(ligne_z,ligne_r,'b-')
    fleche(ligne_z,ligne_r,sens,longfleche,style='b-')
    plot(ligne_z,-ligne_r,'b-')
    fleche(ligne_z,-ligne_r,sens,longfleche,style='b-')
    # Start at zs - dmin, stepping against the field.
    sens = -1
    ligne_r,ligne_z = ligneH(a,zs,zn,M,ri,zs-dmin,sens,rmax,zmax,dmin)
    plot(ligne_z,ligne_r,'b-')
    fleche(ligne_z,ligne_r,sens,longfleche,style='b-')
    plot(ligne_z,-ligne_r,'b-')
    fleche(ligne_z,-ligne_r,sens,longfleche,style='b-')
    # Start at zn - dmin (between the two faces), stepping along the field.
    sens = 1
    ligne_r,ligne_z = ligneH(a,zs,zn,M,ri,zn-dmin,sens,rmax,zmax,dmin)
    plot(ligne_z,ligne_r,'b-')
    fleche(ligne_z,ligne_r,sens,longfleche,style='b-')
    plot(ligne_z,-ligne_r,'b-')
    fleche(ligne_z,-ligne_r,sens,longfleche,style='b-')
axis('equal')
xlabel('z')
ylabel('r')
xlim(-zmax,zmax)
ylim(-rmax,rmax)
# Red rectangle with corners at z = -zs and z = zs; with the values set above
# -zs equals zn, so it spans the two faces.
plot([-zs,zs,zs,-zs,-zs],[a,a,-a,-a,a],'r-')
grid()
title('Lignes de H')
                     

def ligneB(a,zs,zn,M,ri,zi,sens,rmax,zmax,dmin):
    """Trace a line of B from ``(ri, zi)`` with fixed-length steps.

    At every point the field is obtained from ``champAimant(r,z,a,zs,zn,M)``
    and the point is moved by 0.01 along the unit vector of B when
    ``sens`` is +1, or against it when ``sens`` is -1.

    The trace stops when the new point is within ``dmin`` of the plane
    ``zs`` or ``zn`` while ``|r| < a``, or when ``|r| > rmax`` or
    ``|z| > zmax``.  It also stops if the field norm is zero or not finite.
    Without that check the coordinates would turn into NaN, every stop test
    would be false, and the loop would never end.

    Returns ``(ligne_r, ligne_z)`` as NumPy arrays of the visited points
    (the point that triggered the stop is not included).
    """
    ds = 0.01*sens
    r, z = ri, zi
    ligne_r = []
    ligne_z = []
    while True:
        ligne_z.append(z)
        ligne_r.append(r)
        _,_,Br,Bz = champAimant(r,z,a,zs,zn,M)
        B = np.sqrt(Br**2+Bz**2)
        if not (np.isfinite(B) and B > 0):
            # No usable direction here: end the line.
            break
        r += Br/B*ds
        z += Bz/B*ds
        pres_face = abs(r)<a and (abs(z-zs)<dmin or abs(z-zn)<dmin)
        if pres_face or abs(r)>rmax or abs(z)>zmax:
            break
    return (np.array(ligne_r),np.array(ligne_z))
    

# Figure of the B lines; relies on the module-level a, zs, zn, M, dmin, rmax,
# zmax and longfleche already defined at this point.
figure(figsize=(8,8))
for ri in [0,0.2,0.4,0.6,0.8,1.0]: 
    # Each line is drawn twice, at +r and -r; z is the horizontal axis.
    # Start at zn + dmin, stepping along the field.
    sens = 1
    ligne_r,ligne_z = ligneB(a,zs,zn,M,ri,zn+dmin,sens,rmax,zmax,dmin)
    plot(ligne_z,ligne_r,'b-')
    fleche(ligne_z,ligne_r,sens,longfleche,style='b-')
    plot(ligne_z,-ligne_r,'b-')
    fleche(ligne_z,-ligne_r,sens,longfleche,style='b-')
    # Start at zs - dmin, stepping against the field.
    sens = -1 
    ligne_r,ligne_z = ligneB(a,zs,zn,M,ri,zs-dmin,sens,rmax,zmax,dmin)
    plot(ligne_z,ligne_r,'b-')
    fleche(ligne_z,ligne_r,sens,longfleche,style='b-')
    plot(ligne_z,-ligne_r,'b-')
    fleche(ligne_z,-ligne_r,sens,longfleche,style='b-')
    # Start at zn - dmin (between the two faces), stepping against the field.
    sens = -1
    ligne_r,ligne_z = ligneB(a,zs,zn,M,ri,zn-dmin,sens,rmax,zmax,dmin)
    plot(ligne_z,ligne_r,'b-')
    fleche(ligne_z,ligne_r,sens,longfleche,style='b-')
    plot(ligne_z,-ligne_r,'b-')
    fleche(ligne_z,-ligne_r,sens,longfleche,style='b-')
axis('equal')
xlabel('z')
ylabel('r')
xlim(-zmax,zmax)
ylim(-rmax,rmax)
# Red rectangle with corners at z = -zs and z = zs; with zs = -2 and zn = 2
# this coincides with the two faces.
plot([-zs,zs,zs,-zs,-zs],[a,a,-a,-a,a],'r-')
grid()
title('Lignes de B')
                     
