#include "options.inc"


c=======================================================================
c   add vertical friction due to residual mean formulation
c    vertical friction of residual momentum: u_t = .. + ( K f^2/N^2 u_z)_z
c    hooked into the model code only in driver.F
c=======================================================================

#ifdef enable_vert_friction_trm

c   uncomment the next line to write the vertical viscosity and f^2/N^2 to a NetCDF file
c#define enable_vert_friction_diagnostics

      module vert_friction_trm_module
c-----------------------------------------------------------------------
c     Shared state and settings of the residual-mean (TRM) vertical
c     friction.  The arrays are allocated in init_vert_friction_trm
c     and filled in vert_friction_trm.
c-----------------------------------------------------------------------
      implicit none
c     work arrays, allocated with shape (imt,jmt,km)
      real, allocatable :: A_trm(:,:,:)   ! vertical viscosity
      real, allocatable :: K_thk(:,:,:)   ! thickness diffusivity
      real, allocatable :: Nsqrw(:,:,:)   ! N^2 on W grid, floored
      real, allocatable :: fNsqrw(:,:,:)  ! f^2/N^2 on W grid, capped
c     fixed constants
      real, parameter :: epsln = 1.0e-20  ! guards mask-sum divisions
      real, parameter :: K_const = 1e3    ! thickness diff. in m^2/s
c     adjustable settings
      real :: N_min = 1e-5            ! lower bound for N (1/s)
      real :: fNsqr_max = 0.01        ! upper bound for f^2/N^2
      real :: aidif_momentum = 1.0    ! 1 = implicit, 0 = explicit
      end module vert_friction_trm_module

      subroutine init_vert_friction_trm
c-----------------------------------------------------------------------
c     Allocate the TRM work arrays with shape (imt,jmt,km):
c     A_trm, Nsqrw and fNsqrw start at zero and K_thk is set to the
c     constant K_const.  PE 0 optionally creates the diagnostic
c     NetCDF file and prints the parameter settings.
c-----------------------------------------------------------------------
      use cpflame_module
      use vert_friction_trm_module
      implicit none
      if (my_pe==0) print*,' Initializing TRM module '
      allocate( A_trm(imt,jmt,km) ); A_trm=0.
      allocate( K_thk(imt,jmt,km) ); K_thk=K_const
      allocate( Nsqrw(imt,jmt,km) ); Nsqrw=0.
      allocate( fNsqrw(imt,jmt,km) ); fNsqrw=0.
#ifdef enable_vert_friction_diagnostics
c     only PE 0 defines the snapshot file
      if (my_pe==0) call init_vert_friction_trm_diag
#endif
      if (my_pe==0) then
         print*,' K_const =  ',k_const,' m^2/s'
         print*,' N_min   =  ',N_min,' 1/s'
         print*,' fNsqr_max   =  ',fNsqr_max
         print*,' done initializing TRM module '
      endif
      end subroutine init_vert_friction_trm

      subroutine vert_friction_trm
c-----------------------------------------------------------------------
c     Dispatcher (German: Verteiler) for the residual-mean vertical
c     friction  u_t = .. + ( A u_z )_z  with  A = K_thk * f^2/N^2 :
c       1. N^2 on the W grid from b at taum1, floored at N_min**2
c       2. f^2/N^2 on the W grid, capped at fNsqr_max
c       3. A_trm = (K_thk averaged to W) * f^2/N^2, halo exchanged
c       4. for u and for v: average A_trm onto the velocity column,
c          add the explicit part weighted by (1-aidif_momentum) to
c          fu/fv on levels 2..km-1, then call the implicit solver
c          trm_implicit_umix / trm_implicit_vmix
c-----------------------------------------------------------------------
      use cpflame_module
      use vert_friction_trm_module
      implicit none
      integer :: js,je,i,j,k,kp
      real :: fxa,fxb,A_uu(imt,jmt,km),diff_ft(imt,jmt,km)

      js=max(2,js_pe); je = min(je_pe,jmt-1)


c     inactive alternative formulation, kept for reference
#ifdef notdef

      do k=1,km-1
       do j=js,je
c        Nsqrw(:,j,k)=-(b(:,j,k+1,taum1)-b(:,j,k,taum1))/dz
        Nsqrw(:,j,k)=-(b_r(k+1)-b_r(k))/dz
        Nsqrw(:,j,k)=max(Nsqrw(:,j,k),N_min**2)
       enddo
      enddo

      diff_ft=0.
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         kp=min(k+1,km-1)
         diff_ft(i,j,k)=-(b(i,j,k,taum1)-b_r(k)
     &                   +b(i,j,kp,taum1)-b_r(kp) )/2.0
     &                    /Nsqrw(i,j,k)
        enddo
       enddo
      enddo
      diff_ft(:,:,1)=diff_ft(:,:,2)
c      diff_ft(:,:,km-1)=diff_ft(:,:,km-2)

      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fu(i,j,k)=fu(i,j,k)+maskU(i,j,k)*( 
     &     -coriolis_t(j)*(u(i,j  ,k,2,tau)+u(i+1,j  ,k,2,tau)+
     &                    u(i,j-1,k,2,tau)+u(i+1,j-1,k,2,tau))/4.0
     &           *( (diff_ft(i  ,j,k)-diff_ft(i  ,j,k-1))/dz
     &            + (diff_ft(i+1,j,k)-diff_ft(i+1,j,k-1))/dz )/2.0 )
        enddo
       enddo
      enddo

      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fv(i,j,k)= fv(i,j,k)+maskV(i,j,k)*(
     &     +(coriolis_t(j  )*(u(i-1,j  ,k,1,tau)+u(i,j  ,k,1,tau)) 
     &      +coriolis_t(j+1)*(u(i-1,j+1,k,1,tau)+u(i,j+1,k,1,tau)))/4.0
     &           *( (diff_ft(i,j  ,k)-diff_ft(i,j  ,k-1))/dz
     &            + (diff_ft(i,j+1,k)-diff_ft(i,j+1,k-1))/dz )/2.0 )
        enddo
       enddo
      enddo
#endif


#ifndef notdef
c-----------------------------------------------------------------------
c     stability freq. N^2 on W grid, floored at N_min**2
c-----------------------------------------------------------------------
      do k=1,km-1
       do j=js,je
        Nsqrw(:,j,k)=-(b(:,j,k+1,taum1)-b(:,j,k,taum1))/dz
        Nsqrw(:,j,k)=max(Nsqrw(:,j,k),N_min**2)
       enddo
      enddo
c-----------------------------------------------------------------------
c     Calculate f**2/N**2 on W grid and bound by fNsqr_max
c-----------------------------------------------------------------------
      do k=1,km-1
       do j=js,je
        do i=2,imt-1
         fNsqrw(i,j,k)= coriolis_t(j)**2/Nsqrw(i,j,k)
         fNsqrw(i,j,k)=min(fNsqrw(i,j,k),fNsqr_max)
        enddo
       enddo
      enddo
c-----------------------------------------------------------------------
c    vertical viscosity: K_thk from T levels k and k+1 (the k+1 value
c    only where maskW is set), times f^2/N^2
c-----------------------------------------------------------------------
      do k=1,km-1
       do j=js,je
        do i=2,imt-1
          fxa=K_thk(i,j,k)*maskT(i,j,k)+K_thk(i,j,k+1)*maskW(i,j,k) 
          fxb=maskT(i,j,k)+maskW(i,j,k)+epsln
          A_trm(i,j,k)=fxa/fxb*fNsqrw(i,j,k)
        enddo
       enddo
      enddo
c     halos are needed below for the i+1 and j+1 neighbours
      call border_exchg3D(A_trm,1); call setcyclic3D(A_trm)
c-----------------------------------------------------------------------
c      zonal velocity: average A_trm of W columns i and i+1
c-----------------------------------------------------------------------
       A_uu(:,js:je,:)=0
       do k=1,km-1
        do j=js,je
         do i=2,imt-1
          fxa=A_trm(i,j,k)*maskW(i,j,k)+A_trm(i+1,j,k)*maskW(i+1,j,k)
          fxb=maskW(i,j,k)+maskW(i+1,j,k)+epsln
          A_uu(i,j,k) =fxa/fxb*maskU(i,j,k+1)*maskU(i,j,k)
         enddo
        enddo
       enddo
c-----------------------------------------------------------------------
c     explicit part: vertical flux A_uu*u_z at W levels
c-----------------------------------------------------------------------
       diff_ft(:,js:je,:)=0.
       do k=1,km-1
        do j=js,je
         do i=2,imt-1
          diff_ft(i,j,k)=A_uu(i,j,k)
     &              *(u(i,j,k+1,1,taum1)-u(i,j,k,1,taum1))/dz
         enddo
        enddo
       enddo
c-----------------------------------------------------------------------
c     Add to zonal momentum tendencies
c-----------------------------------------------------------------------
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fu(i,j,k)=fu(i,j,k)+(1-aidif_momentum)*
     &           maskU(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
        enddo
       enddo
      enddo
c-----------------------------------------------------------------------
c     implicit part of vertical operator
c-----------------------------------------------------------------------
      call trm_implicit_umix(A_uu,aidif_momentum)
c-----------------------------------------------------------------------
c      meridional velocity: average A_trm of W columns j and j+1;
c      each term is weighted by its own W mask, matching fxb
c-----------------------------------------------------------------------
       A_uu(:,js:je,:)=0
       do k=1,km-1
        do j=js,je
         do i=2,imt-1
          fxa=A_trm(i,j,k)*maskW(i,j,k)+A_trm(i,j+1,k)*maskW(i,j+1,k)
          fxb=maskW(i,j,k)+maskW(i,j+1,k)+epsln
          A_uu(i,j,k) =fxa/fxb*maskV(i,j,k+1)*maskV(i,j,k)
         enddo
        enddo
       enddo
c-----------------------------------------------------------------------
c     explicit part: vertical flux A_uu*v_z at W levels
c-----------------------------------------------------------------------
       diff_ft(:,js:je,:)=0.
       do k=1,km-1
        do j=js,je
         do i=2,imt-1
          diff_ft(i,j,k)=A_uu(i,j,k)
     &              *(u(i,j,k+1,2,taum1)-u(i,j,k,2,taum1))/dz
         enddo
        enddo
       enddo
c-----------------------------------------------------------------------
c     Add to meridional momentum tendencies
c-----------------------------------------------------------------------
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fv(i,j,k)=fv(i,j,k)+(1-aidif_momentum)*
     &          maskV(i,j,k)* (diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
        enddo
       enddo
      enddo
c-----------------------------------------------------------------------
c     implicit part of vertical operator
c-----------------------------------------------------------------------
      call trm_implicit_vmix(A_uu,aidif_momentum)
#endif

#ifdef enable_vert_friction_diagnostics
      if (snapshot_time_step.or.initial_time==current_time) then
        call vert_friction_trm_diag
      endif
#endif
      end subroutine vert_friction_trm






      subroutine trm_implicit_umix(A_uu,aidif)
c=======================================================================
c     Implicit part of the TRM vertical friction for u (index 1).
c
c     A_uu  : viscosity at W levels 1..km-1 on the U columns;
c             only rows js..je are used
c     aidif : implicitness weight (1 = fully implicit, 0 = explicit)
c
c     First steps u(taup1) = u(taum1) + c2dt*fu, then solves in each
c     column the tridiagonal system
c        u_new - aidif*c2dt*D(u_new) = u(taup1)
c     with D(u) = (flx(k)-flx(k-1))/dz, flx(k) = A_uu(k)*(u(k+1)-u(k))/dz
c     and no flux through the top and bottom, using the Thomas
c     algorithm.  Finally fu is replaced by the effective tendency
c     (u(taup1)-u(taum1))/c2dt, masked with maskU.
c=======================================================================
      use cpflame_module
      implicit none
      real, intent(in) :: A_uu(imt,jmt,km), aidif
      integer :: j,k,js,je
      real :: a(imt,km),bb(imt,km),c(imt,km),bet(imt)
      real :: pu(imt,km),gam(imt,km),fxa,r(imt,km)
      js=max(2,js_pe); je = min(je_pe,jmt-1)
c     same factor for every column, so computed once
      fxa = aidif*c2dt/dz**2
c---------------------------------------------------------------------------------
c      first fake integrate du/dt = F_u, then solve for rest
c---------------------------------------------------------------------------------
      u(:,js:je,:,1,taup1)= u(:,js:je,:,1,taum1)
     &         +c2dt*fu(:,js:je,:)*maskU(:,js:je,:)
      do j=js,je
c      a: sub-, bb: main-, c: super-diagonal
       bb(:,1) = 1+fxa * A_uu(:,j,1)
       c(:,1)  =  -fxa * A_uu(:,j,1)
       do k=2,km-1
         a(:,k)  =  -fxa * A_uu(:,j,k-1)
         bb(:,k) = 1+fxa * (A_uu(:,j,k)+A_uu(:,j,k-1) )
         c(:,k)  =  -fxa * A_uu(:,j,k)
       enddo
       a(:,km)  =  -fxa * A_uu(:,j,km-1)
       bb(:,km) = 1+fxa * A_uu(:,j,km-1) 
c      forward elimination; bet holds the current pivot
       pu=0.0;gam=0.0
       r=u(:,j,:,1,taup1)*maskU(:,j,:)
       bet=bb(:,1)
       where (bet/=0.0) pu(:,1)=r(:,1)/bet
       do k=2,km
        where (bet/=0.0) gam(:,k)=c(:,k-1)/bet
        bet=bb(:,k)-a(:,k)*gam(:,k)
        where(bet/=0.0) pu(:,k)=(r(:,k)-a(:,k)*pu(:,k-1))/bet
       enddo
c      back substitution
       do k=km-1,1,-1
        pu(:,k)=pu(:,k)-gam(:,k+1)*pu(:,k+1)
       enddo
       u(:,j,:,1,taup1)=pu
      enddo
      fu(:,js:je,:)=(u(:,js:je,:,1,taup1)
     &      -u(:,js:je,:,1,taum1))/c2dt*maskU(:,js:je,:)
      end subroutine trm_implicit_umix


      subroutine trm_implicit_vmix(A_uu,aidif)
c---------------------------------------------------------------------------------
c     Implicit part of the TRM vertical friction for v (index 2).
c
c     A_uu  : viscosity at W levels 1..km-1 on the V columns;
c             only rows js..je are used
c     aidif : implicitness weight (1 = fully implicit, 0 = explicit)
c
c     First steps v(taup1) = v(taum1) + c2dt*fv, then solves in each
c     column the tridiagonal system
c        v_new - aidif*c2dt*D(v_new) = v(taup1)
c     with D(v) = (flx(k)-flx(k-1))/dz, flx(k) = A_uu(k)*(v(k+1)-v(k))/dz
c     and no flux through the top and bottom, using the Thomas
c     algorithm.  Finally fv is replaced by the effective tendency
c     (v(taup1)-v(taum1))/c2dt, masked with maskV.
c---------------------------------------------------------------------------------
      use cpflame_module
      implicit none
      real, intent(in) :: A_uu(imt,jmt,km), aidif
      integer :: j,k,js,je
      real :: a(imt,km),bb(imt,km),c(imt,km),bet(imt)
      real :: pu(imt,km),gam(imt,km),fxa,r(imt,km)
      js=max(2,js_pe); je = min(je_pe,jmt-1)
c     same factor for every column, so computed once
      fxa = aidif*c2dt/dz**2
c---------------------------------------------------------------------------------
c      first fake integrate dv/dt = F_v, then solve for rest
c---------------------------------------------------------------------------------
      u(:,js:je,:,2,taup1)= u(:,js:je,:,2,taum1)
     &       +c2dt*fv(:,js:je,:)*maskV(:,js:je,:)
      do j=js,je
c      a: sub-, bb: main-, c: super-diagonal
       bb(:,1) = 1+fxa * A_uu(:,j,1)
       c(:,1)  =  -fxa * A_uu(:,j,1)
       do k=2,km-1
         a(:,k)  =  -fxa * A_uu(:,j,k-1)
         bb(:,k) = 1+fxa * (A_uu(:,j,k)+A_uu(:,j,k-1) )
         c(:,k)  =  -fxa * A_uu(:,j,k)
       enddo
       a(:,km)  =  -fxa * A_uu(:,j,km-1)
       bb(:,km) = 1+fxa * A_uu(:,j,km-1) 
c      forward elimination; bet holds the current pivot
       pu=0.0;gam=0.0
       r=u(:,j,:,2,taup1)*maskV(:,j,:)
       bet=bb(:,1)
       where (bet/=0.0) pu(:,1)=r(:,1)/bet
       do k=2,km
        where (bet/=0.0) gam(:,k)=c(:,k-1)/bet
        bet=bb(:,k)-a(:,k)*gam(:,k)
        where(bet/=0.0) pu(:,k)=(r(:,k)-a(:,k)*pu(:,k-1))/bet
       enddo
c      back substitution
       do k=km-1,1,-1
        pu(:,k)=pu(:,k)-gam(:,k+1)*pu(:,k+1)
       enddo
       u(:,j,:,2,taup1)=pu
      enddo
      fv(:,js:je,:)=(u(:,js:je,:,2,taup1)
     &           -u(:,js:je,:,2,taum1))/c2dt*maskV(:,js:je,:)
      end subroutine trm_implicit_vmix

#ifdef enable_vert_friction_diagnostics

      subroutine init_vert_friction_trm_diag
c-----------------------------------------------------------------------
c     Create the NetCDF snapshot file vert_friction_trm.cdf (grid
c     written by def_grid_cdf) and define two variables on the
c     dimensions (xt,yt,zu,Time):
c       kappa : vertical viscosity, m^2/s
c       f2N2  : f^2/N^2
c-----------------------------------------------------------------------
      use cpflame_module
      use vert_friction_trm_module
      implicit none
#include "netcdf.inc"
      integer :: ncid,iret,old_fill,varid
      integer :: lon_tdim,lat_tdim,z_udim,itimedim
      integer :: dims(4)
      character :: name*24, unit*16
      call def_grid_cdf('vert_friction_trm.cdf')
      iret=nf_open('vert_friction_trm.cdf',NF_WRITE,ncid)
c     the previous fill mode goes to its own variable: the same
c     variable must not be both the result and an output argument
      iret=nf_set_fill(ncid, NF_NOFILL, old_fill)
      call ncredf(ncid, iret)
      iret=nf_inq_dimid(ncid,'xt',lon_tdim)
      iret=nf_inq_dimid(ncid,'yt',lat_tdim)
      iret=nf_inq_dimid(ncid,'zu',z_udim)
      iret=nf_inq_dimid(ncid,'Time',itimedim)
      dims = (/lon_tdim,lat_tdim, z_udim, itimedim/)
      varid  = ncvdef (ncid,'kappa', NCFLOAT,4,dims,iret)
      name = 'Vertical viscosity'; unit = 'm^2/s'
      call dvcdf(ncid,varid,name,24,unit,16,spval)
      varid  = ncvdef (ncid,'f2N2', NCFLOAT,4,dims,iret)
      name = 'f^2/N^2               '; unit = ' '
      call dvcdf(ncid,varid,name,24,unit,16,spval)
      call ncclos (ncid, iret)
      end subroutine init_vert_friction_trm_diag


      subroutine vert_friction_trm_diag
c-----------------------------------------------------------------------
c     Append one snapshot of A_trm (variable kappa) and fNsqrw
c     (variable f2N2) to vert_friction_trm.cdf.  The PEs write their
c     rows js_pe:je_pe one after another, separated by barriers; PE 0
c     additionally appends the new Time record (days since
c     initial_time).  Points with maskW == 0 are written as spval.
c-----------------------------------------------------------------------
      use cpflame_module
      use vert_friction_trm_module
      implicit none
#include "netcdf.inc"
      integer :: ncid,iret,old_fill,npe,corner(4),edges(4)
      integer :: itdimid,ilen,rid,itimeid
      real :: a(imt,js_pe:je_pe,km)
      real :: fxa
      type(time_type) :: time
c     NOTE(review): npe runs over 0..n_pes; if my_pe only takes the
c     values 0..n_pes-1, the last pass is just an extra barrier.
      do npe=0,n_pes
       if (my_pe==npe) then
        iret=nf_open('vert_friction_trm.cdf',NF_WRITE,ncid)
c       old fill mode kept apart from the function result
        iret=nf_set_fill(ncid, NF_NOFILL, old_fill)
        iret=nf_inq_dimid(ncid,'Time',itdimid)
        iret=nf_inq_dimlen(ncid, itdimid,ilen)
        iret=nf_inq_varid(ncid,'Time',itimeid)
        if (my_pe==0) then
c        first writer opens a new time record
         ilen=ilen+1
         time = current_time-initial_time
         fxa = time%days + time%seconds/86400.
         iret= nf_put_vara_double(ncid,itimeid,ilen,1,fxa)
        endif
        corner = (/1,js_pe,1,ilen/)
        edges  = (/imt,je_pe-js_pe+1,km,1/)
        iret=nf_inq_varid(ncid,'kappa',rid)
        a=A_trm(:,js_pe:je_pe,:)
        where( maskW(:,js_pe:je_pe,:) == 0.) a = spval
        iret= nf_put_vara_double(ncid,rid,corner,edges,a)
        iret=nf_inq_varid(ncid,'f2N2',rid)
        a=fNsqrw(:,js_pe:je_pe,:)
        where( maskW(:,js_pe:je_pe,:) == 0.) a = spval
        iret= nf_put_vara_double(ncid,rid,corner,edges,a)
        call ncclos (ncid, iret)
       endif
       call barrier
      enddo
      end subroutine vert_friction_trm_diag
#endif


#else
      subroutine vert_friction_trm_dummy
c     empty placeholder compiled when enable_vert_friction_trm is not
c     set; presumably it keeps this compilation unit non-empty
      implicit none
      return
      end subroutine vert_friction_trm_dummy
#endif
