import pyfits
from numpy import array, where, arange, zeros, interp
import matplotlib.pyplot as plt

# Adapted (with modifications) from pyatomdb.
def get_emissivity(linefile, elem, ion, upper, lower, kT=(-1,), hdu=(-1,)):
    """Return the emissivity of one line transition as a function of kT.

    HDU 1 of ``linefile`` is read for a ``kT`` column (the temperature grid;
    the error messages below treat it as keV).  The line list for grid entry
    ``i`` is read from HDU ``i + 2``.  In each such HDU the first row whose
    ``element``, ``ion``, ``upperlev`` and ``lowerlev`` columns equal
    ``elem``, ``ion``, ``upper`` and ``lower`` supplies the ``epsilon`` value.
    If no row matches, the emissivity for that temperature is 0.

    Parameters
    ----------
    linefile : str
        Path of the FITS file to open with ``pyfits``.
    elem, ion, upper, lower : int
        Values that select the transition (see above).
    kT : sequence of float, optional
        Temperatures at which to linearly interpolate the emissivity.  The
        default ``(-1,)`` means "not given".  min(kT) must be strictly above
        the first tabulated kT, and max(kT) strictly below the last.
    hdu : sequence of int, optional
        HDU numbers (starting from 2) to read directly, without
        interpolation.  The default ``(-1,)`` means "not given".

    Returns
    -------
    tuple (kT, emiss)
        With ``kT`` given, returns that ``kT`` unchanged plus the interpolated
        emissivities.  Otherwise returns the tabulated kT of the selected
        HDUs (or of all HDUs, if neither argument is given) plus the
        emissivities read from them.
    int
        -1 if both ``kT`` and ``hdu`` are given, or if a requested kT is
        outside the tabulated range.  An error message is printed first.
    """
    # Validate arguments before touching the file.
    if kT[0] != -1 and hdu[0] != -1:
        print("Error in get_emissivity: provide either the hdu numbers "
              "(starting from 2) or the temperatures")
        return -1
    interpolate = kT[-1] != -1

    fits = pyfits.open(linefile)
    try:
        kt_table = fits[1].data.field('kT')

        # Choose the grid indices (HDU number minus 2) to read.
        if interpolate:
            above = where(kt_table >= min(kT))[0]
            below = where(kt_table <= max(kT))[0]
            # Use `or`, not `|`, so the array is indexed only when non-empty.
            # The former bitwise `|` evaluated both sides and raised
            # IndexError for temperatures beyond either end of the grid.
            if len(above) == 0 or above[0] == 0:
                print("Error in get_emissivity: kT=%e out of range %e:%e keV"
                      % (min(kT), kt_table[0], kt_table[-1]))
                return -1
            if len(below) == 0 or below[-1] == len(kt_table) - 1:
                print("Error in get_emissivity: kT=%e out of range %e:%e keV"
                      % (max(kT), kt_table[0], kt_table[-1]))
                return -1
            # Include one extra grid point on each side so every requested kT
            # is bracketed by tabulated values.
            ikT = arange(above[0] - 1, below[-1] + 2)
        elif hdu[0] != -1:
            ikT = array(hdu) - 2
            kT = kt_table[ikT]
        else:
            ikT = arange(len(kt_table), dtype=int)
            kT = kt_table[ikT]

        # Gather the transition's emissivity at each selected grid point.
        emiss_grid = zeros(len(ikT), dtype=float)
        for iemiss, i in enumerate(ikT):
            data = fits[i + 2].data
            match = where((data.field('element') == elem) &
                          (data.field('ion') == ion) &
                          (data.field('upperlev') == upper) &
                          (data.field('lowerlev') == lower))[0]
            if len(match) > 0:
                emiss_grid[iemiss] = data.field('epsilon')[match[0]]

        if interpolate:
            emiss = interp(kT, kt_table[ikT], emiss_grid)
        else:
            emiss = emiss_grid
        return kT, emiss
    finally:
        # The original never closed the file.  The returned arrays are copies
        # (fancy indexing / new arrays), so closing here is safe.
        fits.close()


# File from AtomDB holding the emissivity data.
# Modify the file location for your system.
linefile = '/kaaret/halosat/atomdb/apec_line.fits'

# Transition to extract: element 8, ion 7 (plotted below as O VII),
# upper level 2 -> lower level 1.
elem, ion, upper, lower = 8, 7, 2, 1

# Interpolate the emissivity onto kT = 0.02, 0.04, ..., 0.20 keV.
# (Calling get_emissivity without kT/hdu returns every tabulated temperature.
# An earlier unconditional call doing that had its result discarded, so it
# was removed.)
kT = (arange(10)+1)*0.02
result = get_emissivity(linefile, elem, ion, upper, lower, kT=kT, hdu=[-1])
if not isinstance(result, tuple):
    # get_emissivity reports errors by printing a message and returning -1.
    # Stop here rather than fail with an opaque unpacking TypeError.
    raise SystemExit("get_emissivity failed; see the message above")
kT, emiss = result

# Plot the emissivities.
# Note that this plot uses temperatures (kT) in keV.
# NOTE(review): with interactive mode on, plt.show() does not block, so a
# plain `python script.py` run may close the figure immediately -- confirm
# the intended usage (e.g. an IPython session).
plt.ion()
plt.clf()
plt.ylabel('Emissivity (ph cm^3/s)')
plt.xlabel('k*Temperature (keV)')
plt.title('Emissivity vs Temperature')
plt.xscale('log')
plt.yscale('log')
plt.plot(kT, emiss, label="O VII (7-2)")
plt.legend(loc='upper left')
plt.show()

