import random
import matplotlib.pyplot as plt


class Noeud:
    """Node of the binary tree filled in by ``generation_arbre``.

    All attributes are plain public fields.  They start at neutral defaults
    and are overwritten while the tree is being built.
    """

    def __init__(self):
        # Split value chosen for this node (stays 0 on single-value leaves).
        self.x_median = 0
        # Values covered by this node; a fresh list per instance.
        self.liste_x = []
        # Smallest / largest value of liste_x (assigned left to right).
        self.x_min = self.x_max = 0
        # Child nodes; both remain None for a leaf.
        self.noeud_gauche = self.noeud_droite = None

def generation_arbre(noeud, liste_x):
    """Recursively build a binary tree over the sorted list ``liste_x``.

    ``noeud`` is filled in place.  It keeps a reference to ``liste_x`` and
    records its bounds in ``x_min`` / ``x_max``.

    When the list holds at least two distinct values, it is cut at an index
    chosen so that all copies of a value land on the same side.
    ``x_median`` is set to the first value of the right part, so left values
    are ``< x_median`` and right values are ``>= x_median``.  Both parts are
    then built recursively into new child nodes.

    Args:
        noeud: node to fill.  Children are created with ``type(noeud)()``,
            so the whole tree uses the same node class as the root.
        liste_x: values sorted in ascending order (not checked).

    An empty list leaves ``noeud`` as an empty leaf (``liste_x == []``,
    other fields untouched) instead of raising ``IndexError``.
    """
    n = len(liste_x)
    noeud.liste_x = liste_x
    if n == 0:
        return
    noeud.x_min = liste_x[0]
    noeud.x_max = liste_x[-1]
    if n == 1:
        return

    m = n // 2
    x_median = liste_x[m]
    noeud.x_median = x_median

    # Move the cut left past the copies of the median so duplicates are
    # not separated between the two children.
    coupure = m
    while coupure > 0 and liste_x[coupure - 1] == x_median:
        coupure -= 1

    if coupure == 0:
        # Every value before the median equals it: cutting at 0 would leave
        # the left part empty (the node would wrongly stay a leaf).
        # Cut right after the run of duplicates instead.
        coupure = m + 1
        while coupure < n and liste_x[coupure] == x_median:
            coupure += 1
        if coupure == n:
            # All values are identical: nothing to split, keep a leaf.
            return

    # First value of the right part.  This equals x_median unless the
    # fallback cut above was used.
    noeud.x_median = liste_x[coupure]

    nouveau_noeud = type(noeud)
    noeud.noeud_gauche = nouveau_noeud()
    generation_arbre(noeud.noeud_gauche, liste_x[:coupure])
    noeud.noeud_droite = nouveau_noeud()
    generation_arbre(noeud.noeud_droite, liste_x[coupure:])
        
def affichage_arbre(noeud, x, y, dx, dy):
    """Draw the subtree rooted at ``noeud`` with matplotlib, recursively.

    Leaves are labelled with their full ``liste_x`` on a yellow box.
    Internal nodes are labelled with ``[x_min, x_max]`` on a white box.
    Each child is drawn ``dy`` lower and ``dx`` to the left (left child) or
    to the right (right child), linked to its parent by a black segment.
    The horizontal offset is halved at every level so that sibling subtrees
    do not overlap.

    Args:
        noeud: node to draw (as built by ``generation_arbre``).
        x, y: position of this node's label in data coordinates.
        dx: horizontal offset from this node to each of its children.
        dy: vertical offset from this node to its children.
    """
    est_feuille = noeud.noeud_gauche is None and noeud.noeud_droite is None
    if est_feuille:
        etiquette, couleur, opacite = str(noeud.liste_x), 'y', 0.4
    else:
        etiquette, couleur, opacite = str([noeud.x_min, noeud.x_max]), 'w', 0.85
    text = plt.text(x, y, etiquette, ha="center", va="center")
    text.set_bbox(dict(boxstyle="round", facecolor=couleur, edgecolor='None', alpha=opacite))

    # Left child first, then right child: same drawing order as before.
    for enfant, x_enfant in ((noeud.noeud_gauche, x - dx), (noeud.noeud_droite, x + dx)):
        if enfant is not None:
            plt.plot([x, x_enfant], [y, y - dy], "k-")
            affichage_arbre(enfant, x_enfant, y - dy, dx * 0.5, dy)
    
            

# Demo: build the tree over N random integers drawn from [0, 100) and draw it.
N = 10
liste_x = sorted(random.randrange(0, 100) for _ in range(N))
print(liste_x)

racine = Noeud()
generation_arbre(racine, liste_x)

fig = plt.figure(figsize=(12, 6))
# Only the tree drawing matters: hide axes and ticks.
plt.axis('off')
plt.xticks([])
plt.yticks([])
# Root at (0, 0); first-level children are dx=20 to each side, dy=10 lower.
affichage_arbre(racine, 0, 0, 20, 10)
plt.show()



def rechercher_valeur(noeud,valeur):
    """Exercise stub: not implemented yet; always returns None.

    NOTE(review): presumably meant to look up ``valeur`` in the tree rooted
    at ``noeud`` -- TODO confirm the expected return value.
    """
    # TODO: to be completed (exercise).
    return None

def obtenir_arbre(noeud,pile):
    """Exercise stub: not implemented yet; always returns None.

    ``pile`` is not touched.
    NOTE(review): presumably meant to collect the contents of the subtree
    rooted at ``noeud`` into the list ``pile`` -- TODO confirm.
    """
    # TODO: to be completed (exercise).
    return None


def rechercher_dans_intervalle(noeud,a,b,pile):
    """Exercise stub: not implemented yet; always returns None.

    ``pile`` is not touched.
    NOTE(review): presumably meant to gather into ``pile`` the values (or
    nodes) of the tree lying between ``a`` and ``b`` -- TODO confirm what is
    collected and whether the bounds are inclusive.
    """
    # TODO: to be completed (exercise).
    return None
                

    
        
