
import serial
import numpy
import math
import time
from matplotlib.pyplot import *
import scipy.signal

class Arduino():
    """Serial link to an Arduino-side digital filter.

    The host designs the filter, then sends the acquisition settings and the
    filter coefficients over the serial port. The protocol is: a handshake
    byte (``IS_READY``), a one-byte reply, a command byte, then big-endian
    integers produced by the ``write_*`` helpers below.
    """

    def __init__(self,port):
        """Open ``port`` at 115200 baud and set the protocol constants.

        Waits one second after opening, presumably so the board is ready
        before the first command (NOTE(review): confirm the delay required by
        the target board).
        """
        self.ser = serial.Serial(port,baudrate=115200)
        time.sleep(1)
        self.SET_FILTRAGE_FLOAT = 104   # command byte: float coefficients
        self.SET_FILTRAGE_32BITS = 103  # command byte: 32-bit integer coefficients
        self.IS_READY = 100             # handshake byte; the board must answer b"O"
        self.MAX_NCOEF = 256            # max number of coefficients per polynomial (a or b)
        self.clockFreq = 42.0e6         # clockFreq / fechant = sampling period in clock ticks

    def close(self):
        """Close the serial port."""
        self.ser.close()

    def write_int8(self,v):
        """Send the low 8 bits of integer ``v`` as one byte.

        Negative values are sent in two's complement (e.g. -1 -> 0xFF).
        Python ints and numpy integers are accepted.

        Raises:
            TypeError: if ``v`` is not an integer (floats must go through
                ``write_float``).
        """
        if not isinstance(v, (int, numpy.integer)):
            raise TypeError("write_int8 attend un entier, pas %r" % (v,))
        # int() first: numpy integers have no .to_bytes(), and masking a Python
        # int avoids numpy dtype overflow rules.
        self.ser.write((int(v) & 0xFF).to_bytes(1, byteorder='big'))

    def write_int16(self,v):
        """Send ``v`` as a 16-bit big-endian two's-complement integer (2 bytes).

        ``v`` is truncated to an integer, then wrapped to 16 bits.
        """
        self.ser.write((int(v) & 0xFFFF).to_bytes(2, byteorder='big'))

    def write_int32(self,v):
        """Send ``v`` as a 32-bit big-endian two's-complement integer (4 bytes).

        ``v`` is truncated to an integer, then wrapped to 32 bits.
        """
        self.ser.write((int(v) & 0xFFFFFFFF).to_bytes(4, byteorder='big'))

    def write_float(self,v):
        """Send ``v`` as a 32-bit mantissa ``m`` followed by an 8-bit exponent ``e``.

        The pair satisfies ``v ~= m * 2**(e - 30)`` with ``2**30 <= |m| < 2**31``,
        so the mantissa keeps 31 significant bits and always fits in a signed
        32-bit integer. ``v == 0`` is sent as ``m = 0, e = 0``.
        """
        if v != 0.0:
            # frexp gives v = f * 2**k with 0.5 <= |f| < 1, so floor(log2|v|) = k - 1
            # exactly (no floating-point rounding of log()), and
            # v * 2**(30 - e) == f * 2**31, which lies strictly inside int32.
            f, k = math.frexp(v)
            e = k - 1
            m = int(f * 2**31)
        else:
            e = 0
            m = 0
        self.write_int32(m)
        self.write_int8(e)

    def _verifier_parametres(self, voies, gains, na, nb):
        """Validate the channel list and coefficient counts before any serial I/O.

        Raises:
            ValueError: too many coefficients, or fewer gains than channels.
        """
        # NOTE(review): counts are sent with write_int8, so a count of exactly
        # MAX_NCOEF (256) goes out as 0 -- confirm how the firmware reads it.
        if na > self.MAX_NCOEF or nb > self.MAX_NCOEF:
            raise ValueError("trop de coefficients de filtrage")
        if len(gains) < len(voies):
            raise ValueError("il faut un gain par voie")

    @staticmethod
    def _signaler_instabilite(b, a):
        """Print a warning if the transfer function b/a has a pole with |p| >= 1."""
        poles = scipy.signal.tf2zpk(b, a)[1]
        if numpy.any(numpy.absolute(poles) >= 1):
            print("Filtre instable")

    def _demarrer(self, commande, voies, gains, fechant):
        """Handshake with the board, then send the command and acquisition header.

        The header is: channel count, channel numbers, one gain per channel,
        and the sampling period in clock ticks (int32).

        Returns:
            True if the board acknowledged with b"O", False otherwise (in which
            case only the handshake byte has been sent).

        Raises:
            ValueError: if ``fechant`` is too high to give at least one tick.
        """
        ticks = int(self.clockFreq/fechant)
        if ticks < 1:
            raise ValueError("frequence d'echantillonnage trop elevee")
        print("F echant = %d"%(self.clockFreq/ticks))
        self.write_int8(self.IS_READY)
        # pyserial returns bytes under Python 3: compare with b'O', not 'O'.
        if self.ser.read(1) != b'O':
            print(u"Arduino bloqué : appuyez sur RESET")
            return False
        self.write_int8(commande)
        nv = len(voies)
        self.write_int8(nv)
        for voie in voies:
            self.write_int8(voie)
        for gain in gains[:nv]:
            self.write_int8(gain)
        self.write_int32(ticks)
        return True

    def lancer_filtrage_float(self,voies,gains,fechant,a,b,offset,typ):
        """Configure the board for filtering with floating-point coefficients.

        Args:
            voies: channel numbers, each sent as one byte.
            gains: one gain per channel, each sent as one byte.
            fechant: requested sampling frequency; the achieved value is printed.
            a: denominator coefficients (scipy ``(b, a)`` convention).
            b: numerator coefficients.
            offset: sent as a 16-bit integer.
            typ: sent as one byte (meaning defined by the firmware).

        Returns:
            0 on success, 1 if the board did not answer the handshake.

        Raises:
            ValueError: invalid coefficient counts, gains or sampling frequency.
        """
        a = numpy.array(a)
        b = numpy.array(b)
        na = len(a)
        nb = len(b)
        self._verifier_parametres(voies, gains, na, nb)
        print(b)
        print(a)
        self._signaler_instabilite(b, a)
        if not self._demarrer(self.SET_FILTRAGE_FLOAT, voies, gains, fechant):
            return 1
        self.write_int8(na)
        for coef in a:
            # a coefficients are real numbers too: send them like b, not as bytes.
            self.write_float(coef)
        self.write_int8(nb)
        for coef in b:
            self.write_float(coef)
        self.write_int16(offset)
        self.write_int8(typ)
        return 0

    def lancer_filtrage_32bits(self,voies,gains,fechant,a,b,offset,typ,gbits):
        """Configure the board for filtering with 32-bit integer coefficients.

        All coefficients are multiplied by ``2**s`` and truncated to int32, with
        ``s`` chosen so the largest magnitude fits in ``P - 1`` bits, where
        ``P = 32 + 1 - 12 - gbits`` is the coefficient bit budget
        (NOTE(review): 12 is presumably the sample width and ``gbits`` a
        number of guard bits -- confirm against the firmware). ``s`` is sent
        after the coefficients.

        Args:
            voies, gains, fechant, offset, typ: as in ``lancer_filtrage_float``.
            a: denominator coefficients (scipy ``(b, a)`` convention).
            b: numerator coefficients.
            gbits: bits removed from the coefficient budget ``P``.

        Returns:
            0 on success, 1 if the board did not answer the handshake.

        Raises:
            ValueError: invalid counts/gains/sampling frequency, or all
                coefficients zero (no scaling exponent exists).
        """
        a = numpy.array(a)
        b = numpy.array(b)
        na = len(a)
        nb = len(b)
        self._verifier_parametres(voies, gains, na, nb)
        m = max(numpy.max(numpy.absolute(a)), numpy.max(numpy.absolute(b)))
        if m == 0:
            # log2(0) = -inf would make s infinite and the int32 cast meaningless
            raise ValueError("coefficients de filtrage tous nuls")
        P = 32+1-12-gbits  # number of bits of the coefficients
        s = numpy.floor(P-1-numpy.log2(m))
        a_int32 = numpy.array(a*2**s,dtype=numpy.int32)
        b_int32 = numpy.array(b*2**s,dtype=numpy.int32)
        # check stability of the quantized filter, which is what the board runs
        self._signaler_instabilite(b_int32, a_int32)
        print(b_int32)
        print(a_int32)
        if not self._demarrer(self.SET_FILTRAGE_32BITS, voies, gains, fechant):
            return 1
        self.write_int8(na)
        for coef in a_int32:
            self.write_int32(coef)
        self.write_int8(nb)
        for coef in b_int32:
            self.write_int32(coef)
        self.write_int8(int(s))
        self.write_int16(offset)
        self.write_int8(typ)
        return 0
            
def _tracer_reponse(b, a, fechant):
    """Plot the gain (dB) and unwrapped phase of the digital filter b/a.

    ``fechant`` is the sampling frequency, used to convert the normalized
    frequencies returned by ``freqz`` into the same unit as ``fechant``.
    """
    w, h = scipy.signal.freqz(b, a)
    f = w / (2 * numpy.pi) * fechant
    # Exact zeros of the transfer function give log10(0) = -inf; matplotlib
    # simply leaves a gap there, so silence the divide-by-zero warning.
    with numpy.errstate(divide="ignore"):
        gain_db = 20 * numpy.log10(numpy.absolute(h))
    figure()
    subplot(211)
    plot(f, gain_db)
    xlabel("f")
    ylabel("GdB")
    grid()
    subplot(212)
    plot(f, numpy.unwrap(numpy.angle(h)))
    xlabel("f")
    ylabel("phase")
    grid()


def test_rif(port="COM17"):
    """Send a 20-tap Hann-window low-pass FIR (cutoff 1500 Hz) in 32-bit mode.

    Plots the filter response, configures the board on ``port`` and blocks
    until the plot window is closed, then closes the serial port.
    """
    fechant = 21000.0
    voies = [0]
    gains = [1]
    fc = 1500.0
    a = [1.0]
    # fs= replaces the nyq= argument removed from recent SciPy; cutoff is then
    # given in the same unit as fechant (equivalent to fc/fechant with nyq=0.5).
    b = scipy.signal.firwin(numtaps=20, cutoff=fc, window='hann', fs=fechant)
    # Alternative IIR design (4th-order Butterworth low-pass):
    # b,a = scipy.signal.iirfilter(N=4,Wn=[fc/fechant*2],btype="lowpass",ftype="butter")
    _tracer_reponse(b, a, fechant)
    offset = 0x7FF
    typ = 1
    gbits = 6
    ard = Arduino(port)
    try:
        ard.lancer_filtrage_32bits(voies, gains, fechant, a, b, offset, typ, gbits)
        show(block=True)
    finally:
        ard.close()


def test_biquad(port="COM17"):
    """Send a 2nd-order Butterworth low-pass (cutoff 1500 Hz) with float coefficients.

    Plots the filter response, configures the board on ``port`` and blocks
    until the plot window is closed, then closes the serial port.
    """
    fechant = 21000.0
    voies = [0]
    gains = [1]
    fc = 1500.0
    b, a = scipy.signal.iirfilter(N=2, Wn=[fc/fechant*2], btype="lowpass", ftype="butter")
    _tracer_reponse(b, a, fechant)
    offset = 0x7FF
    typ = 1
    ard = Arduino(port)
    try:
        ard.lancer_filtrage_float(voies, gains, fechant, a, b, offset, typ)
        show(block=True)
    finally:
        ard.close()


def test_integrateur(port="COM17"):
    """Send the 2nd-order filter g*(1 - z^-2) / ((1 - r1 z^-1)(1 - r2 z^-1)) as floats.

    Its poles are r1 and r2 and its zeros are at z = 1 and z = -1. Plots the
    response, configures the board on ``port`` with ``typ = 2`` (meaning
    defined by the firmware), blocks until the plot window is closed, then
    closes the serial port.
    """
    fechant = 21000.0
    voies = [0]
    gains = [1]
    r1 = 0.98
    r2 = 0.98
    g = 0.03
    b = [g, 0, -g]
    a = [1, -(r1+r2), r1*r2]
    _tracer_reponse(b, a, fechant)
    offset = 0x7FF
    typ = 2
    ard = Arduino(port)
    try:
        ard.lancer_filtrage_float(voies, gains, fechant, a, b, offset, typ)
        show(block=True)
    finally:
        ard.close()
           