module legendre

! Taken from Numerical Recipes, incorporated here by Tomo Tatsuno
! Aug, 2005

  implicit none

  public :: nrgauleg

  private

contains

  subroutine nrgauleg (x1, x2, x, w)!, eps)

! Gauss-Legendre nodes and weights (after Numerical Recipes "gauleg").
!   x1, x2 : end points of the interval (x2 < x1 is allowed; the weights
!            are built from abs(x2-x1) and are therefore positive)
!   x      : on return, the size(x) nodes, running from the x1 end of
!            the interval towards the x2 end
!   w      : on return, the matching weights
! Only the m = (n+1)/2 non-negative roots of P_n on [-1,1] are found by
! Newton iteration; the rest follow from symmetry about the midpoint.

    real, intent(in) :: x1, x2
    real, dimension(:), intent(out) :: x, w
    real  :: eps

    integer, parameter :: maxit=100
    integer :: its, j, k, m, n
    double precision :: xl, xm, pi
    double precision :: p1, p2, p3, zold
    double precision, dimension((size(x)+1)/2) :: pp, z
    logical, dimension((size(x)+1)/2) :: unfinished

! hack for now: tolerance is the double precision machine epsilon
    eps = epsilon(xm)

    n = size(x)
    m = (n+1)/2
    pi = asin(real(1.0,kind(pi)))*2.0

    xm = real(0.5,kind(xm)) * (x1+x2)   ! midpoint of the interval
    xl = real(0.5,kind(xl)) * (x2-x1)   ! signed half length of the interval

! initial guesses for the roots
    z = (/ (cos(pi*(j-0.25)/(n+0.5)), j=1,m) /)
    unfinished = .true.

    do its=1, maxit
       do k=1, m
          ! roots that already converged are left untouched
          if (.not. unfinished(k)) cycle

          ! upward recurrence: on exit p1 = P_n(z(k)), p2 = P_{n-1}(z(k))
          p1 = real(1.0,kind(p1))
          p2 = real(0.0,kind(p2))
          do j=1, n
             p3 = p2
             p2 = p1
             p1 = ((2*j-1) * z(k) * p2 - (j-1) * p3) / j
          end do

          ! derivative of P_n, then one Newton step
          pp(k) = n * (z(k) * p1 - p2) / (z(k)**2 - 1.0)
          zold = z(k)
          z(k) = zold - p1 / pp(k)
          unfinished(k) = (abs(z(k)-zold) > eps)
       end do
       if (.not. any(unfinished)) exit
    end do

! the loop index runs past maxit only when the loop was never exited early
    if (its == maxit+1) then
       print*, 'too many iterations in nrgauleg'
       stop
    end if

! map the roots from [-1,1] onto the requested interval
    x(1:m) = xm - xl * z
    x(n:n-m+1:-1) = xm + xl * z
    w(1:m) = 2.0 * abs(xl) / ((1.0 - z**2) * pp**2)
    w(n:n-m+1:-1) = w(1:m)

  end subroutine nrgauleg
  
end module legendre

module egrid

! By Tomo Tatsuno, Aug 2005
! Improved accuracy and speed and maximum number of energy grid points
!

  implicit none

  public :: energy, xgrid, setegrid, init_egrid
  public :: zeroes, x0

  private

  ! x0: value of xgrid(ecut); assigned in setegrid
  real :: x0
  ! zeroes: allocated with negrid-1 elements in init_egrid and filled in
  ! setegrid with the Gauss-Legendre nodes on the interval from 0 to x0
  real, dimension (:), allocatable :: zeroes

  ! generic E -> x map: xgrid_s takes a scalar, xgrid_v a rank-1 array
  interface xgrid
     module procedure xgrid_s, xgrid_v
  end interface

contains

  subroutine init_egrid (negrid)

! Allocates the module array zeroes with negrid-1 elements and sets it
! to zero.  Only the first call does anything; later calls return
! immediately (the flag "first" is saved between calls).

    integer, intent (in) :: negrid
    logical :: first = .true.

    if (.not. first) return

    first = .false.
    allocate (zeroes(negrid-1))
    zeroes = 0.

  end subroutine init_egrid

  subroutine setegrid (Ecut, negrid, epts, wgts)

! Builds the energy grid.
!   Ecut   : upper energy cutoff
!   negrid : number of energy grid points
!   epts   : on return, epts(1:negrid-1) are the energies whose x-values
!            are the Gauss-Legendre nodes on (0, x0); epts(negrid) = Ecut
!   wgts   : on return, the Gauss-Legendre weights in x for the first
!            negrid-1 points, and 1 - x0 for the last point
! Side effects: sets the module variables x0 and zeroes.

    use legendre, only: nrgauleg
    use file_utils, only: error_unit
    implicit none

    integer, intent (in) :: negrid
    real, intent (in) :: ecut
    real, dimension(:), intent (out) :: epts, wgts
    integer :: ie, np
    integer :: ier, ierr

    call init_egrid (negrid)

    np = negrid-1

! the E -> x map saturates at 1; report overly large cutoffs to the error file
    if (Ecut > 20.0) then
       ierr = error_unit()
       write (ierr,*) 'E -> x transformation is not numerically correct for such a large Ecut'
       write (ierr,*) 'x(E=20) =', xgrid(20.), ' is too close to 1.0'
    end if

    x0 = xgrid(ecut)      ! scalar version, xgrid_s

! Gauss-Legendre nodes and weights in x on (0, x0)
    call nrgauleg (0., x0, zeroes, wgts(:np))!, eps**1.5)

! invert x(E) at every node; a nonzero ier from energy is reported
    do ie=1,np
       epts(ie) = energy(zeroes(ie), Ecut, ier)
       if (ier == 0) cycle
       ierr = error_unit()
       write (ierr, '(a,i8,2(a,f16.10))') 'ERROR in ie= ', ie, &
            & 'zero= ', zeroes(ie), 'epts= ', epts(ie)
    end do

! last grid point sits at the cutoff and carries the remaining x-range
    epts(negrid) = ecut
    wgts(negrid) = 1.-x0

    if (.not. (wgts(negrid) > wgts(np))) return

    write (*,*) 'WARNING: weight at ecut: ', wgts(negrid)
    write (*,*) 'WARNING: is larger than the one at lower grid: ', wgts(np)
    write (*,*) 'WARNING: Recommend using fewer energy grid points'
    if (ecut < 20.0) then
       write (*,*) 'WARNING: or you should increase ecut (<= 20)'
    end if

  end subroutine setegrid

! TT>
!  function energy (xeval, ecut)
  function energy (xeval, ecut, ier)

! Returns the energy E in the bracket [0, ecut] with xgrid(E) = xeval,
! found by roote (bisection followed by Newton-Raphson).
!   ier : error flag returned by roote (0 on success)

    real, intent (in) :: xeval, ecut
    integer, intent(out) :: ier
    real :: energy

    ! tolerances handed to roote: bisection, then Newton-Raphson
    real :: tol_bisect, tol_newton
    real :: lower, upper

    tol_bisect = 1.e-5
    tol_newton = 1.e-13

    lower = 0.0
    upper = ecut

    call roote (xeval, lower, upper, tol_bisect, tol_newton, 1, ier, energy)

  end function energy

! called as roote (xeval, a, b, xerrbi, xerrsec, 1, ier, energy)

  subroutine roote(fval,a,b,xerrbi,xerrsec,nsolv,ier,soln)

    use mp, only: proc0
    use file_utils, only: error_unit
! TT
! solve xgrid(E) == fval for E
!   fval   : target value of xgrid
!   a, b   : initial bracket for E
!   xerrbi : absolute width in E at which bisection stops
!            (bisection is skipped when xerrbi <= 0)
!   xerrsec: relative change in E below which a Newton-Raphson step
!            counts as converged
!   nsolv  : number of converged Newton steps required before stopping
!   ier    : 0 on success; 1 if the bracket does not straddle the root,
!            if the derivative vanishes, or if Newton does not converge
!   soln   : last iterate
! TT

    integer, intent(in) :: nsolv
    integer, intent(out) :: ier
    real, intent(in) :: fval, a, b, xerrbi, xerrsec
    real, intent(out) :: soln
    integer, parameter :: maxit=30 ! maximum iteration for Newton scheme
    integer :: i, niter, isolv, ierr
    real :: a1, b1, f1, f2, f3, trial, aold, slope

    ier=0
    a1=a
    b1=b
    f1=xgrid(a1)-fval
    f2=xgrid(b1)-fval

    if (xerrbi > 0.) then
       if (f1*f2 < 0.) then
          ! number of halvings needed to shrink |b-a| below xerrbi
          niter = int(log(abs(b-a)/xerrbi)/log(2.))+1
          do i=1, niter
             trial = 0.5*(a1+b1)
             f3 = xgrid(trial) - fval
             if (f3*f1 > 0.) then
                a1 = trial
                f1 = f3
             else
                b1 = trial
                f2 = f3
             endif
          enddo
       else
          if (proc0) then
             ierr = error_unit()
             write(ierr,*) 'f1 and f2 have same sign in bisection routine'
             write(ierr,*) 'a1,f1,b1,f2=',a1,f1,b1,f2
             write(ierr,*) 'fval=',fval
          end if
          ier=1
       endif
    end if

    ! to make (a1,f1) the closest
    if ( abs(f1) > abs(f2) ) then
       f1=f2
       aold=a1
       a1=b1
       b1=aold
    endif

! Newton-Raphson method (formerly it was secant method)
    isolv = 0
    do i=1, maxit
       slope = xgrid_prime(a1)
       ! xgrid_prime is exactly zero at a1 = 0; the Newton step would
       ! divide by zero and fill the iterate with Inf/NaN, so flag the
       ! failure and keep the current iterate instead
       if (slope == 0.) then
          if (proc0) then
             ierr = error_unit()
             write (ierr,*) 'le_grids:roote: zero derivative in Newton step'
          end if
          ier = 1
          exit
       end if
       b1 = a1
       a1 = a1 - f1 / slope
! TT> an absolute criterion is too severe for large ecut, so the
!     relative change is tested
       if (abs((a1-b1)/a1) < xerrsec) isolv = isolv+1
! <TT
       if (isolv >= nsolv) exit
       f1 = xgrid(a1) - fval
    end do

    ! the index exceeds maxit only when the loop ran to completion
    if (i > maxit) then
       if (proc0) then
          ierr = error_unit()
          write (ierr,*) 'le_grids:roote: bad convergence'
       end if
       ier = 1
    end if

    soln = a1

  end subroutine roote

  function xgrid_s (e)
! TT
! this is a function
!              2           E
!   x(E) = ---------- * int   exp(-e) sqrt(e) de
!           sqrt(pi)       0
! which gives energy integral with a Maxwellian weight
! x is a monotonic function of E with
!    E | 0 -> infinity
!   -------------------
!    x | 0 -> 1
! The integral is an error function and numerical evaluation is
! obtained from the series of the incomplete gamma function,
!
!   x(E) = 2/sqrt(pi) * exp(-E) * sum_{k=0}^{kmax} t_k ,
!   t_0 = E**1.5 / 1.5 ,   t_k = t_{k-1} * E / (1.5+k)
!
! The terms are built by this recurrence in double precision.  Forming
! E**(1.5+k) directly in default real overflows single precision long
! before k reaches kmax (already for E = 6).
! TT

    real, intent (in) :: e
    real :: xgrid_s

    integer, parameter :: kmax = 100   ! index of the last series term
    double precision :: xg, term, ed, pi
    integer :: k

    pi = asin(real(1.0,kind(pi))) * 2.0
    ed = real(e, kind(ed))

    ! k = 0 term
    term = ed**real(1.5,kind(ed)) / 1.5
    xg = term

    do k = 1, kmax
       term = term * ed / (1.5+k)
       xg = xg + term
    end do

    xgrid_s = xg * exp(-ed) * 2. / sqrt(pi)

  end function xgrid_s

  function xgrid_prime (e) result (dxde)

! Derivative of the E -> x map:
!
!               2
!   x'(E) = ---------- * exp(-E) sqrt(E)
!            sqrt(pi)
!
! Used as the slope in the Newton-Raphson iteration of roote.

    real :: e
    real :: dxde
    real :: pi

    pi = 2.0 * asin(1.0)

    dxde = exp(-e) * sqrt(e) * 2. / sqrt(pi)

  end function xgrid_prime

  function xgrid_v (e) result (xg)

! Array version of the E -> x map (see xgrid_s):
!
!   x(E) = 2/sqrt(pi) * exp(-E) * sum_{k=0}^{kmax} t_k ,
!   t_0 = E**1.5 / 1.5 ,   t_k = t_{k-1} * E / (1.5+k)
!
! evaluated element by element.  As in xgrid_s the series is summed in
! double precision with a term recurrence; forming E**(1.5+k) and the
! running denominator directly in default real overflows single
! precision before k reaches kmax.

    real, dimension (:), intent (in) :: e
    real, dimension (size(e)) :: xg

    integer, parameter :: kmax = 100   ! index of the last series term
    double precision, dimension (size(e)) :: ed, term, acc
    double precision :: pi
    integer :: k

    pi = asin(real(1.0,kind(pi))) * 2.0
    ed = real(e, kind(ed))

    ! k = 0 term
    term = ed**real(1.5,kind(ed)) / 1.5
    acc = term

    do k = 1, kmax
       term = term * ed / (1.5 + k)
       acc = acc + term
    end do

    xg = acc * exp(-ed) * 2. / sqrt(pi)

  end function xgrid_v

end module egrid

module le_grids
  
  implicit none

  public :: init_le_grids, integrate_moment, integrate_species
  public :: e, dele, al, delal, anon
  public :: negrid, nlambda, ng2
  public :: xx, nterp, testfac, ecut
  public :: eint_error, lint_error, integrate_test
  public :: init_weights
!  public :: legendre_transform, lagrange_interp, lagrange_coefs

  private

  ! generic species integral: integrate_species0 takes g(sigma,iglo) and
  ! returns total(it,ik); integrate_species1 takes g(theta,sigma,iglo)
  ! and returns total(theta,it,ik)
  interface integrate_species
     module procedure integrate_species0, integrate_species1
  end interface

  ! e: energy grid; w: energy weights (scaled by 0.25 in set_grids);
  ! anon: set to 1.0 in set_grids; dele: spacing of e
  real, dimension (:,:), allocatable :: e, w, anon, dele ! (negrid,nspec)
  ! al = 1 - xx**2; delal: spacing of al; wl: twice the Gauss-Legendre
  ! weights returned for xx (see setlgrid)
  real, dimension (:), allocatable :: al, delal, wl ! (nlambda)

  ! Gauss-Legendre nodes on the interval from 1 to 0 (see setlgrid)
  real, dimension (:), allocatable :: xx ! (nlambda)
  ! werr(:,ipt) / wlerr(:,ipt): integration weights with grid point ipt
  ! dropped; allocated as (negrid-1,negrid-1) and (nlambda,nlambda) and
  ! filled in init_weights
  real, dimension (:,:), allocatable, save :: werr, wlerr

 ! knobs
  ! ngauss, negrid and ecut are given defaults in read_parameters and may
  ! be overridden through the namelist le_grids_knobs
  integer :: ngauss, negrid
  real :: ecut

  ! set_grids assigns ng2 = 2*ngauss and nlambda = ng2
  integer :: nlambda, ng2
  ! set in init_integrations from the processor layout
  logical :: accel_x = .false.
  logical :: accel_v = .false.
  ! when true, init_le_grids prints the al and e grids and stops
  logical :: test = .false.

  integer :: testfac = 1
  ! passed to get_weights as the maximum number of points per subinterval
  integer :: nmax = 500
  integer :: nterp = 100

  ! get_weights rejects a subdivision whose largest weight exceeds
  ! wgt_fac/maxpts
  real :: wgt_fac = 10.0

contains

  subroutine init_le_grids (accelerated_x, accelerated_v)

! Initialises the energy and lambda grids and the layouts that depend on
! them.  The grids are read and built on proc0 and then broadcast.
!   accelerated_x, accelerated_v : on return, copies of the module flags
!       accel_x and accel_v (set in init_integrations)
! The set-up work is done on the first call only.  Later calls still
! define both output arguments; they are intent(out), so returning
! without assigning them would leave the caller with undefined values.
! When the module flag "test" is set, proc0 prints the al and e grids
! and the program stops.

    use mp, only: proc0, finish_mp
    use species, only: init_species
    use theta_grid, only: init_theta_grid
    use kgrids, only: init_kgrids
    use agk_layouts, only: init_agk_layouts
    implicit none
    logical, intent (out) :: accelerated_x, accelerated_v
    logical, save :: initialized = .false.
    integer :: il, ie

    if (initialized) then
       accelerated_x = accel_x
       accelerated_v = accel_v
       return
    end if
    initialized = .true.

    call init_agk_layouts
    call init_species
    call init_theta_grid
    call init_kgrids

    if (proc0) then
       call read_parameters
       call set_grids
    end if
    call broadcast_results
    call init_integrations

    accelerated_x = accel_x
    accelerated_v = accel_v

    if (test) then
       if (proc0) then
          do il = 1, nlambda
             write(*,*) al(il)
          end do
          write(*,*) 
          do ie = 1, negrid
             write(*,*) e(ie,1)
          end do
       end if
       call finish_mp
       stop
    endif
    
  end subroutine init_le_grids

  subroutine broadcast_results

! Sends the knobs and the grids built on proc0 to all processors.
! The scalar knobs go first so that the other processors can allocate
! the grid arrays with the right sizes before the arrays are sent.

    use mp, only: proc0, broadcast
    use species, only: nspec
    use egrid, only: zeroes, x0, init_egrid
    implicit none
    integer :: is

    call broadcast (ngauss)
    call broadcast (negrid)
    call broadcast (ecut)
    call broadcast (nlambda)
    call broadcast (ng2)
    call broadcast (test)
    call broadcast (testfac)
    call broadcast (nmax)
    call broadcast (wgt_fac)
    call broadcast (nterp)

    ! proc0 already allocated these arrays in set_grids
    if (.not. proc0) then
       allocate (e(negrid,nspec), w(negrid,nspec), anon(negrid,nspec))
       allocate (dele(negrid,nspec))
       allocate (al(nlambda), delal(nlambda))
       allocate (wl(nlambda))
       allocate (xx(nlambda))
    end if

    ! allocates zeroes on processors where this is the first call
    call init_egrid (negrid)

    call broadcast (al)
    call broadcast (delal)
    call broadcast (xx)
    call broadcast (x0)
    call broadcast (zeroes)

    do is = 1, nspec
       call broadcast (e(:,is))
       call broadcast (w(:,is))
       call broadcast (anon(:,is))
       call broadcast (dele(:,is))
    end do
    call broadcast (wl)

  end subroutine broadcast_results

  subroutine read_parameters

! Sets the defaults for ngauss, negrid and ecut and then reads the
! namelist le_grids_knobs if the input file contains it.  The other
! namelist members keep their module-level initial values unless the
! namelist overrides them.

    use file_utils, only: input_unit, error_unit, input_unit_exist
    implicit none
    integer :: in_file
    logical :: exist
    namelist /le_grids_knobs/ ngauss, negrid, ecut, test, &
         testfac, nmax, wgt_fac, nterp

    ! defaults
    ngauss = 8    ! Note: nlambda = 2 * ngauss
    negrid = 16
    ecut = 6.0    ! new default value for advanced scheme

    in_file = input_unit_exist("le_grids_knobs", exist)
    if (exist) then
       read (unit=input_unit("le_grids_knobs"), nml=le_grids_knobs)
    end if

  end subroutine read_parameters

  subroutine init_integrations

! Initialises the distribution function layout and, on the first call
! only, sets the module flags accel_x and accel_v from the layout
! character returned by pe_layout:
!   'x' : accel_x is true when nakx*naky*nspec is a multiple of nproc
!   'v' : accel_v is true when negrid*nlambda*nspec is a multiple of nproc
! For any other character both flags are left unchanged.

    use mp, only: nproc
    use theta_grid, only: ntgrid
    use kgrids, only: nakx, naky
    use species, only: nspec
    use agk_layouts, only: init_dist_fn_layouts, pe_layout
    implicit none
    character (1) :: layout
    logical :: first = .true.

    call init_dist_fn_layouts (ntgrid, naky, nakx, nlambda, negrid, nspec)

    if (.not. first) return
    first = .false.

    call pe_layout (layout)

    select case (layout)
    case ('x')
       accel_x = mod(nakx*naky*nspec, nproc) == 0
       accel_v = .false.
    case ('v')
       accel_x = .false.
       accel_v = mod(negrid*nlambda*nspec, nproc) == 0
    end select

  end subroutine init_integrations

  subroutine init_weights

! Fills the module arrays werr and wlerr.  Column ipt of werr (wlerr)
! holds integration weights for the energy (lambda) grid with point ipt
! removed; the row of the removed point stays zero.

    use file_utils, only: open_output_file, close_output_file
    use egrid, only: x0, zeroes
    use mp, only: proc0

    implicit none

    real, dimension (:), allocatable :: enodes, ewgts   ! (negrid-2)
    real, dimension (:), allocatable :: lnodes, lwgts   ! (nlambda-1)
    integer :: idrop, ndiv, divmax
    logical :: eflag = .false.
    integer :: tmp_unit

    allocate (enodes(negrid-2), ewgts(negrid-2))
    allocate (lnodes(nlambda-1), lwgts(nlambda-1))
    allocate (werr(negrid-1,negrid-1))
    allocate (wlerr(nlambda,nlambda))

    werr = 0.0
    wlerr = 0.0
    enodes = 0.0
    ewgts = 0.0
    lnodes = 0.0
    lwgts = 0.0

! energy: negrid-1 sets of weights are needed, one for every point of
! the Gaussian quadrature that can be dropped

    do idrop = 1, negrid-1

       ! node list without point idrop (the sections are empty at the ends)
       enodes = (/ zeroes(:idrop-1), zeroes(idrop+1:) /)

       call get_weights (nmax,0.0,x0,enodes,ewgts,ndiv,divmax,eflag)

       ! a zero is left in the position of the dropped point; the factor
       ! of 0.25 accounts for plus/minus vpa and the 1/2 factor in the
       ! lambda integrals
       werr(:idrop-1,idrop) = ewgts(:idrop-1)*0.25
       werr(idrop+1:,idrop) = ewgts(idrop:)*0.25

    end do

! lambda: same procedure as for the energy grid

    do idrop = 1, nlambda

       lnodes = (/ xx(:idrop-1), xx(idrop+1:) /)

       call get_weights (nmax,1.0,0.0,lnodes,lwgts,ndiv,divmax,eflag)

       wlerr(:idrop-1,idrop) = 2.0*lwgts(:idrop-1)
       wlerr(idrop+1:,idrop) = 2.0*lwgts(idrop:)

    end do

    ! the .wgts output file is opened and closed; nothing is written to it
    call open_output_file (tmp_unit,".wgts")
    call close_output_file (tmp_unit)

    deallocate (enodes, ewgts)
    deallocate (lnodes, lwgts)

  end subroutine init_weights

! the get_weights subroutine determines how to divide up the integral into 
! subintervals and how many grid points should be in each subinterval

  subroutine get_weights (maxpts_in, llim, ulim, nodes, wgts, ndiv, divmax, err_flag)

! Integration weights for the given nodes on the interval llim..ulim.
! The interval is split into subintervals of at most maxpts nodes
! (neighbouring subintervals share their boundary node), and maxpts is
! reduced until the largest weight no longer exceeds wgt_fac/maxpts.
!   maxpts_in : upper limit for the number of nodes per subinterval
!   llim,ulim : limits of integration
!   nodes     : grid points
!   wgts      : on return, the weight of every node
!   ndiv      : on return, number of subintervals used
!   divmax    : on return, largest number of nodes in a subinterval
!   err_flag  : .true. if the weights are still too large with fewer
!               than 3 nodes per subinterval, .false. otherwise
! NOTE(review): the size test uses abs(maxval(wgts)), so a large
! negative weight does not trigger a finer subdivision - confirm that
! this is intended.

    implicit none

    integer, intent (in) :: maxpts_in
    real, intent (in) :: llim, ulim
    real, dimension (:), intent (in) :: nodes
    real, dimension (:), intent (out) :: wgts
    logical, intent (out) :: err_flag
    integer, intent (out) :: ndiv, divmax

    integer :: npts, rmndr, basepts, divrmndr, base_idx, idiv, epts, im, maxpts
    integer, dimension (:), allocatable :: divpts

    real :: wgt_max

! err_flag is intent(out): give it a defined value on the success path
! too, not only when the subdivision fails
    err_flag = .false.

! npts is the number of grid points in the integration interval
    npts = size(nodes)

    wgts = 0.0; epts = npts; basepts = nmax; divrmndr = 0; ndiv = 1; divmax = npts

! maxpts is the max number of pts in an integration subinterval
    maxpts = min(maxpts_in,npts)

    do

       wgt_max = wgt_fac/maxpts

! only need to subdivide integration interval if maxpts < npts
       if (maxpts .ge. npts) then
          call get_intrvl_weights (llim, ulim, nodes, wgts)
       else
          rmndr = mod(npts-maxpts,maxpts-1)
          
! if rmndr is 0, then each subinterval contains maxpts pts
          if (rmndr == 0) then
! ndiv is the number of subintervals
             ndiv = (npts-maxpts)/(maxpts-1) + 1
             allocate (divpts(ndiv))
! divpts is an array containing the # of pts for each subinterval
             divpts = maxpts
          else
             ndiv = (npts-maxpts)/(maxpts-1) + 2
             allocate (divpts(ndiv))
! epts is the effective number of pts after taking into account double
! counting of some grid points (those that are boundaries of subintervals
! are used twice)
             epts = npts + ndiv - 1
             basepts = epts/ndiv
             divrmndr = mod(epts,ndiv)
             
! determines if all intervals have same # of pts
             if (divrmndr == 0) then
                divpts = basepts
             else
                divpts(:divrmndr) = basepts + 1
                divpts(divrmndr+1:) = basepts
             end if
          end if
          
          base_idx = 0
          
! loop calls subroutine to get weights for each subinterval; the first
! subinterval starts at llim, the last one ends at ulim, and interior
! boundaries are grid points.  get_intrvl_weights adds to wgts, so a
! shared boundary node collects contributions from both neighbours.
          do idiv=1,ndiv
             if (idiv == 1) then
                call get_intrvl_weights (llim, nodes(base_idx+divpts(idiv)), &
                     nodes(base_idx+1:base_idx+divpts(idiv)),wgts(base_idx+1:base_idx+divpts(idiv)))
             else if (idiv == ndiv) then
                call get_intrvl_weights (nodes(base_idx+1), ulim, &
                     nodes(base_idx+1:base_idx+divpts(idiv)),wgts(base_idx+1:base_idx+divpts(idiv)))
             else
                call get_intrvl_weights (nodes(base_idx+1), nodes(base_idx+divpts(idiv)), &
                     nodes(base_idx+1:base_idx+divpts(idiv)),wgts(base_idx+1:base_idx+divpts(idiv)))
             end if
             base_idx = base_idx + divpts(idiv) - 1
          end do
          
          divmax = maxval(divpts)

          deallocate (divpts)
       end if

! check to make sure the weights do not get too large
       if (abs(maxval(wgts)) .gt. wgt_max) then
          if (maxpts .lt. 3) then
             err_flag = .true.
             exit
          end if
! retry with one point fewer than the largest subinterval just used
          maxpts = divmax - 1
       else
          exit
       end if

       wgts = 0.0; epts = npts; divrmndr = 0; basepts = nmax
    end do

  end subroutine get_weights

  subroutine get_intrvl_weights (llim, ulim, nodes, wgts)

! Adds to wgts(iw) the integral from llim to ulim of the Lagrange basis
! polynomial that is one at nodes(iw) and zero at the other nodes.  The
! integral is evaluated with a Gauss-Legendre rule of size(nodes)/2+1
! points.  wgts is accumulated into, not overwritten.

    use legendre, only: nrgauleg

    implicit none

    ! llim (ulim) is lower (upper) limit of integration
    real, intent (in) :: llim, ulim
    real, dimension (:), intent (in) :: nodes
    real, dimension (:), intent (in out) :: wgts

    ! Gaussian quadrature points and weights, and the basis polynomial
    ! evaluated at the quadrature points
    real, dimension (:), allocatable :: gnodes, gwgts, basis
    integer :: inode, iw, iq, nqn, nqw

    nqn = size(nodes)/2 + 1
    nqw = size(wgts)/2 + 1
    allocate (gnodes(nqn), gwgts(nqw), basis(nqn))

    call nrgauleg(llim, ulim, gnodes, gwgts)

    do iw = 1, size(wgts)

       ! Lagrange basis polynomial of node iw at all quadrature points
       basis = 1.0
       do inode = 1, size(nodes)
          if (inode == iw) cycle
          basis = basis*(gnodes - nodes(inode))/(nodes(iw) - nodes(inode))
       end do

       ! quadrature sum, accumulated term by term
       do iq = 1, size(gwgts)
          wgts(iw) = wgts(iw) + basis(iq)*gwgts(iq)
       end do

    end do

  end subroutine get_intrvl_weights

  subroutine set_grids

! Allocates and fills the energy grid (e, w, anon, dele) and the lambda
! grid (al, delal, wl, xx).  The energy grid is computed once and copied
! to every species.  Also sets ng2 = 2*ngauss and nlambda = ng2.

    use species, only: init_species, nspec
    use theta_grid, only: init_theta_grid, ntgrid
    use egrid, only: setegrid
    implicit none
    integer :: is

    call init_theta_grid
    call init_species

    ! energy grid
    allocate (e(negrid,nspec), w(negrid,nspec))
    allocate (anon(negrid,nspec), dele(negrid,nspec))

    call setegrid (ecut, negrid, e(:,1), w(:,1))
    w(:,1) = 0.25*w(:,1)
    do is = 2, nspec
       e(:,is) = e(:,1)
       w(:,is) = w(:,1)
    end do
    anon = 1.0

    ! spacing of the energy grid; the first entry is the first grid value
    dele(1,:) = e(1,:)
    dele(2:negrid,:) = e(2:negrid,:) - e(1:negrid-1,:)

    ! lambda grid
    ng2 = 2*ngauss
    nlambda = ng2

    allocate (al(nlambda), delal(nlambda), wl(nlambda), xx(nlambda))

    call setlgrid

    ! spacing of the lambda grid, built like dele
    delal(1) = al(1)
    delal(2:nlambda) = al(2:nlambda) - al(1:nlambda-1)

  end subroutine set_grids

  subroutine setlgrid

! Modified to use nrgauleg routine, Tomo Tatsuno, Aug 2005
!
! Lambda grid from a Gauss-Legendre rule on the interval from 1 to 0:
!   xx : the nodes
!   wl : twice the Gauss-Legendre weights
!   al : 1 - xx**2

    use legendre, only: nrgauleg
    use constants
    implicit none

    real, dimension (2*ngauss) :: gl_wgts

    call nrgauleg(1., 0., xx, gl_wgts)

    al = 1.0 - xx**2
    wl = 2.0*gl_wgts

  end subroutine setlgrid

  subroutine integrate_species1 (g, weights, total)

! Velocity-space integral of g, summed over species with the factors
! weights(is), for every theta; the result is available on all
! processors.
!   g       : g(theta, sigma, iglo); both sigma entries are added
!   weights : one factor per species
!   total   : total(theta, it, ik)

    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, idx, idx_local
    use agk_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce
    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (-ntgrid:,:,:), intent (out) :: total
    real, dimension (:), intent (in) :: weights

    complex, dimension (:), allocatable :: work
    real :: fac
    integer :: is, il, ie, ik, it, iglo, nth, offset

    ! local part of the sum
    total = 0.
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       fac = weights(is)*w(ie,is)*wl(il)

       total(:, it, ik) = total(:, it, ik) + fac*(g(:,1,iglo)+g(:,2,iglo))
    end do

    ! pack into a rank-1 buffer, theta fastest, then it, then ik
    nth = 2*ntgrid + 1
    allocate (work(nth*naky*nakx))
    work = 0.
    offset = 0
    do ik = 1, naky
       do it = 1, nakx
          work(offset+1:offset+nth) = total(-ntgrid:ntgrid, it, ik)
          offset = offset + nth
       end do
    end do

    call sum_allreduce (work) 

    ! unpack the global sum in the same order
    offset = 0
    do ik = 1, naky
       do it = 1, nakx
          total(-ntgrid:ntgrid, it, ik) = work(offset+1:offset+nth)
          offset = offset + nth
       end do
    end do
    deallocate (work)

  end subroutine integrate_species1

  subroutine integrate_species0 (g, weights, total)

! Velocity-space integral of g, summed over species with the factors
! weights(is); the result is available on all processors.
!   g       : g(sigma, iglo); both sigma entries are added
!   weights : one factor per species
!   total   : total(it, ik)

    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, idx, idx_local
    use agk_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce
    implicit none

    complex, dimension (:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (:,:), intent (out) :: total
    real, dimension (:), intent (in) :: weights

    complex, dimension (:), allocatable :: work
    real :: fac
    integer :: is, il, ie, ik, it, iglo, offset

    ! local part of the sum
    total = 0.
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       fac = weights(is)*w(ie,is)*wl(il)

       total(it, ik) = total(it, ik) + fac*(g(1,iglo)+g(2,iglo))
    end do

    ! pack into a rank-1 buffer, it fastest, then ik
    allocate (work(naky*nakx))
    work = 0.
    offset = 0
    do ik = 1, naky
       work(offset+1:offset+nakx) = total(1:nakx, ik)
       offset = offset + nakx
    end do

    call sum_allreduce (work) 

    ! unpack the global sum in the same order
    offset = 0
    do ik = 1, naky
       total(1:nakx, ik) = work(offset+1:offset+nakx)
       offset = offset + nakx
    end do
    deallocate (work)

  end subroutine integrate_species0

  subroutine integrate_test (g, weights, total, istep)

! Test integral: same weights and reduction as integrate_species1, but
! the integrand is cos(istep*0.1*ypt(il)) instead of g.  ypt is set to
! zero here and never changed, so the integrand is currently one; g is
! only used for its bounds.  The code that once filled ypt from bmag
! and al is disabled.

    use egrid, only: x0, zeroes
    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, idx, idx_local
    use agk_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce

    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (-ntgrid:,:,:), intent (out) :: total
    real, dimension (:), intent (in) :: weights
    integer, intent (in) :: istep

    complex, dimension (:), allocatable :: work
    real, dimension (:), allocatable :: ypt
    real :: fac
    integer :: is, il, ie, ik, it, iglo, nth, offset

    allocate (ypt(nlambda))
    ypt = 0.0

    ! local part of the sum
    total = 0.
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       fac = weights(is)*w(ie,is)*wl(il)

       total(:, it, ik) = total(:, it, ik) + fac*cos(istep*0.1*ypt(il))
    end do

    ! pack into a rank-1 buffer, theta fastest, then it, then ik
    nth = 2*ntgrid + 1
    allocate (work(nth*naky*nakx))
    work = 0.
    offset = 0
    do ik = 1, naky
       do it = 1, nakx
          work(offset+1:offset+nth) = total(-ntgrid:ntgrid, it, ik)
          offset = offset + nth
       end do
    end do

    call sum_allreduce (work) 

    ! unpack the global sum in the same order
    offset = 0
    do ik = 1, naky
       do it = 1, nakx
          total(-ntgrid:ntgrid, it, ik) = work(offset+1:offset+nth)
          offset = offset + nth
       end do
    end do
    deallocate (work)

    deallocate (ypt)

  end subroutine integrate_test

  subroutine eint_error (g, weights, total)

! Species-summed velocity integral of g evaluated negrid-1 times, each
! time with the energy weights werr(:,ipt) in which energy grid point
! ipt was dropped (see init_weights).  The weight of the last energy
! point is w(negrid,1) in every set.
!   total(theta, it, ik, ipt) : integral for weight set ipt, available
!                               on all processors
! On the first call the combined weight table wmod is built on proc0
! and broadcast; it is saved for later calls.

    use egrid, only: x0, zeroes
    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, idx, idx_local
    use agk_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce, proc0, broadcast

    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (:), intent (in) :: weights
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total

    complex, dimension (:), allocatable :: work
    real, dimension (:,:), allocatable, save :: wmod
    real :: fac
    integer :: is, il, ie, ik, it, iglo, ipt, nth, offset
    logical, save :: first = .true.

    if (first) then
       allocate (wmod(negrid,negrid-1))
       if (proc0) then
          wmod = 0.0
          wmod(:negrid-1,:) = werr(:,:)
          wmod(negrid,:) = w(negrid,1)
       end if

       do ie = 1, negrid-1
          call broadcast (wmod(:,ie))
       end do

       first = .false.
    end if

    nth = 2*ntgrid + 1
    allocate (work(nth*naky*nakx))
    work = 0.

    do ipt=1,negrid-1

       ! local part of the sum for weight set ipt
       total(:,:,:,ipt) = 0.
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          ik = ik_idx(g_lo,iglo)
          it = it_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          il = il_idx(g_lo,iglo)
          fac = weights(is)*wmod(ie,ipt)*wl(il)

          total(:, it, ik, ipt) = total(:, it, ik, ipt) + fac*(g(:,1,iglo)+g(:,2,iglo))
       end do

       ! pack (theta fastest, then it, then ik), reduce, unpack
       offset = 0
       do ik = 1, naky
          do it = 1, nakx
             work(offset+1:offset+nth) = total(-ntgrid:ntgrid, it, ik, ipt)
             offset = offset + nth
          end do
       end do

       call sum_allreduce (work) 

       offset = 0
       do ik = 1, naky
          do it = 1, nakx
             total(-ntgrid:ntgrid, it, ik, ipt) = work(offset+1:offset+nth)
             offset = offset + nth
          end do
       end do
    end do

    deallocate (work)

  end subroutine eint_error

  subroutine lint_error (g, weights, total)

! Species-summed velocity integral of g evaluated nlambda times, each
! time with the lambda weights wlerr(:,ipt) in which lambda grid point
! ipt was dropped (see init_weights).
!   total(theta, it, ik, ipt) : integral for weight set ipt, available
!                               on all processors
! On the first call wlerr is allocated where necessary and broadcast
! row by row.

    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, idx, idx_local
    use agk_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce, proc0, broadcast

    use constants, only: pi

    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (:), intent (in) :: weights
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total
    complex, dimension (:), allocatable :: work
    real :: fac
    integer :: is, il, ie, ik, it, iglo, ipt, nth, offset
    logical, save :: first = .true.

    if (first) then
       if (.not. allocated(wlerr)) allocate (wlerr(nlambda, nlambda))
       do il = 1, nlambda
          call broadcast (wlerr(il,:))
       end do
       first = .false.
    end if

    nth = 2*ntgrid + 1
    allocate (work(nth*naky*nakx))
    work = 0.

    do ipt=1,nlambda

       ! local part of the sum for weight set ipt
       total(:,:,:,ipt) = 0.
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          ik = ik_idx(g_lo,iglo)
          it = it_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          il = il_idx(g_lo,iglo)
          fac = weights(is)*w(ie,is)*wlerr(il,ipt)

          total(:, it, ik, ipt) = total(:, it, ik, ipt) + fac*(g(:,1,iglo)+g(:,2,iglo))
       end do

       ! pack (theta fastest, then it, then ik), reduce, unpack
       offset = 0
       do ik = 1, naky
          do it = 1, nakx
             work(offset+1:offset+nth) = total(-ntgrid:ntgrid, it, ik, ipt)
             offset = offset + nth
          end do
       end do

       call sum_allreduce (work) 

       offset = 0
       do ik = 1, naky
          do it = 1, nakx
             total(-ntgrid:ntgrid, it, ik, ipt) = work(offset+1:offset+nth)
             offset = offset + nth
          end do
       end do
    end do

    deallocate (work)

  end subroutine lint_error

  subroutine integrate_moment (g, total, all)
! returns results to PE 0 [or to all processors if 'all' is present in input arg list]
! NOTE: Takes f = f(x, y, z, sigma, lambda, E, species) and returns int f, where the integral
! is over all velocity space
!   g     : g(theta, sigma, iglo); both sigma entries are added
!   total : total(theta, it, ik, is), one integral per species
!   all   : only its presence matters, the value is not read
! With more than one processor and 'all' absent, processors other than
! PE 0 keep only their local partial sums in total.
    use mp, only: nproc
    use theta_grid, only: ntgrid
    use species, only: nspec
    use kgrids, only: naky, nakx
    use agk_layouts, only: g_lo, idx, idx_local
    use mp, only: sum_reduce, proc0, sum_allreduce
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total
    integer, optional, intent(in) :: all

    complex, dimension (:), allocatable :: work
    real :: fac
    integer :: is, il, ie, ik, it, iglo, nth, offset

    ! local part of the sum; grid points owned by other processors are skipped
    total = 0.
    do is = 1, nspec
       do ie = 1, negrid
          do il = 1, nlambda
             fac = w(ie,is)*wl(il)
             do it = 1, nakx
                do ik = 1, naky
                   iglo = idx (g_lo, ik, it, il, ie, is)
                   if (.not. idx_local (g_lo, iglo)) cycle
                   total(-ntgrid:ntgrid, it, ik, is) = total(-ntgrid:ntgrid, it, ik, is) + &
                        fac*(g(-ntgrid:ntgrid,1,iglo)+g(-ntgrid:ntgrid,2,iglo))
                end do
             end do
          end do
       end do
    end do

    ! a single processor already holds the complete sum
    if (.not. (nproc > 1)) return

    ! pack into a rank-1 buffer, theta fastest, then it, ik, is
    nth = 2*ntgrid + 1
    allocate (work(nth*naky*nakx*nspec))
    work = 0.
    offset = 0
    do is = 1, nspec
       do ik = 1, naky
          do it = 1, nakx
             work(offset+1:offset+nth) = total(-ntgrid:ntgrid, it, ik, is)
             offset = offset + nth
          end do
       end do
    end do

    if (present(all)) then
       call sum_allreduce (work)
    else
       call sum_reduce (work, 0)
    end if

    ! unpack only where the reduced result is valid
    if (proc0 .or. present(all)) then
       offset = 0
       do is = 1, nspec
          do ik = 1, naky
             do it = 1, nakx
                total(-ntgrid:ntgrid, it, ik, is) = work(offset+1:offset+nth)
                offset = offset + nth
             end do
          end do
       end do
    end if
    deallocate (work)

  end subroutine integrate_moment

end module le_grids
