
!=======================================================================
!      solve three dimensional Poisson equation
!           A * dpsi = forc,  where A = nabla^2  
!      with Neumann boundary conditions
!      used for non-hydrostatic pressure
!=======================================================================


 subroutine congrad3D(nx_,ny_,nz_,cf, forc, max_iterations, iterations, epsilon)
      use pyOM_module   
      implicit none
!=======================================================================
!      Preconditioned conjugate gradient iteration for
!           A * psi = forc
!      where A is the 3x3x3 stencil operator applied by op9_vec_3D and
!      the preconditioner Z is the inverse of the central stencil
!      coefficient (see make_inv_3D).
!
!      Arguments:
!        nx_,ny_,nz_     extents of the array arguments
!        cf              stencil coefficients, shape (nx_,ny_,nz_,3,3,3)
!        forc            right hand side
!        max_iterations  upper bound of the iteration loop
!        iterations      set on return: 0 if the initial guess was
!                        accepted, the loop counter if the loop was left
!                        early, max_iterations+1 if the loop ran out
!        epsilon         tolerance for the estimated error
!
!      The solution vector psi is not an argument; it is taken from
!      pyOM_module. Its value on entry is the first guess and it is
!      updated in place.
!
!      NOTE(review): the work arrays and the loops use nx,ny,nz from
!      pyOM_module while the arguments are dimensioned with nx_,ny_,nz_;
!      the routine assumes both sets agree - confirm against callers.
!=======================================================================
      integer :: nx_,ny_,nz_
      real*8 :: forc(nx_,ny_,nz_),cf(nx_,ny_,nz_,3,3,3)
      integer :: max_iterations,iterations,i,j,k,n
      real*8 :: epsilon,estimated_error
      real*8 :: res(nx,ny,nz),Z(nx,ny,nz),Zres(nx,ny,nz)
      real*8 :: s(nx,ny,nz),As(nx,ny,nz)
      real*8 :: zresmax,betakm1,betak,betak_min,betaquot,s_dot_As,smax
      real*8 :: alpha,step,step1,convergence_rate
      logical :: diverged
      real*8, external :: dot2_3D,absmax_3D

!-----------------------------------------------------------------------
!     betak_min and step1 are set here by executable statements and not
!     in their declarations: an initialiser in the declaration would
!     give them the SAVE attribute and zero them only once per program.
!-----------------------------------------------------------------------
      betak_min = 0.d0
      step1     = 0.d0
      Z=0.;Zres=0.;s=0.;As=0.
!-----------------------------------------------------------------------
!     impose boundary conditions on guess
!     psi(0) = guess
!-----------------------------------------------------------------------
      call setcyclic3D(nx_,ny_,nz_,psi)
!-----------------------------------------------------------------------
!     make approximate inverse operator Z (always even symmetry)
!-----------------------------------------------------------------------
      call make_inv_3D (nx_,ny_,nz_,cf, Z )
      call setcyclic3D(nx_,ny_,nz_,Z)
!-----------------------------------------------------------------------
!     res(0)  = forc - A * psi(0)
!     (op9_vec_3D zeroes res before accumulating into it)
!-----------------------------------------------------------------------
      call op9_vec_3D(nx_,ny_,nz_,cf, psi, res)
      do k=1,nz
       do j=js_pe,je_pe
        do i=1,nx
         res(i,j,k) = forc(i,j,k) - res(i,j,k)
        enddo
       enddo
      enddo
      call setcyclic3D(nx_,ny_,nz_,res)
!-----------------------------------------------------------------------
!     Zres(k-1) = Z * res(k-1)
!     see if guess is a solution, bail out to avoid division by zero
!-----------------------------------------------------------------------
      n = 0
      diverged=.false.
      call inv_op_3D(nx_,ny_,nz_,Z, res, Zres)
      call setcyclic3D(nx_,ny_,nz_,Zres)
      Zresmax = absmax_3D(nx_,ny_,nz_,Zres)
!       Assume convergence rate of 0.99 to extrapolate error
!       (factor 100 = 1/(1-0.99)); n is still 0 on this exit
      if (100.0 * Zresmax .lt. epsilon) then
       estimated_error = 100.0 * Zresmax 
       goto 101
      endif
!-----------------------------------------------------------------------
!     beta(0) = 1
!     s(0)    = zerovector()
!-----------------------------------------------------------------------
      betakm1 = 1.0
      s=0.
!-----------------------------------------------------------------------
!     begin iteration loop
!-----------------------------------------------------------------------
      do n = 1,max_iterations
!-----------------------------------------------------------------------
!       Zres(k-1) = Z * res(k-1)
!-----------------------------------------------------------------------
        call inv_op_3D(nx_,ny_,nz_,Z, res, Zres)
        call setcyclic3D(nx_,ny_,nz_,Zres)
!-----------------------------------------------------------------------
!       beta(k)   = res(k-1) * Zres(k-1)
!       The smallest |beta| seen is recorded at n=1 and updated from
!       n=3 on; the run is flagged as diverged once |beta| exceeds
!       100 times that minimum.
!-----------------------------------------------------------------------
        betak = dot2_3D(nx_,ny_,nz_,Zres, res)
        if (n .eq. 1) then
          betak_min = abs(betak)
        elseif (n .gt. 2) then
          betak_min = min(betak_min, abs(betak))
          if (abs(betak) .gt. 100.0*betak_min) then
           print*,'PE ',my_pe,' : ','WARNING: 3D conj. gradient terminated because correction', &
          '         steps are diverging. Probable cause...roundoff'
            diverged=.true.
            goto 101
         endif
        endif
!-----------------------------------------------------------------------
!       s(k)      = Zres(k-1) + (beta(k)/beta(k-1)) * s(k-1)
!-----------------------------------------------------------------------
        betaquot = betak/betakm1
        do k=1,nz
         do j=js_pe,je_pe
          do i=1,nx
           s(i,j,k) = Zres(i,j,k) + betaquot * s(i,j,k)
          enddo
         enddo
        enddo
        call setcyclic3D(nx_,ny_,nz_,s)
        call border_exchg3D(nx_,ny_,nz_,s,1)

!-----------------------------------------------------------------------
!       As(k)     = A * s(k)
!-----------------------------------------------------------------------
        call op9_vec_3D(nx_,ny_,nz_,cf, s, As)
        call setcyclic3D(nx_,ny_,nz_,As)
!-----------------------------------------------------------------------
!       If s=0 then the division for alpha(k) gives a float exception.
!       Assume convergence rate of 0.99 to extrapolate error.
!       Also assume alpha(k) ~ 1.
!-----------------------------------------------------------------------
        s_dot_As = dot2_3D(nx_,ny_,nz_,s, As)
        if (abs(s_dot_As) .lt. abs(betak)*1.e-10) then
          smax = absmax_3D(nx_,ny_,nz_,s)
          estimated_error = 100.0 * smax 
          goto 101
        endif
!-----------------------------------------------------------------------
!       alpha(k)  = beta(k) / (s(k) * As(k))
!-----------------------------------------------------------------------
        alpha = betak / s_dot_As
!-----------------------------------------------------------------------
!       update values:
!       psi(k)    = psi(k-1) + alpha(k) * s(k)
!       res(k)    = res(k-1) - alpha(k) * As(k)
!-----------------------------------------------------------------------
        do k=1,nz
         do j=js_pe,je_pe
          do i=1,nx
           psi(i,j,k)  = psi(i,j,k) + alpha * s(i,j,k)
           res(i,j,k)  = res(i,j,k) - alpha * As(i,j,k)
          enddo
         enddo
        enddo
        smax = absmax_3D(nx_,ny_,nz_,s)
!-----------------------------------------------------------------------
!       test for convergence
!       if (estimated_error < epsilon) exit
!
!       step is the largest change applied to psi in this iteration.
!       After the first iteration the error is only extrapolated once
!       step itself is below epsilon: the mean reduction factor per
!       iteration since step1 is taken as the convergence rate and the
!       remaining steps are summed as a geometric series.
!-----------------------------------------------------------------------
        step = abs(alpha) * smax
        if (n .eq. 1) then
          step1 = step
          estimated_error = step
          if (step .lt. epsilon) goto 101
        else if (step .lt. epsilon) then
          convergence_rate = exp(log(step/step1)/(n-1))
          estimated_error = step*convergence_rate/(1.0-convergence_rate)
          if (estimated_error .lt. epsilon) goto 101
        end if
        betakm1 = betak
      end do
!-----------------------------------------------------------------------
!     end of iteration loop
!     (n equals max_iterations+1 here if the loop was not left early)
!-----------------------------------------------------------------------
  101 continue
      if ((n .gt. max_iterations).or.(diverged)) then
          print*,'PE ',my_pe,' : 3D-Poisson solver is not converged '
      end if
      iterations = n
 end subroutine congrad3D




 subroutine op9_vec_3D(nx_,ny_,nz_,cf, p1, res)
      use pyOM_module   
      implicit none
!-----------------------------------------------------------------------
!     Apply the 3x3x3 stencil operator:   res = A * p1
!
!     res is zeroed everywhere first.  Only points with
!     i = 2..nx-1, j = max(2,js_pe)..min(je_pe,ny-1), k = 2..nz-1
!     receive a value; for each of them the 27 products
!     cf(i,j,k,di,dj,dk) * p1(i+di-2,j+dj-2,k+dk-2) are accumulated,
!     with di running fastest, then dj, then dk.
!-----------------------------------------------------------------------
      integer :: nx_,ny_,nz_
      real*8 :: cf(nx_,ny_,nz_,3,3,3) 
      real*8 :: p1(nx_,ny_,nz_), res(nx_,ny_,nz_)
      integer :: di,dj,dk,j_lo,j_hi

      j_lo = max(2,js_pe)
      j_hi = min(je_pe,ny-1)
      res = 0.
!     one whole-section update per stencil entry; the neighbour section
!     of p1 is the target section shifted by (di-2,dj-2,dk-2)
      do dk = 1,3
       do dj = 1,3
        do di = 1,3
          res(2:nx-1,j_lo:j_hi,2:nz-1) = res(2:nx-1,j_lo:j_hi,2:nz-1)   &
               + cf(2:nx-1,j_lo:j_hi,2:nz-1,di,dj,dk)                   &
               * p1(di:nx+di-3,j_lo+dj-2:j_hi+dj-2,dk:nz+dk-3)
        end do
       end do
      end do
 end subroutine op9_vec_3D



 subroutine inv_op_3D(nx_,ny_,nz_,Z, res, Zres)
      use pyOM_module   
      implicit none
!-----------------------------------------------------------------------
!     Apply the approximate inverse Z of the operator A:
!         Zres = Z * res     (point by point)
!     for i = 1..nx, j = js_pe..je_pe, k = 1..nz.  Elements of Zres
!     outside that range are left untouched.
!-----------------------------------------------------------------------
      integer :: nx_,ny_,nz_
      real*8 ::  Z(nx_,ny_,nz_),res(nx_,ny_,nz_),Zres(nx_,ny_,nz_)

      Zres(1:nx,js_pe:je_pe,1:nz) = Z(1:nx,js_pe:je_pe,1:nz)   &
                                  * res(1:nx,js_pe:je_pe,1:nz)
 end subroutine inv_op_3D


 subroutine make_inv_3D (nx_,ny_,nz_,cf, Z)
      use pyOM_module   
      implicit none
!-----------------------------------------------------------------------
!     Construct an approximate inverse Z to A: the reciprocal of the
!     central stencil coefficient cf(i,j,k,2,2,2), or zero where that
!     coefficient is exactly zero.  Filled for i = 1..nx,
!     j = js_pe..je_pe, k = 1..nz; other elements of Z are untouched.
!-----------------------------------------------------------------------
      integer :: nx_,ny_,nz_
      real*8 :: cf(nx_,ny_,nz_,3,3,3) 
      real*8 ::  Z(nx_,ny_,nz_)

!     the where mask keeps the division away from zero coefficients
      where (cf(1:nx,js_pe:je_pe,1:nz,2,2,2) /= 0.)
        Z(1:nx,js_pe:je_pe,1:nz) = 1./cf(1:nx,js_pe:je_pe,1:nz,2,2,2)
      elsewhere
        Z(1:nx,js_pe:je_pe,1:nz) = 0.
      end where
 end subroutine make_inv_3D


 real*8 function absmax_3D(nx_,ny_,nz_,p1)
      use pyOM_module   
      implicit none
!-----------------------------------------------------------------------
!     Largest value of abs(p1*maskT) over i = 2..nx-1, j = js_pe..je_pe,
!     k = 2..nz-1, never less than zero.  The local value is handed to
!     global_max before being returned (presumably a reduction over all
!     PEs - confirm against its definition).
!-----------------------------------------------------------------------
      integer :: nx_,ny_,nz_
      real*8 :: p1(nx_,ny_,nz_)
      real*8 :: local_max

!     the outer max supplies the floor of zero, which also covers an
!     empty index range (maxval of an empty section is -huge)
      local_max = max( 0.d0,                                             &
           maxval( abs( p1(2:nx-1,js_pe:je_pe,2:nz-1)                    &
                       *maskT(2:nx-1,js_pe:je_pe,2:nz-1) ) ) )
      call global_max(local_max)
      absmax_3D = local_max
 end function absmax_3D


 real*8  function dot2_3D(nx_,ny_,nz_,p1,p2)
      use pyOM_module   
      implicit none
!-----------------------------------------------------------------------
!     Masked dot product: sum of p1*p2*maskt over i = 1..nx,
!     j = js_pe..je_pe, k = 1..nz.  The local sum is handed to
!     global_sum before being returned (presumably a reduction over all
!     PEs - confirm against its definition).
!-----------------------------------------------------------------------
      integer :: nx_,ny_,nz_
      real*8 :: p1(nx_,ny_,nz_),p2(nx_,ny_,nz_)
      real*8 :: acc
      integer :: ix,jy,kz

!     explicit loops keep the order of the floating point additions
!     fixed (ix fastest, then jy, then kz)
      acc = 0
      do kz = 1,nz
        do jy = js_pe,je_pe
          do ix = 1,nx
            acc = acc + p1(ix,jy,kz)*p2(ix,jy,kz)*maskt(ix,jy,kz)
          end do
        end do
      end do
      call global_sum(acc)
      dot2_3D = acc
 end function dot2_3D










