# -*- coding: utf-8 -*-

from PyQt5 import QtGui, QtCore, QtWidgets
from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import sys,math
import numpy
import scipy.signal
import time
from ArduinoDueFiltrageNumerique import Arduino


class FiltreNumerique:
    """Digital filter described by its transfer-function coefficients.

    H(z) = (b[0] + b[1] z^-1 + ...) / (a[0] + a[1] z^-1 + ...)

    Coefficients are stored as double arrays. integer32()/integer64() quantize
    them in place to fixed point: the integer values are kept in a_i / b_i and
    a / b are replaced by the values those integers actually represent, so that
    reponse_freq() then shows the response of the quantized filter.
    """

    def __init__(self,a,b):
        self.a = numpy.array(a,dtype=numpy.double)
        self.b = numpy.array(b,dtype=numpy.double)
        self.a_max = numpy.max(numpy.abs(self.a))
        self.b_max = numpy.max(numpy.abs(self.b))
        # largest coefficient magnitude; sets the fixed-point scaling
        self.ab_max = max(self.a_max,self.b_max)
        self.a_size = self.a.size
        self.b_size = self.b.size
        self.type = "double"
        self.dtype = numpy.double

    @staticmethod
    def _polynome_z(coef, f):
        """Evaluate sum_k coef[k] * exp(-2j*pi*k*f) for every reduced frequency in f.

        The accumulator starts as a complex array (not the scalar coef[0]) so the
        result always has the shape of f, even for a single coefficient.
        """
        total = numpy.full(f.shape, coef[0], dtype=complex)
        for k in range(1, coef.size):
            total += coef[k]*numpy.exp(-1j*2*numpy.pi*k*f)
        return total

    def reponse_freq(self):
        """Frequency response on 1000 reduced frequencies f in [0, 0.5].

        f is the frequency divided by the sampling frequency.
        Returns (f, g, phi): f the reduced frequencies, g = |H(f)| and phi the
        unwrapped phase of H(f) in radians.
        """
        f = numpy.linspace(0.0,0.5,1000)
        hf = self._polynome_z(self.b, f)/self._polynome_z(self.a, f)
        g = numpy.absolute(hf)
        phi = numpy.unwrap(numpy.angle(hf))
        return (f,g,phi)

    def _quantifier_entier(self, nbits_accu, dtype, nom_type, signal_nbits, filtre_gbits):
        """Quantize a and b in place for an nbits_accu-bit integer accumulator.

        The coefficients are scaled by 2**ab_shift and truncated toward zero
        (cast to dtype) into a_i / b_i. ab_shift is the largest shift such that
        ab_max * 2**ab_shift <= 2**(filtre_nbits-1).
        """
        self.type = nom_type
        # bits available for the coefficients: accumulator + 1 - signal bits - guard bits
        self.filtre_nbits = nbits_accu+1-signal_nbits-filtre_gbits
        self.signal_nbits = signal_nbits
        if self.ab_max == 0:
            # all-zero coefficients: log2(0) = -inf, and any shift represents them exactly
            self.ab_shift = self.filtre_nbits-1
        else:
            # log2 is exact for powers of two (log(x)/log(2) can be one ulp high), and
            # floor (unlike int(), which truncates toward zero) keeps the largest
            # coefficient inside the range even when the shift is negative.
            # NaN/inf coefficients still raise ValueError/OverflowError here.
            self.ab_shift = int(math.floor(self.filtre_nbits-1-numpy.log2(self.ab_max)))
        echelle = 2.0**self.ab_shift
        self.a_i = numpy.array(self.a*echelle,dtype=dtype)
        self.b_i = numpy.array(self.b*echelle,dtype=dtype)
        # float factor: an integer one (negative shift) would multiply in dtype and could overflow
        inverse = 2.0**(-self.ab_shift)
        self.a = numpy.array(self.a_i*inverse,dtype=numpy.double)
        self.b = numpy.array(self.b_i*inverse,dtype=numpy.double)
        self.dtype = dtype

    def integer32(self,signal_nbits,filtre_gbits):
        """Quantize the coefficients for a 32-bit integer accumulator (see _quantifier_entier)."""
        self._quantifier_entier(32, numpy.int32, "int32", signal_nbits, filtre_gbits)

    def integer64(self,signal_nbits,filtre_gbits):
        """Quantize the coefficients for a 64-bit integer accumulator (see _quantifier_entier)."""
        self._quantifier_entier(64, numpy.int64, "int64", signal_nbits, filtre_gbits)

        


class MyMplCanvas(FigureCanvas):
    """Matplotlib canvas that can be placed inside a Qt layout.

    Creates a Figure holding one axes (self.axes), gives subclasses a chance to
    draw something through compute_initial_figure(), then attaches the canvas to
    *parent* with a size policy that expands in both directions.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = figure.gca()
        # subclass hook, called before the Qt canvas itself is initialised
        self.compute_initial_figure()
        FigureCanvas.__init__(self, figure)
        self.setParent(parent)
        expanding = QtWidgets.QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, expanding, expanding)
        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        """Draw the initial content of the canvas; the base class draws nothing."""
    
class TraceGain(MyMplCanvas):
    """Canvas showing the gain of a filter against frequency (in kHz)."""

    def compute_initial_figure(self):
        """Nothing is drawn until update_figure() is called."""

    def update_figure(self,fmax,f,G):
        """Replace the plot with the gain curve G(f).

        fmax: upper limit of the frequency axis, in Hz.
        f: frequencies in Hz; G: gain at each of these frequencies.
        The gain axis is fixed to [0, 1.5].
        """
        ax = self.axes
        ax.clear()
        ax.set_position([0.15,0.2,0.7,0.7])
        ax.plot(f/1000, G)
        ax.axis([0, fmax/1000, 0, 1.5])
        ax.set_xlabel("f (kHz)")
        ax.set_ylabel("G")
        ax.grid()
        self.draw()
        
class TracePhase(MyMplCanvas):
    """Canvas showing the phase of a filter, in units of pi, against frequency (in kHz)."""

    def compute_initial_figure(self):
        """Nothing is drawn until update_figure() is called."""

    def update_figure(self,fmax,f,phi):
        """Replace the plot with the unwrapped phase phi(f)/pi.

        fmax: upper limit of the frequency axis, in Hz.
        f: frequencies in Hz; phi: phase in radians at each frequency.
        The vertical axis is fitted to the range of the plotted curve.
        """
        psi = numpy.unwrap(phi)/math.pi
        ax = self.axes
        ax.clear()
        ax.set_position([0.15,0.2,0.7,0.7])
        ax.plot(f/1000, psi)
        ax.axis([0, fmax/1000, psi.min(), psi.max()])
        ax.set_xlabel("f (kHz)")
        ax.set_ylabel("phi/pi")
        ax.grid()
        self.draw()
        
       
class FilterApplicationWindow(QtWidgets.QMainWindow):
    """Main window: designs a FIR ("RIF") or IIR ("RII") filter, plots its
    gain and phase, and transmits the coefficients to an Arduino.

    filterType selects the design mode, "RIF" or "RII"; other values are not
    supported (no filter is designed and nothing is transmitted).
    """

    class FrequenceError(ValueError):
        """Cut-off frequency entered by the user is out of range."""

    # Passed to FiltreNumerique.integer32 as signal_nbits
    # (presumably the resolution of the Arduino ADC samples -- confirm).
    SIGNAL_NBITS = 12

    # Realisation combo label -> code passed to the Arduino ("real" argument).
    REALISATIONS = ((u"Forme directe I", 1),
                    (u"Forme directe II", 2),
                    (u"Forme directe transposée II", 3))

    def __init__(self,filterType):
        QtWidgets.QMainWindow.__init__(self)
        # Python attributes are set after the Qt base class is initialised.
        self.samplesReading_thread = 0
        self.arduino = 0  # 0 means no Arduino link opened yet (see transmettre)
        self.setAttribute(QtCore.Qt.WA_DeleteOnClose)
        self.filterType = filterType
        if filterType=="RIF":
            self.setWindowTitle(u'Filtre RIF Arduino')
        elif filterType=="RII":
            self.setWindowTitle(u'Filtre RII Arduino')
        self.main_widget = QtWidgets.QWidget(self)
        main_box_layout = QtWidgets.QHBoxLayout(self.main_widget)
        # Sub-layouts get no parent: main_widget already owns main_box_layout, and a
        # widget can only have one top-level layout (Qt warns otherwise).
        plot_layout = QtWidgets.QVBoxLayout()
        self.traceGain = TraceGain(self.main_widget, width=5, height=4, dpi=100)
        plot_layout.addWidget(self.traceGain)
        self.tracePhase = TracePhase(self.main_widget, width=5, height=4, dpi=100)
        plot_layout.addWidget(self.tracePhase)
        main_box_layout.addLayout(plot_layout)
        self.file_menu = QtWidgets.QMenu('&Fichier', self)
        self.file_menu.addAction('&Quitter', self.quit,QtCore.Qt.CTRL + QtCore.Qt.Key_Q)
        self.menuBar().addMenu(self.file_menu)

        # --- filter design controls
        if filterType=="RIF":
            filtre_label = QtWidgets.QLabel(u"Filtre RIF")
        elif filterType=="RII":
            filtre_label = QtWidgets.QLabel(u"Filtre RII")
            self.type_filtre_combo = QtWidgets.QComboBox()
            self.type_filtre_combo.addItems(["butter","cheby1","cheby2","ellip","bessel"])
        self.filtre_combo = QtWidgets.QComboBox()
        self.filtre_combo.addItems(["Passe bas","Passe haut","Passe bande","Coupe bande"])
        fen_label = QtWidgets.QLabel(u"Fenêtre")
        if filterType=="RIF":
            self.fen_combo = QtWidgets.QComboBox()
            self.fen_combo.addItems(["boxcar","hamming","hann","blackman","triang"])
        fechant_label = QtWidgets.QLabel(u"Fréquence d'échantillonnage (kHz)")
        self.fechant_combo = QtWidgets.QComboBox()
        self.fechant_combo.addItems(["1","2","4","10","20","40","80","100"])
        if filterType=="RIF":
            self.P_label = QtWidgets.QLabel(u"Nombre de coefficients")
            self.P_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
            self.P_slider.setMinimum(1)
            self.P_slider.setMaximum(100)
        if filterType=="RII":
            ordre_label = QtWidgets.QLabel(u"Ordre")
            self.ordre_combo = QtWidgets.QComboBox()
            self.ordre_combo.addItems(["1","2","3","4","5","6"])
        fc1_label = QtWidgets.QLabel(u"Fréquence de coupure 1 (Hz)")
        self.fc1_edit = QtWidgets.QLineEdit()
        self.fc1_edit.setText("100")
        fc2_label = QtWidgets.QLabel(u"Fréquence de coupure 2 (Hz)")
        self.fc2_edit = QtWidgets.QLineEdit()
        self.fc2_edit.setText("200")

        # --- implementation / transmission controls
        numtype_label = QtWidgets.QLabel(u"Type de l'accumulateur")
        self.numtype_combo = QtWidgets.QComboBox()
        self.numtype_combo.addItems(["Entier 32 bits","Flottant"])
        nbits_label = QtWidgets.QLabel(u"Nombre de bits de sécurité")
        self.nbits_combo = QtWidgets.QComboBox()
        self.nbits_array_32 = ["1","2","3","4","5","6","7","8"]
        self.nbits_combo.addItems(self.nbits_array_32)
        real_label = QtWidgets.QLabel(u"Réalisation")
        self.real_combo = QtWidgets.QComboBox()
        self.real_combo.addItems([nom for nom, _ in self.REALISATIONS])
        offset_label = QtWidgets.QLabel(u"Offset")
        self.offset_edit = QtWidgets.QLineEdit()
        self.offset_edit.setText("2048")
        com_label = QtWidgets.QLabel(u"Port COM")
        self.com_combo = QtWidgets.QComboBox()
        self.com_combo.addItems([str(i) for i in range(1, 21)])
        send_button = QtWidgets.QPushButton("Transmettre")

        # --- form: one (label, widget) pair per grid row; None leaves the label cell empty
        if filterType=="RIF":
            lignes = [(filtre_label, self.filtre_combo),
                      (fen_label, self.fen_combo),
                      (fechant_label, self.fechant_combo),
                      (self.P_label, self.P_slider)]
        elif filterType=="RII":
            lignes = [(filtre_label, self.type_filtre_combo),
                      (None, self.filtre_combo),
                      (fechant_label, self.fechant_combo),
                      (ordre_label, self.ordre_combo)]
        else:
            lignes = [(fechant_label, self.fechant_combo)]
        lignes += [(fc1_label, self.fc1_edit),
                   (fc2_label, self.fc2_edit),
                   (numtype_label, self.numtype_combo),
                   (nbits_label, self.nbits_combo),
                   (real_label, self.real_combo),
                   (offset_label, self.offset_edit),
                   (com_label, self.com_combo)]
        layout = QtWidgets.QGridLayout()
        for k, (label, widget) in enumerate(lignes):
            if label is not None:
                layout.addWidget(label,k,0)
            layout.addWidget(widget,k,1)
        layout.addWidget(send_button,len(lignes),0)

        ui_layout = QtWidgets.QVBoxLayout()
        ui_layout.addLayout(layout)
        self.console = QtWidgets.QTextEdit()
        ui_layout.addWidget(self.console)
        main_box_layout.addLayout(ui_layout)
        self.main_widget.setFocus()
        self.setCentralWidget(self.main_widget)

        # --- signals: any change of a design parameter redesigns the filter
        self.filtre_combo.currentIndexChanged.connect(self.update)
        self.fechant_combo.currentIndexChanged.connect(self.update)
        if filterType=="RIF":
            self.fen_combo.currentIndexChanged.connect(self.update)
            self.P_slider.sliderReleased.connect(self.update)
        elif filterType=="RII":
            self.ordre_combo.currentIndexChanged.connect(self.update)
            self.type_filtre_combo.currentIndexChanged.connect(self.update)
        self.fc1_edit.returnPressed.connect(self.update)
        self.fc2_edit.returnPressed.connect(self.update)
        send_button.clicked.connect(self.transmettre)
        self.numtype_combo.currentIndexChanged.connect(self.update)
        self.nbits_combo.currentIndexChanged.connect(self.update)
        self.update()

    def updatePlabel(self):
        """Show the tap count N = 2P+1 for the current slider position.

        Not connected to any signal by this class.
        """
        P = self.P_slider.value()
        self.P_label.setText("Nombre de coefficients = %d"%(2*P+1))

    def update(self):
        """Redesign and redisplay the filter for the current filterType.

        Returns True on success and False otherwise; a warning has already
        been shown to the user in the False case.
        NOTE: this name shadows QWidget.update() for Python callers.
        """
        if self.filterType=="RIF":
            return self.update_rif()
        elif self.filterType=="RII":
            return self.update_rii()
        return False

    def _avertir(self, message):
        """Show *message* in a warning box."""
        QtWidgets.QMessageBox.warning(self, "", message)

    @staticmethod
    def _lire_nombre(edit):
        """Parse the number typed in a QLineEdit; raise ValueError with a readable message."""
        texte = edit.text()
        try:
            return float(texte)
        except ValueError:
            raise ValueError(u"Valeur numérique invalide : '%s'" % texte)

    def _lire_parametres_communs(self):
        """Read the widgets shared by both filter types.

        Stores fsample_hz, offset, numtype and nbits on self.
        Returns (filtre, fechant, fc1, fc2), with fechant and the cut-offs in Hz.
        Raises ValueError if a text field does not hold a number.
        """
        filtre = self.filtre_combo.currentText()
        fechant = float(self.fechant_combo.currentText())*1000.0
        self.fsample_hz = fechant
        fc1 = self._lire_nombre(self.fc1_edit)
        fc2 = self._lire_nombre(self.fc2_edit)
        self.offset = self._lire_nombre(self.offset_edit)
        self.numtype = self.numtype_combo.currentIndex()
        # NOTE(review): the combo shows 1..8 but the index (0..7) is what is stored and
        # later sent to the Arduino; kept as-is -- confirm the intended convention.
        self.nbits = int(self.nbits_combo.currentIndex())
        return filtre, fechant, fc1, fc2

    def _frequences_reduites(self, filtre, fechant, fc1, fc2):
        """Return the cut-offs divided by fechant as (f1, f2), after validation.

        Raises FrequenceError unless 0 < f < 0.5 for each cut-off the filter uses,
        since scipy rejects 0 and fechant/2. For band filters, f1 < f2 is also
        required. fc2 is ignored for low-pass / high-pass filters.
        """
        bande = filtre in ("Passe bande", "Coupe bande")
        f1 = fc1/fechant
        f2 = fc2/fechant
        utilisees = [f1, f2] if bande else [f1]
        if any(f <= 0 for f in utilisees):
            raise self.FrequenceError(u"La fréquence de coupure doit être strictement positive")
        if any(f >= 0.5 for f in utilisees):
            raise self.FrequenceError(u"La fréquence de coupure doit être inférieure à fechant/2")
        if bande and f2 <= f1:
            raise self.FrequenceError(u"La première fréquence de coupure doit être inférieure à la seconde")
        return f1, f2

    def _quantifier(self, coef_a, coef_b):
        """Build the FiltreNumerique, quantized when the 32-bit integer accumulator is selected."""
        filtre_num = FiltreNumerique(coef_a,coef_b)
        if self.numtype==0: # entier 32 bits
            filtre_num.integer32(self.SIGNAL_NBITS,self.nbits)
        return filtre_num

    def _afficher_reponse(self, filtre_num, fechant):
        """Plot gain and phase of filtre_num up to fechant/2 (frequencies in Hz)."""
        (f,g,phi) = filtre_num.reponse_freq()
        fhz = f*fechant
        fn = fechant/2
        self.traceGain.update_figure(fn,fhz,g)
        self.tracePhase.update_figure(fn,fhz,phi)

    @staticmethod
    def _texte_coefficients(nom, valeurs):
        """One "nom[k] : value" line per coefficient."""
        return "".join("%s[%d] : %g\n" % (nom, k, v) for k, v in enumerate(valeurs))

    def update_rif(self):
        """Design the FIR filter described by the widgets and display it.

        Uses scipy.signal.firwin with N = 2P+1 taps (P from the slider) and the
        selected window. Returns True on success. On invalid input it shows a
        warning, keeps the previous design and returns False.
        """
        fen = str(self.fen_combo.currentText())
        N = 2*int(self.P_slider.value())+1
        self.P_label.setText("Nombre de coefficients N=%d"%(N))
        try:
            filtre, fechant, fc1, fc2 = self._lire_parametres_communs()
            f1, f2 = self._frequences_reduites(filtre, fechant, fc1, fc2)
            # pass_zero=True keeps DC (low-pass, band-stop); False rejects it
            cutoff, pass_zero = {
                "Passe bas": ([f1], True),
                "Passe haut": ([f1], False),
                "Passe bande": ([f1,f2], False),
                "Coupe bande": ([f1,f2], True),
            }[filtre]
            # fs=1.0: cut-offs are fractions of fechant (same meaning as the former
            # nyq=0.5, which SciPy deprecated and then removed)
            coef_b = scipy.signal.firwin(numtaps=N,cutoff=cutoff,pass_zero=pass_zero,fs=1.0,window=fen)
            filtre_num = self._quantifier([1.0], coef_b)
        except (ValueError, OverflowError) as e:
            self._avertir(str(e))
            return False
        self.coef_a = [1.0]
        self.coef_b = coef_b
        self.filtreNum = filtre_num
        self.fechant = fechant
        self._afficher_reponse(filtre_num, fechant)
        self.console.setText(u"Coefficients du filtre\n" + self._texte_coefficients("b", filtre_num.b))
        return True

    def update_rii(self):
        """Design the IIR filter described by the widgets and display it.

        Uses scipy.signal.iirfilter with the selected order and family, with
        rp=1 and rs=1. Returns True on success. On invalid input it shows a
        warning, keeps the previous design and returns False.
        """
        ordre = int(self.ordre_combo.currentText())
        type_filtre = str(self.type_filtre_combo.currentText())
        try:
            filtre, fechant, fc1, fc2 = self._lire_parametres_communs()
            f1, f2 = self._frequences_reduites(filtre, fechant, fc1, fc2)
            # iirfilter expects critical frequencies normalised to Nyquist (fechant/2)
            btype, freq = {
                "Passe bas": ("lowpass", [f1/0.5]),
                "Passe haut": ("highpass", [f1/0.5]),
                "Passe bande": ("bandpass", [f1/0.5,f2/0.5]),
                "Coupe bande": ("bandstop", [f1/0.5,f2/0.5]),
            }[filtre]
            (coef_b,coef_a) = scipy.signal.iirfilter(ordre,freq,btype=btype,rp=1,rs=1,ftype=type_filtre)
            filtre_num = self._quantifier(coef_a, coef_b)
        except (ValueError, OverflowError) as e:
            self._avertir(str(e))
            return False
        self.coef_a = coef_a
        self.coef_b = coef_b
        self.filtreNum = filtre_num
        self.fechant = fechant
        self._afficher_reponse(filtre_num, fechant)
        s = u"Coefficients du filtre\n\n"
        s += self._texte_coefficients("b", filtre_num.b)
        s += "\n"
        s += self._texte_coefficients("a", filtre_num.a)
        self.console.setText(s)
        return True

    def transmettre(self):
        """Send the current filter to the Arduino on the selected COM port.

        The filter is redesigned first; nothing is sent if that fails. The
        Arduino link is opened on first use and then reused.
        """
        if not self.update():
            return
        port = "COM%s"%self.com_combo.currentText()
        real = dict(self.REALISATIONS)[self.real_combo.currentText()]
        m = u"Transmission des paramètres à l'Arduino"
        QtWidgets.QMessageBox.warning(self, "",m)

        if self.arduino==0:
            try:
                self.arduino = Arduino(port)
            except Exception as e:
                # GUI boundary: an exception escaping a PyQt5 slot aborts the whole
                # application, so report it (e.g. a wrong COM port) and let the user retry.
                self._avertir(u"Impossible d'ouvrir le port %s : %s" % (port, e))
                return
        ard = self.arduino
        r = False
        if self.numtype==0:
            r=ard.lancer_filtrage_32bits([0],[1],self.fechant,self.coef_a,self.coef_b,self.offset,real,self.nbits)
        elif self.numtype==1:
            r=ard.lancer_filtrage_float([0],[1],self.fechant,self.coef_a,self.coef_b,self.offset,real)
        if r:
            QtWidgets.QMessageBox.warning(self, "",u"Arduino ne répond pas, appuyer sur RESET")

    def _fermer_arduino(self):
        """Close the Arduino link if one is open; safe to call more than once."""
        if self.arduino!=0:
            self.arduino.close()
            self.arduino = 0

    def quit(self):
        """Release the Arduino link and close the window."""
        self._fermer_arduino()
        self.close()

    def closeEvent(self, ce):
        """Qt close handler: release the Arduino link whatever triggered the close."""
        self._fermer_arduino()
        ce.accept()
        
