import random
import numpy
import matplotlib.pyplot as plt


def partition(liste,axe,debut,fin):
    """Partition ``liste[debut..fin]`` (inclusive) on coordinate ``axe``.

    Lomuto scheme: the last element ``liste[fin]`` is the pivot. Elements
    whose coordinate ``axe`` is <= the pivot's are moved to the front of the
    range, then the pivot is swapped into the slot right after them.
    ``liste`` is reordered in place.

    Returns:
        The final index of the pivot. Inside the range, every element before
        it satisfies ``coord <= pivot``; the elements after it do not.
    """
    seuil = liste[fin][axe]
    frontiere = debut  # next slot that receives an element <= seuil
    for courant in range(debut, fin):
        if not liste[courant][axe] <= seuil:
            continue
        liste[frontiere], liste[courant] = liste[courant], liste[frontiere]
        frontiere += 1
    # place the pivot just after the "<= seuil" prefix
    liste[frontiere], liste[fin] = liste[fin], liste[frontiere]
    return frontiere

def selection(liste,axe,debut,fin,rang):
    """Return the ``rang``-th smallest value of coordinate ``axe`` in ``liste[debut..fin]``.

    Quickselect built on ``partition``. ``rang`` is 1-based and relative to
    ``debut`` (``rang == 1`` is the minimum of the range, ``fin - debut + 1``
    the maximum). ``liste`` is reordered in place.

    The search is a loop rather than recursion. With the last-element pivot
    used by ``partition``, already-sorted input (or many equal values) shrinks
    the range by only one element per step. A recursive version then exceeds
    Python's default recursion limit for ranges of a couple of thousand
    points. The worst case is still quadratic in comparisons; only the stack
    growth is removed.

    Raises:
        ValueError: if ``rang`` is not between 1 and ``fin - debut + 1``.
    """
    if not 1 <= rang <= fin - debut + 1:
        raise ValueError("rang must be between 1 and fin - debut + 1")
    # Invariant: 1 <= rang <= fin - debut + 1, so the target lies in the range.
    while debut < fin:
        i = partition(liste,axe,debut,fin)
        k = i - debut + 1  # 1-based rank of the pivot within the current range
        if rang == k:
            return liste[i][axe]
        if rang < k:
            fin = i - 1          # target is left of the pivot
        else:
            debut = i + 1        # target is right of the pivot...
            rang -= k            # ...with k smaller elements now excluded
    return liste[debut][axe]
    
def mediane(liste,axe):
    """Return the median of coordinate ``axe`` over the points of ``liste``.

    Quickselect runs on a shallow copy, so the order of ``liste`` itself is
    left untouched. The selected rank is ``len(liste)//2 + 1``. For an odd
    count this is the middle value; for an even count it is the upper of the
    two middle values.

    Raises:
        ValueError: if ``liste`` is empty.
    """
    copie = list(liste)
    # Check the copy rather than ``liste``: truth-testing a multi-row numpy
    # array would itself raise.
    if not copie:
        raise ValueError("mediane() of an empty list")
    return selection(copie,axe,0,len(copie)-1,len(copie)//2+1)

N=10
K=2
# Demo data: N points with K coordinates each, drawn with random.random()
# (uniform in [0, 1)); coordinates are drawn point by point, axis by axis.
points = [[random.random() for _ in range(K)] for _ in range(N)]

# median of the first coordinate (axe 0)
print(mediane(points,0))

def separation_mediane(liste,axe):
    """Split ``liste`` around the median of coordinate ``axe``.

    Returns:
        A tuple ``(m, L1, L2)``:

        - ``m`` is ``mediane(liste, axe)``;
        - ``L1`` holds the points whose coordinate ``axe`` is strictly below ``m``;
        - ``L2`` holds all the other points.

        ``L2`` therefore always contains the point that carries the median
        value. Both lists keep the original relative order and reference the
        same point objects as ``liste`` (no copies).
    """
    m = mediane(liste,axe)
    L1 = [point for point in liste if point[axe] < m]
    # "not <" rather than ">=" so every point lands in exactly one list
    L2 = [point for point in liste if not point[axe] < m]
    return (m,L1,L2)

class NoeudKd:
    """Node of a k-d tree.

    An internal node stores the splitting coordinate index ``axe`` and the
    threshold ``valeur_mediane``, and has child nodes. A node with no child is
    treated as a leaf by ``affichage_arbre``, which displays its
    ``liste_points``.

    Every attribute may be passed as a keyword argument. ``NoeudKd()`` builds
    an empty node: axe 0, median 0, no points, no children.
    """

    def __init__(self, *, axe=0, valeur_mediane=0, liste_points=None,
                 noeud_gauche=None, noeud_droite=None):
        self.axe = axe                        # index of the splitting coordinate
        self.valeur_mediane = valeur_mediane  # threshold on coordinate `axe`
        # A fresh list per node: a shared mutable default would leak between nodes.
        self.liste_points = [] if liste_points is None else liste_points
        # Children; presumably the left subtree holds points below
        # valeur_mediane (cf. separation_mediane) -- TODO confirm.
        self.noeud_gauche = noeud_gauche
        self.noeud_droite = noeud_droite

    def __repr__(self):
        return "%s(axe=%r, valeur_mediane=%r, liste_points=%r, noeud_gauche=%r, noeud_droite=%r)" % (
            type(self).__name__,
            self.axe,
            self.valeur_mediane,
            self.liste_points,
            self.noeud_gauche,
            self.noeud_droite,
        )

def affichage_arbre(noeud,x,y,dx,dy):
    """Draw the subtree rooted at ``noeud`` on the current matplotlib figure.

    ``noeud`` is drawn at (x, y):

    - a leaf (no children) shows ``numpy.array(noeud.liste_points)`` in a
      translucent yellow box;
    - an internal node shows ``"axe : valeur_mediane"`` (median with 3
      decimals) in a white box.

    Each existing child is drawn one level lower (``y - dy``), shifted left
    (left child) or right (right child) by ``dx``, and joined to its parent
    by a black line. The horizontal offset is halved at every level while the
    vertical step ``dy`` stays constant.
    """
    gauche = noeud.noeud_gauche
    droite = noeud.noeud_droite
    if gauche is None and droite is None:
        # leaf: display the points it stores
        text = plt.text(x,y,str(numpy.array(noeud.liste_points)),ha="center",va="center")
        text.set_bbox(dict(boxstyle="round",facecolor='y',edgecolor='None',alpha=0.4))
    else:
        # internal node: splitting axis and median threshold
        text = plt.text(x,y,"%d : %0.3f"%(noeud.axe,noeud.valeur_mediane),ha="center",va="center")
        text.set_bbox(dict(boxstyle="round",facecolor='w',edgecolor='None',alpha=0.85))
    if gauche is not None:
        plt.plot([x,x-dx],[y,y-dy],"k-")
        affichage_arbre(gauche,x-dx,y-dy,dx*0.5,dy)
    if droite is not None:
        plt.plot([x,x+dx],[y,y-dy],"k-")
        affichage_arbre(droite,x+dx,y-dy,dx*0.5,dy)

def generation_arbre_kd(noeud,liste_points,K,profondeur=0):
    """Build a k-d tree from ``liste_points`` into ``noeud`` -- NOT IMPLEMENTED YET.

    Currently a placeholder: the body does nothing and the function returns None.

    Parameters:
        noeud: node to fill in, presumably a ``NoeudKd`` mutated in place since
            nothing is returned -- TODO confirm the intended contract.
        liste_points: the points to store; presumably lists of K coordinates
            like the module-level ``points`` -- TODO confirm.
        K: number of coordinates per point.
        profondeur: current depth in the tree, 0 by default. NOTE(review):
            presumably used to pick the splitting axis (e.g. ``profondeur % K``)
            -- confirm.
    """
    pass
    # TODO: to be completed (original note: "à compléter")


def recherche_domaine(noeud,xmin,ymin,pile):
    """Range query on the k-d tree rooted at ``noeud`` -- NOT IMPLEMENTED YET.

    Currently a placeholder: the body does nothing and the function returns None.

    NOTE(review): the intended contract is not visible here. Presumably it
    should collect into ``pile`` (French for "stack") the points whose first
    two coordinates are at least ``xmin`` and ``ymin``, and use each node's
    ``axe`` / ``valeur_mediane`` to prune subtrees. The signature has no upper
    bounds, so the query region may be a quadrant -- TODO confirm.
    """
    pass
    # TODO: to be completed (original note: "à compléter")
            
    

