# -*- coding: utf-8 -*-
"""
Created on Sat Dec  1 19:30:12 2018

@author: kaaret
"""

# make Python 2.7 act like Python 3
from __future__ import division, print_function
input = raw_input

# import commonly used libraries
import numpy as np
import matplotlib.pyplot as plt
# import differential equation solver
from scipy.integrate import solve_ivp
# import curve fitting routine
from scipy.optimize import curve_fit
# import stuff for GUI
import matplotlib
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2TkAgg
import Tkinter as tk  # python 2.7


def dfun(t, y):
  """Return the time derivatives of the coupled-oscillator state vector.

  Two masses m, each attached to a wall by an edge spring (constant k) and
  to each other by a center spring (constant kappa), obey

      m * d2x1/dt2 = -(k + kappa)*x1 + kappa*x2
      m * d2x2/dt2 =  kappa*x1 - (k + kappa)*x2

  which gives the eigenfrequencies sqrt(k/m) and sqrt((k + 2*kappa)/m)
  used elsewhere in this script.

  t : time [s]; not used explicitly, but required by solve_ivp's fun(t, y)
  y : state vector [x1, v1, x2, v2]
  Reads the module-level constants k, kappa and m.
  Returns np.array([dx1/dt, dv1/dt, dx2/dt, dv2/dt]).
  """
  # unpack y into physical variables
  x1, v1, x2, v2 = y
  # every force term is divided by m (the coupling terms previously were not,
  # which only went unnoticed because m = 1)
  dv1dt = (-(k + kappa)*x1 + kappa*x2)/m
  dv2dt = (kappa*x1 - (k + kappa)*x2)/m
  # d(position)/dt is simply the velocity
  return np.array([v1, dv1dt, v2, dv2dt])


# model fitted to the motion of one mass: a sum of the two eigenfunctions
def eigenf(t, a1, t1, a2, t2):
  """Superpose the two normal-mode oscillations at time(s) t.

  a1, t1 : amplitude and time offset [s] of the mode with frequency w1
  a2, t2 : amplitude and time offset [s] of the mode with frequency w2
  w1 and w2 are read from module-level globals.
  """
  mode1 = a1*np.cos(w1*(t-t1))
  mode2 = a2*np.cos(w2*(t-t2))
  return mode1 + mode2


def update(*dummy): # use * to allow a variable number of arguments
  """Integrate the equations of motion, fit mass 1, and redraw the window.

  Serves as the ``command`` callback of every Scale widget (Tkinter passes
  the new slider value, which is ignored via ``*dummy``) and is also called
  once at start-up. Reads the sliders x1s, v1s, x2s, v2s, ts and writes to
  the labels l0, l1, l2 and to ax / canvas (all module-level globals).
  """
  # get initial values from scale widgets
  y0 = [x1s.get(), v1s.get(), x2s.get(), v2s.get()]
  t_span = (0.0, ts.get()) # get time interval from scale widget
  rtol, atol = 1E-6, 1E-9 # set tolerances for numerical integration
  # numerically integrate equations of motion
  sol = solve_ivp(dfun, t_span, y0, rtol=rtol, atol=atol)
  x1 = sol.y[0] # position of mass 1
  x2 = sol.y[2] # position of mass 2

  # information about solution
  if sol.success:
    text0 = 'Number time steps = '+str(len(sol.t))
  else:
    text0 = 'No solution found'
  l0.config(text = text0) # write to a label to show on screen

  # fit mass 1 to the eigenfunctions with Levenberg-Marquardt; popt stays
  # None when no fit is possible so the GUI callback never raises
  pinit = [0.0, 0.0, 0.0, 0.0] # initial guess for [a1, t1, a2, t2]
  popt = None
  # 'lm' needs at least as many samples as parameters (a zero-length time
  # span yields a single sample), and fitting a failed integration is moot
  if sol.success and len(sol.t) >= len(pinit):
    # errors are set by numerical accuracy; |x1| keeps sigma strictly
    # positive and weights positive/negative displacements symmetrically
    sigma = atol + np.abs(x1)*rtol
    try:
      popt, _ = curve_fit(eigenf, sol.t, x1, p0=pinit, sigma=sigma, method='lm')
    except RuntimeError: # least-squares minimization did not converge
      popt = None

  # write amplitude and phase of each eigenfunction to screen
  if popt is None:
    text1 = 'Eigenvalue = '+'{0:.3f}'.format(w1) + '   fit failed'
    text2 = 'Eigenvalue = '+'{0:.3f}'.format(w2) + '   fit failed'
  else:
    a1, t1, a2, t2 = popt
    text1 = 'Eigenvalue = '+'{0:.3f}'.format(w1) + \
            '   amplitude = '+'{0:.4f}'.format(a1) + \
            '   phase [s] = '+'{0:.3f}'.format(t1)
    text2 = 'Eigenvalue = '+'{0:.3f}'.format(w2) + \
            '   amplitude = '+'{0:.4f}'.format(a2) + \
            '   phase [s] = '+'{0:.3f}'.format(t2)
  l1.config(text = text1)
  l2.config(text = text2)

  # plot the solution
  ax.clear() # clear previous plot
  ax.set_title('Displacement')
  ax.set_xlabel('Time (s)') # label the X-axis
  ax.set_ylabel('Position (m)') # label the Y-axis
  ax.plot(sol.t, x1, '+g', label='1') # plot position for mass 1
  ax.plot(sol.t, x2, '+r', label='2') # plot position for mass 2
  if popt is not None:
    # smooth best-fit curve for mass 1 over the integrated time range
    t = np.linspace(min(sol.t), max(sol.t), 1000)
    ax.plot(t, eigenf(t, *popt), '-g', label='fit 1')
  # build the legend after all curves so the fit is included
  ax.legend(loc='best')
  canvas.draw()


# physical constants of the coupled-oscillator system
k = 2.0      # [N/m] spring constant for edge springs
kappa = 2.0  # [N/m] spring constant for center spring
m = 1.0      # [kg] mass of masses

# normal-mode (eigen) angular frequencies [rad/s]:
# w1 -> masses move in phase, center spring never stretched
# w2 -> masses move in opposition, center spring stretched from both ends
w1 = np.sqrt(k/m)
w2 = np.sqrt((k+2*kappa)/m)

# NOTE(review): pyplot is imported before this call, so depending on the
# matplotlib version it may only warn; the embedded figure below is drawn by
# FigureCanvasTkAgg directly and does not rely on the pyplot backend.
matplotlib.use('TkAgg')
root = tk.Tk()

# plot area; plt.Figure (capital F) builds a figure without a pyplot window
fig = plt.Figure(figsize=(6,4))
ax = fig.add_subplot(111)
canvas = FigureCanvasTkAgg(fig, root)
canvas.get_tk_widget().grid(row=0, column=5, rowspan=5)
canvas.draw()

# sliders for the initial conditions; moving any of them re-runs update()
x1s = tk.Scale(root, from_=+1.0, to=-1.0, resolution=0.01, label='x1 [m]', command = update)
x1s.set(0.2)
x1s.grid(row=0, column=0)
v1s = tk.Scale(root, from_=+1.0, to=-1.0, resolution=0.01, label='v1 [m/s]', command = update)
v1s.grid(row=0, column=1)
x2s = tk.Scale(root, from_=+1.0, to=-1.0, resolution=0.01, label='x2 [m]', command = update)
x2s.set(-0.2)
x2s.grid(row=0, column=2)
v2s = tk.Scale(root, from_=+1.0, to=-1.0, resolution=0.01, label='v2 [m/s]', command = update)
v2s.grid(row=0, column=3)

# integration time; starts at 1 s because a zero-length span yields a single
# sample, too few for the four-parameter eigenfunction fit
ts = tk.Scale(root, from_=1.0, to=20.0, resolution=1, label='Time [s]',
              orient=tk.HORIZONTAL, command = update)
ts.set(10.0)
ts.grid(row=1, column=0, columnspan=2, sticky=tk.W)

# text labels filled in by update()
l0 = tk.Label(root)
l0.grid(row=1, column=2, columnspan=2, sticky=tk.W)
l1 = tk.Label(root)
l1.grid(row=2, column=0, columnspan=3, sticky=tk.W)
l2 = tk.Label(root)
l2.grid(row=3, column=0, columnspan=3, sticky=tk.W)

quitbutton = tk.Button(root, text="Quit", command=root.quit)
quitbutton.grid(row=4, column=0, sticky=tk.W)

update() # draw the plot window

# closing the window leaves mainloop the same way the Quit button does;
# the window is then destroyed explicitly
root.protocol("WM_DELETE_WINDOW", root.quit)
root.mainloop()
root.destroy()

