module collisions

  use redistribute, only: redist_type

  implicit none

  public :: init_collisions
  public :: solfp1
  public :: reset_init
  public :: update_vnewh          !Used for adaptive hypercollisionality

  private

  ! knobs
  logical :: conserve_momentum, const_v
  integer :: collision_model_switch
  logical :: adjust
  logical :: heating
  logical :: hyper_colls

  integer, parameter :: collision_model_lorentz = 1
  integer, parameter :: collision_model_none = 3
  integer, parameter :: collision_model_lorentz_test = 5

! TT> First index omitted.
! GS2 has tunits(ik) factor which is removed in AstroGK.
!  real, dimension (:,:,:), allocatable :: vnew
  real, dimension (:,:), allocatable :: vnew, vnew_ss
  ! (negrid,nspec) replicated
! <TT
  ! (the GS2 shape was (naky,negrid,nspec); the leading ky index is omitted here)

  ! TT: Do we want to keep ig dependence for vnewh?
  ! only for hyper-diffusive collisions
  real, dimension (:,:,:,:), allocatable :: vnewh
  ! (-ntgrid:ntgrid,nakx,naky,nspec) replicated

  ! TT: we can probably remove the first index
  ! only for "new" momentum conservation (8.06)
  complex, dimension(:,:,:), allocatable :: z0, z1

  ! only for original parallel mom conservation (not used nowadays)
  real, dimension (:,:,:), allocatable :: sq
  ! (-ntgrid:ntgrid,nlambda,2) replicated

  ! only for lorentz
  real :: cfac
  real, dimension (:,:), allocatable :: c1, betaa, ql, d1, h1
  complex, dimension (:,:), allocatable :: glz, glzc
  ! ( (2*nlambda+1), -*- lz_layout -*-)
  ! (-ntgrid:ntgrid,naky,2*nlambda+1,negrid,nspec)

  type (redist_type), save :: lorentz_map

  logical :: hypermult
  logical :: initialized = .false.
  logical :: accelerated_x = .false.
  logical :: accelerated_v = .false.

contains

  ! Set up the collision operator.
  !
  ! Initialises the modules this one relies on (layouts, species, grids,
  ! run parameters), decides whether hypercollisions are active, then reads
  ! the collisions_knobs namelist and builds the operator arrays.
  ! Does nothing when the module flag "initialized" is already .true.
  subroutine init_collisions
    use agk_layouts, only: init_agk_layouts, init_dist_fn_layouts
    use species, only: init_species, nspec, spec
    use theta_grid, only: init_theta_grid, ntgrid
    use kgrids, only: init_kgrids, naky, nakx
    use le_grids, only: init_le_grids, nlambda, negrid
    use run_parameters, only: init_run_parameters
    implicit none

    if (.not. initialized) then
       ! The flag is raised before any of the work below is done.
       initialized = .true.

       call init_agk_layouts
       call init_species

       ! Hypercollisions are on as soon as one species has a non-negligible nu_h.
       hyper_colls = any(spec%nu_h > epsilon(0.0))

       call init_theta_grid
       call init_kgrids
       call init_le_grids (accelerated_x, accelerated_v)
       call init_run_parameters
       call init_dist_fn_layouts (ntgrid, naky, nakx, nlambda, negrid, nspec)

       call read_parameters
       call init_arrays
    end if
  end subroutine init_collisions

  ! Read the collisions_knobs namelist on proc0 and broadcast the knobs.
  !
  ! Defaults are set first, so a missing namelist leaves them in place.
  ! The textual collision_model is translated to collision_model_switch
  ! through the option table below.
  subroutine read_parameters
    use mp, only: proc0, broadcast
    use file_utils, only: input_unit, error_unit, input_unit_exist
    use text_options
    implicit none
    ! Accepted spellings of collision_model and the switch value each maps to.
    type (text_option), dimension (5), parameter :: model_table = &
         (/ text_option('default', collision_model_lorentz), &
            text_option('lorentz', collision_model_lorentz), &
            text_option('lorentz-test', collision_model_lorentz_test), &
            text_option('none', collision_model_none), &
            text_option('collisionless', collision_model_none) /)
    character(20) :: collision_model
    namelist /collisions_knobs/ collision_model, conserve_momentum, heating, adjust, const_v, cfac, hypermult
    integer :: report_unit, knob_unit
    logical :: found

    if (proc0) then
       ! --- defaults ---
       collision_model = 'default'
       conserve_momentum = .true.   ! default since 8/06 (improved momentum conservation)
       const_v = .false.
       heating = .false.
       adjust = .true.
       cfac = 1.                    ! default since April 18, 2006 (includes classical diffusion)
       hypermult = .false.

       ! --- user overrides, if the namelist is present ---
       knob_unit = input_unit_exist ("collisions_knobs", found)
       if (found) then
          read (unit=input_unit("collisions_knobs"), nml=collisions_knobs)
       end if

       ! --- translate the model name into its integer switch ---
       report_unit = error_unit()
       call get_option_value (collision_model, model_table, &
            collision_model_switch, report_unit, &
            "collision_model in collisions_knobs")
    end if

    ! Every processor takes part in the broadcasts, in this fixed order.
    call broadcast (hypermult)
    call broadcast (cfac)
    call broadcast (conserve_momentum)
    call broadcast (const_v)
    call broadcast (collision_model_switch)
    call broadcast (heating)
    call broadcast (adjust)
  end subroutine read_parameters

  ! Allocate the collisional heating-rate array and build the operator arrays.
  !
  ! c_rate is allocated and zeroed on the first call only.  When collisions
  ! are enabled the collision frequencies are computed; if they turn out to
  ! be negligible at the first energy grid point for every species, the
  ! model is switched to collision_model_none.  Otherwise the Lorentz
  ! operator (and, if requested, the momentum-conservation arrays) is set up.
  subroutine init_arrays
    use species, only: nspec
    use le_grids, only: negrid
    use agk_layouts, only: g_lo
    use kgrids, only: naky, nakx
    use theta_grid, only: ntgrid
    use dist_fn_arrays, only: c_rate
    implicit none
    real, dimension (negrid,nspec) :: hee      ! scratch output of init_vnew, not used further here
    logical, save :: need_c_rate = .true.      ! .true. until c_rate has been allocated

    if (need_c_rate) then
       allocate (c_rate(-ntgrid:ntgrid, nakx, naky, nspec, 3))
       c_rate = 0.
       need_c_rate = .false.
    end if

    if (collision_model_switch /= collision_model_none) then
       call init_vnew (hee)

       ! Only the first energy index of vnew is inspected (GS2 tested vnew(:,1,:)).
       if (all(abs(vnew(1,:)) <= 2.0*epsilon(0.0))) then
          collision_model_switch = collision_model_none
       else if (collision_model_switch == collision_model_lorentz .or. &
                collision_model_switch == collision_model_lorentz_test) then
          call init_lorentz
          if (conserve_momentum) call init_lz_mom_conserve
       end if
    end if

  end subroutine init_arrays

  subroutine init_lz_mom_conserve 

!
! Precompute the two quantities needed for momentum conservation: z0, z1.
! conserve_mom adds z0*v0y0 - z1*v1y0 to the distribution function after
! the Lorentz step.
!
! Outline:
!   1. Enuinv = 1 / int (E nu f_0)
!   2. first forms of z0 and z1, each passed through solfp_lorentz
!   3. velocity moments v0z0, v1z0, v0z1, v1z1 of those responses
!   4. z1 and z0 are redefined from the first forms and the moments
!
! z0 and z1 are allocated on the first call only; all work arrays are
! released before returning.
!
    use agk_layouts, only: g_lo, ie_idx, is_idx, ik_idx, it_idx
    use species, only: nspec, spec
    use kgrids, only: naky, nakx
    use theta_grid, only: ntgrid
    use le_grids, only: e, integrate_moment
    use agk_time, only: dtime
    use dist_fn_arrays, only: aj0, aj1vp2, kperp2, vpa

    logical, save :: first = .true.
    ! Dummy heating arrays handed to solfp_lorentz (initialised, hence saved).
    complex, dimension (1,1,1) :: dum1 = 0., dum2 = 0.
    complex, dimension (:,:,:), allocatable :: gtmp
    complex, dimension (:,:,:,:), allocatable :: v0z0, v0z1, v1z0, v1z1, Enuinv
    integer :: ie, ik, is, isgn, iglo, it
    integer :: iall   ! third argument of integrate_moment; always 1 here

    if (first) then
       allocate (z0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
       allocate (z1(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
       first = .false.
    end if

! First, get Enu and then 1/Enu == Enuinv

    allocate (gtmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate (Enuinv(-ntgrid:ntgrid, nakx, naky, nspec))

!
! Enu == int (E nu f_0);  Enu = Enu(z, kx, ky, s)
! Enuinv = 1/Enu
! (GS2 has an unknown factor of vnm here.)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       gtmp(:,:,iglo) = e(ie,is) * vnew_ss(ie,is)
    end do

! No real z dependence anymore, since B is constant.  Could be simplified.

    iall = 1
    call integrate_moment (gtmp, Enuinv, iall)  ! not 1/Enu yet

    ! Necessary because some species may have nu=0.
    ! Enuinv=0 iff vnew=0, so it is fine to keep Enuinv=0 there.
    where (abs(Enuinv) > epsilon(0.0))
       Enuinv = 1./Enuinv  ! now it is 1/Enu
    end where

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get z0 (first form)
!
! V_perp == e(ie,is)*al(il)*aj1(iglo) = 0.5 * aj1vp2(iglo)
! u0 = -3 nu V_perp dt a f_0 / Enu
! where a = kperp2 * (T m / q**2)  ! missing factor of 1/B(theta)**2 ???
! (GS2 has an unknown factor of vnm here.)

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          z0(:,isgn,iglo) = - 1.5*vnew_ss(ie,is)*aj1vp2(iglo) &
               * dtime * spec(is)%smz**2 * kperp2(it,ik) * Enuinv(:,it,ik,is)
       end do
    end do

    call solfp_lorentz (z0, dum1, dum2)   ! z0 is redefined below

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get z1 (first form)
!
! v_parallel == vpa or vpac?? (zone vs. boundary)
! V_parallel == v_parallel J0
! u1 = -3 nu V_parallel dt f_0 / Enu
!
! No factor of sqrt(T/m) here on purpose (see derivation) 
! (GS2 has an unknown factor of vnm here.)

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ! "it" must be refreshed here: Enuinv is indexed by (ig,it,ik,is) and
       ! would otherwise be read with the value left over from the z0 loop.
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          z1(:,isgn,iglo) = - 3.*vnew_ss(ie,is)*vpa(isgn,iglo)*aj0(iglo) &
               * dtime * Enuinv(:,it,ik,is)
       end do
    end do

    deallocate (Enuinv)  ! Done with this variable

    call solfp_lorentz (z1, dum1, dum2)    ! z1 is redefined below

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0z0
!
! V_perp == e(ie,is)*al(il)*aj1(iglo) = 0.5 * aj1vp2(iglo)
! v0 = nu V_perp

    allocate (v0z0(-ntgrid:ntgrid, nakx, naky, nspec))         

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          gtmp(:,isgn,iglo) = 0.5*vnew_ss(ie,is)*aj1vp2(iglo) &
               * z0(:,isgn,iglo)
       end do
    end do

    call integrate_moment (gtmp, v0z0, iall)    ! v0z0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v1z0
!
! v_parallel == vpa or vpac?? (zone vs. boundary)
! V_parallel == v_parallel J0
! v1 = nu V_parallel f_0
!
! No factor of sqrt(T/m) here on purpose (see derivation) 

    allocate (v1z0(-ntgrid:ntgrid, nakx, naky, nspec))         

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          gtmp(:,isgn,iglo) = vnew_ss(ie,is)*vpa(isgn,iglo)*aj0(iglo) &
               * z0(:,isgn,iglo)
       end do
    end do

    call integrate_moment (gtmp, v1z0, iall)    ! v1z0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0z1 (same weight as v0z0, applied to z1)

    allocate (v0z1(-ntgrid:ntgrid, nakx, naky, nspec))         

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          gtmp(:,isgn,iglo) = 0.5*vnew_ss(ie,is)*aj1vp2(iglo) &
               * z1(:,isgn,iglo)
       end do
    end do

    call integrate_moment (gtmp, v0z1, iall)

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now redefine z1 == z1 - z0 [v0 . z1]/(1+v0.z0)

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          z1(:,isgn,iglo) = z1(:,isgn,iglo) - z0(:,isgn,iglo)*v0z1(:,it,ik,is) &
               / (1.+v0z0(:,it,ik,is))
       end do
    end do

    deallocate (v0z1)

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v1z1 (same weight as v1z0, applied to the redefined z1)

    allocate (v1z1(-ntgrid:ntgrid, nakx, naky, nspec))         

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          gtmp(:,isgn,iglo) = vnew_ss(ie,is)*vpa(isgn,iglo)*aj0(iglo) &
               * z1(:,isgn,iglo)
       end do
    end do

    call integrate_moment (gtmp, v1z1, iall)    ! redefined below

    deallocate (gtmp)
    
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now redefine z1 == z1/(1 + v1z1)

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          z1(:,isgn,iglo) = z1(:,isgn,iglo) / (1.+v1z1(:,it,ik,is))
       end do
    end do

    deallocate (v1z1)
    
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now redefine z0 == (z1 * v1z0 - z0) / (1 + v0z0)

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          z0(:,isgn,iglo) = (z1(:,isgn,iglo) * v1z0(:,it,ik,is) - z0(:,isgn,iglo)) &
               / (1.+v0z0(:,it,ik,is))
       end do
    end do

    deallocate (v0z0, v1z0)
    
  end subroutine init_lz_mom_conserve

  ! Compute the collision frequencies vnew, vnew_ss and the hypercollision
  ! rate vnewh.
  !
  !   hee (out): energy-dependent factor evaluated at every (ie,is) grid point.
  !
  ! vnew and vnew_ss are (negrid,nspec); vnewh is
  ! (-ntgrid:ntgrid,nakx,naky,nspec).  All three are allocated here on first
  ! use.  For electron species vnew contains zeff in addition to the hee
  ! factor, while vnew_ss keeps only the hee part; for every other species
  ! the two are equal.  With const_v the factor is evaluated at v = 1
  ! (thermal speed) and the 1/e**1.5 energy scaling is dropped.
  subroutine init_vnew (hee)
    use species, only: nspec, spec, electron_species
    use le_grids, only: negrid, e
    use kgrids, only: naky, nakx
    use theta_grid, only: ntgrid
    use run_parameters, only: zeff
    use dist_fn_arrays, only: kperp2
    use constants
    real, dimension (:,:), intent (out) :: hee
    real :: hee_thermal     ! hee evaluated at the thermal speed (energy independent)
    real :: k_nexp_h_max    ! kperp_max normalisation of the hypercollisionality
    integer :: ik, ie, is, it

    do is = 1, nspec
       do ie = 1, negrid
          hee(ie,is) = hee_of (sqrt(e(ie,is)), e(ie,is))
       end do
    end do

    ! Same formula with v = 1, so v**2 = 1 as well.
    hee_thermal = hee_of (1.0, 1.0)

    if (.not.allocated(vnew)) then
       allocate (vnew(negrid,nspec))
       allocate (vnew_ss(negrid,nspec))
    end if
    if (.not.allocated(vnewh)) allocate (vnewh(-ntgrid:ntgrid,nakx,naky,nspec))

    do is = 1, nspec

       do ie = 1, negrid
          if (spec(is)%type == electron_species) then
             if (const_v) then
                vnew(ie,is) = spec(is)%nu * (zeff + hee_thermal) * 0.5
                vnew_ss(ie,is) = spec(is)%nu * hee_thermal * 0.5
             else
                vnew(ie,is) = spec(is)%nu / e(ie,is)**1.5 &
                     *(zeff + hee(ie,is))*0.5        ! w/ ee
                vnew_ss(ie,is) = spec(is)%nu / e(ie,is)**1.5 &
                     * hee(ie,is) * 0.5
             end if
          else
             if (const_v) then
                vnew(ie,is) = spec(is)%nu * hee_thermal * 0.5
             else
                vnew(ie,is) = spec(is)%nu / e(ie,is)**1.5 &
                     * hee(ie,is) * 0.5
             end if
             vnew_ss(ie,is) = vnew(ie,is)
          end if
       end do

       ! Add hyper-terms inside the collision operator.
       ! BD: warning - there is no "grid_norm" option here.
       ! GGH 07SEP07: the exponent is given by spec(is)%nexp_h (default=4.0).
       if (hyper_colls) then
          k_nexp_h_max = (maxval(kperp2))**(spec(is)%nexp_h/2.)
          do ik = 1, naky
             do it = 1, nakx
                vnewh(-ntgrid:ntgrid,it,ik,is) = &
                     spec(is)%nu_h * kperp2(it,ik)**(spec(is)%nexp_h/2.)/k_nexp_h_max
             end do
          end do
       else
          vnewh = 0.
       end if

    end do

  contains

    ! Energy-dependent factor for speed v and energy en (en = v**2 at both
    ! call sites).  The bracket raised to the 16th power is a rational
    ! polynomial in v with fixed coefficients.
    ! NOTE(review): the coefficients look like the Abramowitz & Stegun
    ! 7.1.28 approximation of erf(v) - confirm.
    function hee_of (v, en) result (h)
      real, intent (in) :: v, en
      real :: h

      h = 1.0/sqrt(pi)/v*exp(-en) &
           + (1.0 - 0.5/en) &
           *(1.0 - 1.0/(1.0          + v &
           *(0.0705230784 + v &
           *(0.0422820123 + v &
           *(0.0092705272 + v &
           *(0.0001520143 + v &
           *(0.0002765672 + v &
           *(0.0000430638)))))))**16)
    end function hee_of

  end subroutine init_vnew
!================================================================================
! Update vnewh if adaptive hypercollisions has changed nu_h
! GGH: 2007 JUN 5
!
! Recomputes vnewh(ig,it,ik,is) = nu_h * kperp2**(nexp_h/2) / max(kperp2)**(nexp_h/2)
! from the current spec(is)%nu_h (or zeroes it when hypercollisions are
! off) and then rebuilds the Lorentz operator arrays that depend on it.
!
! Returns without doing anything when vnewh has never been allocated.
! That is the case when the model was read as "none": init_arrays then
! returns before init_vnew, which is where vnewh is allocated.
!================================================================================
  subroutine update_vnewh
    use species, only: nspec, spec
    use kgrids, only: naky, nakx
    use theta_grid, only: ntgrid
    use dist_fn_arrays, only: kperp2
    implicit none
    integer :: ik, is, it
    real :: k_nexp_h_max            !kperp_max normalization of hypercollisionality

    ! Nothing to update if the collision arrays were never set up.
    if (.not. allocated(vnewh)) return

    !Recalculate vnewh based on updated value of nu_h
    if (hyper_colls) then
       do is = 1, nspec
          k_nexp_h_max = (maxval(kperp2))**(spec(is)%nexp_h/2.)
          do ik = 1, naky
             do it = 1, nakx
                ! The rate does not depend on ig, so fill the whole column at once.
                vnewh(-ntgrid:ntgrid,it,ik,is) = &
                     spec(is)%nu_h * kperp2(it,ik)**(spec(is)%nexp_h/2.)/k_nexp_h_max
             end do
          end do
       end do
    else
       vnewh = 0.
    end if

    !Recalculate arrays for collision operator
    select case (collision_model_switch)
    case (collision_model_lorentz,collision_model_lorentz_test)
       call init_lorentz
       if (conserve_momentum) call init_lz_mom_conserve
    end select

  end subroutine update_vnewh
!================================================================================

  ! Build the coefficient arrays of the Lorentz (pitch-angle) operator.
  !
  ! For every locally owned lz_lo index the routine forms a tridiagonal
  ! system in xi = v_par/v on 2*nlambda points (the first nlambda from the
  ! al grid, the rest by mirror symmetry) and stores
  !   c1        super-diagonal,
  !   betaa, ql factors of the forward elimination of that system,
  !   d1, h1    coefficients for the entropy heating calculation
  !             (collisional and hypercollisional part respectively).
  ! NOTE(review): betaa/ql/c1 are presumably consumed by solfp_lorentz,
  ! which is not fully visible here - confirm.
  ! glz (and glzc when heating is on) are allocated and zeroed on first use.
  subroutine init_lorentz
    use species, only: nspec, spec
    use theta_grid, only: ntgrid
    use kgrids, only: naky, nakx
    use le_grids, only: nlambda, negrid, al, e
    use agk_time, only: dtime
    use dist_fn_arrays, only: kperp2
    use agk_layouts, only: init_lorentz_layouts
    use agk_layouts, only: lz_lo
    use agk_layouts, only: ig_idx, ik_idx, ie_idx, is_idx, it_idx
    implicit none
    integer :: ig, il, ilz, it, ik, ie, is
    ! Half-grid coefficients: sub-diagonal, diagonal, super-diagonal,
    ! collisional heating and hypercollisional heating.
    real, dimension (nlambda+1) :: aa, bb, cc, dd, hh
    ! Sub-diagonal and diagonal on the full (mirrored) grid; local only.
    real, dimension (2*nlambda+1) :: a1, b1
    real :: xi0, xi1, xi2, xil, xir, vn, ee, vnh, vnc

    call init_lorentz_layouts (ntgrid, naky, nakx, nlambda, negrid, nspec)
    call init_lorentz_redistribute

    if (.not.allocated(glz)) then
       allocate (glz(2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
       glz = 0.0
       if (heating) then
          allocate (glzc(2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
          glzc = 0.0
       end if
    end if

    if (.not.allocated(c1)) then
       allocate (c1   (2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
       allocate (betaa(2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
       allocate (ql   (2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
       allocate (d1   (2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
       allocate (h1   (2*nlambda+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
    endif

    ! The arrays are rebuilt from scratch on every call.
    c1 = 0.0 ; betaa = 0.0 ; ql = 0.0 ; d1 = 0.0 ; h1 = 0.0

    do ilz = lz_lo%llim_proc, lz_lo%ulim_proc
       is = is_idx(lz_lo,ilz)
       ik = ik_idx(lz_lo,ilz)
       it = it_idx(lz_lo,ilz)
       ie = ie_idx(lz_lo,ilz)
       ig = ig_idx(lz_lo,ilz)
       ! vn  : total rate entering the implicit matrix
       ! vnc : collisional part, vnh : hypercollisional part (heating only)
       if (collision_model_switch == collision_model_lorentz_test) then
          ! Test model: energy-independent rate |nu|, no heating coefficients.
          vn = abs(spec(is)%nu)
          vnc = 0.
          vnh = 0.
       else
          if (hypermult) then
             ! Hypercollisions multiply the collision frequency.
! TT>
!             vn = vnew(ik,ie,is)*(1.+vnewh(ig,it,ik,is))
!             vnc = vnew(ik,ie,is)
             vn = vnew(ie,is)*(1.+vnewh(ig,it,ik,is))
             vnc = vnew(ie,is)
! <TT
             vnh = vnewh(ig,it,ik,is)*vnc
          else
             ! Hypercollisions are added to the collision frequency.
! TT>
!             vn = vnew(ik,ie,is)+vnewh(ig,it,ik,is)
!             vnc = vnew(ik,ie,is)
             vn = vnew(ie,is)+vnewh(ig,it,ik,is)
             vnc = vnew(ie,is)
! <TT
             vnh = vnewh(ig,it,ik,is)
          end if
       end if

       ! Interior points of the half grid.
       do il = 2, nlambda-1
          ! xi = v_par/v
          xi0 = sqrt(abs(1.0 - al(il-1)))   ! xi_{j-1}
          xi1 = sqrt(abs(1.0 - al(il)))     ! xi_j
          xi2 = sqrt(abs(1.0 - al(il+1)))   ! xi_{j+1}
          
          xil = (xi1 + xi0)/2.0  ! xi(j-1/2)
          xir = (xi1 + xi2)/2.0  ! xi(j+1/2)

          ! kperp2-dependent term, scaled by the cfac knob
          ee = 0.5*e(ie,is)*(1+xi1**2) / spec(is)%zstm**2 * kperp2(it,ik)*cfac

          ! coefficients for tridiagonal matrix:
          cc(il) = -vn*dtime*(1.0 - xir*xir)/(xir - xil)/(xi2 - xi1)
          aa(il) = -vn*dtime*(1.0 - xil*xil)/(xir - xil)/(xi1 - xi0)
          bb(il) = 1.0 - (aa(il) + cc(il)) + ee*vn*dtime
          
          ! coefficients for entropy heating calculation
          dd(il) =vnc*((1.0-xir*xir)/(xir-xil)/(xi2-xi1) + ee)
          hh(il) =vnh*((1.0-xir*xir)/(xir-xil)/(xi2-xi1) + ee)
       end do

! boundary at xi = 1
       ! No point to the left: xi0 is set to 1 and the sub-diagonal aa(1) is zero.
       xi0 = 1.0
       xi1 = sqrt(abs(1.0-al(1)))
       xi2 = sqrt(abs(1.0-al(2)))
       
       xil = (xi1 + xi0)/2.0
       xir = (xi1 + xi2)/2.0
       
       ee = 0.5*e(ie,is)*(1+xi1**2) / spec(is)%zstm**2 * kperp2(it,ik)*cfac
       
       cc(1) = -vn*dtime*(-1.0 - xir)/(xi2-xi1)
       aa(1) = 0.0
       bb(1) = 1.0 - (aa(1) + cc(1)) + ee*vn*dtime
       
       dd(1) =vnc*((1.0-xir*xir)/(xir-xil)/(xi2-xi1) + ee)
       hh(1) =vnh*((1.0-xir*xir)/(xir-xil)/(xi2-xi1) + ee)
       
! boundary at xi = 0
       ! The right neighbour is the mirror image of the point itself (xi2 = -xi1).
       il = nlambda
       xi0 = sqrt(abs(1.0 - al(il-1)))
       xi1 = sqrt(abs(1.0 - al(il)))
       xi2 = -xi1
       
       xil = (xi1 + xi0)/2.0
       xir = (xi1 + xi2)/2.0

       ee = 0.5*e(ie,is)*(1+xi1**2) / spec(is)%zstm**2 * kperp2(it,ik)*cfac
       
       cc(il) = -vn*dtime*(1.0 - xir*xir)/(xir - xil)/(xi2 - xi1)
       aa(il) = -vn*dtime*(1.0 - xil*xil)/(xir - xil)/(xi1 - xi0)
       bb(il) = 1.0 - (aa(il) + cc(il)) + ee*vn*dtime
       
       dd(il) =vnc*((1.0-xir*xir)/(xir-xil)/(xi2-xi1) + ee)
       hh(il) =vnh*((1.0-xir*xir)/(xir-xil)/(xi2-xi1) + ee)
       
! start to fill in the arrays for the tridiagonal
       a1(:nlambda) = aa(:nlambda)
       b1(:nlambda) = bb(:nlambda)
       c1(:nlambda,ilz) = cc(:nlambda)
       
       d1(:nlambda,ilz) = dd(:nlambda)
       h1(:nlambda,ilz) = hh(:nlambda)

! assuming symmetry in xi, fill in the rest of the arrays.
       ! The second half is the first half reversed, with the sub- and
       ! super-diagonals (aa, cc) exchanged.
       a1(nlambda+1:2*nlambda)     = cc(nlambda:1:-1)
       b1(nlambda+1:2*nlambda)     = bb(nlambda:1:-1)
       c1(nlambda+1:2*nlambda,ilz) = aa(nlambda:1:-1)
       
       d1(nlambda+1:2*nlambda,ilz) = dd(nlambda:1:-1)
       h1(nlambda+1:2*nlambda,ilz) = hh(nlambda:1:-1)
       
       ! Forward elimination: betaa holds the reciprocal pivots and ql the
       ! row multipliers of the tridiagonal system (a1, b1, c1).
       betaa(1,ilz) = 1.0/b1(1)
       do il = 1, 2*nlambda-1
          ql(il+1,ilz) = a1(il+1)*betaa(il,ilz)
          betaa(il+1,ilz) = 1.0/(b1(il+1)-ql(il+1,ilz)*c1(il,ilz))
       end do
       
       ql(1,ilz) = 0.0

       ! The extra entry 2*nlambda+1 of every stored array is set to zero.
       ql(2*nlambda+1,ilz) = 0.0
       c1(2*nlambda+1,ilz) = 0.0
       betaa(2*nlambda+1,ilz) = 0.0
       
       d1(2*nlambda+1,ilz) = 0.0
       h1(2*nlambda+1,ilz) = 0.0
    end do
  end subroutine init_lorentz

  ! Build the redistribution map lorentz_map between the g_lo layout
  ! (ig, isign, iglo) and the lz_lo layout (il, ilz).
  !
  ! The map is built once; later calls return immediately.  Two passes are
  ! made over every global (ig, isign, iglo) triple: the first counts how
  ! many elements go to / come from each processor so the index lists can
  ! be sized, the second fills those lists.
  subroutine init_lorentz_redistribute
    use mp, only: nproc
    use species, only: nspec
    use theta_grid, only: ntgrid
    use kgrids, only: naky, nakx
    use le_grids, only: nlambda, negrid
    use agk_layouts, only: init_lorentz_layouts
    use agk_layouts, only: g_lo, lz_lo
    use agk_layouts, only: idx_local, proc_id
    use agk_layouts, only: gidx2lzidx 
    use redistribute, only: index_list_type, init_redist, delete_list
    implicit none
    type (index_list_type), dimension(0:nproc-1) :: to_list, from_list
    integer, dimension(0:nproc-1) :: nn_to, nn_from
    integer, dimension(3) :: from_low, from_high
    integer, dimension(2) :: to_high
    integer :: to_low
    integer :: ig, isign, iglo, il, ilz
    integer :: slot, iproc
    logical, save :: map_ready = .false.

    if (map_ready) return

    call init_lorentz_layouts (ntgrid, naky, nakx, nlambda, negrid, nspec)

    ! Pass 1: count the elements exchanged with each processor.
    nn_from = 0
    nn_to = 0
    do iglo = g_lo%llim_world, g_lo%ulim_world
       do isign = 1, 2
          do ig = -ntgrid, ntgrid
             call gidx2lzidx (ig, isign, g_lo, iglo, lz_lo, ntgrid, il, ilz)
             if (idx_local(g_lo,iglo)) then
                iproc = proc_id(lz_lo,ilz)
                nn_from(iproc) = nn_from(iproc) + 1
             end if
             if (idx_local(lz_lo,ilz)) then
                iproc = proc_id(g_lo,iglo)
                nn_to(iproc) = nn_to(iproc) + 1
             end if
          end do
       end do
    end do

    ! Size the index lists; processors with nothing to exchange get none.
    do iproc = 0, nproc-1
       if (nn_from(iproc) > 0) then
          allocate (from_list(iproc)%first(nn_from(iproc)))
          allocate (from_list(iproc)%second(nn_from(iproc)))
          allocate (from_list(iproc)%third(nn_from(iproc)))
       end if
       if (nn_to(iproc) > 0) then
          allocate (to_list(iproc)%first(nn_to(iproc)))
          allocate (to_list(iproc)%second(nn_to(iproc)))
       end if
    end do

    ! Pass 2: record the local indices; the counters are reused as cursors.
    nn_from = 0
    nn_to = 0
    do iglo = g_lo%llim_world, g_lo%ulim_world
       do isign = 1, 2
          do ig = -ntgrid, ntgrid
             call gidx2lzidx (ig, isign, g_lo, iglo, lz_lo, ntgrid, il, ilz)
             if (idx_local(g_lo,iglo)) then
                iproc = proc_id(lz_lo,ilz)
                slot = nn_from(iproc) + 1
                nn_from(iproc) = slot
                from_list(iproc)%first(slot) = ig
                from_list(iproc)%second(slot) = isign
                from_list(iproc)%third(slot) = iglo
             end if
             if (idx_local(lz_lo,ilz)) then
                iproc = proc_id(g_lo,iglo)
                slot = nn_to(iproc) + 1
                nn_to(iproc) = slot
                to_list(iproc)%first(slot) = il
                to_list(iproc)%second(slot) = ilz
             end if
          end do
       end do
    end do

    ! Array bounds on the source (g_lo) side ...
    from_low (1) = -ntgrid
    from_high(1) = ntgrid
    from_low (2) = 1
    from_high(2) = 2
    from_low (3) = g_lo%llim_proc
    from_high(3) = g_lo%ulim_alloc

    ! ... and on the destination (lz_lo) side.
    to_low = lz_lo%llim_proc
    to_high(1) = 2*nlambda+1
    to_high(2) = lz_lo%ulim_alloc

    call init_redist (lorentz_map, 'c', to_low, to_high, to_list, &
         from_low, from_high, from_list)

    call delete_list (to_list)
    call delete_list (from_list)

    map_ready = .true.

  end subroutine init_lorentz_redistribute

  ! Apply the collision operator to the distribution function g.
  !
  !   g, gold     : distribution function at the new and old time level;
  !                 both are shifted by g_adjust before the collision step
  !                 and shifted back afterwards when "adjust" is set.
  !   g1          : work array handed to conserve_mom.
  !   phi, bpar   : fields used by g_adjust.
  !   diagnostics : optional; forwarded to solfp_lorentz.  Together with
  !                 the "heating" knob it enables the heating-rate
  !                 diagnostic stored in c_rate(:,:,:,:,1:3).
  !
  ! Returns immediately when collision_model_switch is collision_model_none.
  subroutine solfp1 (g, gold, g1, phi, bpar, diagnostics)
    use agk_layouts, only: g_lo
    use theta_grid, only: ntgrid
    use run_parameters, only: fphi, fbpar
    use le_grids, only: integrate_moment
    use agk_time, only: dtime
    use dist_fn_arrays, only: c_rate
    use constants
    implicit none
    complex, dimension (-ntgrid:,:,:), intent (in out) :: g, gold, g1 
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi, bpar
    complex, dimension (:,:,:), allocatable :: gc1, gc2, gc3
    integer, optional, intent (in) :: diagnostics
    logical :: get_heating   ! .true. when the heating diagnostic is computed

    if (collision_model_switch == collision_model_none) return

    get_heating = heating .and. present(diagnostics)

    if (get_heating) then
       allocate (gc1(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc)) ; gc1 = 0.
       allocate (gc2(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc)) ; gc2 = 0.
       allocate (gc3(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc)) ; gc3 = 0.
    else
       ! Placeholders only: solfp_lorentz still needs actual arguments.
       allocate (gc1(1,1,1)) ; gc1 = 0.
       allocate (gc2(1,1,1)) ; gc2 = 0.
       allocate (gc3(1,1,1)) ; gc3 = 0.
    end if
    
    if (adjust) then
       call g_adjust (g, phi, bpar, fphi, fbpar)
       call g_adjust (gold, phi, bpar, fphi, fbpar)
    end if

    ! Keep the pre-collision g only when gc3 has the full size.  Copying
    ! whenever diagnostics is present would write the whole distribution
    ! into the (1,1,1) placeholder if heating is off.
    if (get_heating) gc3 = g

    ! An absent optional argument may be passed on to an optional dummy.
    call solfp_lorentz (g, gc1, gc2, diagnostics)
    
    if (conserve_momentum) call conserve_mom (g, g1)

    if (get_heating) then

       call integrate_moment (gc1, c_rate(:,:,:,:,1))

       if (hyper_colls) call integrate_moment (gc2, c_rate(:,:,:,:,2))

! form (h_i+1 + h_i)/2 * C(h_i+1) and integrate.  

       gc3 = 0.5*conjg(g+gold)*(g-gc3)/dtime

       call integrate_moment (gc3, c_rate(:,:,:,:,3))

    end if

    deallocate (gc1, gc2, gc3)
    
    ! Undo the shift applied before the collision step.
    if (adjust) then
       call g_adjust (g, phi, bpar, -fphi, -fbpar)
       call g_adjust (gold, phi, bpar, -fphi, -fbpar)
    end if
    
  end subroutine solfp1

  ! Momentum-conserving correction applied after the Lorentz step.
  !
  !   g  : distribution function, corrected in place.
  !   g1 : work array; overwritten with the integrands of the two moments.
  !
  ! Two velocity moments of g are taken,
  !   v0y0 = int (nu V_perp g)      with V_perp = 0.5*aj1vp2,
  !   v1y0 = int (nu V_parallel g)  with V_parallel = vpa*aj0,
  ! and g is replaced by g + z0*v0y0 - z1*v1y0, where z0 and z1 are the
  ! module arrays filled by init_lz_mom_conserve.
  subroutine conserve_mom (g, g1)

    use mp, only: proc0
    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, ik_idx, it_idx, ie_idx, il_idx, is_idx
    use le_grids, only: integrate_moment
    use dist_fn_arrays, only: aj0, aj1vp2, vpa
    
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g, g1
    complex, dimension (:,:,:,:), allocatable :: v0y0, v1y0

    integer :: isgn, iglo, ik, ie, il, is, it
    integer, save :: iall = 1   ! third argument of integrate_moment

    allocate (v0y0(-ntgrid:ntgrid, nakx, naky, nspec))
    allocate (v1y0(-ntgrid:ntgrid, nakx, naky, nspec))

    ! --- v0y0: perpendicular moment -------------------------------------
    ! V_perp == e(ie,is)*al(il)*aj1(iglo) = 0.5*aj1vp2(iglo);  v0 = nu V_perp
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g1(-ntgrid:ntgrid,isgn,iglo) = 0.5*vnew_ss(ie,is)*aj1vp2(iglo) &
               * g(-ntgrid:ntgrid,isgn,iglo)
       end do
    end do

    call integrate_moment (g1, v0y0, iall)

    ! --- v1y0: parallel moment --------------------------------------------
    ! v_parallel == vpa or vpac?? (zone vs. boundary)
    ! V_parallel == v_parallel J0;  v1 = nu V_parallel f_0
    ! No factor of sqrt(T/m) here on purpose (see derivation)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g1(-ntgrid:ntgrid,isgn,iglo) = vnew_ss(ie,is)*vpa(isgn,iglo)*aj0(iglo) &
               * g(-ntgrid:ntgrid,isgn,iglo)
       end do
    end do

    call integrate_moment (g1, v1y0, iall)

    ! --- conserve momentum --------------------------------------------------
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g(-ntgrid:ntgrid,isgn,iglo) = g(-ntgrid:ntgrid,isgn,iglo) &
               + z0(-ntgrid:ntgrid,isgn,iglo)*v0y0(-ntgrid:ntgrid,it,ik,is) &
               - z1(-ntgrid:ntgrid,isgn,iglo) * v1y0(-ntgrid:ntgrid,it,ik,is)
       end do
    end do

    deallocate (v0y0, v1y0)

  end subroutine conserve_mom

  subroutine g_adjust (g, phi, bpar, facphi, facbpar)
    ! Shift both sign components of g by the same field-dependent term:
    !
    !   g <- g + anon*aj1vp2*bpar*facbpar + z*anon*phi*aj0/temp*facphi
    !
    ! Arguments:
    !   g        : distribution function in the g_lo layout, updated in place.
    !   phi,bpar : fields indexed (ig, it, ik).
    !   facphi   : scalar multiplier of the phi term.
    !   facbpar  : scalar multiplier of the bpar term.
    ! Passing the factors with opposite sign subtracts the same term again.
    use species, only: spec
    use theta_grid, only: ntgrid
    use le_grids, only: anon
    use dist_fn_arrays, only: aj0, aj1vp2
    use agk_layouts, only: g_lo, ik_idx, it_idx, ie_idx, is_idx
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi, bpar
    real, intent (in) :: facphi, facbpar

    ! Adjustment along the whole parallel grid for one iglo.
    complex, dimension (-ntgrid:ntgrid) :: shift
    integer :: iglo, isgn, ik, it, ie, is

    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       is = is_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)

       ! Same left-to-right arithmetic as the scalar form, evaluated
       ! element by element over ig.
       shift = anon(ie,is)*aj1vp2(iglo)*bpar(-ntgrid:ntgrid,it,ik)*facbpar &
            + spec(is)%z*anon(ie,is)*phi(-ntgrid:ntgrid,it,ik)*aj0(iglo) &
               /spec(is)%temp*facphi

       ! The term does not depend on the sign index.
       do isgn = 1, 2
          g(-ntgrid:ntgrid,isgn,iglo) = g(-ntgrid:ntgrid,isgn,iglo) + shift
       end do
    end do
  end subroutine g_adjust

  subroutine solfp_lorentz (g, gc, gh, diagnostics)
    ! Apply the Lorentz collision step to g.
    !
    ! g is gathered into the module work array glz through lorentz_map, each
    ! row of glz is advanced by a tridiagonal solve using the module
    ! coefficient arrays ql, c1 and betaa, and the result is scattered back
    ! into g.
    !
    ! Arguments:
    !   g           : distribution function in the g_lo layout, updated in place.
    !   gc          : written only when heating is on and diagnostics is
    !                 present; receives |glz(il+1)-glz(il)|**2 * d1.
    !   gh          : written only when, in addition, hyper_colls is on;
    !                 receives the same quantity weighted by h1.
    !   diagnostics : only its presence is tested; its value is not used.
    use theta_grid, only: ntgrid
    use le_grids, only: nlambda
    use agk_layouts, only: g_lo, lz_lo
    use agk_layouts, only: is_idx
    use redistribute, only: gather, scatter
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g, gc, gh
    integer, optional, intent (in) :: diagnostics

    complex, dimension (2*nlambda+1) :: delta
    complex, dimension (2*nlambda-1) :: jump
    integer :: ilz, il, is, je, npair

    call gather (lorentz_map, g, glz)

    if (heating .and. present(diagnostics)) then
       ! Differences between neighbouring entries il and il+1 are formed for
       ! il = 1 .. 2*nlambda-1.  Entries of glzc above that index are not
       ! written here and keep the value they had on entry.
       npair = 2*nlambda - 1

       do ilz = lz_lo%llim_proc, lz_lo%ulim_proc
          jump = glz(2:npair+1,ilz) - glz(1:npair,ilz)
          glzc(1:npair,ilz) = conjg(jump)*jump*d1(1:npair,ilz)  ! d1 accounts for hC(h) entropy
       end do
       call scatter (lorentz_map, glzc, gc)

       if (hyper_colls) then
          do ilz = lz_lo%llim_proc, lz_lo%ulim_proc
             jump = glz(2:npair+1,ilz) - glz(1:npair,ilz)
             glzc(1:npair,ilz) = conjg(jump)*jump*h1(1:npair,ilz)  ! h1 accounts for hH(h) entropy
          end do
          call scatter (lorentz_map, glzc, gh)
       end if
    end if

    ! Number of unknowns per row; the same for every row.
    je = 2*nlambda+1

    ! solve for glz row by row
    do ilz = lz_lo%llim_proc, lz_lo%ulim_proc
       is = is_idx(lz_lo,ilz)

       ! Rows of a species whose vnew(1,is) is zero to within round-off are
       ! left untouched.  (The older form tested vnew(ik,1,is).)
       if (abs(vnew(1,is)) < 2.0*epsilon(0.0)) cycle

       glz(je:,ilz) = 0.0

       ! right and left sweeps for tridiagonal solve:

       delta(1) = glz(1,ilz)
       do il = 2, je
          delta(il) = glz(il,ilz) - ql(il,ilz)*delta(il-1)
       end do

       glz(je,ilz) = delta(je)*betaa(je,ilz)
       do il = je-1, 1, -1
          glz(il,ilz) = (delta(il) - c1(il,ilz)*glz(il+1,ilz))*betaa(il,ilz)
       end do

    end do

    call scatter (lorentz_map, glz, g)

  end subroutine solfp_lorentz

  subroutine reset_init
    ! Clear the module-level "initialized" flag.
    !
    ! Intended to force the collision-operator coefficients to be
    ! recalculated when the timestep changes.
    ! NOTE(review): presumably init_collisions tests this flag before
    ! rebuilding its coefficients -- confirm in that routine.
    initialized = .false.

  end subroutine reset_init

end module collisions
