

!=======================================================================
!   Add the vertical friction that arises from the residual mean (TRM) formulation:
!    the residual momentum obeys  u_t = ... + ( K f^2/N^2 u_z )_z
!=======================================================================

 subroutine vert_friction_trm(ierr)
!-----------------------------------------------------------------------
!     Vertical friction of the residual momentum (TRM formulation):
!        u_t = ... + ( A u_z )_z ,   v_t = ... + ( A v_z )_z
!     with A = K_gm * min( f^2/N^2 , fNsqr_max ), stored in A_trm.
!
!     The operator is split: the part weighted by (1-aidif_trm) is
!     evaluated explicitly from u,v at taum1 and added to fu,fv; the
!     rest is handled by trm_implicit_umix / trm_implicit_vmix.
!     Fluxes through the outer faces of levels k=1 and k=nz are zero,
!     so the explicit tendency is applied on all levels 1..nz, matching
!     the rows of the implicit tridiagonal system (previously only
!     k=2..nz-1 received the explicit part).
!
!     ierr : status flag, always set to 0 here
!-----------------------------------------------------------------------
      use pyOM_module   
      implicit none
      integer, intent(out) :: ierr
      integer :: js,je,i,j,k
      real*8 :: fxa,fxb,A_uu(nx,ny,nz),diff_ft(nx,ny,nz)
      real*8 :: Nsqrw(nx,ny,nz)
      ! small offset keeping the masked averages finite where all masks are 0
      real*8, parameter :: epsln = 1.0d-12

      ierr = 0
      js=max(2,js_pe); je = min(je_pe,ny -1)
!-----------------------------------------------------------------------
!     stability frequency N^2 from b (and back, if enabled),
!     bounded below by N_min**2 so that f^2/N^2 stays finite
!-----------------------------------------------------------------------
      do k=1,nz-1
       do j=js,je
        Nsqrw(:,j,k)=-(b(:,j,k+1,taum1)-b(:,j,k,taum1))/dz
        if (enable_back_state) then
         Nsqrw(:,j,k)=Nsqrw(:,j,k)-(back(:,j,k+1,taum1)-back(:,j,k,taum1))/dz
        endif
        Nsqrw(:,j,k)=max(Nsqrw(:,j,k),N_min**2)
       enddo
      enddo
!-----------------------------------------------------------------------
!     vertical viscosity A_trm = K_gm * min(f^2/N^2, fNsqr_max)
!-----------------------------------------------------------------------
      do k=1,nz-1
       do j=js,je
        do i=2,nx-1
         A_trm(i,j,k)=K_gm*min(coriolis_t(j)**2/Nsqrw(i,j,k),fNsqr_max)
        enddo
       enddo
      enddo

      call border_exchg3D(nx,ny,nz,A_trm,1)
      call setcyclic3D(nx,ny,nz,A_trm)
!-----------------------------------------------------------------------
!     zonal momentum: viscosity = maskW-weighted mean of A_trm at i and
!     i+1, zero wherever either maskU(k) or maskU(k+1) vanishes
!-----------------------------------------------------------------------
      A_uu(:,js:je,:)=0
      do k=1,nz-1
       do j=js,je
        do i=2,nx-1
         fxa=A_trm(i,j,k)*maskW(i,j,k)+A_trm(i+1,j,k)*maskW(i+1,j,k)
         fxb=maskW(i,j,k)+maskW(i+1,j,k)+epsln
         A_uu(i,j,k)=fxa/fxb*maskU(i,j,k+1)*maskU(i,j,k)
        enddo
       enddo
      enddo
!     explicit flux between levels k and k+1; diff_ft(:,:,nz) stays zero,
!     i.e. no flux through the outer face of level nz
      diff_ft(:,js:je,:)=0.
      do k=1,nz-1
       do j=js,je
        do i=2,nx-1
         diff_ft(i,j,k)=A_uu(i,j,k)*(u(i,j,k+1,taum1)-u(i,j,k,taum1))/dz
        enddo
       enddo
      enddo
!     explicit tendency on all levels; no flux through the outer face of k=1
      do j=js,je
       do i=2,nx-1
        fu(i,j,1)=fu(i,j,1)+(1-aidif_trm)*maskU(i,j,1)*diff_ft(i,j,1)/dz
       enddo
      enddo
      do k=2,nz
       do j=js,je
        do i=2,nx-1
         fu(i,j,k)=fu(i,j,k)+(1-aidif_trm)*maskU(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
        enddo
       enddo
      enddo
!     implicit part of the vertical operator
      call trm_implicit_umix(nx,ny,nz,A_uu)
!-----------------------------------------------------------------------
!     meridional momentum: same treatment, A_trm averaged over j and j+1
!-----------------------------------------------------------------------
      A_uu(:,js:je,:)=0
      do k=1,nz-1
       do j=js,je
        do i=2,nx-1
         fxa=A_trm(i,j,k)*maskW(i,j,k)+A_trm(i,j+1,k)*maskW(i,j+1,k)
         fxb=maskW(i,j,k)+maskW(i,j+1,k)+epsln
         A_uu(i,j,k)=fxa/fxb*maskV(i,j,k+1)*maskV(i,j,k)
        enddo
       enddo
      enddo
!     explicit flux between levels k and k+1; diff_ft(:,:,nz) stays zero
      diff_ft(:,js:je,:)=0.
      do k=1,nz-1
       do j=js,je
        do i=2,nx-1
         diff_ft(i,j,k)=A_uu(i,j,k)*(v(i,j,k+1,taum1)-v(i,j,k,taum1))/dz
        enddo
       enddo
      enddo
!     explicit tendency on all levels; no flux through the outer face of k=1
      do j=js,je
       do i=2,nx-1
        fv(i,j,1)=fv(i,j,1)+(1-aidif_trm)*maskV(i,j,1)*diff_ft(i,j,1)/dz
       enddo
      enddo
      do k=2,nz
       do j=js,je
        do i=2,nx-1
         fv(i,j,k)=fv(i,j,k)+(1-aidif_trm)*maskV(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
        enddo
       enddo
      enddo
!     implicit part of the vertical operator
      call trm_implicit_vmix(nx,ny,nz,A_uu)

 end subroutine vert_friction_trm



 subroutine trm_implicit_vmix(nx_,ny_,nz_,A_uu)
!---------------------------------------------------------------------------------
!     Implicit part of the TRM vertical friction for v.
!
!     1) provisional step   v(taup1) = v(taum1) + c2dt*fv   (masked by maskV)
!     2) for each row j, solve one tridiagonal system per column i:
!           v_k - fac*( A_k (v_{k+1}-v_k) - A_{k-1} (v_k-v_{k-1}) ) = rhs_k
!        with fac = aidif_trm*c2dt/dz**2; the missing neighbour term is
!        dropped at k=1 and k=nz. Thomas algorithm; wherever a pivot is
!        zero the corresponding entries are left at zero.
!     3) fv is overwritten with the total tendency (v(taup1)-v(taum1))/c2dt
!
!     A_uu(:,j,k) : viscosity coupling levels k and k+1 (k=1..nz-1 is read)
!---------------------------------------------------------------------------------
      use pyOM_module   
      implicit none
      integer :: nx_,ny_,nz_
      real*8 :: A_uu(nx_,ny_,nz_) 
      integer :: j_row,k_lev,j_lo,j_hi
      real*8 :: tdm_lo(nx,nz),tdm_di(nx,nz),tdm_up(nx,nz),tdm_piv(nx)
      real*8 :: tdm_x(nx,nz),tdm_cp(nx,nz),tdm_rhs(nx,nz),tdm_fac

      j_lo = max(2,js_pe)
      j_hi = min(je_pe,ny-1)

      ! step 1: provisional explicit step
      v(:,j_lo:j_hi,:,taup1) = v(:,j_lo:j_hi,:,taum1) &
                             + c2dt*fv(:,j_lo:j_hi,:)*maskV(:,j_lo:j_hi,:)
      tdm_fac = aidif_trm*c2dt/dz**2

      ! step 2: implicit solve, row by row
      do j_row = j_lo, j_hi
         ! sub-, main and super-diagonal of the matrix
         tdm_di(:,1) = 1 + tdm_fac*A_uu(:,j_row,1)
         tdm_up(:,1) = -tdm_fac*A_uu(:,j_row,1)
         do k_lev = 2, nz-1
            tdm_lo(:,k_lev) = -tdm_fac*A_uu(:,j_row,k_lev-1)
            tdm_di(:,k_lev) = 1 + tdm_fac*(A_uu(:,j_row,k_lev) + A_uu(:,j_row,k_lev-1))
            tdm_up(:,k_lev) = -tdm_fac*A_uu(:,j_row,k_lev)
         end do
         tdm_lo(:,nz) = -tdm_fac*A_uu(:,j_row,nz-1)
         tdm_di(:,nz) = 1 + tdm_fac*A_uu(:,j_row,nz-1)

         tdm_rhs = v(:,j_row,:,taup1)*maskV(:,j_row,:)

         ! forward elimination; tdm_cp(:,k) is the eliminated super-diagonal of row k
         tdm_x  = 0.0
         tdm_cp = 0.0
         tdm_piv = tdm_di(:,1)
         where (tdm_piv /= 0.0) tdm_x(:,1) = tdm_rhs(:,1)/tdm_piv
         do k_lev = 2, nz
            where (tdm_piv /= 0.0) tdm_cp(:,k_lev-1) = tdm_up(:,k_lev-1)/tdm_piv
            tdm_piv = tdm_di(:,k_lev) - tdm_lo(:,k_lev)*tdm_cp(:,k_lev-1)
            where (tdm_piv /= 0.0) &
               tdm_x(:,k_lev) = (tdm_rhs(:,k_lev) - tdm_lo(:,k_lev)*tdm_x(:,k_lev-1))/tdm_piv
         end do

         ! back substitution
         do k_lev = nz-1, 1, -1
            tdm_x(:,k_lev) = tdm_x(:,k_lev) - tdm_cp(:,k_lev)*tdm_x(:,k_lev+1)
         end do
         v(:,j_row,:,taup1) = tdm_x
      end do

      ! step 3: replace fv by the total tendency over the 2*dt step
      fv(:,j_lo:j_hi,:) = (v(:,j_lo:j_hi,:,taup1) - v(:,j_lo:j_hi,:,taum1))/c2dt &
                          *maskV(:,j_lo:j_hi,:)
 end subroutine trm_implicit_vmix



 subroutine trm_implicit_umix(nx_,ny_,nz_,A_uu)
!=======================================================================
!     Implicit part of the TRM vertical friction for u.
!
!     1) provisional step   u(taup1) = u(taum1) + c2dt*fu   (masked by maskU)
!     2) for each row j, solve one tridiagonal system per column i:
!           u_k - fac*( A_k (u_{k+1}-u_k) - A_{k-1} (u_k-u_{k-1}) ) = rhs_k
!        with fac = aidif_trm*c2dt/dz**2; the missing neighbour term is
!        dropped at k=1 and k=nz. Thomas algorithm; wherever a pivot is
!        zero the corresponding entries are left at zero.
!     3) fu is overwritten with the total tendency (u(taup1)-u(taum1))/c2dt
!
!     A_uu(:,j,k) : viscosity coupling levels k and k+1 (k=1..nz-1 is read)
!=======================================================================
      use pyOM_module   
      implicit none
      integer :: nx_,ny_,nz_
      real*8 :: A_uu(nx_,ny_,nz_) 
      integer :: j_row,k_lev,j_lo,j_hi
      real*8 :: tdm_lo(nx,nz),tdm_di(nx,nz),tdm_up(nx,nz),tdm_piv(nx)
      real*8 :: tdm_x(nx,nz),tdm_cp(nx,nz),tdm_rhs(nx,nz),tdm_fac

      j_lo = max(2,js_pe)
      j_hi = min(je_pe,ny-1)

      ! step 1: provisional explicit step
      u(:,j_lo:j_hi,:,taup1) = u(:,j_lo:j_hi,:,taum1) &
                             + c2dt*fu(:,j_lo:j_hi,:)*maskU(:,j_lo:j_hi,:)
      tdm_fac = aidif_trm*c2dt/dz**2

      ! step 2: implicit solve, row by row
      do j_row = j_lo, j_hi
         ! sub-, main and super-diagonal of the matrix
         tdm_di(:,1) = 1 + tdm_fac*A_uu(:,j_row,1)
         tdm_up(:,1) = -tdm_fac*A_uu(:,j_row,1)
         do k_lev = 2, nz-1
            tdm_lo(:,k_lev) = -tdm_fac*A_uu(:,j_row,k_lev-1)
            tdm_di(:,k_lev) = 1 + tdm_fac*(A_uu(:,j_row,k_lev) + A_uu(:,j_row,k_lev-1))
            tdm_up(:,k_lev) = -tdm_fac*A_uu(:,j_row,k_lev)
         end do
         tdm_lo(:,nz) = -tdm_fac*A_uu(:,j_row,nz-1)
         tdm_di(:,nz) = 1 + tdm_fac*A_uu(:,j_row,nz-1)

         tdm_rhs = u(:,j_row,:,taup1)*maskU(:,j_row,:)

         ! forward elimination; tdm_cp(:,k) is the eliminated super-diagonal of row k
         tdm_x  = 0.0
         tdm_cp = 0.0
         tdm_piv = tdm_di(:,1)
         where (tdm_piv /= 0.0) tdm_x(:,1) = tdm_rhs(:,1)/tdm_piv
         do k_lev = 2, nz
            where (tdm_piv /= 0.0) tdm_cp(:,k_lev-1) = tdm_up(:,k_lev-1)/tdm_piv
            tdm_piv = tdm_di(:,k_lev) - tdm_lo(:,k_lev)*tdm_cp(:,k_lev-1)
            where (tdm_piv /= 0.0) &
               tdm_x(:,k_lev) = (tdm_rhs(:,k_lev) - tdm_lo(:,k_lev)*tdm_x(:,k_lev-1))/tdm_piv
         end do

         ! back substitution
         do k_lev = nz-1, 1, -1
            tdm_x(:,k_lev) = tdm_x(:,k_lev) - tdm_cp(:,k_lev)*tdm_x(:,k_lev+1)
         end do
         u(:,j_row,:,taup1) = tdm_x
      end do

      ! step 3: replace fu by the total tendency over the 2*dt step
      fu(:,j_lo:j_hi,:) = (u(:,j_lo:j_hi,:,taup1) - u(:,j_lo:j_hi,:,taum1))/c2dt &
                          *maskU(:,j_lo:j_hi,:)
 end subroutine trm_implicit_umix



