
import numpy
import math
import cmath
                
class Reflexion:
    """Reflection/transmission of a wave at an interface located at x = L.

    Medium 1 occupies x < L and medium 2 occupies x > L.  Each frequency
    component f carries the phase exp(2*pi*i*(k*x - f*t)).  In medium 1 it
    is an incident wave of amplitude 1 plus a reflected wave of amplitude r
    (wavenumber -k1).  In medium 2 it is a transmitted wave of amplitude tau.
    r and tau are chosen so that psi and d(psi)/dx are continuous at x = L.

    Attributes:
        freq, amp: frequencies and spectral amplitudes of the components.
        reflex, trans: reflection (r) and transmission (tau) coefficients.
        k1, k2: wavenumbers in medium 1 and medium 2 for each component.
        nf: number of components currently stored.
        fc: cut-off parameter used by the medium-2 dispersion relation.
        L: position of the interface.
    """

    def __init__(self, milieu):
        """Select the dispersion relations for the given medium.

        Args:
            milieu: "plasma" (k1 = f, k2 = sqrt(f**2 - fc**2)) or
                "particule" (k1 = sqrt(f), k2 = sqrt(f - fc)).

        Raises:
            ValueError: if ``milieu`` is neither of the supported values.
        """
        self.freq = []
        self.amp = []
        self.reflex = []
        self.trans = []
        self.k1 = []
        self.k2 = []
        self.nf = 0
        self.fc = 0.0
        self.L = 0.0
        if milieu == "plasma":
            self.f_k1 = self.k1_vide
            self.f_k2 = self.k2_plasma
        elif milieu == "particule":
            self.f_k1 = self.k1_particule
            self.f_k2 = self.k2_particule
        else:
            # Fail fast: otherwise f_k1/f_k2 would be missing and the error
            # would only surface later, inside reflexion().
            raise ValueError(
                f"unknown milieu {milieu!r}; expected 'plasma' or 'particule'"
            )

    def k1_vide(self, f):
        """Wavenumber in medium 1 for the "plasma" case: k1 = f."""
        return f

    def k2_plasma(self, f):
        """Wavenumber in medium 2 for the "plasma" case: sqrt(f**2 - fc**2).

        cmath.sqrt returns a positive imaginary value below the cut-off, so
        the transmitted wave decays for x > L instead of growing.
        """
        return cmath.sqrt(f*f - self.fc*self.fc)

    def k1_particule(self, f):
        """Wavenumber in medium 1 for the "particule" case: sqrt(f)."""
        return cmath.sqrt(f)

    def k2_particule(self, f):
        """Wavenumber in medium 2 for the "particule" case: sqrt(f - fc)."""
        return cmath.sqrt(f - self.fc)

    def reflexion(self, f):
        """Compute the wavenumbers and r, tau for frequency ``f``.

        Returns:
            Tuple (k1, k2, r, tau).

        Raises:
            ZeroDivisionError: if k1 + k2 == 0 (e.g. f = 0 with fc = 0).
        """
        deuxpi = 2*math.pi
        k1 = self.f_k1(f)
        k2 = self.f_k2(f)
        # r and tau follow from continuity of psi and d(psi)/dx at x = L.
        r = (k1-k2)/(k1+k2)*cmath.exp(2j*k1*deuxpi*self.L)
        tau = (cmath.exp(1j*k1*deuxpi*self.L)
               + r*cmath.exp(-1j*k1*deuxpi*self.L))*cmath.exp(-1j*k2*deuxpi*self.L)
        return (k1, k2, r, tau)

    def sinus(self, f):
        """Store a single monochromatic component of frequency ``f`` (amplitude 1)."""
        self.nf = 1
        self.freq = [f]
        self.amp = [1]
        (k1, k2, r, tau) = self.reflexion(f)
        self.reflex = [r]
        self.trans = [tau]
        self.k1 = [k1]
        self.k2 = [k2]

    def paquet(self, N, P):
        """Store a Gaussian wave packet centred on frequency ``N``.

        Components are f = N + n for n in [-P, P].  Negative frequencies are
        skipped.  The amplitude is exp(-n**2 / (4*sigma**2)) with sigma = P/4.
        When P == 0 the packet degenerates to one component of amplitude 1.0.
        """
        sigma = P*0.25
        self.freq = []
        self.amp = []
        self.reflex = []
        self.trans = []
        self.k1 = []
        self.k2 = []
        self.nf = 0
        for n in range(-P, P+1):
            f = N + n
            if f < 0:
                continue
            (k1, k2, r, tau) = self.reflexion(f)
            if sigma:
                amplitude = math.exp(-n*n/(4*sigma*sigma))
            else:
                # sigma == 0 would give 0/0; the only component (n = 0)
                # is the limit value exp(0) = 1.
                amplitude = 1.0
            self.freq.append(f)
            self.amp.append(amplitude)
            self.reflex.append(r)
            self.k1.append(k1)
            self.k2.append(k2)
            self.trans.append(tau)
            self.nf += 1

    def echantillons(self, xmin, xmax, t, np, proba=False):
        """Sample the wave at time ``t`` on both sides of the interface.

        Args:
            xmin, xmax: left end of medium 1 and right end of medium 2.
            t: time at which the field is evaluated.
            np: number of sample points on each side (not the numpy module).
            proba: if True, return psi * conj(psi) instead of psi (the values
                are complex arrays with zero imaginary part).

        Returns:
            Tuple (x1, x2, psi1_incident, psi1_reflected, psi1, psi2).
            x1 spans [xmin, L] and x2 spans [L, xmax].
        """
        deuxpi = 2*math.pi
        x1 = numpy.linspace(xmin, self.L, np)
        x2 = numpy.linspace(self.L, xmax, np)
        # builtin ``complex`` replaces the removed ``numpy.complex`` alias
        # (same dtype: complex128).
        psi1_i = numpy.zeros(x1.size, dtype=complex)
        psi1_r = numpy.zeros(x1.size, dtype=complex)
        psi2 = numpy.zeros(x2.size, dtype=complex)
        for a, f, r, tau, k1, k2 in zip(self.amp, self.freq, self.reflex,
                                         self.trans, self.k1, self.k2):
            psi1_i += a*(numpy.exp(1j*deuxpi*(k1*x1 - f*t)))
            psi1_r += a*r*numpy.exp(1j*deuxpi*(-k1*x1 - f*t))
            psi2 += a*tau*numpy.exp(1j*deuxpi*(k2*x2 - f*t))
        psi1 = psi1_i + psi1_r
        if proba:
            return (x1, x2, psi1_i*numpy.conj(psi1_i), psi1_r*numpy.conj(psi1_r),
                    psi1*numpy.conj(psi1), psi2*numpy.conj(psi2))
        return (x1, x2, psi1_i, psi1_r, psi1, psi2)
                 