# -*- coding: utf-8 -*-

import numpy
from schrodinger2d.main import Schrodinger2d
import time

# --- Solver configuration -------------------------------------------------
# Grid/resolution parameters forwarded positionally to Schrodinger2d.
# NOTE(review): exact meaning of px_min/py_min/levels is defined by
# Schrodinger2d — confirm there before changing them.
px_min, py_min = 7, 6
levels = 3
Lx = 2.0

solver = Schrodinger2d(
    px_min, py_min, levels, Lx,
    normalize=True, colormap=2, gamma=0.7,
)

# Select OpenCL platform 0 / device 0 and initialise the OpenCL context.
# opencl_platforms() is presumably informational (lists platforms) — confirm.
solver.opencl_platforms()
solver.set_opencl_platform_device(0, 0)
solver.opencl_init()

# Time step derived from the solver's grid spacing.
dt = 2 * solver.dx ** 2

# --- Initial state parameters ---------------------------------------------
# Presumably the centre (x0, y0), wave number k0 and width sigma0 of the
# initial packet passed to solver.paquet() — confirm against Schrodinger2d.
x0, y0 = 0.7, 0.5
k0 = 100
sigma0 = 0.1
E = k0 * k0  # k0 squared; not referenced anywhere in this script

# Build the propagator for dt, reset the state, place the packet, then
# allocate the OpenCL buffers.
solver.schrodinger(dt)
solver.init()
solver.paquet(x0, y0, k0, sigma0)
solver.opencl_create_memory()

# --- Benchmark --------------------------------------------------------------
# Run solver.iterations() from ti to tf with 1, 2 and 4 threads, then
# solver.opencl_iterations(), printing the elapsed time of each run.
# NOTE(review): each run may continue from the state left by the previous
# one — confirm whether iterations() resets the solver state.
ti = 0
tf = 100*dt

# time.clock() was removed in Python 3.8. On Unix it also reported process
# CPU time, which is summed across threads and excludes OpenCL device work.
# Elapsed wall-clock time is the meaningful metric here. Fall back to
# time.time on interpreters without perf_counter.
_now = getattr(time, "perf_counter", time.time)


def _elapsed(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return the elapsed wall time in seconds."""
    start = _now()
    func(*args, **kwargs)
    return _now() - start


try:
    for nthreads in (1, 2, 4):
        temps = _elapsed(solver.iterations, ti, tf, threads=nthreads)
        print("%d threads : %f s" % (nthreads, temps))

    temps = _elapsed(solver.opencl_iterations, ti, tf)
    print("opencl : %f s" % temps)
finally:
    # Always free the OpenCL buffers and close the solver, even if a run fails.
    solver.opencl_release_memory()
    solver.close()



