
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as linalg
from scipy.linalg import solve

            

def poisson(f,x,K0,u0,g0,KL,uL,gL,solver="iterative"):
    """Solve -u'' = f on the mesh x with P1 finite elements and Robin ends.

    Boundary conditions (the fluxes g0, gL keep the sign convention of the
    pure Neumann case K = 0):
        at x[0] :  u'(x[0]) =  K0*(u(x[0]) - u0) + g0
        at x[N] :  u'(x[N]) = -KL*(u(x[N]) - uL) + gL
    K0 = 0 (resp. KL = 0) gives a pure Neumann condition; a large K0
    (resp. KL) approximates the Dirichlet condition u = u0 (resp. u = uL)
    by penalisation. For K0, KL >= 0 the assembled matrix is symmetric
    positive definite (singular when K0 = KL = 0), which conjugate
    gradient requires.

    Parameters
    ----------
    f : callable
        Source term, called once per node with a scalar abscissa.
    x : array_like
        Node coordinates x_0 < x_1 < ... < x_N, with N >= 1.
    K0, u0, g0 : float
        Robin coefficient, exterior value and flux at x[0].
    KL, uL, gL : float
        Robin coefficient, exterior value and flux at x[N].
    solver : str, optional
        "iterative" (default) uses scipy's conjugate gradient and prints its
        exit code (0 means converged); any other value uses the dense direct
        solver scipy.linalg.solve.

    Returns
    -------
    ndarray
        The N+1 nodal values of the discrete solution.

    Raises
    ------
    ValueError
        If x has fewer than two nodes.
    """
    x = np.asarray(x,dtype=np.float64)
    N = len(x)-1
    if N < 1:
        raise ValueError("poisson needs at least two nodes")
    h = np.diff(x)            # h[k] = x[k+1]-x[k], length of element k
    inv_h = 1.0/h
    # Stiffness matrix: node i is coupled to each neighbour through 1/h of
    # the element they share; the diagonal is the sum over adjacent elements.
    diag = np.zeros(N+1)
    diag[:-1] += inv_h
    diag[1:] += inv_h
    # Robin terms. Both enter with a + sign so that C stays symmetric
    # positive definite; a -KL at the right end would make C indefinite as
    # soon as KL > 1/h and invalidate conjugate gradient.
    diag[0] += K0
    diag[N] += KL
    C = np.diag(diag)-np.diag(inv_h,1)-np.diag(inv_h,-1)
    # Lumped (nodal quadrature) load vector: f(x_i) times half the length
    # of each adjacent element.
    fx = np.array([f(t) for t in x],dtype=np.float64)
    half = np.zeros(N+1)
    half[:-1] += 0.5*h
    half[1:] += 0.5*h
    D = fx*half
    D[0] += K0*u0-g0
    D[N] += KL*uL+gL
    if solver=="iterative":
        C = sparse.bsr_array(C) # optional: CG only needs matrix-vector products
        xi, exitCode = linalg.cg(C, D, atol=1e-5)
        print("exitCode : ",exitCode)
    else:
        xi = solve(C,D)
    return xi
            

def raffinement(f,x,alpha):
    """Refine the mesh x by bisecting the elements where f is large.

    Each element [x[k], x[k+1]] gets the indicator
        eta_k = h_k * sqrt((f(x[k])**2 + f(x[k+1])**2) * h_k / 2),
    i.e. h_k times a trapezoidal estimate of the L2 norm of f on the element.
    Every element with eta_k > alpha * max(eta) receives its midpoint.

    f : callable, evaluated at the end points of each element.
    x : node coordinates (at least two nodes).
    alpha : threshold, as a fraction of the largest indicator.

    Returns the new node array: the original nodes plus the inserted
    midpoints, in mesh order.
    """
    n_elem = len(x)-1
    indicators = np.empty(n_elem)
    for k, (left, right) in enumerate(zip(x[:-1], x[1:])):
        width = right-left
        indicators[k] = width*np.sqrt((f(left)**2+f(right)**2)*width/2)
    threshold = alpha*np.max(indicators)
    refined = [x[0]]
    for k, eta_k in enumerate(indicators):
        if eta_k > threshold:
            refined.append((x[k]+x[k+1])/2)
        refined.append(x[k+1])
    return np.array(refined)
            

# Validation case: -u'' = 1 on [0, 1] with homogeneous end values imposed
# through Robin terms with a large coefficient (K0 = KL = 1e3), so that
# u(0) and u(1) are approximately 0. The numerical solution (red) is
# compared with the exact Dirichlet solution u(x) = -x^2/2 + x/2 (black dashed).
x = np.linspace(0,1,300)
def f(x):
    """Constant unit source term."""
    return 1.0
K0 = KL = 1e3
xi = poisson(f,x,K0,0,0,KL,0,0)
def u(x):
    """Exact solution of -u'' = 1 with u(0) = u(1) = 0."""
    return -0.5*x*x+0.5*x
# pyplot is star-imported here; figure/plot/grid/xlabel/ylabel used below
# and in all following cells come from this import.
from matplotlib.pyplot import *
figure(figsize=(12,6))
plot(x,xi,"r")
plot(x,u(x),"k--")
grid()
xlabel("x")
ylabel("u")
                

# Validation case: -u'' = 1 with K0 = 0 and g0 = 0 (homogeneous Neumann
# condition at x = 0) and KL = 1e3 approximating u(1) = 0. The exact
# solution of the Neumann/Dirichlet problem is u(x) = -x^2/2 + 1/2.
x = np.linspace(0,1,300)
def f(x):
    """Constant unit source term."""
    return 1.0
K0 = 0
KL = 1e3
xi = poisson(f,x,K0,0,0,KL,0,0)
def u(x):
    """Exact solution of -u'' = 1 with u'(0) = 0, u(1) = 0."""
    return -0.5*x*x+0.5
figure(figsize=(12,6))
plot(x,xi,"r")
plot(x,u(x),"k--")
grid()
xlabel("x")
ylabel("u")
                

# Sharply peaked Gaussian source centred at x = 0.5, with end values close
# to 0 imposed by Robin penalisation (K0 = KL = 1e3), on a uniform 100-node
# mesh. No reference solution is plotted.
x = np.linspace(0,1,100)
def f(x):
    """Gaussian source term centred at x = 0.5."""
    return np.exp(-200*(x-0.5)**2)
K0 = KL = 1e3
xi = poisson(f,x,K0,0,0,KL,0,0)
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

# Same Gaussian problem on a coarse uniform 20-node mesh (f, K0 and KL are
# the ones defined in the previous cell). Nodes are drawn as dots to show
# how the peak is resolved.
x = np.linspace(0,1,20)
xi = poisson(f,x,K0,0,0,KL,0,0)
figure(figsize=(12,6))
plot(x,xi,"r.")
plot(x,xi,"r-")
grid()
xlabel("x")
ylabel("u")
                

# Two adaptive refinement passes starting from the 20-node mesh: each pass
# bisects the elements whose indicator exceeds 90% of the largest one
# (alpha = 0.9). The problem is then solved again on the refined mesh; the
# dots show where nodes were added.
x = raffinement(f,x,0.9)
x = raffinement(f,x,0.9)
xi = poisson(f,x,K0,0,0,KL,0,0)
figure(figsize=(12,6))
plot(x,xi,"r.")
plot(x,xi,"r-")
grid()
xlabel("x")
ylabel("u")
                

# Zero source. u(0) is forced to be approximately 0 by a very large Robin
# coefficient (K0 = 1e9), and x = 1 carries a pure flux condition
# (KL = 0, gL = 1, i.e. slope 1). The expected profile is (close to) u = x.
x = np.linspace(0,1,300)
def f(x):
    """Zero source term."""
    return 0
xi = poisson(f,x,1e9,0,0,0,0,1)
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

# Zero source, with u(0) = 0 and u(1) = 1 approximated by Robin
# penalisation with K0 = KL = 1e3 at both ends, using the default iterative
# (CG) solver. The expected profile is close to the line u = x.
x = np.linspace(0,1,300)
def f(x):
    """Zero source term."""
    return 0
xi = poisson(f,x,1e3,0,0,1e3,1,0)
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

# Same problem as the K = 1e3 case, with a stronger penalisation
# K0 = KL = 1e6, again solved with the default iterative (CG) solver; its
# exit code is printed by poisson.
x = np.linspace(0,1,300)
def f(x):
    """Zero source term."""
    return 0
xi = poisson(f,x,1e6,0,0,1e6,1,0)
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

# Same K0 = KL = 1e6 problem (x and f from the previous cell), solved with
# the dense direct solver scipy.linalg.solve instead of CG, for comparison.
xi = poisson(f,x,1e6,0,0,1e6,1,0,solver="direct")
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

def poissonDirichlet(f,x,u0,uL,solver="iterative"):
    """Solve -u'' = f on the mesh x with u(x[0]) = u0 and u(x[N]) = uL.

    Only the interior nodes x[1] .. x[N-1] are unknowns. The P1 stiffness
    system is assembled directly on them, and the known boundary values are
    moved to the right-hand side.

    f : callable, evaluated at each interior node.
    x : node coordinates x_0 < ... < x_N.
    u0, uL : prescribed values at x[0] and x[N].
    solver : "iterative" -> conjugate gradient (exit code printed),
             any other value -> dense scipy.linalg.solve.

    Returns the N+1 nodal values, boundary values included.
    """
    N = len(x)-1
    lengths = np.diff(x)              # lengths[k] = x[k+1]-x[k]
    recip = 1/lengths
    # Tridiagonal interior matrix; row j corresponds to node j+1.
    main = recip[:-1]+recip[1:]
    side = -recip[1:-1]
    A = np.diag(main)+np.diag(side,1)+np.diag(side,-1)
    # Lumped load: f at the node times half the length of both neighbours.
    source = np.array([f(p) for p in x[1:N]],dtype=np.float64)
    B = source*(lengths[:-1]+lengths[1:])*0.5
    # Dirichlet values moved to the right-hand side.
    B[0] += u0/lengths[0]
    B[-1] += uL/lengths[N-1]
    if solver!="iterative":
        interior = solve(A,B)
    else:
        A = sparse.bsr_array(A) # optional: CG only needs matrix-vector products
        interior, exitCode = linalg.cg(A, B, atol=1e-5)
        print("exitCode : ",exitCode)
    return np.concatenate((np.array([u0]),interior,np.array([uL])))
            

# poissonDirichlet check: zero source with u(0) = 0 and u(1) = 1 imposed
# exactly by elimination; the result should be the straight line u = x.
x = np.linspace(0,1,300)
def f(x):
    """Zero source term."""
    return 0
xi = poissonDirichlet(f,x,0,1)
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

def poissonGeneral(f,x,type0,K0,u0,g0,typeL,KL,uL,gL,solver="iterative"):
    """Solve -u'' = f on the mesh x (P1 finite elements) with a selectable
    boundary condition at each end.

    type0 / typeL choose the condition at x[0] / x[N]:
        "dir" : Dirichlet, u = u0 (resp. u = uL); the node is eliminated.
        "neu" : Neumann,   u'(x[0]) = g0 (resp. u'(x[N]) = gL);
                K0 (resp. KL) is ignored.
        "rob" : Robin,     u'(x[0]) =  K0*(u - u0) + g0
                           u'(x[N]) = -KL*(u - uL) + gL
    For K0, KL >= 0 the matrix is symmetric positive definite (the
    Neumann/Neumann problem is singular), as conjugate gradient requires.

    Parameters
    ----------
    f : callable
        Source term, called once per node with a scalar abscissa.
    x : array_like
        Node coordinates x_0 < ... < x_N, with N >= 1.
    type0, typeL : str
        "dir", "neu" or "rob".
    K0, u0, g0, KL, uL, gL : float
        Robin coefficients, boundary values and fluxes at each end.
    solver : str, optional
        "iterative" (default): conjugate gradient, exit code printed
        (0 means converged); any other value: dense scipy.linalg.solve.

    Returns
    -------
    ndarray
        The N+1 nodal values, including the prescribed Dirichlet values.

    Raises
    ------
    ValueError
        If type0/typeL is not "dir", "neu" or "rob", or x has fewer than
        two nodes.
    """
    valid = ("dir","neu","rob")
    if type0 not in valid or typeL not in valid:
        raise ValueError(f"type0 and typeL must be one of {valid}, got {type0!r} and {typeL!r}")
    if type0=="neu": K0 = 0
    if typeL=="neu": KL = 0
    x = np.asarray(x,dtype=np.float64)
    N = len(x)-1
    if N < 1:
        raise ValueError("poissonGeneral needs at least two nodes")
    h = np.diff(x)            # h[k] = x[k+1]-x[k], length of element k
    inv_h = 1.0/h
    # Full (N+1)x(N+1) stiffness matrix with Robin terms. Both enter with a
    # + sign so that the matrix stays symmetric positive definite.
    diag = np.zeros(N+1)
    diag[:-1] += inv_h
    diag[1:] += inv_h
    diag[0] += K0
    diag[N] += KL
    C = np.diag(diag)-np.diag(inv_h,1)-np.diag(inv_h,-1)
    # Lumped load vector plus Robin/Neumann boundary contributions.
    fx = np.array([f(t) for t in x],dtype=np.float64)
    half = np.zeros(N+1)
    half[:-1] += 0.5*h
    half[1:] += 0.5*h
    D = fx*half
    D[0] += K0*u0-g0
    D[N] += KL*uL+gL
    # Dirichlet ends: drop the node's row and column; its known value moves
    # to the right-hand side of the neighbouring equation.
    first = 1 if type0=="dir" else 0
    last = N-1 if typeL=="dir" else N
    C = C[first:last+1,first:last+1]
    D = D[first:last+1]
    if D.size == 0:
        # Two nodes, both Dirichlet: nothing left to solve.
        xi = np.zeros(0)
    else:
        if type0=="dir":
            D[0] += u0/h[0]
        if typeL=="dir":
            D[-1] += uL/h[N-1]
        if solver=="iterative":
            C = sparse.bsr_array(C) # optional: CG only needs matrix-vector products
            xi, exitCode = linalg.cg(C, D, atol=1e-5)
            print("exitCode : ",exitCode)
        else:
            xi = solve(C,D)
    # Re-insert the prescribed Dirichlet values.
    if type0=="dir":
        xi = np.concatenate((np.array([u0]),xi))
    if typeL=="dir":
        xi = np.concatenate((xi,np.array([uL])))
    return xi
            

# poissonGeneral check: -u'' = 1 with a homogeneous Neumann condition at
# x = 0 and a homogeneous Dirichlet condition at x = 1; exact solution
# u(x) = -x^2/2 + 1/2 (black dashed).
x = np.linspace(0,1,300)
def f(x):
    """Constant unit source term."""
    return 1.0
g0 = 0 # Neumann at x=0
uL = 0 # Dirichlet at x=L
xi = poissonGeneral(f,x,"neu",0,0,g0,"dir",0,uL,0)
def u(x):
    """Exact solution of -u'' = 1 with u'(0) = 0, u(1) = 0."""
    return -0.5*x*x+0.5
figure(figsize=(12,6))
plot(x,xi,"r")
plot(x,u(x),"k--")
grid()
xlabel("x")
ylabel("u")
                

# poissonGeneral with Dirichlet conditions at both ends and zero source:
# u(0) = 0, u(1) = 1, so the expected solution is u = x.
x = np.linspace(0,1,300)
def f(x):
    """Zero source term."""
    return 0.0
u0 = 0 #Dirichlet at x=0
uL = 1 # Dirichlet at x=L
xi = poissonGeneral(f,x,"dir",0,u0,0,"dir",0,uL,0)
figure(figsize=(12,6))
plot(x,xi,"r")
grid()
xlabel("x")
ylabel("u")
                

def poissonGeneral2(s,a,x,type0,K0,u0,g0,typeL,KL,uL,gL,solver="iterative"):
    """Solve -(a u')' = s on the mesh x (P1 finite elements, coefficient a
    constant on each element) with a selectable condition at each end.

    type0 / typeL choose the condition at x[0] / x[N]:
        "dir" : Dirichlet, u = u0 (resp. u = uL); the node is eliminated.
        "neu" : Neumann,   a u'(x[0]) = a[1]*g0, a u'(x[N]) = a[N]*gL
                (K0 / KL are ignored).
        "rob" : Robin,     a u'(x[0]) = a[1]*( K0*(u - u0) + g0)
                           a u'(x[N]) = a[N]*(-KL*(u - uL) + gL)
    For a > 0 and K0, KL >= 0 the matrix is symmetric positive definite
    (the Neumann/Neumann problem is singular), as conjugate gradient requires.

    Parameters
    ----------
    s : callable
        Source term, called once per node with a scalar abscissa.
    a : array_like
        Coefficient per element: a[i] is the value on [x[i-1], x[i]] for
        i = 1..N; a[0] is not used. Needs at least len(x) entries.
    x : array_like
        Node coordinates x_0 < ... < x_N, with N >= 1.
    type0, typeL : str
        "dir", "neu" or "rob".
    K0, u0, g0, KL, uL, gL : float
        Robin coefficients, boundary values and fluxes at each end.
    solver : str, optional
        "iterative" (default): conjugate gradient, exit code printed
        (0 means converged); any other value: dense scipy.linalg.solve.

    Returns
    -------
    ndarray
        The N+1 nodal values, including the prescribed Dirichlet values.

    Raises
    ------
    ValueError
        If type0/typeL is invalid, x has fewer than two nodes, or a is
        shorter than x.
    """
    valid = ("dir","neu","rob")
    if type0 not in valid or typeL not in valid:
        raise ValueError(f"type0 and typeL must be one of {valid}, got {type0!r} and {typeL!r}")
    if type0=="neu": K0 = 0
    if typeL=="neu": KL = 0
    x = np.asarray(x,dtype=np.float64)
    a = np.asarray(a,dtype=np.float64)
    N = len(x)-1
    if N < 1:
        raise ValueError("poissonGeneral2 needs at least two nodes")
    if len(a) < N+1:
        raise ValueError("a needs one coefficient per element (len(a) >= len(x))")
    h = np.diff(x)            # h[k] = x[k+1]-x[k], length of element k+1
    cond = a[1:N+1]/h         # a_i/h_i for elements i = 1..N
    # Full stiffness matrix with Robin terms. Both enter with a + sign so
    # that the matrix stays symmetric positive definite.
    diag = np.zeros(N+1)
    diag[:-1] += cond
    diag[1:] += cond
    diag[0] += K0*a[1]
    diag[N] += KL*a[N]
    C = np.diag(diag)-np.diag(cond,1)-np.diag(cond,-1)
    # Lumped load vector from the source s (at every node, including x[N])
    # plus Robin/Neumann boundary contributions.
    sx = np.array([s(t) for t in x],dtype=np.float64)
    half = np.zeros(N+1)
    half[:-1] += 0.5*h
    half[1:] += 0.5*h
    D = sx*half
    D[0] += (K0*u0-g0)*a[1]
    D[N] += (gL+KL*uL)*a[N]
    # Dirichlet ends: drop the node's row and column; its known value moves
    # to the right-hand side of the neighbouring equation.
    first = 1 if type0=="dir" else 0
    last = N-1 if typeL=="dir" else N
    C = C[first:last+1,first:last+1]
    D = D[first:last+1]
    if D.size == 0:
        # Two nodes, both Dirichlet: nothing left to solve.
        xi = np.zeros(0)
    else:
        if type0=="dir":
            D[0] += u0*cond[0]
        if typeL=="dir":
            D[-1] += uL*cond[N-1]
        if solver=="iterative":
            C = sparse.bsr_array(C) # optional: CG only needs matrix-vector products
            xi, exitCode = linalg.cg(C, D, atol=1e-5)
            print("exitCode : ",exitCode)
        else:
            xi = solve(C,D)
    # Re-insert the prescribed Dirichlet values.
    if type0=="dir":
        xi = np.concatenate((np.array([u0]),xi))
    if typeL=="dir":
        xi = np.concatenate((xi,np.array([uL])))
    return xi
            

# Piecewise-constant coefficient: a[i] is the coefficient of element
# [x[i-1], x[i]] (a[0] is never read by poissonGeneral2). Here a = 1 on the
# left elements and a = 2 on the right ones. Zero source, Dirichlet 0 at
# x = 0 and 1 at x = 1: expect a piecewise-linear profile with a kink at
# the interface near x = 0.5.
N = 300
x = np.linspace(0,1,N+1)
a = np.zeros(N+1)
# NOTE(review): a[1:N//2] covers elements 1..N//2-1, so the interface sits at
# node x[N//2-1] (~0.497) rather than exactly at x = 0.5; a[1:N//2+1] = 1
# would put it at 0.5 - confirm which was intended.
a[1:N//2] = 1
a[N//2:N+1] = 2
def s(x):
    """Zero source term."""
    return 0
u0 = 0
uL = 1
xi = poissonGeneral2(s,a,x,"dir",0,u0,0,"dir",0,uL,0)
figure(figsize=(12,6))
plot(x,xi,"r-")
grid()
xlabel("x")
ylabel("u")
            

# Same coefficient a and source s as the previous cell, with Dirichlet
# u(0) = 0 and a Neumann flux at x = 1: a u'(1) = a[N]*gL, i.e. u'(1) = 1.
u0 = 0
gL = 1
xi = poissonGeneral2(s,a,x,"dir",0,u0,0,"neu",0,0,gL)
figure(figsize=(12,6))
plot(x,xi,"r-")
grid()
xlabel("x")
ylabel("u")
            
