"""
Program to read in polarization modulation curve data, plot the data,
fit a constant value, and fit and plot a modulation curve of the form
A + B*cos^2(phi-phi0).

Data are assumed to be in csv format.
"""

# make Python 2.7 act like Python 3
from __future__ import division, print_function
try:
  # Python 2: rebind input() to raw_input() so it returns the typed text
  # unevaluated, which is what Python 3's input() does
  input = raw_input
except NameError:
  # Python 3: raw_input no longer exists and the builtin input() already
  # behaves this way, so there is nothing to rebind
  pass


import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


# calculate weighted average
def wavg(value, err):
  """Return the inverse-variance weighted average of value.

  Each point is weighted by 1/err**2.  With these weights the result is the
  constant c that minimizes sum(((value - c)/err)**2), i.e. the best-fit
  constant in the chi-squared sense.

  value -- measurements (array or sequence of numbers)
  err   -- 1-sigma uncertainty of each measurement, same length as value;
           entries must be non-zero
  """
  values = np.asarray(value, dtype=float)
  weights = 1.0/np.asarray(err, dtype=float)**2
  return np.sum(weights*values)/np.sum(weights)

# define function used to fit modulation curve
def func_cos2t(phi, a, b, phi0):
  """Evaluate the modulation curve a + b*cos^2(phi - phi0).

  phi  -- angle(s) at which to evaluate the curve, in degrees
  a    -- constant term
  b    -- coefficient of the cos^2 term
  phi0 -- angle offset of the cos^2 term, in degrees

  The current a, b, phi0 are printed on every call, so every evaluation
  requested by the fitting routine shows up on the console.
  """
  print('In func_cos2t a, b, phi0 = ', a, b, phi0)
  # np.cos expects radians, so convert the angle difference from degrees
  delta_rad = (phi-phi0)*np.pi/180
  cos_delta = np.cos(delta_rad)
  return a + b*cos_delta**2


#%% read in the modulation curve data
# the file holds three comma-separated columns, used below as the angle, the
# rate and the error on the rate; unpack=True returns one array per column
columns = np.loadtxt('mxs_polarization.txt', delimiter=',', unpack=True)
angle, rate, rate_err = columns
print('Read in ', len(angle), ' data points.')

#%% plot the modulation curve
plt.ion() # interactive plotting
fig = plt.figure('Rotation curve') # make (or re-select) the plot window
fig.clf() # clear the plot window
ax = fig.gca() # the axes everything below is drawn on
# data points with vertical error bars of size rate_err
ax.errorbar(angle, rate, yerr=rate_err, fmt='o')
ax.set_ylim(ymin=0) # set the bottom of the y-axis to zero
# labelling
ax.set_title('MXS Rotation curve')
ax.set_xlabel('Rotation angle (degrees)')
ax.set_ylabel('Count rate (events/second)')
plt.show()

#%% fit with a constant
print('Fit data with a constant:')
# the constant model is the weighted average of the measured rates
avg = wavg(rate, rate_err)
print( 'Average (c/s) = %.3f' % (avg))
# chi-squared: sum over the data points of the squared residuals, each
# residual measured in units of that point's error
resid0 = (rate-avg)/rate_err
chisq0 = sum(resid0**2)
# degrees of freedom: number of data points minus the one fitted parameter
dof0 = len(angle)-1
print('Chisq/Dof = %.1f/%d' % (chisq0, dof0))
print()

#%% fit to a + b*cos^2(phi-phi0)
print('Fit data with A + B*cos^2(phi-phi0) :')
#pinit = [avg, 0.0, 0.0] # an ok initial guess
pinit = [1.0, 1.0, 1.0] # the default initial guess
#pinit = [0.0, -0.05, -30.0] # a confusing initial guess
# use the scipy curve_fit routine using Levenberg-Marquardt algorithm to find the best fit
# absolute_sigma=True: rate_err is used as the actual 1-sigma error of each
# point (it also enters the chi-squared below), so the covariance matrix must
# not be rescaled to force chisq/dof = 1, which is what the default does
popt, pcov = curve_fit(func_cos2t, angle, rate, p0=pinit, sigma=rate_err,
                       absolute_sigma=True, method='lm')
# translate fit results to user friendly variables; the parameter errors are
# the square roots of the diagonal elements of the covariance matrix
a, b, phi0 = popt
a_err, b_err, phi0_err = np.sqrt(np.diag(pcov))
# print the results
print('A = ', a, ' +/- ', a_err)
print('B = ', b, ' +/- ', b_err)
print('phi_0 = ', phi0, ' +/- ', phi0_err)
# modulation = (max-min)/(max+min) of the fitted curve; its extremes are a
# and a+b, so this is |b|/(2a+b) whichever sign of b the fit converged to
print('Modulation = ', abs(b)/(2*a+b))
# find the chi-squared
chisq = np.sum(((rate-func_cos2t(angle, a, b, phi0))/rate_err)**2)
dof = len(angle)-3  # number of data points minus the three fitted parameters
print('Chisq/Dof = ', chisq, '/', dof)
# plot the fit as a dashed blue line over the range 0 to 180 degrees
pang = np.linspace(0.0, 180.0, 180)
plt.plot(pang, func_cos2t(pang, a, b, phi0), 'b--')



