#include "options.inc"
    

c=======================================================================
c       linked in the code in driver.F and diag_numbers.F
c=======================================================================



#ifdef enable_tke_closure

c=======================================================================
c       TKE closure for non-hydrostatic regime
c
c       Based on the standard prognostic TKE equation (eke = tke):
c       - length scale L = delta for unstable/neutral stratification
c         and L = min(delta, 0.76 eke^0.5/N) for stable stratification
c       - mixing coefficient K = eke^0.5 L; momentum viscosity c_m K,
c         TKE diffusivity c_eke K, buoyancy diffusivity c_b K
c       - shear production is c_m K |nabla u|^2
c       - buoyancy work is -c_b K N^2 with c_b = c_m (1 + 2L/delta)
c       - dissipation is c_eps eke^(3/2)/L, c_eps = 0.19 + 0.51 L/delta
c       - delta is chosen as c_l*max(dx,dz)
c       
c=======================================================================
      module tke_closure_module
c-----------------------------------------------------------------------
c     shared state and constants of the TKE closure
c-----------------------------------------------------------------------
      implicit none
c     turbulent kinetic energy; the last index is the time level
c     (allocated 0:2, addressed with taum1, tau and taup1)
      real, allocatable :: tke(:,:,:,:)
c     mixing coefficient sqrt(tke)*len and the diffusivities derived
c     from it for momentum (diffU), buoyancy (diffT) and TKE (diffE)
      real, allocatable :: diff(:,:,:)
      real, allocatable :: diffU(:,:,:)
      real, allocatable :: diffT(:,:,:)
      real, allocatable :: diffE(:,:,:)
c     TKE budget terms: dissipation (also scratch space while the
c     shear production is assembled), shear production, buoyancy work
      real, allocatable :: diss(:,:,:)
      real, allocatable :: T4(:,:,:)
      real, allocatable :: T2(:,:,:)
c     stability frequency squared and eddy length scale (T points)
      real, allocatable :: Nsqr(:,:,:)
      real, allocatable :: len(:,:,:)
c     closure coefficients for momentum (c_m) and TKE (c_eke) mixing
      real, parameter :: c_m = 0.1
      real, parameter :: c_eke = 0.2
c     factor on max(dx,dz) that defines the grid length scale delta
      real, parameter :: c_l = 1.0
c     small number guarding divisions and square roots
      real, parameter :: epsln = 1e-20
c     upper bound on the grid Peclet number (limits diffusivities)
      real, parameter :: Pec_max = 10.
      end module tke_closure_module


 
      subroutine init_tke_closure
c=======================================================================
c       allocate the closure work space on the full (imt,jmt,km)
c       grid, set initial values and set up the NetCDF diagnostics
c       (init_tke_diag)
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none

c     TKE on time levels 0:2, started from a tiny positive value
      allocate( tke(imt,jmt,km,0:2) )
      tke = 1.0e-12

c     diffusivities, budget terms, N^2 and length scale start at zero
      allocate( diff(imt,jmt,km),  diffT(imt,jmt,km) )
      allocate( diffU(imt,jmt,km), diffE(imt,jmt,km) )
      allocate( diss(imt,jmt,km),  len(imt,jmt,km) )
      allocate( T4(imt,jmt,km),    T2(imt,jmt,km) )
      allocate( Nsqr(imt,jmt,km) )
      diff  = 0.0
      diffT = 0.0
      diffU = 0.0
      diffE = 0.0
      diss  = 0.0
      len   = 0.0
      T4    = 0.0
      T2    = 0.0
      Nsqr  = 0.0

      call init_tke_diag
      end subroutine init_tke_closure



      subroutine tke_closure
c=======================================================================
c       main driver routine for this module:
c       N^2 -> TKE budget -> friction and buoyancy mixing, followed by
c       NetCDF output on snapshot steps and at the initial time
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none

      call calc_Nsqr
      call integrate_tke_closure
      call apply_tke_closure

c     Disabled debugging aid: area-weighted mean of tke(tau) over the
c     interior. Re-enabling it needs the locals i,j,k,js,je (with
c     js=max(2,js_pe); je=min(je_pe,jmt-1)) and real totke,area,fx.
c      fx = dx**2*dz
c      totke=0.0;area=0.0
c      do k=2,km-1
c       do j=js,je
c        do i=2,imt-1
c          area=area + fx*maskT(i,j,k)
c          totke = totke + tke(i,j,k,tau)*fx*maskT(i,j,k)
c        enddo
c       enddo
c      enddo
c      call global_sum(area)
c      call global_sum(totke)
c      totke = totke/area
c      if (my_pe==0) print'(a,e12.7,a,f10.8,a,a,i5,a,i5)', 
c     &      ' tke=',totke
c      call sub_flush(6)

      if (snapshot_time_step.or.initial_time==current_time)
     &   call diag_tke
      end subroutine tke_closure





      subroutine integrate_tke_closure
c=======================================================================
c       integrate the prognostic budget for TKE
c
c       Steps tke from time level taum1 to taup1 over c2dt,
c         d tke/dt = - div(advective flux) + div(diffE grad tke)
c                    + T4 (shear production) + T2 (buoyancy work)
c                    - diss (dissipation),
c       and then applies a Robert time filter with coefficient gamma
c       to tke(:,:,:,tau) on all rows js:je of this PE.
c       The module arrays len, diff, diffU, diffT, diffE, T4, T2 and
c       diss are recomputed here; apply_tke_closure and diag_tke use
c       them. Meridional grid spacing is taken equal to dx.
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none
      integer :: i,j,k,js,je
      real :: adv_fe(imt,jmt,km), adv_ft(imt,jmt,km)
      real :: adv_fn(imt,jmt,km), diff_fn(imt,jmt,km)
      real :: diff_fe(imt,jmt,km),diff_ft(imt,jmt,km)
      real :: ut(imt,jmt,km),vt(imt,jmt,km),wt(imt,jmt,km)
      real :: fxa,delta,uabs,c_b,k_max,k_min

      js=max(2,js_pe); je = min(je_pe,jmt-1)
c     grid length scale delta, used by all stability-dependent terms
      delta = c_l*max(dx,dz)
c
c=======================================================================
c       eddy length scale: delta, limited to 0.76*sqrt(tke/N^2)
c       where the stratification is stable (Nsqr > 0)
c=======================================================================
c
      do k=2,km-1
       do j=js,je
        len(:,j,k)=delta
        where( Nsqr(:,j,k) >0 ) len(:,j,k)=min(delta,
     &          0.76*sqrt(tke(:,j,k,taum1)/Nsqr(:,j,k)) )
       enddo
      enddo
c      call smoothT(len)
      call border_exchg3D(len,1)
      call setcyclic3D(len)
c
c=======================================================================
c       u, v and w averaged to the T grid (mask-weighted)
c=======================================================================
c
      ut=0;vt=0;wt=0
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         ut(i,j,k)=(u(i-1,j,k,1,taum1)*maskU(i-1,j,k)
     &             +u(i  ,j,k,1,taum1)*maskU(i  ,j,k))/
     &         (maskU(i-1,j,k)+maskU(i,j,k)+epsln)
         vt(i,j,k)=(u(i,j-1,k,2,taum1)*maskV(i,j-1,k)
     &             +u(i,j  ,k,2,taum1)*maskV(i,j  ,k))/
     &         (maskV(i,j-1,k)+maskV(i,j,k)+epsln)
         wt(i,j,k)=(u(i,j,k-1,3,taum1)*maskW(i,j,k-1)
     &             +u(i,j,k  ,3,taum1)*maskW(i,j,k  ))/
     &         (maskW(i,j,k-1)+maskW(i,j,k)+epsln)
        enddo
       enddo
      enddo
      call border_exchg3D(ut,1)
      call setcyclic3D(ut)
      call border_exchg3D(vt,1)
      call setcyclic3D(vt)
      call border_exchg3D(wt,1)
      call setcyclic3D(wt)
c
c=======================================================================
c      diffusivities: diff = sqrt(tke)*len times c_m (momentum),
c      c_eke (TKE) and c_b = c_m*(1+2*len/delta) (buoyancy)
c=======================================================================
c
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         diff(i,j,k)=sqrt( max(epsln,tke(i,j,k,taum1)) )
     &               *len(i,j,k)*maskT(i,j,k)
         diffU(i,j,k)=c_m*diff(i,j,k)
         diffE(i,j,k)=c_eke*diff(i,j,k)
         c_b = (1+2*len(i,j,k)/delta)*c_m
         diffT(i,j,k)=c_b*diff(i,j,k)
        enddo
       enddo
      enddo
c
c=======================================================================
c      limit the diffusivities such that the grid Peclet number
c      Pec = U*delta/K does not exceed Pec_max (K >= U*delta/Pec_max)
c      and cap them at K_max = min(dx,dz)**2/(4*dt_in) (presumably a
c      stability bound for the explicit diffusion). If K_min exceeds
c      K_max the Peclet bound wins.
c      diff, and thus T4 and T2, are not limited.
c=======================================================================
c
      K_max = min(dx,dz)**2/dt_in/4.
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         uabs=max(     abs(u(i,j,k,1,tau)) , abs(u(i-1,j,k,1,tau)) )
         uabs=max(uabs,abs(u(i,j,k,2,tau)) , abs(u(i,j-1,k,2,tau)) )
         uabs=max(uabs,abs(u(i,j,k,3,tau)) , abs(u(i,j,k-1,3,tau)) )
         K_min=uabs*delta/Pec_max
         diffU(i,j,k)=max(k_min,min(K_max,diffU(i,j,k)))
         diffT(i,j,k)=max(k_min,min(K_max,diffT(i,j,k)))
         diffE(i,j,k)=max(k_min,min(K_max,diffE(i,j,k)))
        enddo
       enddo
      enddo
      call border_exchg3D(diff,1)
      call setcyclic3D(diff)
      call border_exchg3D(diffT,1)
      call setcyclic3D(diffT)
      call border_exchg3D(diffU,1)
      call setcyclic3D(diffU)
      call border_exchg3D(diffE,1)
      call setcyclic3D(diffE)
c
c=======================================================================
c        advective fluxes of TKE
c=======================================================================
c
      adv_fe=0; adv_fn=0; adv_ft=0
      call adv_flux(adv_fe,adv_fn,adv_ft,tke)
      call border_exchg3D(adv_fn,1)
      call setcyclic3D(adv_fe)
      call setcyclic3D(adv_fn)
c
c=======================================================================
c      diffusive fluxes of TKE (diffE averaged to the faces, masked)
c=======================================================================
c
      diff_fe=0;diff_fn=0;diff_ft=0
      do k=2,km-1
       do j=js,je
        do i=1,imt-1
         fxa=(diffE(i,j,k)+diffE(i+1,j,k))/2.*maskU(i,j,k)
         diff_fe(i,j,k)=fxa*(tke(i+1,j,k,taum1)-tke(i,j,k,taum1))/dx
        enddo
       enddo
      enddo
      call setcyclic3D(diff_fe)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fxa=(diffE(i,j,k)+diffE(i,j+1,k))/2.*maskV(i,j,k)
         diff_fn(i,j,k)=fxa*(tke(i,j+1,k,taum1)-tke(i,j,k,taum1))/dx
        enddo
       enddo
      enddo
      call border_exchg3D(diff_fn,1)
      call setcyclic3D(diff_fn)
      do k=1,km-2
       do j=js,je
        do i=2,imt-1
         fxa=(diffE(i,j,k)+diffE(i,j,k+1))/2.*maskW(i,j,k)
         diff_ft(i,j,k)=fxa*(tke(i,j,k+1,taum1)-tke(i,j,k,taum1))/dz
        enddo
       enddo
      enddo
c
c=======================================================================
c       shear production T4 = c_m*diff*(|grad u|^2+|grad v|^2
c       +|grad w|^2); diss serves as scratch space here
c=======================================================================
c
      call gradT2(ut,diss)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         T4(i,j,k)= c_m*diff(i,j,k)*diss(i,j,k)
        enddo
       enddo
      enddo
      call gradT2(vt,diss)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         T4(i,j,k)= T4(i,j,k)+c_m*diff(i,j,k)*diss(i,j,k)
        enddo
       enddo
      enddo
      call gradT2(wt,diss)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         T4(i,j,k)= T4(i,j,k)+c_m*diff(i,j,k)*diss(i,j,k)
        enddo
       enddo
      enddo
c
c=======================================================================
c      buoyancy work T2 = -c_b*diff*N^2
c=======================================================================
c
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         c_b = (1+2*len(i,j,k)/delta)*c_m
         T2(i,j,k)=-c_b*diff(i,j,k)*Nsqr(i,j,k)*maskT(i,j,k)
        enddo
       enddo
      enddo
c      call smoothT(T2)
c      call smoothT(T4)
c
c=======================================================================
c      dissipation c_eps*tke**1.5/len, c_eps = 0.19+0.51*len/delta
c=======================================================================
c
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         diss(i,j,k)=(0.19+0.51*len(i,j,k)/delta)
     &                  *max(epsln,tke(i,j,k,taum1))**1.5
     &                    /max(epsln,len(i,j,k))*maskT(i,j,k)
        enddo
       enddo
      enddo
c
c=======================================================================
c       time step of tke (kept >= epsln)
c=======================================================================
c
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
          tke(i,j,k,taup1)=tke(i,j,k,taum1)+maskT(i,j,k)*c2dt*(
     &  -(adv_fe(i,j,k)-adv_fe(i-1,j,k))/dx
     &  -(adv_fn(i,j,k)-adv_fn(i,j-1,k))/dx
     &  -(adv_ft(i,j,k)-adv_ft(i,j,k-1))/dz
     &  +(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
     &  +(diff_fe(i,j,k)-diff_fe(i-1,j,k))/dx
     &  +(diff_fn(i,j,k)-diff_fn(i,j-1,k))/dx
     &  -diss(i,j,k) +T4(i,j,k)+T2(i,j,k)  )
         tke(i,j,k,taup1)=max(epsln,tke(i,j,k,taup1))
        enddo
       enddo
      enddo
      call border_exchg3D(tke(:,:,:,taup1),2)
      call setcyclic3D(tke(:,:,:,taup1) )
c
c=======================================================================
c       Robert time filter on level tau. It must cover every row
c       js:je of this PE: a whole-slice expression is used because no
c       loop index is live here (indexing with j would only touch the
c       single row je+1 left over from the last loop).
c=======================================================================
c
      tke(:,js:je,:,tau) = tke(:,js:je,:,tau) + gamma*
     &     (0.5*(tke(:,js:je,:,taup1)+tke(:,js:je,:,taum1))
     &      -tke(:,js:je,:,tau))
      call border_exchg3D(tke(:,:,:,tau),2)
      call setcyclic3D(tke(:,:,:,tau) )

      end subroutine integrate_tke_closure




      subroutine gradT2(a,bb)
c=======================================================================
c       calculate bb = |nabla a|^2 on the T grid
c
c       a  : field on T points; its halo rows and cyclic columns
c            should be valid (callers exchange them first)
c       bb : squared gradient magnitude, written for i=2..imt-1,
c            rows js..je and k=2..km-1 only; all other points keep
c            their previous values (hence intent(inout))
c       Each component is the mean of the two adjacent one-sided
c       differences weighted by the face masks (maskU, maskV, maskW),
c       so differences across masked faces are dropped.
c       Meridional grid spacing is taken equal to dx.
c=======================================================================
      use cpflame_module
      implicit none
      real, intent(in)    :: a(imt,jmt,km)
      real, intent(inout) :: bb(imt,jmt,km)
      real, parameter :: epsln     = 1.0e-20
      integer :: js,je,i,j,k

      js=max(2,js_pe); je = min(je_pe,jmt-1)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         bb(i,j,k)= maskT(i,j,k)*(
     &        (((a(i+1,j,k)-a(i  ,j,k))/dx*maskU(i  ,j,k)
     &         +(a(i  ,j,k)-a(i-1,j,k))/dx*maskU(i-1,j,k))/
     &         (maskU(i-1,j,k)+maskU(i,j,k)+epsln))**2
     &       +(((a(i,j+1,k)-a(i,j  ,k))/dx*maskV(i,j  ,k)
     &         +(a(i,j  ,k)-a(i,j-1,k))/dx*maskV(i,j-1,k))/
     &         (maskV(i,j-1,k)+maskV(i,j,k)+epsln))**2
     &       +(((a(i,j,k+1)-a(i,j,k  ))/dz*maskW(i,j,k  )
     &         +(a(i,j,k  )-a(i,j,k-1))/dz*maskW(i,j,k-1))/
     &         (maskW(i,j,k-1)+maskW(i,j,k)+epsln))**2 )
        enddo
       enddo
      enddo
      end subroutine gradT2



      subroutine smoothT(a)
c=======================================================================
c      smooth a on the T grid with a 3x3x3 box filter
c
c      Interior points (i=2..imt-1, rows js..je, k=2..km-1) are set
c      to the maskT-weighted mean of a over their 27-point
c      neighbourhood (the point itself included). Every other point
c      is zeroed before the halo rows and cyclic columns of a are
c      refreshed. In this file it is only referenced from
c      commented-out calls.
c=======================================================================
      use cpflame_module
      implicit none
      real, intent(inout) :: a(imt,jmt,km)
      real :: bb(imt,jmt,km), c(imt,jmt,km)
      real, parameter :: epsln     = 1.0e-20
      integer :: js,je,i,j,k,ii,jj,kk

      js=max(2,js_pe); je = min(je_pe,jmt-1)

c     Unsmoothed copy read by the stencil. Copy the whole array so
c     that points the exchange does not refresh are still defined,
c     then fill the halo rows and cyclic columns of the COPY (a is
c     overwritten right below, so fixing up a here would be lost).
      bb = a
      call border_exchg3D(bb,1)
      call setcyclic3D(bb)

c     accumulate the masked sum (a) and the mask count (c)
      a=0;c=0
      do ii=-1,1
       do jj=-1,1
        do kk=-1,1
         do k=2,km-1
          do j=js,je
           do i=2,imt-1
            a(i,j,k)= a(i,j,k)
     &               +bb(i+ii,j+jj,k+kk)*maskT(i+ii,j+jj,k+kk)
            c(i,j,k)= c(i,j,k)+maskT(i+ii,j+jj,k+kk)
           enddo
          enddo
         enddo
        enddo
       enddo
      enddo
      a(:,js:je,:)= a(:,js:je,:)/(c(:,js:je,:)+epsln)
      call border_exchg3D(a,1)
      call setcyclic3D(a)

      end subroutine smoothT




      subroutine  calc_Nsqr
c=======================================================================
c       stability frequency squared, Nsqr, on the T grid
c
c       Nsqrw holds -(b(k+1)-b(k))/dz at time level taum1 between
c       levels k and k+1. It is averaged to the T levels 2..km-1 with
c       maskW as weights and multiplied by maskT. Halo rows and cyclic
c       columns of both arrays are refreshed.
c       NOTE(review): whether this is +N^2 depends on the sign
c       convention of b and on the direction of k - confirm.
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none
      integer :: i,j,k,js,je
      real :: fxa,fxb, Nsqrw(imt,jmt,km)

      js=max(2,js_pe); je = min(je_pe,jmt-1)

c     vertical difference of b between level k and k+1
      Nsqrw=0
      do k=1,km-1
       do j=js,je
        do i=1,imt
         Nsqrw(i,j,k) = -(b(i,j,k+1,taum1)-b(i,j,k,taum1))/dz
        enddo
       enddo
      enddo
      call border_exchg3D(Nsqrw,1)
      call setcyclic3D(Nsqrw )

c     masked two-point average back to the T levels
      do k=2,km-1
       do j=js,je
        do i=1,imt
         fxa = Nsqrw(i,j,k)*maskW(i,j,k)+Nsqrw(i,j,k-1)*maskW(i,j,k-1)
         fxb = maskW(i,j,k)+maskW(i,j,k-1)+epsln
         Nsqr(i,j,k) = maskT(i,j,k)*fxa/fxb
        enddo
       enddo
      enddo
      call border_exchg3D(Nsqr,1)
      call setcyclic3D(Nsqr )
      end subroutine  calc_Nsqr






      subroutine apply_tke_closure
c=======================================================================
c       add friction (viscosity diffU) to the tendencies fu, fv and fw
c       and mix buoyancy (diffusivity diffT) directly into
c       b(:,:,:,taup1), scaled by c2dt
c
c       All gradients use u and b at time level taum1. Velocity fluxes
c       are multiplied by the masks of both points they connect.
c       Meridional grid spacing is taken equal to dx.
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none
      integer :: i,j,k,js,je
      real :: diff_ft(imt,jmt,km)
      real :: diff_fe(imt,jmt,km),diff_fn(imt,jmt,km)
      real :: fxa,fxb,fxc

      js=max(2,js_pe); je = min(je_pe,jmt-1)
c---------------------------------------------------------------------------------
c=======================================================================
c      friction of u (zonal, meridional and vertical fluxes)
c      the zonal flux between u(i) and u(i+1) uses diffU(i+1); the
c      other two fluxes use a 4-point average of diffU
c=======================================================================
c---------------------------------------------------------------------------------
      diff_ft=0; diff_fe=0.0; diff_fn=0.0; 
      do k=2,km-1
       do j=js,je
        do i=1,imt-1
         fxa=diffU(i+1,j,k)
         diff_fe(i,j,k)=fxa*(u(i+1,j,k,1,taum1)-u(i,j,k,1,taum1))/dx
     &                     *maskU(i+1,j,k)*maskU(i,j,k)
        enddo
       enddo
      enddo
      call setcyclic3D(diff_fe)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
          fxa=0.25*(diffU(i,j  ,k)+diffU(i+1,j  ,k)
     &             +diffU(i,j+1,k)+diffU(i+1,j+1,k))
          diff_fn(i,j,k)=fxa*(u(i,j+1,k,1,taum1)-u(i,j,k,1,taum1))/dx
     &                      *maskU(i,j+1,k)*maskU(i,j,k)
        enddo
       enddo
      enddo
      call border_exchg3D(diff_fn,1)
      call setcyclic3D(diff_fn)
      do k=1,km-1
       do j=js,je
        do i=2,imt-1
          fxa=0.25*(diffU(i,j,k  )+diffU(i+1,j,k)
     &             +diffU(i,j,k+1)+diffU(i+1,j,k+1))
          diff_ft(i,j,k)=fxa*(u(i,j,k+1,1,taum1)-u(i,j,k,1,taum1))/dz
     &                   *maskU(i,j,k+1)*maskU(i,j,k)
        enddo
       enddo
      enddo
c     flux divergence added to the u tendency
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fu(i,j,k)=fu(i,j,k)+maskU(i,j,k)*(
     &     (diff_ft(i,j,k) - diff_ft(i,j,k-1))/dz
     &    +(diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx
     &    +(diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx)
        enddo
       enddo
      enddo
c---------------------------------------------------------------------------------
c=======================================================================
c      friction of v (zonal, meridional and vertical fluxes)
c      the meridional flux between v(j) and v(j+1) uses diffU(j+1);
c      the other two fluxes use a 4-point average of diffU
c=======================================================================
c---------------------------------------------------------------------------------
      diff_ft=0.0; diff_fe=0.0; diff_fn=0.0;
      do k=2,km-1
       do j=js,je
        do i=1,imt-1
          fxa=0.25*(diffU(i+1,j  ,k)+diffU(i,j  ,k)
     &             +diffU(i+1,j+1,k)+diffU(i,j+1,k))
          diff_fe(i,j,k)=fxa*(u(i+1,j,k,2,taum1)-u(i,j,k,2,taum1))/dx
     &                   *maskV(i+1,j,k)*maskV(i,j,k)
        enddo
       enddo
      enddo
      call setcyclic3D(diff_fe)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
          fxa=diffU(i,j+1,k)
          diff_fn(i,j,k)=fxa*(u(i,j+1,k,2,taum1)-u(i,j,k,2,taum1) )/dx
     &                      *maskV(i,j+1,k)*maskV(i,j,k)
        enddo
       enddo
      enddo
      call border_exchg3D(diff_fn,1)
      call setcyclic3D(diff_fn)
      do k=1,km-1
       do j=js,je
        do i=2,imt-1
          fxa=0.25*(diffU(i,j,k  )+diffU(i,j+1,k  )
     &             +diffU(i,j,k+1)+diffU(i,j+1,k+1))
          diff_ft(i,j,k)=fxa*(u(i,j,k+1,2,taum1)-u(i,j,k,2,taum1))/dz
     &                   *maskV(i,j,k+1)*maskV(i,j,k)
        enddo
       enddo
      enddo
c     flux divergence added to the v tendency
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         fv(i,j,k)=fv(i,j,k)+ maskV(i,j,k)*(
     &     (diff_ft(i,j,k) - diff_ft(i,j,k-1))/dz
     &    +(diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx
     &    +(diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx)
        enddo
       enddo
      enddo
c---------------------------------------------------------------------------------
c=======================================================================
c       friction of w (zonal, meridional and vertical fluxes)
c       the vertical flux between w(k) and w(k+1) uses diffU(k+1);
c       the other two fluxes use a 4-point average of diffU
c=======================================================================
c---------------------------------------------------------------------------------
      diff_ft=0.0; diff_fe=0.0; diff_fn=0.0; 

      do k=2,km-1
       do j=js,je
        do i=1,imt-1
          fxa=0.25*(diffU(i+1,j,k  )+diffU(i,j,k  )
     &             +diffU(i+1,j,k+1)+diffU(i,j,k+1))
          diff_fe(i,j,k)=
     &    fxa*(u(i+1,j,k,3,taum1)-u(i,j,k,3,taum1))/dx
     &         *maskW(i+1,j,k)*maskW(i,j,k)
        enddo
       enddo
      enddo
      call setcyclic3D(diff_fe)
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
          fxa=0.25*(diffU(i,j,k  )+diffU(i,j+1,k  )
     &             +diffU(i,j,k+1)+diffU(i,j+1,k+1))
          diff_fn(i,j,k)=
     &     fxa*(u(i,j+1,k,3,taum1)-u(i,j,k,3,taum1))/dx
     &         *maskW(i,j+1,k)*maskW(i,j,k)
        enddo
       enddo
      enddo
      call border_exchg3D(diff_fn,1)
      call setcyclic3D(diff_fn)
      do k=1,km-1
       do j=js,je
        do i=2,imt-1
         fxa=diffU(i,j,k+1)
         diff_ft(i,j,k)=fxa*(u(i,j,k+1,3,taum1)-u(i,j,k,3,taum1))/dz
     &        *maskW(i,j,k+1)*maskW(i,j,k)
        enddo
       enddo
      enddo
c     flux divergence added to the w tendency
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
          fw(i,j,k)= fw(i,j,k)+maskW(i,j,k)*(
     &     +(diff_ft(i,j,k) - diff_ft(i,j,k-1))/dz
     &     +(diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx
     &     +(diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx )
        enddo
       enddo
      enddo

      call border_exchg3D(fu,1)
      call setcyclic3D(fu)
      call border_exchg3D(fv,1)
      call setcyclic3D(fv)
      call border_exchg3D(fw,1)
      call setcyclic3D(fw)
c
c=======================================================================
c  buoyancy mixing: diffT averaged to the faces and masked with
c  maskU/maskV/maskW; the result is added to b(:,:,:,taup1) directly
c  (scaled by c2dt) rather than to a tendency array
c=======================================================================
c
      diff_fe=0;diff_fn=0;diff_ft=0
      do k=2,km-1
        do j=js,je
         do i=1,imt-1
          fxa=0.5*(diffT(i,j,k)+diffT(i+1,j,k))*maskU(i,j,k)
          diff_fe(i,j,k)=fxa*(b(i+1,j,k,taum1)-b(i,j,k,taum1))/dx
         enddo
        enddo
      enddo
      call setcyclic3D(diff_fe)
      do k=2,km-1
        do j=js,je 
         do i=2,imt-1
          fxa=0.5*(diffT(i,j,k)+diffT(i,j+1,k))*maskV(i,j,k)
          diff_fn(i,j,k)=fxa*(b(i,j+1,k,taum1)-b(i,j,k,taum1))/dx
         enddo
        enddo
      enddo
      call border_exchg3D(diff_fn,1)
      call setcyclic3D(diff_fn)
      do k=1,km-2
       do j=js,je
        do i=2,imt-1
         fxa=0.5*(diffT(i,j,k)+diffT(i,j,k+1))*maskW(i,j,k)
         diff_ft(i,j,k)=fxa*(b(i,j,k+1,taum1)-b(i,j,k,taum1))/dz
        enddo
       enddo
      enddo
      do k=2,km-1
       do j=js,je
        do i=2,imt-1
         b(i,j,k,taup1)=b(i,j,k,taup1)+maskT(i,j,k)*c2dt*( 
     &  +(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
     &  +(diff_fe(i,j,k)-diff_fe(i-1,j,k))/dx
     &  +(diff_fn(i,j,k)-diff_fn(i,j-1,k))/dx  )
        enddo
       enddo
      enddo
      call border_exchg3D(b(:,:,:,taup1),2)
      call setcyclic3D(b(:,:,:,taup1) )
      end subroutine apply_tke_closure





      subroutine init_tke_diag
c=======================================================================
c     initialize the NetCDF snapshot file tke.cdf
c
c     The grid part is written by def_grid_cdf. This routine then
c     defines the float variables tke, len, diff, diffT, T4, T2 and
c     eps on (xt,yt,zt,Time), passing a long name, a unit and spval
c     to dvcdf (project helper, presumably writing the attributes).
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none
#include "netcdf.inc"
      integer :: ncid,iret,oldfill
      integer :: lon_tdim,z_tdim,itimedim,lat_tdim,varid,dims(4)
      character :: name*24, unit*16

      call def_grid_cdf('tke.cdf')
      iret=nf_open('tke.cdf',NF_WRITE,ncid)
c     the previous fill mode goes to its own variable: iret must not
c     be both an argument defined by the call and the assigned result
      iret=nf_set_fill(ncid, NF_NOFILL, oldfill)
      call ncredf(ncid, iret)
      iret=nf_inq_dimid(ncid,'xt',lon_tdim)
      iret=nf_inq_dimid(ncid,'yt',lat_tdim)
      iret=nf_inq_dimid(ncid,'zt',z_tdim)
      iret=nf_inq_dimid(ncid,'Time',itimedim)
      dims = (/lon_tdim,lat_tdim, z_tdim, itimedim/)
      call def_tke_var('tke'  ,'Turbulent kinetic energy','m^2/s^2')
      call def_tke_var('len'  ,'Length scale','m')
      call def_tke_var('diff' ,'Diffusivity','m^2/s')
c     diffT is the buoyancy diffusivity; give it its own long name so
c     it is distinguishable from diff
      call def_tke_var('diffT','Buoyancy diffusivity','m^2/s')
      call def_tke_var('T4'   ,'shear production','m^2/s^3 ')
      call def_tke_var('T2'   ,'baroclinic production','m^2/s^3 ')
      call def_tke_var('eps'  ,'Dissipation','m^2/s^3 ')
      call ncclos (ncid, iret)

      contains

      subroutine def_tke_var(vname,lname,uname)
c     define one float variable vname on dims and pass its long name
c     (padded to 24 chars) and unit (padded to 16 chars) to dvcdf
      character(len=*), intent(in) :: vname, lname, uname
      varid = ncvdef (ncid,vname, NCFLOAT,4,dims,iret)
      name = lname
      unit = uname
      call dvcdf(ncid,varid,name,24,unit,16,spval)
      end subroutine def_tke_var

      end subroutine init_tke_diag



      subroutine diag_tke
c=======================================================================
c     append one record to the NetCDF snapshot file tke.cdf
c
c     PEs take turns, separated by barriers, to open the file and
c     write their rows js_pe:je_pe. PE 0 goes first and appends the
c     new Time value (days elapsed since initial_time). Later PEs then
c     read the enlarged Time length, so ilen is the new record for
c     every PE (assuming Time is the record dimension).
c     Points with maskT == 0 are written as spval.
c=======================================================================
      use cpflame_module
      use tke_closure_module
      implicit none
#include "netcdf.inc"
      integer :: ncid,iret,npe,oldfill, corner(4), edges(4)
      real :: a(imt,js_pe:je_pe,km),fxa
      integer :: itdimid,ilen,itimeid,varid
      type(time_type) :: time

c     NOTE(review): if n_pes is the number of PEs, the last pass of
c     this loop matches no PE and only adds a barrier - confirm
      do npe=0,n_pes
       if (my_pe==npe) then
        iret=nf_open('tke.cdf',NF_WRITE,ncid)
c       old fill mode in its own variable (iret is the result)
        iret=nf_set_fill(ncid, NF_NOFILL, oldfill)
        iret=nf_inq_dimid(ncid,'Time',itdimid)
        iret=nf_inq_dimlen(ncid, itdimid,ilen)
        iret=nf_inq_varid(ncid,'Time',itimeid)
        if (my_pe==0) then
         ilen=ilen+1
         time = current_time-initial_time
         fxa = time%days + time%seconds/86400.
         iret= nf_put_vara_double(ncid,itimeid,ilen,1,fxa)
        endif
        corner = (/1,js_pe,1,ilen/)
        edges  = (/imt,je_pe-js_pe+1,km,1/)
        call put_tke_var('tke'  ,tke(:,js_pe:je_pe,:,tau))
        call put_tke_var('len'  ,len(:,js_pe:je_pe,:))
        call put_tke_var('diff' ,diff(:,js_pe:je_pe,:))
        call put_tke_var('diffT',diffT(:,js_pe:je_pe,:))
        call put_tke_var('T4'   ,T4(:,js_pe:je_pe,:))
        call put_tke_var('T2'   ,T2(:,js_pe:je_pe,:))
        call put_tke_var('eps'  ,diss(:,js_pe:je_pe,:))
        call ncclos (ncid, iret)
       endif
       call barrier
      enddo

      contains

      subroutine put_tke_var(vname,fld)
c     copy fld into the buffer a, replace points with maskT == 0 by
c     spval and write it to variable vname at corner/edges
      character(len=*), intent(in) :: vname
      real, intent(in) :: fld(imt,js_pe:je_pe,km)
      a = fld
      where( maskT(:,js_pe:je_pe,:) == 0.) a = spval
      iret=nf_inq_varid(ncid,vname,varid)
      iret= nf_put_vara_double(ncid,varid,corner,edges,a)
      end subroutine put_tke_var

      end subroutine diag_tke

#else
      subroutine tke_closure_dummy
c     empty placeholder compiled when enable_tke_closure is not
c     defined (presumably so this source file never yields an empty
c     object file)
      end subroutine tke_closure_dummy
#endif

