#include "options.inc"

c=======================================================================
c      solve the two-dimensional Poisson equation
c           A * dpsi = forc,  where A = nabla_h^2
c      with Neumann boundary conditions;
c      used for the surface pressure or the free surface
c=======================================================================

      module congrad2D_module
c-----------------------------------------------------------------------
c     Diagonally preconditioned conjugate-gradient solver for the
c     2D problem  A * p_surf = forc, where A is the 9-point operator
c     whose coefficients are passed in cf. Only congrad2D is public;
c     the other procedures are internal helpers.
c-----------------------------------------------------------------------
      implicit none
      private
      public congrad2D
      contains


      subroutine congrad2D(cf, forc,
     &                  max_iterations, iterations, epsilon)
c-----------------------------------------------------------------------
c     Solve A * p_surf = forc by preconditioned conjugate gradients.
c     p_surf (from cpflame_module) holds the first guess on entry and
c     the approximate solution on exit; the CG update itself only
c     touches rows js_pe..je_pe of this PE.
c
c     cf(i,j,ii,jj)  : stencil coefficient coupling (i,j) to
c                      (i+ii,j+jj), ii,jj = -1..1
c     forc           : right-hand side
c     max_iterations : maximal number of CG iterations
c     iterations     : on return, the iteration count n at exit
c                      (0 if the first guess already passes the test,
c                       max_iterations+1 if the loop ran out)
c     epsilon        : tolerance for the extrapolated error
c
c     A message is printed if the iteration diverges or does not
c     converge within max_iterations.
c-----------------------------------------------------------------------
      use cpflame_module
      implicit none
      integer, intent(in)  :: max_iterations
      integer, intent(out) :: iterations
      real, intent(in)     :: epsilon
      real, intent(in)     :: cf(imt,jmt,-1:1,-1:1)
      real, intent(in)     :: forc(imt,jmt)
      integer :: n
      real :: estimated_error
      logical :: converged, diverged
      real :: res(imt,jmt),Z(imt,jmt),Zres(imt,jmt)
      real :: ss(imt,jmt),As(imt,jmt)
      real :: zresmax,betakm1,betak,betak_min,betaquot,s_dot_As,smax
      real :: alpha,step,step1,convergence_rate

      Z    = 0.
      Zres = 0.
      ss   = 0.
      As   = 0.
c-----------------------------------------------------------------------
c     impose boundary conditions on the first guess p_surf(0)
c     NOTE(review): op9_vec reads p_surf at j-1 and j+1, but unlike ss
c     below p_surf gets no border_exchg2D here; presumably the caller
c     keeps its halo rows current - confirm.
c-----------------------------------------------------------------------
      call setcyclic2D(p_surf)
c-----------------------------------------------------------------------
c     diagonal preconditioner Z = 1/cf(:,:,0,0) (0 where cf(:,:,0,0)=0)
c-----------------------------------------------------------------------
      call make_inv(cf, Z)
      call setcyclic2D(Z)
c-----------------------------------------------------------------------
c     initial residual  res(0) = forc - A * p_surf(0)
c-----------------------------------------------------------------------
      call op9_vec(cf, p_surf, res)
      res(:,js_pe:je_pe) = forc(:,js_pe:je_pe) - res(:,js_pe:je_pe)
      call setcyclic2D(res)
c-----------------------------------------------------------------------
c     Zres(0) = Z * res(0)
c     if the guess already passes the test, leave now; this also
c     avoids a division by zero later on
c-----------------------------------------------------------------------
      n = 0
      diverged = .false.
      call inv_op(Z, res, Zres)
      call setcyclic2D(Zres)
      zresmax = absmax(Zres)
c     assume a convergence rate of 0.99 to extrapolate the error,
c     i.e. a factor 1/(1-0.99) = 100
      if (100.0 * zresmax .lt. epsilon) then
        estimated_error = 100.0 * zresmax
        goto 101
      endif
c-----------------------------------------------------------------------
c     beta(0) = 1,  ss(0) = 0
c-----------------------------------------------------------------------
      betakm1 = 1.0
      ss = 0.
c-----------------------------------------------------------------------
c     iteration loop
c-----------------------------------------------------------------------
      do n = 1,max_iterations
c       Zres(k-1) = Z * res(k-1)
        call inv_op(Z, res, Zres)
        call setcyclic2D(Zres)
c       beta(k) = res(k-1) . Zres(k-1)
        betak = dot2(Zres, res)
c       divergence check: stop when |beta| exceeds 100 times the
c       smallest |beta| recorded so far. The minimum is seeded at
c       n = 1; the value at n = 2 is neither checked nor recorded.
        if (n .eq. 1) then
          betak_min = abs(betak)
        elseif (n .gt. 2) then
          betak_min = min(betak_min, abs(betak))
          if (abs(betak) .gt. 100.0*betak_min) then
           print*,'PE ',my_pe,' : ',
     &      'WARNING: 2D solver terminated because correction'
     &,     '         steps are diverging. Probable cause...roundoff'
            diverged = .true.
            goto 101
          endif
        endif
c       ss(k) = Zres(k-1) + (beta(k)/beta(k-1)) * ss(k-1)
        betaquot = betak/betakm1
        ss(:,js_pe:je_pe) = Zres(:,js_pe:je_pe)
     &                    + betaquot * ss(:,js_pe:je_pe)
        call setcyclic2D(ss)
c       op9_vec needs ss on the neighbouring rows j-1 and j+1
        call border_exchg2D(ss,1)
c       As(k) = A * ss(k)
        call op9_vec(cf, ss, As)
        call setcyclic2D(As)
c       If ss . As ~ 0, the division for alpha(k) would blow up. Leave
c       instead. The error is extrapolated with an assumed convergence
c       rate of 0.99 and alpha(k) ~ 1.
        s_dot_As = dot2(ss, As)
        if (abs(s_dot_As) .lt. abs(betak)*1.e-10) then
          smax = absmax(ss)
          estimated_error = 100.0 * smax
          goto 101
        endif
c       alpha(k) = beta(k) / (ss(k) . As(k))
        alpha = betak / s_dot_As
c       p_surf(k) = p_surf(k-1) + alpha(k) * ss(k)
c       res(k)    = res(k-1)    - alpha(k) * As(k)
        p_surf(1:imt,js_pe:je_pe) = p_surf(1:imt,js_pe:je_pe)
     &                            + alpha * ss(:,js_pe:je_pe)
        res(:,js_pe:je_pe) = res(:,js_pe:je_pe)
     &                     - alpha * As(:,js_pe:je_pe)
c       On massively parallel machines the global reduction in absmax
c       could be restricted to every few iterations, e.g.
c       if ((mod(n,10)==0).or.(n==1)) then ... endif
        smax = absmax(ss)
c       Test for convergence. The step size at n = 1 is the reference.
c       Afterwards the mean reduction rate per iteration extrapolates
c       the remaining error.
        step = abs(alpha) * smax
        if (n .eq. 1) then
          step1 = step
          estimated_error = step
          if (step .lt. epsilon) goto 101
        else if (step .lt. epsilon) then
          convergence_rate = exp(log(step/step1)/(n-1))
          estimated_error = step*convergence_rate
     &                      /(1.0-convergence_rate)
          if (estimated_error .lt. epsilon) goto 101
        end if

        betakm1 = betak
c        print*,'PE=',my_pe,'n=',n,' error=',step,' eps=',epsilon
      end do
c-----------------------------------------------------------------------
c     end of iteration loop; n > max_iterations here means the loop
c     ran out without meeting the convergence test
c-----------------------------------------------------------------------
  101 continue
      if ((n .gt. max_iterations).or.(diverged)) then
          print*,'PE ',my_pe,' : 2D-Poisson solver is not converged '
        converged = .false.
      else
        converged = .true.
      end if
      iterations = n
c      if (.not.converged) call halt_stop('  ')
      end subroutine congrad2D



      subroutine op9_vec(cf, p1, res)
c-----------------------------------------------------------------------
c     Apply the 9-point operator:  res = A * p1, i.e.
c       res(i,j) = sum_{ii,jj=-1..1} cf(i,j,ii,jj) * p1(i+ii,j+jj)
c     It is evaluated for i = 2..imt-1 and
c     j = max(2,js_pe)..min(je_pe,jmt-1). The remaining points of rows
c     js_pe..je_pe are set to 0; rows outside that range are left
c     unchanged.
c-----------------------------------------------------------------------
      use cpflame_module
      implicit none
      real, intent(in)    :: cf(imt,jmt,-1:1,-1:1)
      real, intent(in)    :: p1(imt,jmt)
      real, intent(inout) :: res(imt,jmt)
      integer :: i,j,ii,jj,js,je
      js = max(2,js_pe)
      je = min(je_pe,jmt-1)
      res(:,js_pe:je_pe) = 0.
c     the stencil offsets are the outer loops, so every point
c     accumulates its 9 contributions in the same fixed order
      do jj=-1,1
       do ii=-1,1
        do j=js,je
         do i=2,imt-1
          res(i,j) = res(i,j) + cf(i,j,ii,jj)*p1(i+ii,j+jj)
         end do
        end do
       end do
      end do
      end subroutine op9_vec


      subroutine inv_op(Z, res, Zres)
c-----------------------------------------------------------------------
c     Apply the approximate inverse Z of the operator A pointwise:
c     Zres = Z * res on rows js_pe..je_pe (other rows unchanged).
c-----------------------------------------------------------------------
      use cpflame_module
      implicit none
      real, intent(in)    :: Z(imt,jmt), res(imt,jmt)
      real, intent(inout) :: Zres(imt,jmt)
      Zres(:,js_pe:je_pe) = Z(:,js_pe:je_pe) * res(:,js_pe:je_pe)
      end subroutine inv_op


      subroutine make_inv(cf, Z)
c-----------------------------------------------------------------------
c     Construct the diagonal approximate inverse of A on rows
c     js_pe..je_pe: Z = 1/cf(:,:,0,0) where the centre coefficient is
c     nonzero, and Z = 0 where it is zero (avoids division by zero).
c-----------------------------------------------------------------------
      use cpflame_module
      implicit none
      real, intent(in)    :: cf(imt,jmt,-1:1,-1:1)
      real, intent(inout) :: Z(imt,jmt)
      integer :: i,j
      do j=js_pe,je_pe
       do i=1,imt
        if (cf(i,j,0,0) .ne. 0.0) then
         Z(i,j) = 1./cf(i,j,0,0)
        else
         Z(i,j) = 0.
        endif
       end do
      end do
      end subroutine make_inv


      real function absmax(p1)
c-----------------------------------------------------------------------
c     Global maximum of |p1 * maskT(:,:,km-1)| over i = 2..imt-1 and
c     rows js_pe..je_pe (reduced over all PEs when n_pes > 1).
c     Returns 0 if all contributing values are 0.
c-----------------------------------------------------------------------
      use cpflame_module
      implicit none
      real, intent(in) :: p1(imt,jmt)
      real :: s2
      integer :: i,j
      s2 = 0.
      do j=js_pe,je_pe
       do i=2,imt-1
        s2 = max( abs(p1(i,j)*maskT(i,j,km-1)), s2 )
       enddo
      enddo
      if (n_pes>1) call global_max(s2)
      absmax = s2
      end function absmax


      real function dot2(p1,p2)
c-----------------------------------------------------------------------
c     Global masked dot product
c        sum p1(i,j) * p2(i,j) * maskT(i,j,km-1)
c     over i = 2..imt-1 and rows js_pe..je_pe, summed over all PEs.
c     Columns i = 1 and i = imt are excluded. They match the interior
c     range of op9_vec and absmax, and op9_vec never applies A there.
c     If setcyclic2D fills them with copies of the opposite interior
c     columns, including them would count those columns twice.
c-----------------------------------------------------------------------
      use cpflame_module
      implicit none
      real, intent(in) :: p1(imt,jmt), p2(imt,jmt)
      real :: s2
      integer :: i,j
      s2 = 0.
      do j=js_pe,je_pe
       do i=2,imt-1
        s2 = s2 + p1(i,j)*p2(i,j)*maskT(i,j,km-1)
       enddo
      enddo
      call global_sum(s2)
      dot2 = s2
      end function dot2


      end module congrad2D_module
