# Python program to load a FITS spectrum and display it

# import needed extensions
from numpy import *
import pyfits
import matplotlib.pyplot as plt # plotting package
import matplotlib.cm as cm # colormaps
from matplotlib.colors import LogNorm
from scipy import stats

# open the FITS file holding the spectrum image
# (edit the file name below to point at your own data)
#s1file = pyfits.open('hydrogen1s.fit')
s1file = pyfits.open('lamp.fits')

# pull the pixel data out of the first HDU of the file
s1img = s1file[0].data

# interactive mode, so plotting calls do not block the script
plt.ion()
# gray colormap used for the image display
colmap = plt.get_cmap('gray')

# show the spectrum image in figure 1, starting from a clean figure
plt.figure(1)
plt.clf()
# logarithmic intensity scaling between vmin and vmax;
# tune these two limits to the brightness range of your spectrum
log_scale = LogNorm(vmin=5, vmax=6E4)
plt.imshow(s1img, cmap=colmap, norm=log_scale)
plt.show() # display the image


# calculate a spectrum
# y0 is the center row of the band over which the spectrum is extracted
y0 = 80.0
# dy sets the half-width of the band (rows y0-dy to y0+dy, inclusive)
dy = 10
# figure out dimensions of spectrum image (rows, columns)
(ny, nx) = shape(s1img)
# Slice bounds must be integers: NumPy raises TypeError for float indices,
# so convert explicitly (truncating, as the old implicit conversion did).
row_lo = int(y0 - dy)
row_hi = int(y0 + dy + 1)
# A negative start would silently wrap around to the bottom of the image,
# so clamp the band to the image instead.
if row_lo < 0:
  row_lo = 0
if row_hi > ny:
  row_hi = ny
if row_lo >= row_hi:
  raise ValueError('extraction band y0=%s, dy=%s lies outside the image '
                   '(%d rows)' % (y0, dy, ny))
# find a 1-d spectrum by integrating across the band: sum every column
# over the selected rows in one call, accumulating in floating point
s1 = s1img[row_lo:row_hi, :].sum(axis=0, dtype=float)
# calculate pixel numbers (1-based: first column is pixel 1)
p = 1+arange(len(s1))

# draw the extracted spectrum (counts against pixel number) in figure 2
plt.figure(2)
plt.clf() # start from an empty figure
plt.plot(p, s1, '-b') # solid blue line
# axis titles
plt.ylabel('Counts')
plt.xlabel('Pixel number')
plt.show() # display the plot

# single-argument print(...) gives the same output on Python 2 and Python 3
print('Line centroids in pixels')
# calculate centroids for each line
# linec and lined are given in pixel numbers, i.e. the same 1-based units
# as p and the x axis of the spectrum plot; edit them to match your lines
linec = array([  1000,   1100,   1200,   1300]) # center (pixel)
lined = array([     5,      5,      5,      5]) # half-width (pixel)
# known wavelength of each line, in the same order as linec
# (units presumably nm -- confirm against your line list)
linew = array([546.07, 587.56, 435.83, 404.66]) # He+HgCd

# maximum value in spectrum (only for plotting)
smax = s1.max()
# one centroid per line, filled in by the loop below
centroid = zeros(len(linec))
for i in range(len(linec)):
  # window edges in pixel numbers
  lo = linec[i] - lined[i]
  hi = linec[i] + lined[i]
  # array elements included in this line.  Select on the pixel numbers p
  # rather than using lo..hi directly as array indices: p is 1-based, so
  # indexing with the pixel numbers would shift the window by one pixel
  # relative to the dashed markers drawn below.
  k = (p >= lo) & (p <= hi)
  if not k.any():
    raise ValueError('line window %s..%s lies outside the spectrum '
                     '(pixels 1..%d)' % (lo, hi, len(p)))
  # intensity-weighted mean pixel number over the window
  centroid[i] = (p[k]*s1[k]).sum() / s1[k].sum()
  # solid line at the centroid, dashed lines at the window edges
  plt.plot([centroid[i], centroid[i]], [0, smax], '-g')
  plt.plot([lo, lo], [0, smax], '--g')
  plt.plot([hi, hi], [0, smax], '--g')
  print('center, range, mean =  %s %s %s' % (linec[i], lined[i], centroid[i]))
plt.show() # display the plot

# single-argument print(...) gives the same output on Python 2 and Python 3
print('')
print('Calibration')
# calibration: linear relation between pixel number and wavelength,
#   wavelength = centralw + slope*(pixel - centralp)
# every measured centroid needs exactly one reference wavelength
if len(linew) != len(centroid):
  raise ValueError('linew has %d entries but there are %d centroids'
                   % (len(linew), len(centroid)))
# a straight line cannot be determined from fewer than two points
if len(centroid) < 2:
  raise ValueError('need at least two lines to calibrate')
# pixel positions are measured relative to the mean centroid, so the
# intercept of the fit is the wavelength at centralp
centralp = mean(centroid)
if len(centroid) == 2:
  # straight calculation with two points
  centralw = mean(linew)
  slope = (linew[1]-linew[0])/(centroid[1]-centroid[0])
else:
  # do a linear fit to the data; linregress returns the slope and
  # intercept first, the remaining fit statistics are not used here
  slope, centralw, r_value, p_value, std_err = \
    stats.linregress(centroid-centralp, linew)
print('central pixel =  %s' % centralp)
print('central wavelength =  %s' % centralw)
print('slope =  %s' % slope)


