#include "options.inc"

c=======================================================================
c      solve two dimensional Poisson equation
c           A * dpsi = forc,  where A = nabla_h^2  
c      with Neumann boundary conditions
c      used for surface pressure or free surface
c=======================================================================

      module congrad_module
c-----------------------------------------------------------------------
c     Preconditioned conjugate gradient solver for  A * p_surf = forc.
c     A is the three point operator described by the coefficient
c     array cf (see op9_vec); the preconditioner is the inverse of the
c     diagonal of A (see make_inv).
c     Only congrad is public, everything else is a private helper.
c-----------------------------------------------------------------------
      implicit none
      private
      public :: congrad
      contains

c-----------------------------------------------------------------------
c     Solve  A * p_surf = forc  by preconditioned conjugate gradients.
c
c     jmt            (in)    length of all vectors
c     cf             (in)    operator coefficients; cf(j,jj) multiplies
c                            the unknown at point j+jj, jj = -1,0,1
c     p_surf         (inout) first guess on entry, updated solution
c                            on exit
c     forc           (in)    right hand side
c     max_iterations (in)    upper limit for the number of iterations
c     iterations     (out)   0 if the first guess was accepted, else
c                            the iteration in which the loop was left;
c                            max_iterations+1 if the loop ran to its
c                            end without meeting the tolerance
c     epsilon        (in)    tolerance for the estimated error
c
c     A message is printed if the solver did not converge. The
c     estimated error is only used internally, it is not returned.
c-----------------------------------------------------------------------
      subroutine congrad(jmt, cf, p_surf, forc,
     &                   max_iterations, iterations, epsilon)
      implicit none
      integer, intent(in)    :: jmt
      real,    intent(in)    :: cf(jmt,-1:1)
      real,    intent(inout) :: p_surf(jmt)
      real,    intent(in)    :: forc(jmt)
      integer, intent(in)    :: max_iterations
      integer, intent(out)   :: iterations
      real,    intent(in)    :: epsilon

      real :: res(jmt), Z(jmt), Zres(jmt), ss(jmt), As(jmt)
      real :: zresmax, betakm1, betak, betak_min, betaquot
      real :: s_dot_As, smax, alpha, step, step1
      real :: convergence_rate, estimated_error
      integer :: n
      logical :: diverged

      n = 0
      diverged = .false.
c-----------------------------------------------------------------------
c     make approximate inverse operator Z
c-----------------------------------------------------------------------
      call make_inv(jmt, cf, Z)
c-----------------------------------------------------------------------
c     res(0)  = forc - A * p_surf(0)
c-----------------------------------------------------------------------
      call op9_vec(jmt, cf, p_surf, res)
      res = forc - res
c-----------------------------------------------------------------------
c     Zres(0) = Z * res(0)
c     see if guess is a solution, bail out to avoid division by zero
c-----------------------------------------------------------------------
      call inv_op(jmt, Z, res, Zres)
      zresmax = absmax(jmt, Zres)

      if (100.0*zresmax .lt. epsilon) then
c
c       Assume convergence rate of 0.99 to extrapolate error.
c       The guess is accepted, n stays 0.
c
        estimated_error = 100.0*zresmax
      else
c-----------------------------------------------------------------------
c       beta(0) = 1
c       ss(0)   = zerovector()
c-----------------------------------------------------------------------
        betakm1 = 1.0
        ss = 0.
c-----------------------------------------------------------------------
c       begin iteration loop
c-----------------------------------------------------------------------
        cg_loop: do n = 1, max_iterations
c-----------------------------------------------------------------------
c         Zres(k-1) = Z * res(k-1)
c-----------------------------------------------------------------------
          call inv_op(jmt, Z, res, Zres)
c-----------------------------------------------------------------------
c         beta(k)   = res(k-1) * Zres(k-1)
c-----------------------------------------------------------------------
          betak = dot2(jmt, Zres, res)
c
c         Stop if |beta| has grown to more than 100 times the
c         smallest value seen so far (checked from iteration 3 on).
c
          if (n .eq. 1) then
            betak_min = abs(betak)
          else if (n .gt. 2) then
            betak_min = min(betak_min, abs(betak))
            if (abs(betak) .gt. 100.0*betak_min) then
              print*,' ',
     &      'WARNING: 2D solver terminated because correction'
     &,     '         steps are diverging. Probable cause...roundoff'
              diverged = .true.
              exit cg_loop
            endif
          endif
c-----------------------------------------------------------------------
c         ss(k)     = Zres(k-1) + (beta(k)/beta(k-1)) * ss(k-1)
c-----------------------------------------------------------------------
          betaquot = betak/betakm1
          ss = Zres + betaquot*ss
c-----------------------------------------------------------------------
c         As(k)     = A * ss(k)
c-----------------------------------------------------------------------
          call op9_vec(jmt, cf, ss, As)
c-----------------------------------------------------------------------
c         If ss=0 then the division for alpha(k) gives a float
c         exception.
c         Assume convergence rate of 0.99 to extrapolate error.
c         Also assume alpha(k) ~ 1.
c         The comparison is .le. (not .lt.) so that the case where
c         betak and s_dot_As are both exactly zero (the residual
c         vanished exactly in the previous step) is caught as well;
c         otherwise alpha = 0/0 would be added to p_surf.
c-----------------------------------------------------------------------
          s_dot_As = dot2(jmt, ss, As)

          if (abs(s_dot_As) .le. abs(betak)*1.e-10) then
            smax = absmax(jmt, ss)
            estimated_error = 100.0*smax
            exit cg_loop
          endif
c-----------------------------------------------------------------------
c         alpha(k)  = beta(k) / (ss(k) * As(k))
c-----------------------------------------------------------------------
          alpha = betak/s_dot_As
c-----------------------------------------------------------------------
c         update values:
c         p_surf(k) = p_surf(k-1) + alpha(k) * ss(k)
c         res(k)    = res(k-1)    - alpha(k) * As(k)
c-----------------------------------------------------------------------
          p_surf = p_surf + alpha*ss
          res    = res    - alpha*As
c-----------------------------------------------------------------------
c         test for convergence
c         step is the largest change applied to p_surf in this
c         iteration. Once step < epsilon the remaining error is
c         extrapolated as a geometric series with the mean rate
c         observed since the first iteration.
c         (The test could be done only every few iterations to save
c         global reductions on massive parallel architectures.)
c-----------------------------------------------------------------------
          smax = absmax(jmt, ss)
          step = abs(alpha)*smax
          if (n .eq. 1) then
            step1 = step
            estimated_error = step
            if (step .lt. epsilon) exit cg_loop
          else if (step .lt. epsilon) then
            convergence_rate = exp(log(step/step1)/(n-1))
            estimated_error = step*convergence_rate
     &                        /(1.0-convergence_rate)
            if (estimated_error .lt. epsilon) exit cg_loop
          endif

          betakm1 = betak
        end do cg_loop
c-----------------------------------------------------------------------
c       end of iteration loop
c-----------------------------------------------------------------------
      endif

c     n = max_iterations+1 means the loop ran out of iterations
      if ((n .gt. max_iterations) .or. diverged) then
        print*,'2D-Poisson solver is not converged '
      endif
      iterations = n
      end subroutine congrad



c-----------------------------------------------------------------------
c     res = A * p1
c     res(j) = sum over jj=-1,0,1 of cf(j,jj)*p1(j+jj)  for j=2,jmt-1
c     res(1) and res(jmt) are set to zero.
c-----------------------------------------------------------------------
      pure subroutine op9_vec(jmt, cf, p1, res)
      implicit none
      integer, intent(in)  :: jmt
      real,    intent(in)  :: cf(jmt,-1:1)
      real,    intent(in)  :: p1(jmt)
      real,    intent(out) :: res(jmt)
      integer :: j, jj
      res = 0.
      do jj = -1, 1
        do j = 2, jmt-1
          res(j) = res(j) + cf(j,jj)*p1(j+jj)
        end do
      end do
      end subroutine op9_vec


c-----------------------------------------------------------------------
c     apply the approximate inverse Z of the operator A:
c     Zres(j) = Z(j) * res(j)
c-----------------------------------------------------------------------
      pure subroutine inv_op(jmt, Z, res, Zres)
      implicit none
      integer, intent(in)  :: jmt
      real,    intent(in)  :: Z(jmt), res(jmt)
      real,    intent(out) :: Zres(jmt)
      Zres = Z*res
      end subroutine inv_op


c-----------------------------------------------------------------------
c     construct an approximate inverse Z to A:
c     Z(j) = 1/cf(j,0) where the diagonal coefficient is non-zero,
c     Z(j) = 0 elsewhere (avoids a division by zero).
c-----------------------------------------------------------------------
      pure subroutine make_inv(jmt, cf, Z)
      implicit none
      integer, intent(in)  :: jmt
      real,    intent(in)  :: cf(jmt,-1:1)
      real,    intent(out) :: Z(jmt)
      where (cf(:,0) .ne. 0.0)
        Z = 1./cf(:,0)
      elsewhere
        Z = 0.
      end where
      end subroutine make_inv


c-----------------------------------------------------------------------
c     largest absolute value of p1(1:jmt); zero if jmt < 1
c-----------------------------------------------------------------------
      pure function absmax(jmt, p1) result(amax)
      implicit none
      integer, intent(in) :: jmt
      real,    intent(in) :: p1(jmt)
      real :: amax
      integer :: j
      amax = 0.
      do j = 1, jmt
        amax = max(abs(p1(j)), amax)
      end do
      end function absmax


c-----------------------------------------------------------------------
c     scalar product of p1(1:jmt) and p2(1:jmt).
c     The explicit loop keeps the summation order fixed (j = 1 ... jmt).
c-----------------------------------------------------------------------
      pure function dot2(jmt, p1, p2) result(s2)
      implicit none
      integer, intent(in) :: jmt
      real,    intent(in) :: p1(jmt), p2(jmt)
      real :: s2
      integer :: j
      s2 = 0.
      do j = 1, jmt
        s2 = s2 + p1(j)*p2(j)
      end do
      end function dot2


      end module congrad_module
