!
! This module is much simpler in AstroGK than GS2
! because of the lack of equilibrium magnetic shear. 
! Shear couples different modes together linearly
! through the twist and shift boundary conditions, 
! and leads to very large matrix inversions which 
! are parallelized in GS2.  
!
! The only routine which would benefit strongly from 
! parallelization in the AstroGK case is getfield.
! Getfield is called in the timestepping loop.  
! (Note that advance_implicit is already parallel.)
! Getfield calls a couple of array-packing and 
! unpacking routines, which could be made parallel
! if getfield were parallelized.  
!
! One obvious strategy for parallelizing these loops 
! would be to distribute the it and ik indices.  Since
! many large runs use layouts with these indices on 
! processor (to optimize the evaluation of the nonlinear
! terms), I am putting this off for another time.
!

module fields_implicit
  ! Implicit field solve: builds the field response matrix once, stores its
  ! inverse in fields_arrays::aminv, and uses it every time step (getfield)
  ! to obtain the updated fields.
  use fields_arrays, only: nidx
  implicit none

  public :: init_fields_implicit   ! Initialize this module (grids, response matrix, antenna)
  public :: advance_implicit       ! Do a full timestep using original algorithm
  public :: init_phi_implicit      ! Set up fields for first time step
  public :: nidx                   ! Re-exported from fields_arrays; set in init_response_matrix to (2*ntgrid+1)*nfield
  public :: reset_init             ! Deallocates matrices for implicit fields solution, in prep for changing dt.

  private

  integer, save :: nfield          ! number of fields (1,2, or 3); counted from use_Phi/use_Apar/use_Bpar in init_response_matrix
  logical :: initialized = .false. ! Set by init_fields_implicit so repeated calls return at once; cleared by reset_init

contains

  subroutine init_fields_implicit
    ! One-time initialization of the implicit field solver.
    !
    ! Does nothing if the module is already initialized.  Otherwise it
    ! marks the module as initialized and runs the set-up steps in the
    ! order below.  reset_init clears the flag so that this routine
    ! runs again.

    use theta_grid, only: init_theta_grid
    use kgrids, only: init_kgrids
    use antenna, only: init_antenna
    implicit none

    if (.not. initialized) then
       initialized = .true.

       ! Grid along the field line
       call init_theta_grid
       ! Grids perpendicular to the field line
       call init_kgrids
       ! Matrices for the implicit Maxwell solve
       call init_response_matrix
       ! External driving antenna
       call init_antenna
    end if

  end subroutine init_fields_implicit

  subroutine init_phi_implicit
    ! Set up the fields for the first time step.
    !
    ! Makes sure the module is initialized, obtains the initial fields
    ! from dist_fn::get_init_field, and then copies the "new" arrays
    ! into the "current" ones.
    use fields_arrays, only: phi, apar, bpar
    use fields_arrays, only: phinew, aparnew, bparnew
    use dist_fn, only: get_init_field
    use dist_fn_arrays, only: g, gnew
    implicit none

    call init_fields_implicit

    ! TT: getfield (phinew, aparnew, bparnew) was used here originally,
    ! but that is not good for the initialization of the EM problem;
    ! get_init_field is the bug fix.
    call get_init_field (phinew, aparnew, bparnew)

    ! TT: the copies below do not hurt, but may not be needed except
    ! for the heating diagnostics.
    phi = phinew
    apar = aparnew
    bpar = bparnew
    g = gnew

  end subroutine init_phi_implicit

  subroutine advance_implicit (istep)
    ! Advance the distribution function and the fields by one time step.
    !
    ! istep : time step counter, passed through to timeadv.
    !
    ! The order of the statements below matters: the second timeadv
    ! call uses the fields updated by getfield.
    use fields_arrays, only: phi, apar, bpar
    use fields_arrays, only: phinew, aparnew, bparnew, apar_ext
    use antenna, only: antenna_amplitudes
    use dist_fn, only: timeadv, exb_shear
    use dist_fn_arrays, only: g, gnew
    implicit none
    integer, intent (in) :: istep

    ! GGH NOTE: apar_ext is initialized in this call
    call antenna_amplitudes (apar_ext)

    ! See Hammett & Loureiro, APS 2006
    call exb_shear (gnew, phinew, aparnew, bparnew)

    ! Store previously advanced values as current values
    g = gnew
    phi = phinew
    apar = aparnew
    bpar = bparnew

    ! Advance f
    call timeadv (phi, apar, bpar, phinew, aparnew, bparnew, istep)

    ! Add in antenna
    aparnew = aparnew + apar_ext

    ! Find future part of fields (result is written to phinew, aparnew, bparnew)
    call getfield (phinew, aparnew, bparnew)

    ! Update fields
    phinew = phinew + phi
    aparnew = aparnew + apar
    bparnew = bparnew + bpar

    ! Complete f advance with updated fields
    call timeadv (phi, apar, bpar, phinew, aparnew, bparnew, istep)

  end subroutine advance_implicit

  subroutine getfield (phi, apar, bpar)
    ! Solve the field equations with the precomputed inverse response
    ! matrix:  am*u = fl  =>  u = aminv*fl  (applied with a minus sign).
    !
    ! phi, apar, bpar : input fields, used only to build the vector fl
    !                   through get_field_vector.
    !
    ! The result is not returned through the arguments: get_field_solution
    ! stores it in the module arrays phinew, aparnew, bparnew of
    ! fields_arrays.
    use theta_grid, only: ntgrid
    use kgrids, only: nakx, naky
    use fields_arrays, only: aminv
    implicit none
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi, apar, bpar
    complex, dimension (:,:,:), allocatable :: rhs, soln
    integer :: irow, ikx, iky

    allocate (rhs(nidx, nakx, naky))
    allocate (soln(nidx, nakx, naky))
    soln = 0.

    ! Pack the field-equation values into one vector per (kx, ky) mode
    call get_field_vector (rhs, phi, apar, bpar)

    ! aminv(:,irow,ikx,iky) is contracted with the packed vector for each
    ! solution entry.  A plain sum of products is used; no complex
    ! conjugation is involved.
    do iky = 1, naky
       do ikx = 1, nakx
          do irow = 1, nidx
             soln(irow,ikx,iky) = - sum(aminv(:,irow,ikx,iky)*rhs(:,ikx,iky))
          end do
       end do
    end do

    deallocate (rhs)

    ! Unpack the solution into phinew, aparnew, bparnew
    call get_field_solution (soln)
    deallocate (soln)

  end subroutine getfield

  subroutine get_field_vector (fl, phi, apar, bpar)
    ! Evaluate the field equations for the given fields and pack the
    ! results into the interleaved vector fl.
    !
    ! fl              : output; first dimension is indexed 1..nidx, with the
    !                   enabled fields interleaved at stride nfield.
    ! phi, apar, bpar : input fields passed to dist_fn::getfieldeq.
    !
    ! Only the fields switched on by use_Phi / use_Apar / use_Bpar occupy
    ! a slot; slots are handed out in that order.
    use theta_grid, only: ntgrid
    use kgrids, only: nakx, naky
    use dist_fn, only: getfieldeq
    use run_parameters, only: use_Phi, use_Apar, use_Bpar
    implicit none
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi, apar, bpar
    complex, dimension (:,:,:), intent (out) :: fl
    complex, dimension (:,:,:), allocatable :: eq_phi, eq_apar, eq_bpar
    integer :: slot

    allocate (eq_phi (-ntgrid:ntgrid,nakx,naky))
    allocate (eq_apar(-ntgrid:ntgrid,nakx,naky))
    allocate (eq_bpar(-ntgrid:ntgrid,nakx,naky))

    call getfieldeq (phi, apar, bpar, eq_phi, eq_apar, eq_bpar)

    ! slot counts the enabled fields seen so far; it is the starting
    ! index of the current field inside each group of nfield entries.
    slot = 0

    if (use_Phi) then
       slot = slot + 1
       fl(slot:nidx:nfield,:,:) = eq_phi
    end if

    if (use_Apar) then
       slot = slot + 1
       fl(slot:nidx:nfield,:,:) = eq_apar
    end if

    if (use_Bpar) then
       slot = slot + 1
       fl(slot:nidx:nfield,:,:) = eq_bpar
    end if

    deallocate (eq_phi, eq_apar, eq_bpar)

  end subroutine get_field_vector

  subroutine get_field_solution (u)
    ! Unpack the interleaved solution vector u into the module arrays
    ! phinew, aparnew and bparnew of fields_arrays.
    !
    ! u : solution vector; first dimension indexed 1..nidx, with the
    !     enabled fields interleaved at stride nfield (same layout as
    !     the vector produced by get_field_vector).
    !
    ! Arrays of fields that are switched off are left untouched.
    use fields_arrays, only: phinew, aparnew, bparnew
    use run_parameters, only: use_Phi, use_Apar, use_Bpar
    implicit none
    complex, dimension (:,:,:), intent (in) :: u
    integer :: slot

    ! slot is the position of the current field among the enabled ones
    slot = 0

    if (use_Phi) then
       slot = slot + 1
       phinew = u(slot:nidx:nfield,:,:)
    end if

    if (use_Apar) then
       slot = slot + 1
       aparnew = u(slot:nidx:nfield,:,:)
    end if

    if (use_Bpar) then
       slot = slot + 1
       bparnew = u(slot:nidx:nfield,:,:)
    end if

  end subroutine get_field_solution

  subroutine reset_init
    ! Undo init_fields_implicit so that the response matrix is rebuilt
    ! on the next call (in preparation for changing dt).
    !
    ! Clears the initialized flag and releases the inverse response
    ! matrix aminv, which init_inverse_matrix allocates.  Safe to call
    ! before the module has been initialized, or more than once in a
    ! row: aminv is only deallocated if it is currently allocated.

    use fields_arrays, only: aminv
    implicit none

    initialized = .false.
    if (allocated(aminv)) deallocate (aminv)

  end subroutine reset_init

  subroutine init_response_matrix
    ! Build the field response matrix and store its inverse.
    !
    ! Steps:
    !   1. Count the enabled fields (nfield) and set the packed vector
    !      length nidx = (2*ntgrid+1)*nfield.
    !   2. Zero g and all current/new field arrays.
    !   3. For every grid point ig and every enabled field, set that
    !      single field value to 1 (for all kx, ky at once), let
    !      init_response_row record the response in am, then reset the
    !      field to zero.
    !   4. Hand am to init_inverse_matrix, which allocates and fills
    !      fields_arrays::aminv.
    !
    ! Side effects: overwrites the module variables nfield and nidx and
    ! leaves g, phi, apar, bpar, phinew, aparnew, bparnew set to zero
    ! (init_response_row additionally calls timeadv).

    use fields_arrays, only: phi, apar, bpar, phinew, aparnew, bparnew
    use theta_grid, only: ntgrid
    use kgrids, only: naky, nakx
    use dist_fn_arrays, only: g
    use run_parameters, only: use_Phi, use_Apar, use_Bpar
    implicit none
    integer :: ig, ifield
    complex, dimension(:,:,:,:), allocatable :: am

    nfield = 0
    if (use_Phi)  nfield = nfield + 1
    if (use_Apar) nfield = nfield + 1
    if (use_Bpar) nfield = nfield + 1
    nidx = (2*ntgrid+1)*nfield

    allocate (am(nidx, nidx, nakx, naky))

    am = 0.0
    g = 0.0

    phi = 0.0
    apar = 0.0
    bpar = 0.0
    phinew = 0.0
    aparnew = 0.0
    bparnew = 0.0

    do ig = -ntgrid, ntgrid
       ! ifield numbers the enabled fields 1..nfield in the order
       ! phi, apar, bpar; init_response_row turns (ig, ifield) into the
       ! matrix index.
       ifield = 0

       if (use_Phi) then
          ifield = ifield + 1
          phinew(ig,:,:) = 1.0
          call init_response_row (ig, ifield, am)
          phinew = 0.0
       end if

       if (use_Apar) then
          ifield = ifield + 1
          aparnew(ig,:,:) = 1.0
          call init_response_row (ig, ifield, am)
          aparnew = 0.0
       end if

       if (use_Bpar) then
          ifield = ifield + 1
          bparnew(ig,:,:) = 1.0
          call init_response_row (ig, ifield, am)
          bparnew = 0.0
       end if
    end do

    call init_inverse_matrix (am)
    deallocate (am)

  end subroutine init_response_matrix

  subroutine init_response_row (ig, ifield, am)
    ! Record the response of the field equations to the perturbation
    ! currently stored in phinew / aparnew / bparnew.
    !
    ! ig     : grid index (-ntgrid..ntgrid) of the perturbed point.
    ! ifield : position (1..nfield) of the perturbed field among the
    !          enabled fields.
    ! am     : response matrix; the slice am(:,irow,:,:) is filled, with
    !          irow = ifield + nfield*(ig+ntgrid).
    !
    ! The routine advances the distribution function with timeadv
    ! (step index 0) and evaluates the field equations with getfieldeq.
    ! The results are stored with the same interleaved layout used by
    ! get_field_vector: the k-th enabled field occupies entries
    ! k, k+nfield, ..., of the first dimension, for all kx and ky.

    use fields_arrays, only: phi, apar, bpar, phinew, aparnew, bparnew
    use theta_grid, only: ntgrid
    use kgrids, only: naky, nakx
    use dist_fn, only: getfieldeq, timeadv
    use run_parameters, only: use_Phi, use_Apar, use_Bpar
    implicit none
    integer, intent (in) :: ig, ifield
    complex, dimension(:,:,:,:), intent (in out) :: am
    complex, dimension (:,:,:), allocatable :: fieldeq, fieldeqa, fieldeqp
    integer :: irow, istart

    allocate (fieldeq (-ntgrid:ntgrid, nakx, naky))
    allocate (fieldeqa(-ntgrid:ntgrid, nakx, naky))
    allocate (fieldeqp(-ntgrid:ntgrid, nakx, naky))

    call timeadv (phi, apar, bpar, phinew, aparnew, bparnew, 0)
    call getfieldeq (phinew, aparnew, bparnew, fieldeq, fieldeqa, fieldeqp)

    irow = ifield + nfield*(ig+ntgrid)

    ! istart counts the enabled fields seen so far.  Each strided section
    ! istart:nidx:nfield has 2*ntgrid+1 entries, matching the first
    ! dimension of the fieldeq arrays, so all (kx, ky) modes can be
    ! stored with one whole-array assignment per field.
    istart = 0

    if (use_Phi) then
       istart = istart + 1
       am(istart:nidx:nfield,irow,:,:) = fieldeq
    end if

    if (use_Apar) then
       istart = istart + 1
       am(istart:nidx:nfield,irow,:,:) = fieldeqa
    end if

    if (use_Bpar) then
       istart = istart + 1
       am(istart:nidx:nfield,irow,:,:) = fieldeqp
    end if

    deallocate (fieldeq, fieldeqa, fieldeqp)
  end subroutine init_response_row

  subroutine init_inverse_matrix (am)
    ! Allocate fields_arrays::aminv and fill it with the inverse of the
    ! response matrix am, one nidx*nidx block per (kx, ky) mode.
    !
    ! am : response matrix of shape (nidx, nidx, nakx, naky).  It is used
    !      as work space and is overwritten for every mode that is
    !      inverted.
    !
    ! Method: Gauss-Jordan elimination without pivoting, so a zero
    ! diagonal pivot is not detected or handled.  The mode with
    ! akx = 0 and aky = 0 is not inverted; its block of aminv is left
    ! equal to zero and its block of am is not modified.
    !
    ! Storage convention: the row operations applied to am are applied
    ! to aminv along its second index, i.e. aminv(:,j,it,ik) receives
    ! the updates belonging to row j.  getfield relies on this when it
    ! contracts aminv(:,i,it,ik) with the right-hand side.

    use kgrids, only: aky, akx, nakx, naky
    use fields_arrays, only: aminv
    implicit none
    complex, dimension(:,:,:,:), intent (in out) :: am
    complex, dimension(:), allocatable :: lhscol, rhsrow
    complex :: fac
    integer :: i, j, ik, it

    allocate (aminv(nidx,nidx,nakx,naky))
    aminv = 0.0

    ! Work vectors for a single mode: copies of the pivot column of am
    ! and of the matching slice of aminv taken before they are updated.
    allocate (lhscol (nidx))
    allocate (rhsrow (nidx))

    ! The modes are independent, so each one is inverted completely
    ! before moving on to the next.
    do ik=1,naky
       do it=1,nakx

          ! kx = ky = 0: keep aminv = 0 for this mode
          if (aky(ik) == 0.0 .and. akx(it) == 0.0) cycle

          ! Start from the identity
          do i=1,nidx
             aminv(i, i, it, ik) = 1.0
          end do

          do i=1,nidx
             lhscol(:) = am   (:,i,it,ik)
             rhsrow(:) = aminv(:,i,it,ik)

             do j=1,nidx

                ! Scale pivot row i by 1/pivot and eliminate column i
                ! from all other rows of am
                fac = am(i,j,it,ik)/lhscol(i)
                am(i,j,it,ik) = fac
                am(:i-1,j,it,ik) = am(:i-1,j,it,ik) - lhscol(:i-1)*fac
                am(i+1:,j,it,ik) = am(i+1:,j,it,ik) - lhscol(i+1:)*fac

                ! Same operations on aminv
                if (j == i) then
                   aminv(:,j,it,ik) = aminv(:,j,it,ik)/lhscol(i)
                else
                   aminv(:,j,it,ik) = aminv(:,j,it,ik) &
                        - rhsrow(:)*lhscol(j)/lhscol(i)
                end if

             end do
          end do

       end do
    end do

    deallocate (lhscol, rhsrow)

  end subroutine init_inverse_matrix

end module fields_implicit
