""" Read in a FITS image and display it in several different ways
"""

# import needed extensions
import numpy as np
import matplotlib.pyplot as plt # plotting package
import matplotlib.cm as cm # colormaps
from mpl_toolkits.mplot3d import Axes3D
import astropy.io.fits as pyfits # FITS format files
from scipy.ndimage.filters import gaussian_filter

plt.ion() # do plots in interactive mode

#%% Read in the FITS image
# name of the file holding the image in FITS format
# (move the comment marker to look at the other image instead)
#image_file = 'm16_f502n.fits'
image_file = 'm57_f658n.fits'

# Open the file inside a with-block so that it is closed again as soon as
# the data have been read, even if reading fails part way through.
with pyfits.open(image_file) as h:
    # The image lives in the first HDU.  Copy the pixels into an in-memory
    # numpy array so that img stays valid after the file has been closed.
    img = np.array(h[0].data)

#%% plot a histogram of the image values
# plt.hist wants a flat sequence of values, so unroll the 2-d image into 1-d
ny, nx = img.shape # number of rows (y) first, then number of columns (x)
imgh = img.reshape(ny*nx) # same pixels, now as a 1-d array
plt.figure('Histogram').clf() # open (or reuse) the histogram window, emptied
# histogram of all pixel values, sorted into 100 bins
t = plt.hist(imgh, bins=100)

# show the raw pixel values as a picture
plt.figure('Image').clf() # open (or reuse) the image window, emptied
# a colormap converts pixel values to colors; plain gray keeps things simple
colmap = plt.get_cmap('gray')
# draw the image; aspect='equal' gives square pixels
plt.imshow(img, aspect='equal', cmap=colmap)



#%%
# compute a few summary statistics of the pixel values ...
img_min, img_max = np.min(img), np.max(img)
img_mean, img_median = np.mean(img), np.median(img)
# ... and print each one behind its label
for prefix, value in (('Image minimum = ', img_min),
                      ('Image maximum = ', img_max),
                      ('Image mean = ', img_mean),
                      ('Image median = ', img_median)):
    print(prefix+str(value))

#%% figure out what range of pixel values is interesting
vmin = img_min # use the minimum pixel value as the minimum interesting pixel value
# There are several ways to set the maximum interesting pixel value.
# Exactly one vmax assignment must be active: with none active, the print
# below and every later use of vmax stop with a NameError.
# To try another option, comment out the active line and uncomment one of
# the alternatives; also feel free to change the numerical factors.
#vmax = 0.1*img_max # use a fraction of the maximum pixel value
vmax = 10*img_mean # use a multiple of the mean pixel value
#vmax = 2*img_median # use a multiple of the median pixel value
print("Scaling pixel values to range "+str(vmin)+" to "+str(vmax))

# histogram of the image values again, now limited to the interesting range
# flatten the 2-d image to 1-d, which is what plt.hist expects
ny, nx = img.shape # number of rows (y) first, then number of columns (x)
imgh = img.reshape(ny*nx) # same pixels, now as a 1-d array
plt.figure('Better Histogram').clf() # open (or reuse) the window, emptied
# 100 bins going from vmin to vmax
plt.hist(imgh, range=[vmin, vmax], bins=100)

# show the image again, with the gray scale stretched over vmin..vmax
plt.figure('Better Image').clf() # open (or reuse) the image window, emptied
# a colormap converts pixel values to colors; plain gray keeps things simple
colmap = plt.get_cmap('gray')
# draw the image; aspect='equal' gives square pixels
plt.imshow(img, aspect='equal', cmap=colmap, vmin=vmin, vmax=vmax)


#%% the data are somewhat noisy
# smooth the data before making the contour and surface plots
# smooth with a 2D Gaussian
nsigma = 10.0 # width of the Gaussian in pixels
smooth = gaussian_filter(img, nsigma) # do the smoothing

# re-evaluate vmin and vmax after smoothing
# The display range has to be computed from the smoothed array, otherwise
# whatever vmax was left over from the unsmoothed image would be used.
# Exactly one vmax assignment must be active; to try another option,
# comment out the active line and uncomment one of the alternatives.
vmin = 0.0
vmax = np.max(smooth) # use the maximum smoothed pixel value
#vmax = 10*np.mean(smooth) # use a multiple of the mean smoothed value
#vmax = 2*np.median(smooth) # use a multiple of the median smoothed value
print("Scaling smoothed pixel values to range "+str(vmin)+" to "+str(vmax))

# show the smoothed data as a picture, scaled to vmin..vmax
plt.figure('Smoothed Image').clf() # open (or reuse) the window, emptied
# a colormap converts pixel values to colors; plain gray keeps things simple
colmap = plt.get_cmap('gray')
# draw the image; aspect='equal' gives square pixels
plt.imshow(smooth, aspect='equal', cmap=colmap, vmin=vmin, vmax=vmax)


#%% for contour and surface plots, we need a grid of points with the X,Y values
# the Z value will be the pixel value
# Use the pixel indices 0 .. nx-1 and 0 .. ny-1 as coordinates, so that grid
# point [j, i] has x = i and y = j.  (np.linspace(0.0, nx, nx) would instead
# spread nx points over 0 .. nx, a spacing of nx/(nx-1) rather than 1 pixel.)
x_vals = np.arange(nx, dtype=float)
y_vals = np.arange(ny, dtype=float)
# expand the two 1-d axes into 2-d arrays of shape (ny, nx)
x, y = np.meshgrid(x_vals, y_vals)

# plot the data as contours
# contour levels: ncontour evenly spaced values from vmin to vmax,
# with the lowest and the highest value dropped
ncontour = 9
contours = np.linspace(vmin, vmax, ncontour)[1:-1]

# draw the contours in their own, freshly cleared window
plt.figure('Contours').clf()
# keep the returned contour information in cs, it is needed for the labels
cs = plt.contour(x, y, smooth, levels=contours)
plt.clabel(cs, fontsize=10, inline=True) # label the contours

# plot the data as a surface
surface_plot = plt.figure('Surface') # set up a new window for the surface plot
plt.clf() # clear the window
# Ask the figure for the 3D axes so that they are attached to it.
# Constructing Axes3D(surface_plot) directly no longer adds the axes to the
# figure in current Matplotlib releases, which leaves the window empty.
ax = surface_plot.add_subplot(111, projection='3d')
# NOTE(review): no cmap is passed, so vmin/vmax probably have no visible
# effect on the surface color - confirm whether a colormap was intended.
ax.plot_surface(x, y, smooth, vmin=vmin, vmax=vmax)

plt.show() # display the plots
