
!=======================================================================
! Mesoscale eddy parameterisation based on potential vorticity (PV) mixing.
! The lateral diffusivity for the eddy PV fluxes can be calculated from
! linear stability theory (alternatively it is read from file or set constant).
!=======================================================================

module pv_mixing_module
 !=======================================================================
 ! Shared state and parameters of the pv mixing parameterisation.
 ! The diffusivity arrays K_rho and K_pv carry two time levels in their
 ! last index: slot 1 = previous estimate, slot 2 = latest estimate; the
 ! momentum routines interpolate linearly in time between the two.
 !=======================================================================
 use pyOM_module
 implicit none
 ! lateral constraint scale (300 km) in grid points, set in init_pv_mixing
 integer :: j_constraint
 ! double-precision literals: default-real literals truncated pi to about
 ! 7 digits and turned grav into 9.8100004...
 real*8,parameter :: pi = 3.14159265358979323846264338327950588d0, grav = 9.81d0, rho_0 = 1024d0
 ! vertical / horizontal eddy tendencies and the removed mean force for u and v
 real*8, dimension(:,:,:), allocatable :: vtendu, htendu,theta_u,vtendv, htendv,theta_v
 ! N^2 and f^2/N^2 on the W grid; PV-flux diffusivity before projection
 real*8, dimension(:,:,:), allocatable :: Nsqrw, fNsqrw, K_iso
 ! diffusivities for the eddy buoyancy flux (K_rho) and eddy PV flux (K_pv)
 real*8, dimension(:,:,:,:), allocatable :: K_rho, K_pv
 ! wave vector of the analysed mode; depth mean of N
 real*8, dimension(:,:),   allocatable :: kx_max,ky_max,meanN
 ! complex frequency of the analysed mode (growth rate = imaginary part)
 complex*8, allocatable :: om(:,:)
 ! K_min/K_max: bounds of the diffusivities
 ! len_fac    : sampling distance in mean Rossby radii (pv_mixing_calc_diff_somewhere)
 ! K_w        : amplitude factor of the unstable mode
 ! K_hor_min  : applied as an upper cap (min) of the lateral-friction diffusivity
 real*8  :: K_min = 25d0, K_max = 2d4, len_fac=2d0, K_w = 1d0, K_hor_min=500d0
 ! diffusivities are divided by (1+u^2/u_max^2)
 real*8  :: u_max     = 1d18 !0.2
 ! lower bound of the horizontal buoyancy gradient magnitude
 real*8  :: M_min     = 1d-12! 1d-9
 ! intervals used as interval/dt time steps (3 days if dt is in seconds)
 integer :: pv_mixing_calcint = 3*86400
 integer :: pv_mixing_diagint = 3*86400
 logical :: enable_pv_mixing_read_diff = .false.       ! read diffusivity from file
 logical :: enable_pv_mixing_set_diff  = .false.       ! set diffusivity to constant
 logical :: enable_pv_mixing_calc_diff = .true.        ! calculate diffusivity by linear stability analysis
 logical :: enable_pv_mixing_smooth_backstate = .false.! smooth background state by running mean
 logical :: enable_pv_mixing_smooth_diff      = .false. ! smooth diffusivity
 integer :: k_runmean=5                                 ! half-width of the running mean (grid points)
 logical :: enable_pv_mixing_diag      = .true.
 !
 logical :: enable_pv_mixing_full      = .false. ! enables full pv mixing instead of TRM formulation
 logical :: enable_pv_mixing_beta_term = .true.  ! enables the force by beta term
 logical :: enable_pv_mixing_friction  = .false. ! enables also the lateral friction
 !
 ! set by init_pv_mixing, cleared at the end of pv_mixing_main; forces a
 ! diffusivity calculation on the first call
 logical :: pv_mixing_initialized = .false.
end module pv_mixing_module


subroutine init_pv_mixing
 !-----------------------------------------------------------------------
 ! Initialise the pv mixing module: allocate and zero the module arrays,
 ! echo the parameter settings on PE 0, optionally read or set the
 ! diffusivity and initialise the diagnostics. Sets pv_mixing_initialized
 ! so that the next pv_mixing_main call computes the diffusivity at once.
 !-----------------------------------------------------------------------
 use pyOM_module
 use pv_mixing_module
 implicit none
 logical :: root   ! true on the PE that writes the log

 root = (my_pe == 0)

 if (root) then
   print*,' '
   print*,' initializing pv mixing module'
 endif
 pv_mixing_initialized = .true.

 ! lateral constraint scale of 300 km expressed in grid points
 j_constraint = int( 300e3/dx )
 if (root) print*,' pv mixing constraint over ',j_constraint,' grid points'

 ! arrays needed in every configuration
 allocate(K_rho(nx,ny,nz,2));  K_rho   = 0.0
 allocate(K_pv(nx,ny,nz,2));   K_pv    = 0.0
 allocate(Nsqrw(nx,ny,nz));    Nsqrw   = 0.0
 allocate(fNsqrw(nx,ny,nz));   fNsqrw  = 0.0
 allocate(htendu(nx,ny,nz));   htendu  = 0.0
 allocate(htendv(nx,ny,nz));   htendv  = 0.0
 allocate(vtendu(nx,ny,nz));   vtendu  = 0.0
 allocate(vtendv(nx,ny,nz));   vtendv  = 0.0
 allocate(theta_u(nx,ny,nz));  theta_u = 0.0
 allocate(theta_v(nx,ny,nz));  theta_v = 0.0

 ! parameter summary
 if (root) then
   print*,' K_min     = ',K_min
   print*,' K_max     = ',K_max
   print*,' K_w       = ',K_w
   print*,' len fac   = ',len_fac
   print*,' K_hor_min = ',K_hor_min
   print*,' M_min     = ',M_min
   print*,' N_min     = ',N_min
   print*,' u_max     = ',u_max
   if (enable_pv_mixing_diag) then
     print*,' diagnosing any ',int(pv_mixing_diagint/dt),' time steps'
   endif
   if (enable_pv_mixing_full) then
     print*,' using full pv mixing'
     if (enable_pv_mixing_friction)  print*,' with lateral friction'
     if (enable_pv_mixing_beta_term) print*,' with beta term'
   else
     print*,' using partial pv mixing'
   endif
 endif

 ! work arrays of the linear stability analysis
 if (enable_pv_mixing_calc_diff) then
   if (root) print*,' calculating diffusivity any ',int(pv_mixing_calcint/dt),' time steps'
   allocate(meanN(nx,ny));     meanN  = 0.0
   allocate(om(nx,ny));        om     = cmplx(0.0,0.0)
   allocate(kx_max(nx,ny));    kx_max = 0.0
   allocate(ky_max(nx,ny));    ky_max = 0.0
   allocate(K_iso(nx,ny,nz));  K_iso  = 0.0
 endif

 ! alternative sources of the diffusivity (set_diff runs after read_diff)
 if (enable_pv_mixing_read_diff) then
   if (root) print*,' reading diffusivity'
   call pv_mixing_read_diff()
 endif
 if (enable_pv_mixing_set_diff) then
   if (root) print*,' setting diffusivity'
   call pv_mixing_set_diff()
 endif

 if (root .and. enable_pv_mixing_smooth_backstate) then
   print*,' smoothing background state for linear stability analysis'
   print*,' with running mean over ',k_runmean,' grid points'
 endif
 if (root .and. enable_pv_mixing_smooth_diff) then
   print*,' smoothing diffusivity from lin. stab. analysis'
   print*,' with running mean over ',k_runmean,' grid points'
 endif

 if (enable_pv_mixing_diag) call pv_mixing_diag_init

 if (root) then
   print*,' done initializing pv mixing module'
   print*,' '
 endif
end subroutine init_pv_mixing


subroutine pv_mixing_main(ierr)
 !=======================================================================
 ! Main driver of the pv mixing parameterisation:
 !  1. N^2 on the W grid from the buoyancy at taum1 (plus back state)
 !  2. every pv_mixing_calcint/dt time steps (and on the first call after
 !     init_pv_mixing) the depth mean of N and the diffusivities
 !  3. f^2/N^2 on the W grid, bounded by fNsqr_max
 !  4. eddy momentum forcing: full pv mixing or TRM formulation
 ! ierr is always returned as 0.
 !=======================================================================
  use pyOM_module
  use fcontrol_module
  use pv_mixing_module
  implicit none
  integer, intent(out) :: ierr
  integer :: i,j,k,js,je,ncalc

  ierr=0
  js=max(2,js_pe); je = min(je_pe,ny -1)
 !-----------------------------------------------------------------------
 !     stability freq.
 !-----------------------------------------------------------------------
  Nsqrw(:,js:je,:)=0.0
  do k=1,nz-1
   do j=js,je
    Nsqrw(:,j,k)=-(b(:,j,k+1,taum1)-b(:,j,k,taum1))/dz
   enddo
  enddo
  if (enable_back_state) then
   do k=1,nz-1
    do j=js,je
     Nsqrw(:,j,k)=Nsqrw(:,j,k) -(back(:,j,k+1,taum1)-back(:,j,k,taum1))/dz
    enddo
   enddo
  endif
 !-----------------------------------------------------------------------
 !     calculate diffusivity
 !-----------------------------------------------------------------------
  ! time steps between two calculations; at least 1 so that mod() cannot
  ! divide by zero when pv_mixing_calcint < dt (Fortran does not
  ! short-circuit .and., so this matters even if calc_diff is disabled)
  ncalc = max(1, int(pv_mixing_calcint/dt))
  if (enable_pv_mixing_calc_diff .and. (mod(itt,ncalc)==0 .or. pv_mixing_initialized)) then
        ! depth mean of N over wet points (negative N^2 counts as zero)
        meanN(:,js:je)=0.0
        do k=1,nz-1
         do j=js,je
          do i=2,nx-1
            meanN(i,j)=meanN(i,j)+sqrt(max(0d0,Nsqrw(i,j,k)))*maskT(i,j,k)*dz
          enddo
         enddo
        enddo
        where (ht /=0.0 ) meanN=meanN/ht
        where (ht ==0.0 ) meanN=0.0
        call pv_mixing_calc_diff
        !call pv_mixing_calc_diff_somewhere
  endif
  Nsqrw(:,js:je,:)=max(Nsqrw(:,js:je,:),N_min**2)
  call border_exchg3D(nx,ny,nz,Nsqrw,1); call setcyclic3D(nx,ny,nz,Nsqrw )
 !-----------------------------------------------------------------------
 !     Calculate f**2/N**2 on W grid and bound 
 !-----------------------------------------------------------------------
  do k=1,nz-1
   do j=js,je
    do i=2,nx -1
     fNsqrw(i,j,k)= coriolis_t(j)**2/Nsqrw(i,j,k)
    enddo
   enddo
  enddo
  fNsqrw(:,js:je,:)=min(fNsqrw(:,js:je,:),fNsqr_max)
  call border_exchg3D(nx,ny,nz,fNsqrw,1); call setcyclic3D(nx,ny,nz,fNsqrw )

  if (enable_pv_mixing_full) then
    call pv_mixing_zonal
    call pv_mixing_meridional 
  else
    call pv_mixing_zonal_trm
    call pv_mixing_meridional_trm
  endif
  pv_mixing_initialized = .false.
end subroutine pv_mixing_main




subroutine pv_mixing_set_diff
 !-----------------------------------------------------------------------
 ! Set the PV diffusivity K_pv to the constant K_gm at every grid point
 ! and for both time levels.
 ! NOTE(review): K_rho is not touched here, and pv_mixing_zonal_trm /
 ! pv_mixing_meridional_trm read K_rho, so in the TRM formulation this
 ! setting has no effect unless K_rho is set elsewhere - confirm intent.
 !-----------------------------------------------------------------------
 use pyOM_module
 use pv_mixing_module
 implicit none
 K_pv(:,:,:,:) = K_gm
end subroutine pv_mixing_set_diff



subroutine pv_mixing_read_diff
 !-----------------------------------------------------------------------
 !     read diffusivity from NetCDF file
 !
 ! Reads variable DIFF2 (nx,ny,nz, index 1 of the 4th dimension) from
 ! diff_interp.cdf into K_pv(:,:,:,1), multiplies it by maskT, bounds it
 ! to [1d2,5d4] (land points therefore also get 1d2) and copies it to the
 ! second time level. Any NetCDF error stops the model with a message
 ! instead of silently continuing with undefined values.
 !-----------------------------------------------------------------------
      use pyOM_module
      use pv_mixing_module
      implicit none
      include "netcdf.inc"
      integer :: ncid,iret,corner(4), edges(4)
      integer :: varid

      if (my_pe==0) print*,' reading diffusivity from file diff_interp.cdf'
      ! the file is only read, so open it read-only
      iret=nf_open('diff_interp.cdf',NF_NOWRITE,ncid)
      call check_nc(iret,'cannot open diff_interp.cdf')
      iret=nf_inq_varid(ncid,'DIFF2',varid)
      call check_nc(iret,'cannot find variable DIFF2')
      corner = (/1,1,1,1/)
      edges  = (/nx,ny,nz,1/)
      iret=nf_get_vara_double(ncid,varid,corner,edges,K_pv(:,:,:,1))
      call check_nc(iret,'cannot read variable DIFF2')
      K_pv(:,:,:,1)=min(max(K_pv(:,:,:,1)*maskT,1d2),5d4)
      ! identical time levels: no time interpolation between old and new
      K_pv(:,:,:,2) = K_pv(:,:,:,1)
      iret=nf_close(ncid)

 contains

      subroutine check_nc(status,msg)
       ! stop with the NetCDF error text if status signals a failure
       integer, intent(in) :: status
       character(len=*), intent(in) :: msg
       if (status /= NF_NOERR) then
         print*,' ERROR in pv_mixing_read_diff: ',msg
         print*,' ',trim(nf_strerror(status))
         stop
       endif
      end subroutine check_nc
end subroutine pv_mixing_read_diff


subroutine running_mean1(a,b,nx,k)
 !-----------------------------------------------------------------------
 ! One-dimensional running (box-car) SUM of half-width k >= 0:
 !    b(i) = sum( a(max(1,i-k) : min(nx,i+k)) )
 ! The window is truncated at both ends. Despite the name no division is
 ! done; callers normalise by a running sum of the mask. Computed
 ! incrementally, so the cost is O(nx) independent of k.
 !-----------------------------------------------------------------------
 implicit none
 integer, intent(in) :: nx,k
 real*8, intent(in)  :: a(nx)
 real*8, intent(out) :: b(nx)
 integer :: i,ii

 if (nx < 1) return      ! empty input; b(1) would be out of bounds
 b(1)=0d0
 do ii=1,min(nx,1+k)
   b(1)=b(1)+a(ii)
 enddo
 do i=2,nx
   b(i)=b(i-1)
   ii=i-k-1               ! element leaving the window
   if (ii>=1) b(i)=b(i)-a(ii)
   ii=i+k                 ! element entering the window
   if (ii<=nx) b(i)=b(i)+a(ii)
 enddo
end subroutine running_mean1


subroutine running_mean2(a,bb,nx_,ny_,k)
 !-----------------------------------------------------------------------
 ! Two-dimensional running sum of half-width k (see running_mean1): first
 ! along x on the rows max(2,js_pe)..min(je_pe,ny-1) of this PE, then,
 ! after the halo exchange, along y for every column. Rows of the
 ! intermediate field outside that range start at zero.
 ! NOTE(review): the work arrays are dimensioned with nx_/ny_, but the row
 ! range, the 1-D sums and the exchange use the global nx/ny; this is only
 ! consistent for nx_=nx, ny_=ny (which is how this file calls it).
 !-----------------------------------------------------------------------
 use pyOM_module
 implicit none
 integer :: nx_,ny_,k
 real*8 :: a(nx_,ny_),bb(nx_,ny_),c(nx_,ny_)
 integer :: icol,jrow,jlo,jhi,npass,ipass

 jlo = max(2,js_pe)
 jhi = min(je_pe,ny-1)

 ! x pass on the local rows
 c(:,1:jlo)    = 0.
 c(:,jhi+1:ny) = 0.
 do jrow=jlo,jhi
   call running_mean1(a(:,jrow),c(:,jrow),nx,k)
 enddo

 ! halo exchange with width argument k+1; one extra pass each when k+1
 ! exceeds n_pes_j and 2*n_pes_j
 npass = 1
 if (k+1 >   n_pes_j) npass = npass+1
 if (k+1 > 2*n_pes_j) npass = npass+1
 do ipass=1,npass
   call border_exchg2D(nx,ny,c,k+1)
 enddo

 ! y pass on every column
 do icol=1,nx
   call running_mean1(c(icol,:),bb(icol,:),ny,k)
 enddo
end subroutine running_mean2




subroutine pv_mixing_calc_diff
 !-----------------------------------------------------------------------
 ! use linear instability analysis to calculate diffusivity at all points
 !
 ! For every wet column (maskT at level nz-1) vertical_eigenvalues0b is
 ! solved with the local background density and geostrophic velocity
 ! (from p_full). From the returned complex frequency om, wave vector
 ! (kx_max,ky_max), eigenfunction phi and (qx,qy) (presumably the
 ! background PV gradient) the routine derives
 !   K_pv (:,:,:,2) : diffusivity for the eddy PV flux
 !   K_rho(:,:,:,2) : diffusivity for the eddy buoyancy flux
 !   K_iso          : PV-flux diffusivity before the projection
 ! Both K are divided by (1+u^2/u_max^2) and bounded to [K_min,K_max]; the
 ! previous estimates are kept in slot 1 for time interpolation.
 ! Columns without a wave vector (kmag=0) get the zero-amplitude values
 ! instead of a division by zero.
 !-----------------------------------------------------------------------
 use pyOM_module
 use pv_mixing_module
 implicit none
 integer :: js,je,i,j,k,nb,ier
 real*8 :: rho1(nx,ny,nz),p1(nx,ny,nz),rxytmp(nx,ny)
 real*8, save, allocatable :: rmask(:,:,:)
 real*8 :: rho(nz),ub(nz),vb(nz),beta,f0,Lr(nx,ny)
 real*8 :: fxa,by,bx,amp,gb,umag,qx(nz),qy(nz),kmag
 complex*8 :: phiz(nz),phi(nz),im,c
 logical, save :: first = .true.

  im=cmplx(0,1)
  js=max(2,js_pe); je = min(je_pe,ny -1)
  nb=nz-2      ! interior levels 2..nz-1 handed to the solver

  ! local Rossby radius N_mean*H/|f|
  Lr(:,js:je)=0.0
  do j=js,je
   do i=2,nx-1
     Lr(i,j) = meanN(i,j)*ht(i,j)/abs(coriolis_t(j))*maskT(i,j,nz-1)
   enddo
  enddo

  ! background state: scaled buoyancy and pressure at tau
  rho1=b(:,:,:,tau)*maskT*rho_0/grav
  if (enable_back_state) rho1=rho1+back(:,:,:,tau)*maskT*rho_0/grav
  p1=p_full(:,:,:,tau)*maskT

  if (enable_pv_mixing_smooth_backstate .or. enable_pv_mixing_smooth_diff) then
   if (first) then
       ! running sum of the mask on wet points, used to turn running sums
       ! into means over wet points
       allocate(rmask(nx,ny,nz))
       rmask(:,:,:)=maskT
       call border_exchg3D(nx,ny,nz,rmask,k_runmean+1)
       do k=2,nz-1
        call running_mean2(rmask(:,:,k),rxytmp,nx,ny,k_runmean)
        rmask(:,:,k)=rxytmp*maskT(:,:,k)
       enddo
   endif
  endif

  if (enable_pv_mixing_smooth_backstate) then
   ! apply a running mean to background state
   do k=2,nz-1
     call running_mean2(rho1(:,:,k),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); rho1(:,:,k)=rxytmp/rmask(:,:,k); elsewhere; rho1(:,:,k)=0.0; end where
     call running_mean2(p1(:,:,k),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); p1(:,:,k)=rxytmp/rmask(:,:,k) ; elsewhere ; p1(:,:,k)=0.0; end where
   enddo
  endif

  ! keep the previous estimates for the time interpolation
  K_rho(:,:,:,1)= K_rho(:,:,:,2)
  K_pv (:,:,:,1)= K_pv (:,:,:,2)
  do j=js,je
   do i=2,nx-1
    if (maskT(i,j,nz-1) == 1.0) then
      f0=coriolis_t(j)
      ! solver works on reversed vertical order (index 1 = level nz-1)
      rho(1:nb) = rho1(i,j,nz-1:2:-1)
      do k=2,nz-1
        ub(k) = - ( (p1(i,j+1,k)-p1(i,j  ,k))/dx*maskV(i,j  ,k)  &
                   +(p1(i,j  ,k)-p1(i,j-1,k))/dx*maskV(i,j-1,k)) &
                        /( ( maskV(i,j,k)+maskV(i,j-1,k)+1e-12)*f0)
        vb(k) =   ( (p1(i+1,j,k)-p1(i  ,j,k))/dx*maskU(i  ,j,k)  &
                   +(p1(i  ,j,k)-p1(i-1,j,k))/dx*maskU(i-1,j,k)) &
                        /( ( maskU(i,j,k)+maskU(i-1,j,k)+1e-12)*f0)
      enddo
      ub(1:nb)=ub(nz-1:2:-1); vb(1:nb)=vb(nz-1:2:-1)
      beta=(coriolis_t(j+1)-coriolis_t(j-1))/(2*dx)
      fxa=max(dx,Lr(i,j))
      ier=0
      call vertical_eigenvalues0b(nb,rho,ub,vb,f0,beta,dz,fxa,om(i,j),phi,kx_max(i,j),ky_max(i,j),qx,qy,ier)
      ! NOTE(review): ier is not evaluated here; confirm what
      ! vertical_eigenvalues0b reports in it.
      ! back to model vertical order
      phi(nz-1:2:-1)=phi(1:nb)
      ub(nz-1:2:-1)=ub(1:nb); vb(nz-1:2:-1)=vb(1:nb)
      qx(nz-1:2:-1)=qx(1:nb); qy(nz-1:2:-1)=qy(1:nb)
      kmag = sqrt(kx_max(i,j)**2+ky_max(i,j)**2)
      if (kmag == 0.0) then
        ! no wave vector: use the zero-amplitude result of the formulas
        ! below instead of dividing by kmag (which produced NaN)
        K_iso(i,j,2:nz-1)   = 0.0
        K_pv (i,j,2:nz-1,2) = min(K_max,max(K_min,0d0))
        K_rho(i,j,2:nz-1,2) = 0.0
      else
        ! amplitude, only for growing modes
        amp = 0.0
        if (aimag(om(i,j))>0.0 ) amp = K_w*2*pi*aimag(om(i,j))/kmag**2
        ! complex phase speed along the wave vector
        c=om(i,j)/kmag
        !  diffusivity for eddy potential vorticity flux
        do k=2,nz-1
          umag = (ub(k)*kx_max(i,j)+vb(k)*ky_max(i,j))/kmag
          fxa = 0.5*amp**2*kmag*aimag(c) *abs(phi(k)/(umag-c))**2
          gb=sqrt(qx(k)**2+qy(k)**2)+1e-28
          K_iso(i,j,k) = fxa
          K_pv(i,j,k,2)= fxa * ( (-ky_max(i,j)*Qx(k)+kx_max(i,j)*Qy(k))/kmag)**2/gb**2
          K_pv(i,j,k,2)=K_pv(i,j,k,2)/(1.+u(i,j,k,tau)**2/u_max**2 )
          K_pv(i,j,k,2)=min(K_max,max(K_min, K_pv(i,j,k,2) ))
        enddo
        ! vertical derivative is buoyancy; centred in the interior,
        ! one-sided at the ends (phi(1), phi(nz) are never read)
        do k=3,nz-2
         phiz(k)=(phi(k+1)-phi(k-1))/(2*dz)
        enddo
        phiz(2)=(phi(3)-phi(2))/dz
        phiz(nz-1)=(phi(nz-1)-phi(nz-2))/dz
        ! diffusivity for eddy buoyancy flux
        do k=2,nz-1
          by   = -grav/rho_0*( (rho1(i,j+1,k)-rho1(i,j,k))/dx *maskV(i,j,k) &
                              +(rho1(i,j,k)-rho1(i,j-1,k))/dx*maskV(i,j-1,k)) &
                                 /( maskV(i,j,k)+maskV(i,j-1,k)+1e-12)
          bx   = -grav/rho_0*( (rho1(i+1,j,k)-rho1(i,j,k))/dx *maskU(i,j,k) &
                              +(rho1(i,j,k)-rho1(i-1,j,k))/dx*maskU(i-1,j,k)) &
                                 /( maskU(i,j,k)+maskU(i-1,j,k)+1e-12)
          gb =    max(M_min, sqrt(bx**2+by**2) )
          fxa = 0.5*amp**2*kmag*f0*real(im*phi(k)*conjg(phiz(k)))/gb
          K_rho(i,j,k,2) =  fxa*(-ky_max(i,j)*bx+kx_max(i,j)*by)/kmag/gb
        enddo
      endif
      if (sum(K_rho(i,j,2:nz-1,2))<0.0) then
           K_rho(i,j,:,2)=-K_rho(i,j,:,2) ! choose complex conjugate solution
      endif
      K_rho(i,j,:,2)=K_rho(i,j,:,2)/(1.+u(i,j,:,tau)**2/u_max**2 )
      K_rho(i,j,:,2)=min(K_max,max(K_min, K_rho(i,j,:,2) ))
    endif ! maskT
   enddo
  enddo

  if (enable_pv_mixing_smooth_diff) then
   call setcyclic3D(nx,ny,nz,K_rho(:,:,:,2) )
   call setcyclic3D(nx,ny,nz,K_pv(:,:,:,2) )
   rxytmp=0
   do k=2,nz-1
     call running_mean2(K_rho(:,:,k,2),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); K_rho(:,:,k,2)=rxytmp/rmask(:,:,k); elsewhere; K_rho(:,:,k,2)=0.0; end where
     call running_mean2(K_pv(:,:,k,2),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); K_pv(:,:,k,2)=rxytmp/rmask(:,:,k); elsewhere; K_pv(:,:,k,2)=0.0; end where
   enddo
  endif

  call border_exchg3D(nx,ny,nz,K_rho(:,:,:,2),1)
  call setcyclic3D(nx,ny,nz,K_rho(:,:,:,2) )
  call border_exchg3D(nx,ny,nz,K_pv(:,:,:,2),1)
  call setcyclic3D(nx,ny,nz,K_pv(:,:,:,2) )
  ! first calculation: no previous estimate to interpolate from
  if (pv_mixing_initialized)  then
       K_pv (:,:,:,1) = K_pv (:,:,:,2)
       K_rho(:,:,:,1) = K_rho(:,:,:,2)
  endif
  first=.false.
end subroutine pv_mixing_calc_diff





subroutine pv_mixing_calc_diff_somewhere
 !-----------------------------------------------------------------------
 ! use linear instability analysis to calculate diffusivity
 ! at some points and interpolate in between
 !
 ! The local analysis (vertical_eigenvalues3) is done only on a coarse
 ! grid with spacing jLr = len_fac * (domain-mean Rossby radius)/dx grid
 ! points, bounded to [1,ny/2]. The coarse PV diffusivity (diff1), buoyancy
 ! diffusivity (diff2) and complex frequency om are broadcast from the PE
 ! holding the row to all PEs, gaps are filled (fillgaps) and the fields
 ! are interpolated back to the model grid (rgrd2) into K_pv(:,:,:,2),
 ! K_rho(:,:,:,2) and om. Missing coarse values are flagged with spval.
 !-----------------------------------------------------------------------
 use pyOM_module
 use pv_mixing_module
 implicit none
 integer :: js,je,i,j,k,nb,ier,ii,jj,jLr
 real*8 :: rho(nz),ub(nz),vb(nz),beta,f0,Lr(nx,ny),mLr
 real*8 :: fxa,by,fxb,bx,amp,umag,gb
 real*8 :: xx(nx),yy(ny),diff1(nx,ny,nz),d2(nx,ny),diff2(nx,ny,nz),d3(nx,ny),d4(nx,ny)
 real*8 :: qx(nz),qy(nz),spval
 complex*8 :: phiz(nz),phi(nz),im,c,new_om
 logical, save :: first = .true.
 real*8, save, allocatable :: rmask(:,:,:)
 real*8 :: rho1(nx,ny,nz), p1(nx,ny,nz), rxytmp(nx,ny)

  spval = -1d33      ! flag for "no estimate"
  ier = 0
  im=cmplx(0,1)
  js=max(2,js_pe); je = min(je_pe,ny -1)
  nb=nz-2
  om=cmplx(1.,1.)*spval; phi=cmplx(1,1)*spval
  K_iso=spval
  kx_max=spval; ky_max=spval

  ! local Rossby radius N_mean*H/|f| and its mean over all wet columns
  Lr(:,js:je)=0.0
  fxa=0;mLr=0
  do j=js,je
   do i=2,nx-1
     Lr(i,j) = meanN(i,j)*ht(i,j)/abs(coriolis_t(j))*maskT(i,j,nz-1)
     mLr=mLr+Lr(i,j)*maskT(i,j,nz-1)
     fxa=fxa+maskT(i,j,nz-1)
   enddo
  enddo
  call global_sum(mLr)
  call global_sum(fxa)
  if (fxa/=0.0) mLr=mLr/fxa
  ! sampling distance in grid points; at least 1 so the loop strides
  ! below can never be zero
  jLr = int(len_fac*mLr/dx);  jLr = min(ny/2,jLr); jLr = max(1,jLr)
  if (my_pe==0) print*,' mean Rossby radius is ',mLr/1e3,' km'
  if (my_pe==0) print*,' calculating diffusivity any',jLr,' grid points'

  ! background state: scaled buoyancy and pressure at tau
  rho1=b(:,:,:,tau)*maskT*rho_0/grav
  if (enable_back_state) rho1=rho1+back(:,:,:,tau)*maskT*rho_0/grav
  p1=p_full(:,:,:,tau)*maskT

  if (enable_pv_mixing_smooth_backstate .or. enable_pv_mixing_smooth_diff) then
   if (first) then
      ! running sum of the mask on wet points, used to normalise smoothed fields
      allocate(rmask(nx,ny,nz))
      rmask(:,:,:)=maskT
      call border_exchg3D(nx,ny,nz,rmask,k_runmean+1)
      do k=2,nz-1
       call running_mean2(rmask(:,:,k),rxytmp,nx,ny,k_runmean)
       rmask(:,:,k)=rxytmp*maskT(:,:,k)
      enddo
   endif
  endif

  if (enable_pv_mixing_smooth_backstate) then
   ! apply a running mean to background state
   do k=2,nz-1
     call running_mean2(rho1(:,:,k),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); rho1(:,:,k)=rxytmp/rmask(:,:,k); elsewhere; rho1(:,:,k)=0.0; end where
     call running_mean2(p1(:,:,k),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); p1(:,:,k)=rxytmp/rmask(:,:,k) ; elsewhere ; p1(:,:,k)=0.0; end where
   enddo
  endif

  ! local analysis on the coarse grid; (ii,jj) are coarse indices
  diff1(:,:,:)=spval; diff2(:,:,:)=spval ! all rows need to be initialized
  jj=0
  do j=2+jLr/2,ny-1,jLr
   jj=jj+1; yy(jj)=yt(j); ii=0
   do i=2+jLr/2,nx-1,jLr
    ii=ii+1; xx(ii)=xt(i)
    if ((maskT(i,j,nz-1) == 1.0) .and. (j>=js .and. j<= je ) ) then
      f0=coriolis_t(j)
      ! solver works on reversed vertical order (index 1 = level nz-1)
      rho(1:nb) = rho1(i,j,nz-1:2:-1)
      do k=2,nz-1
        ub(k) = - ( (p1(i,j+1,k)-p1(i,j  ,k))/dx*maskV(i,j  ,k)  &
                   +(p1(i,j  ,k)-p1(i,j-1,k))/dx*maskV(i,j-1,k)) &
                        /( ( maskV(i,j,k)+maskV(i,j-1,k)+1e-12)*f0)
        vb(k) =   ( (p1(i+1,j,k)-p1(i  ,j,k))/dx*maskU(i  ,j,k)  &
                   +(p1(i  ,j,k)-p1(i-1,j,k))/dx*maskU(i-1,j,k)) &
                        /( ( maskU(i,j,k)+maskU(i-1,j,k)+1e-12)*f0)
      enddo
      ub(1:nb)=ub(nz-1:2:-1); vb(1:nb)=vb(nz-1:2:-1)
      beta=(coriolis_t(j+1)-coriolis_t(j-1))/(2*dx)
      fxa=max(dx,Lr(i,j))
      ier=0     ! do not carry a failure flag over from the previous point
      call vertical_eigenvalues3(nb,rho,ub,vb,f0,beta,dz,fxa,new_om,phi,kx_max(i,j),ky_max(i,j),qx,qy,ier)
      if (ier/=0) then
        print*,' ERROR: failed to solve vertical eigenvalue problem at i,j=',i,j
        ! reuse the previous diffusivities at this point; om was reset to
        ! spval above (and om(i,j) may hold another sample's coarse value),
        ! so flag the growth rate as missing and let fillgaps fill it
        diff1(ii,jj,:)=K_pv(i,j,:,2)
        diff2(ii,jj,:)=K_rho(i,j,:,2)
        om(ii,jj) = cmplx(1.,1.)*spval
      else
        om(ii,jj)=new_om
        ! back to model vertical order
        phi(nz-1:2:-1)=phi(1:nb)
        ub(nz-1:2:-1)=ub(1:nb); vb(nz-1:2:-1)=vb(1:nb)
        qx(nz-1:2:-1)=qx(1:nb); qy(nz-1:2:-1)=qy(1:nb)
        fxb = sqrt(kx_max(i,j)**2+ky_max(i,j)**2)
        if (fxb == 0.0) then
          ! no wave vector: zero-amplitude result of the formulas below,
          ! avoiding the division by fxb
          K_iso(i,j,2:nz-1)   = 0.0
          diff1(ii,jj,2:nz-1) = min(K_max,max(K_min,0d0))
          diff2(ii,jj,2:nz-1) = 0.0
        else
          ! amplitude, only for growing modes (as in pv_mixing_calc_diff)
          amp = 0.0
          if (aimag(om(ii,jj))>0.0) amp = K_w*2*pi/fxb*aimag(om(ii,jj))/fxb
          c=om(ii,jj)/fxb
          !  diffusivity for eddy potential vorticity flux
          do k=2,nz-1
            umag = (ub(k)*kx_max(i,j)+vb(k)*ky_max(i,j))/fxb
            fxa = 0.5*amp**2 *fxb*aimag(c) *abs(phi(k)/(umag-c))**2
            gb=sqrt(qx(k)**2+qy(k)**2)+1e-28
            K_iso(i,j,k)=fxa
            diff1(ii,jj,k)= fxa* ( (-ky_max(i,j)*Qx(k)+kx_max(i,j)*Qy(k))/fxb )**2/gb**2
            diff1(ii,jj,k)=min(K_max,max(K_min, diff1(ii,jj,k) ))
          enddo
          ! vertical derivative is buoyancy; centred in the interior,
          ! one-sided at the ends (phi(1), phi(nz) are never read)
          do k=3,nz-2
           phiz(k)=(phi(k+1)-phi(k-1))/(2*dz)
          enddo
          phiz(2)=(phi(3)-phi(2))/dz
          phiz(nz-1)=(phi(nz-1)-phi(nz-2))/dz
          ! diffusivity for eddy buoyancy flux
          do k=2,nz-1
            by   = -(rho1(i,j+1,k)-rho1(i,j-1,k))/(2*dx) /rho_0*grav
            bx   = -(rho1(i+1,j,k)-rho1(i-1,j,k))/(2*dx) /rho_0*grav
            gb =    max(M_min, sqrt(bx**2+by**2) )
            fxa = 0.5*amp**2*fxb*f0*real(im*phi(k)*conjg(phiz(k)))/gb
            diff2(ii,jj,k) =  fxa*(-ky_max(i,j)*bx+kx_max(i,j)*by)/fxb/gb
          enddo
        endif
        if (sum(diff2(ii,jj,2:nz-1))<0.0) then
           diff2(ii,jj,:)=-diff2(ii,jj,:) ! choose complex conjugate solution
        endif
        diff2(ii,jj,:)=min(K_max,max(K_min, diff2(ii,jj,:) ))
      endif
    endif ! maskT
   enddo
  enddo

  ! communication between PEs: the PE whose rows contain j (found with
  ! global_max_int) broadcasts the ii coarse values of that row
  jj=0
  do j=2+jLr/2 ,ny-1,jLr
   jj=jj+1
   i=0; if (j>=js .and. j<=je) i=my_pe
   call global_max_int(i,1)
   do k=2,nz-1
    call bcast_real(diff1(1,jj,k),ii,i)
    call bcast_real(diff2(1,jj,k),ii,i)
   enddo
   d2(1:ii,1)=real(om(1:ii,jj)); call bcast_real(d2(1,1),ii,i)
   d2(1:ii,2)=aimag(om(1:ii,jj)); call bcast_real(d2(1,2),ii,i)
   om(1:ii,jj) = cmplx( d2(1:ii,1), d2(1:ii,2) )
  enddo

  ! now interpolate between estimates (done by every pe); the coarse grid
  ! is extended by the first and last model points
  xx(1:ii+2)=(/xt(1),xx(1:ii),xt(nx)/)
  yy(1:jj+2)=(/yt(1),yy(1:jj),yt(ny)/)

  ! interpolate growth rate (values below -1 are treated as missing)
  d2=spval; d3=0
  d2(2:ii+1,2:jj+1)=real( om(1:ii,1:jj) )
  where (d2(2:ii+1,2:jj+1)<-1.0) d2(2:ii+1,2:jj+1)=spval
  if (enable_cyclic_x) then
    d2(1,:)=d2(ii+1,:); d2(ii+2,:)=d2(2,:)
  endif
  call fillgaps(ii+2,jj+2,d2(1:ii+2,1:jj+2),spval)
  call rgrd2(ii+2,jj+2,xx,yy,d2(1:ii+2,1:jj+2),nx,ny,xt,yt,d3,ier)
  d2=spval; d4=0
  d2(2:ii+1,2:jj+1)=aimag( om(1:ii,1:jj) )
  where (d2(2:ii+1,2:jj+1)<-1.0) d2(2:ii+1,2:jj+1)=spval
  if (enable_cyclic_x) then
    d2(1,:)=d2(ii+1,:); d2(ii+2,:)=d2(2,:)
  endif
  call fillgaps(ii+2,jj+2,d2(1:ii+2,1:jj+2),spval)
  call rgrd2(ii+2,jj+2,xx,yy,d2(1:ii+2,1:jj+2),nx,ny,xt,yt,d4,ier)
  om = cmplx(d3,d4)

  do k=2,nz-1
   ! interpolate PV diffusivity
   d2=spval
   d2(2:ii+1,2:jj+1)=diff1(1:ii,1:jj,k)
   if (enable_cyclic_x) then
    d2(1,:)=d2(ii+1,:); d2(ii+2,:)=d2(2,:)
   endif
   call fillgaps(ii+2,jj+2,d2(1:ii+2,1:jj+2),spval)
   K_pv(:,:,k,1)= K_pv(:,:,k,2)
   call rgrd2(ii+2,jj+2,xx,yy,d2(1:ii+2,1:jj+2),nx,ny,xt,yt,K_pv(:,:,k,2),ier)
   if (ier/=0) then
      print*,' Error: interpolation of K_pv failed'
      print*,'ier=',ier
      stop
   endif
   ! interpolate density diffusivity
   d2=spval
   d2(2:ii+1,2:jj+1)=diff2(1:ii,1:jj,k)
   if (enable_cyclic_x) then
    d2(1,:)=d2(ii+1,:); d2(ii+2,:)=d2(2,:)
   endif
   call fillgaps(ii+2,jj+2,d2(1:ii+2,1:jj+2),spval)
   K_rho(:,:,k,1)= K_rho(:,:,k,2)
   call rgrd2(ii+2,jj+2,xx,yy,d2(1:ii+2,1:jj+2),nx,ny,xt,yt,K_rho(:,:,k,2),ier)
   if (ier/=0) then
      print*,' Error: interpolation of K_rho failed'
      print*,'ier=',ier
      stop
   endif
  enddo

  if (enable_pv_mixing_smooth_diff) then
   call setcyclic3D(nx,ny,nz,K_rho(:,:,:,2) )
   call setcyclic3D(nx,ny,nz,K_pv(:,:,:,2) )
   rxytmp=0
   do k=2,nz-1
     call running_mean2(K_rho(:,:,k,2),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); K_rho(:,:,k,2)=rxytmp/rmask(:,:,k); elsewhere; K_rho(:,:,k,2)=0.0; end where
     call running_mean2(K_pv(:,:,k,2),rxytmp,nx,ny,k_runmean)
     where (rmask(:,:,k)/=0.0); K_pv(:,:,k,2)=rxytmp/rmask(:,:,k); elsewhere; K_pv(:,:,k,2)=0.0; end where
   enddo
  endif

  call border_exchg3D(nx,ny,nz,K_pv(:,:,:,2),1)
  call setcyclic3D(nx,ny,nz,K_pv(:,:,:,2) )
  call border_exchg3D(nx,ny,nz,K_rho(:,:,:,2),1)
  call setcyclic3D(nx,ny,nz,K_rho(:,:,:,2) )
  ! first calculation: no previous estimate to interpolate from
  if (pv_mixing_initialized)  then
       K_rho(:,:,:,1) = K_rho(:,:,:,2)
       K_pv(:,:,:,1)  = K_pv(:,:,:,2)
  endif
  first = .false.
end subroutine pv_mixing_calc_diff_somewhere





subroutine pv_mixing_zonal_trm
 !=======================================================================
 !     calculate only TRM force and add to tendencies for u
 !
 !     vertical  : (K f^2/N^2 u_z)_z with K = K_rho on U points; the
 !                 explicit share (1-aidif_trm) is added to fu here, the
 !                 implicit part is done by trm_implicit_umix with A_uu
 !     horizontal: friction from harm_hfric_u with A_h set to K_gm
 !     K_rho is interpolated linearly in time between slot 1 (previous)
 !     and slot 2 (latest estimate).
 !=======================================================================
  use pyOM_module
  use fcontrol_module
  use pv_mixing_module
  implicit none
  integer :: i,j,k,js,je,n,ne
  real*8 :: fxa,fxb,fxc, tx
  real*8, dimension(nx,ny,nz) :: diff_ft,A_uu,diff_fe,diff_fn

  ! time interpolation factor tx = (steps since last update)/ne; ne is at
  ! least 1 (pv_mixing_calcint < dt would otherwise divide by zero) and tx
  ! is formed in double precision
  ne = max(1, int(pv_mixing_calcint/dt))
  n  = mod(itt, ne)
  tx = dble(n)/dble(ne)

  js=max(2,js_pe); je = min(je_pe,ny -1)
  vtendu(:,js:je,:)=0.; A_uu(:,js:je,:)=0.0; htendu(:,js:je,:)=0.0
 !-----------------------------------------------------------------------
 !      vertical friction of residual momentum: K(f^2/N^2 u_z)_z
 !      prepare coefficients for implicit part
 !-----------------------------------------------------------------------
  do k=1,nz-1
   do j=js,je
    do i=2,nx -1
     fxb=0.5*(K_rho(i,j,k,2)+K_rho(i+1,j,k,2))   ! latest K_rho on U point
     fxc=0.5*(K_rho(i,j,k,1)+K_rho(i+1,j,k,1))   ! previous K_rho on U point
     fxb = fxb * tx + fxc * (1-tx)
     fxa=0.5*(fNsqrw(i,j,k)+fNsqrw(i+1,j,k))
     A_uu(i,j,k) =fxb*fxa*maskU(i,j,k+1)*maskU(i,j,k)
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 !     now explicit part
 !-----------------------------------------------------------------------
  diff_ft(:,js:je,:)=0.
  do k=1,nz-1
   do j=js,je
    do i=2,nx -1
     diff_ft(i,j,k)=A_uu(i,j,k)*(u(i,j,k+1,taum1)-u(i,j,k,taum1))/dz
    enddo
   enddo
  enddo
  do k=2,nz-1
   do j=js,je
    do i=2,nx -1
     vtendu(i,j,k)=maskU(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 ! add also horizontal friction
 !-----------------------------------------------------------------------
  diff_fe(:,js:je,:)=0.0; diff_fn(:,js-1:je,:)=0.0
  ! harm_hfric_u presumably takes its viscosity from A_h: swap in K_gm
  fxa=A_h; A_h=K_gm
  call harm_hfric_u(nx,ny,nz,diff_fe,diff_fn)
  A_h=fxa
  do k=2,nz-1
   do j=js,je
    do i=2,nx -1
     htendu(i,j,k)=maskU(i,j,k)*( (diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx    &
                                + (diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx )
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 !     Add to zonal momentum tendencies
 !-----------------------------------------------------------------------
  do k=2,nz-1
   do j=js,je
    do i=2,nx -1
     fu(i,j,k)=fu(i,j,k)+vtendu(i,j,k)*(1-aidif_trm)+htendu(i,j,k)
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 !     implicit part of vertical operator
 !-----------------------------------------------------------------------
  call trm_implicit_umix(nx,ny,nz,A_uu)
end subroutine pv_mixing_zonal_trm




subroutine pv_mixing_meridional_trm
 !=======================================================================
 !     calculate only TRM force and add to tendencies for v
 !
 !     vertical  : (K f^2/N^2 v_z)_z with K = K_rho on V points; the
 !                 explicit share (1-aidif_trm) is added to fv here, the
 !                 implicit part is done by trm_implicit_vmix with A_uu
 !     horizontal: friction from harm_hfric_v with A_h set to K_gm
 !     K_rho is interpolated linearly in time between slot 1 (previous)
 !     and slot 2 (latest estimate).
 !=======================================================================
  use pyOM_module
  use fcontrol_module
  use pv_mixing_module
  implicit none
  integer :: i,j,k,js,je,n,ne
  real*8 :: fxa,fxb,fxc,tx
  real*8, dimension(nx,ny,nz) :: diff_ft,A_uu,diff_fe,diff_fn

  ! time interpolation factor tx = (steps since last update)/ne; ne is at
  ! least 1 (pv_mixing_calcint < dt would otherwise divide by zero) and tx
  ! is formed in double precision
  ne = max(1, int(pv_mixing_calcint/dt))
  n  = mod(itt, ne)
  tx = dble(n)/dble(ne)

  js=max(2,js_pe); je = min(je_pe,ny -1)
  vtendv(:,js:je,:)=0.; A_uu(:,js:je,:)=0.0; htendv(:,js:je,:)=0.0
 !-----------------------------------------------------------------------
 !      vertical friction of residual momentum: (Kf^2/N^2 v_z)_z
 !      prepare coefficients for implicit part
 !-----------------------------------------------------------------------
  do k=1,nz-1
   do j=js,je
    do i=2,nx -1
     fxb=0.5*(K_rho(i,j,k,2)+K_rho(i,j+1,k,2))   ! latest K_rho on V point
     fxc=0.5*(K_rho(i,j,k,1)+K_rho(i,j+1,k,1))   ! previous K_rho on V point
     fxb = fxb * tx + fxc * (1-tx)
     fxa=0.5*(fNsqrw(i,j,k)+fNsqrw(i,j+1,k))
     A_uu(i,j,k) =fxb*fxa*maskV(i,j,k+1)*maskV(i,j,k)
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 !     now explicit part
 !-----------------------------------------------------------------------
  diff_ft(:,js:je,:)=0.
  do k=1,nz-1
   do j=js,je
    do i=2,nx -1
     diff_ft(i,j,k)=A_uu(i,j,k)*(v(i,j,k+1,taum1)-v(i,j,k,taum1))/dz
    enddo
   enddo
  enddo
  do k=2,nz-1
   do j=js,je
    do i=2,nx -1
     vtendv(i,j,k)=maskV(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 ! add also horizontal friction
 !-----------------------------------------------------------------------
  diff_fe(:,js:je,:)=0.0; diff_fn(:,js-1:je,:)=0.0
  ! harm_hfric_v presumably takes its viscosity from A_h: swap in K_gm
  fxa=A_h; A_h=K_gm
  call harm_hfric_v(nx,ny,nz,diff_fe,diff_fn)
  A_h=fxa
  do k=2,nz-1
   do j=js,je
    do i=2,nx -1
     htendv(i,j,k)=maskV(i,j,k)* ( (diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx   &
                                 + (diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx )
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 !     Add to meridional momentum tendencies
 !-----------------------------------------------------------------------
  do k=2,nz-1
   do j=js,je
    do i=2,nx -1
     fv(i,j,k)=fv(i,j,k)+vtendv(i,j,k)*(1-aidif_trm) +htendv(i,j,k)
    enddo
   enddo
  enddo
 !-----------------------------------------------------------------------
 !     implicit part of vertical operator
 !-----------------------------------------------------------------------
  call trm_implicit_vmix(nx,ny,nz,A_uu)
end subroutine pv_mixing_meridional_trm



subroutine pv_mixing_zonal
 !=======================================================================
 !     calculate all forces and add to tendencies for u
 !=======================================================================
  use pyOM_module
  use fcontrol_module
  use pv_mixing_module
  implicit none
  integer :: i,j,k,js,je,jj,ii,n,ne
  real*8 :: fxa,fxb,fxc,tx
  real*8, dimension(nx,ny,nz) :: diff_ft,diff_fe,diff_fn,A_uu,A2_uu
  logical, save :: first = .false.
  real*8, allocatable, save :: rmask(:,:,:)
  real*8 :: rxytmp(nx,ny)

  ! time interpolation factor
  ne = int(pv_mixing_calcint/dt)
  n = itt/ne
  n = itt-n*ne
  tx = (n*1.0)/(1.0*ne)

  js=max(2,js_pe); je = min(je_pe,ny -1)
  vtendu(:,js:je,:)=0.; htendu(:,js:je,:)=0; theta_u(:,js:je,:) = 0
  A_uu(:,js:je,:)=0.0;A2_uu(:,js:je,:)=0 ! A_uu is factor inside 2. derivative, A2_uu outside
 !-----------------------------------------------------------------------
 !      vertical friction of residual momentum: K(f^2/N^2 u_z)_z
 !      prepare coefficients for implicit part
 !-----------------------------------------------------------------------
  do k=1,nz-1
        do j=js,je
         do i=2,nx -1
          fxb=0.5*(K_pv(i,j,k,2)+K_pv(i+1,j,k,2))
          fxc=0.5*(K_pv(i,j,k,1)+K_pv(i+1,j,k,1))
          fxb = fxb * tx + fxc * (1-tx)

          fxa=0.5*(fNsqrw(i,j,k)+fNsqrw(i+1,j,k))
          A2_uu(i,j,k)=fxb
          A_uu(i,j,k) =fxa*maskU(i,j,k+1)*maskU(i,j,k)
         enddo
        enddo
  enddo
 !-----------------------------------------------------------------------
 !     now explicit part
 !-----------------------------------------------------------------------
  diff_ft(:,js:je,:)=0.
  do k=1,nz-1
        do j=js,je
         do i=2,nx -1
          diff_ft(i,j,k)=A_uu(i,j,k)*(u(i,j,k+1,taum1)-u(i,j,k,taum1))/dz
         enddo
        enddo
       enddo
       do k=2,nz-1
        do j=js,je
         do i=2,nx -1
          vtendu(i,j,k)=maskU(i,j,k)*A2_uu(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
         enddo
        enddo
  enddo
 !-----------------------------------------------------------------------
 !     horizontal friction and beta term: K(u_yy+u_xx-beta)
 !-----------------------------------------------------------------------
  if (enable_pv_mixing_beta_term) then
   do k=2,nz-1
     do j=js,je
       do i=2,nx -1
         fxb=0.5*(K_pv(i,j,k,2)+K_pv(i+1,j,k,2))
         fxc=0.5*(K_pv(i,j,k,1)+K_pv(i+1,j,k,1))
         fxa = fxb * tx + fxc * (1-tx) 
         fxb = (coriolis_t(j+1)-coriolis_t(j-1) )/(2*dx)
         htendu(i,j,k)=fxa*maskU(i,j,k)*( - fxb)
        enddo
     enddo
   enddo
  endif

  if (enable_pv_mixing_friction) then
    diff_fe(:,js:je,:)=0.0; diff_fn(:,js-1:je,:)=0.0; fxa=A_h;A_h=1.0
    call harm_hfric_u(nx,ny,nz,diff_fe,diff_fn)
    A_h=fxa
    do k=2,nz-1
     do j=js,je
       do i=2,nx -1
         fxb=0.5*(K_pv(i,j,k,2)+K_pv(i+1,j,k,2))
         fxc=0.5*(K_pv(i,j,k,1)+K_pv(i+1,j,k,1))
         fxa = fxb * tx + fxc * (1-tx) 
         fxa = min( K_hor_min, fxa  )
         htendu(i,j,k)=fxa*maskU(i,j,k)*( (diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx    &
                                        + (diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx )
        enddo
     enddo
   enddo
  endif

 !-----------------------------------------------------------------------
 !     integrate Reynolds stress and substract mean force
 !     theta = 1/(V) int dv (K(u_yy+u_xx-beta)+K_z f b_y 
 !-----------------------------------------------------------------------
  if (first) then
    allocate(rmask(nx,ny,nz))
    rmask(:,:,:)=maskU
    do k=2,nz-1
       call running_mean2(rmask(:,:,k),rxytmp,nx,ny,j_constraint)    
       rmask(:,:,k)=rxytmp*maskU(:,:,k)
    enddo
  endif
  do k=2,nz-1
    call running_mean2(vtendu(:,:,k)+htendu(:,:,k),rxytmp,nx,ny,j_constraint)    
    where (rmask(:,:,k)/=0.0) theta_u(:,:,k)=rxytmp/rmask(:,:,k)
  enddo

  fxb=0.0;fxc=0.0
  do k=2,nz-1
       do j=js,je
        do i=2,nx -1
         fxa=dz*dx**2*maskU(i,j,k)
         fxb=fxb+fxa*(vtendu(i,j,k)+htendu(i,j,k)-theta_u(i,j,k))
         fxc=fxc+fxa
        enddo
       enddo
  enddo
  call global_sum(fxb)
  call global_sum(fxc)
  if (fxc/=0.0) theta_u(:,js:je,:)=theta_u(:,js:je,:)+fxb/fxc   

 !-----------------------------------------------------------------------
 !     Add to zonal momentum tendencies
 !-----------------------------------------------------------------------
  do k=2,nz-1
      do j=js,je
        do i=2,nx -1
         fu(i,j,k)=fu(i,j,k)+vtendu(i,j,k)*(1-aidif_trm)+htendu(i,j,k)-theta_u(i,j,k)
        enddo
      enddo
  enddo
 !-----------------------------------------------------------------------
 !     implicit part of vertical operator
 !-----------------------------------------------------------------------
  call pv_mixing_implicit_umix(nx,ny,nz,A_uu,A2_uu)
  first = .false.
end subroutine pv_mixing_zonal








subroutine pv_mixing_meridional
 !=======================================================================
 !     calculate all forces of the pv mixing parameterisation acting on
 !     the meridional velocity v and add them to the tendency fv:
 !       - vertical friction of residual momentum K (f^2/N^2 v_z)_z;
 !         the explicit part enters fv with weight (1-aidif_trm), the
 !         implicit part is solved in pv_mixing_implicit_vmix
 !       - optional lateral friction (enable_pv_mixing_friction)
 !       - minus a running-mean force theta_v, shifted by a constant so
 !         that the volume-weighted global sum of
 !         (vtendv+htendv-theta_v) over the maskV points vanishes
 !     K is linearly interpolated in time between K_pv(...,1) and
 !     K_pv(...,2).
 !=======================================================================
  use pyOM_module
  use fcontrol_module
  use pv_mixing_module
  implicit none
  integer :: i,j,k,js,je,n,ne
  real*8 :: fxa,fxb,fxc,tx
  real*8, dimension(nx,ny,nz) :: diff_ft,diff_fe,diff_fn,A_uu,A2_uu
  ! running mean of maskV, used to normalise the running mean of the
  ! forces; built on the first call and kept for all later calls
  real*8, allocatable, save :: rmask(:,:,:)
  real*8 :: rxytmp(nx,ny)

  ! time interpolation factor: tx = mod(itt,ne)/ne, i.e. in [0,1) for
  ! itt >= 0, with ne the number of steps per pv_mixing_calcint
  ne = int(pv_mixing_calcint/dt)
  n = itt/ne
  n = itt-n*ne
  tx = (n*1.0)/(1.0*ne)

  js=max(2,js_pe); je = min(je_pe,ny -1)
  vtendv(:,js:je,:)=0.; htendv(:,js:je,:)=0; theta_v(:,js:je,:) = 0
  A_uu(:,js:je,:)=0.0;A2_uu(:,js:je,:)=0 ! A_uu is factor inside 2. derivative, A2_uu outside
 !-----------------------------------------------------------------------
 !      vertical friction of residual momentum: K(f^2/N^2 v_z)_z
 !      prepare coefficients for implicit part
 !-----------------------------------------------------------------------
  do k=1,nz-1
        do j=js,je
         do i=2,nx -1
          fxb=0.5*(K_pv(i,j,k,2)+K_pv(i,j+1,k,2))
          fxc=0.5*(K_pv(i,j,k,1)+K_pv(i,j+1,k,1))
          fxb =fxb * tx + fxc * (1-tx) 
          fxa=0.5*(fNsqrw(i,j,k)+fNsqrw(i,j+1,k))
          A2_uu(i,j,k)=fxb
          A_uu(i,j,k) =fxa*maskV(i,j,k+1)*maskV(i,j,k)
         enddo
        enddo
  enddo
 !-----------------------------------------------------------------------
 !     now explicit part
 !-----------------------------------------------------------------------
  diff_ft(:,js:je,:)=0.
  do k=1,nz-1
        do j=js,je
         do i=2,nx -1
          diff_ft(i,j,k)=A_uu(i,j,k)*(v(i,j,k+1,taum1)-v(i,j,k,taum1))/dz
         enddo
        enddo
  enddo
  do k=2,nz-1
        do j=js,je
         do i=2,nx -1
          vtendv(i,j,k)=maskV(i,j,k)*A2_uu(i,j,k)*(diff_ft(i,j,k)-diff_ft(i,j,k-1))/dz
         enddo
        enddo
  enddo
 !-----------------------------------------------------------------------
 !     horizontal friction and no beta term: K(v_yy+v_xx)
 !-----------------------------------------------------------------------
  if (enable_pv_mixing_friction) then
   ! harm_hfric_v is called with A_h temporarily set to 1 so that it
   ! returns the bare fluxes; the diffusivity is applied below
   diff_fe(:,js:je,:)=0.0; diff_fn(:,js-1:je,:)=0.0; fxa=A_h;A_h=1.0
   call harm_hfric_v(nx,ny,nz,diff_fe,diff_fn)
   A_h=fxa
   do k=2,nz-1
        do j=js,je
         do i=2,nx -1
          fxb=0.5*(K_pv(i,j,k,2)+K_pv(i,j+1,k,2))
          fxc=0.5*(K_pv(i,j,k,1)+K_pv(i,j+1,k,1))
          ! NOTE(review): min() makes K_hor_min an upper cap on the
          ! lateral diffusivity despite its name - confirm intent
          fxa = min( K_hor_min, fxb * tx + fxc * (1-tx)  )
          htendv(i,j,k)=fxa*maskV(i,j,k)* ( (diff_fn(i,j,k) - diff_fn(i,j-1,k))/dx   &
                                          + (diff_fe(i,j,k) - diff_fe(i-1,j,k))/dx )
         enddo
        enddo
   enddo
  endif
 !-----------------------------------------------------------------------
 !     integrate Reynolds stress and substract mean force
 !     theta = 1/(V) int dv (K(v_yy+v_xx)-K_z f b_x 
 !-----------------------------------------------------------------------

  ! build the smoothed mask lazily on the first call; testing the
  ! allocation status (instead of a separate flag) guarantees rmask is
  ! allocated before it is read below
  if (.not. allocated(rmask)) then
    allocate(rmask(nx,ny,nz))
    rmask(:,:,:)=maskV
    do k=2,nz-1
       call running_mean2(rmask(:,:,k),rxytmp,nx,ny,j_constraint)    
       rmask(:,:,k)=rxytmp*maskV(:,:,k)
    enddo
  endif
  do k=2,nz-1
    call running_mean2(vtendv(:,:,k)+htendv(:,:,k),rxytmp,nx,ny,j_constraint)    
    where (rmask(:,:,k)/=0.0) theta_v(:,:,k)=rxytmp/rmask(:,:,k)
  enddo

  ! global volume-weighted residual of (force - theta_v), removed below
  fxb=0.0;fxc=0.0
  do k=2,nz-1
       do j=js,je
        do i=2,nx -1
         fxa=dz*dx**2*maskV(i,j,k)
         fxb=fxb+fxa*(vtendv(i,j,k)+htendv(i,j,k)-theta_v(i,j,k))
         fxc=fxc+fxa
        enddo
       enddo
  enddo
  call global_sum(fxb)
  call global_sum(fxc)
  if (fxc/=0.0) theta_v(:,js:je,:)=theta_v(:,js:je,:)+fxb/fxc   
 !-----------------------------------------------------------------------
 !     Add to meridional momentum tendencies
 !-----------------------------------------------------------------------
  do k=2,nz-1
       do j=js,je
        do i=2,nx -1
         fv(i,j,k)=fv(i,j,k)+vtendv(i,j,k)*(1-aidif_trm) +htendv(i,j,k)-theta_v(i,j,k)
        enddo
       enddo
  enddo
 !-----------------------------------------------------------------------
 !     implicit part of vertical operator
 !-----------------------------------------------------------------------
  call pv_mixing_implicit_vmix(nx,ny,nz,A_uu,A2_uu)
 end subroutine pv_mixing_meridional



subroutine pv_mixing_diag
 !-----------------------------------------------------------------------
 !    diagnostic output of the pv mixing module
 !
 !    Every int(pv_mixing_diagint/dt) time steps one time record is
 !    appended to the netCDF file 'pvmix.cdf' (whose variables are
 !    defined in pv_mixing_diag_init):
 !      K_pv, K_rho  : linearly interpolated in time between the stored
 !                     states (...,1) and (...,2) with weight tx
 !      K_iso, om_i, kx_max, ky_max : only if enable_pv_mixing_calc_diff
 !    Points with maskT == 0 are written as spval.
 !    Each PE writes its own latitude strip js_pe:je_pe; the PEs take
 !    turns (loop over npe with fortran_barrier), and only PE 0 writes
 !    the new Time value. netCDF return codes (iret) are not checked.
 !    NOTE(review): the loop runs npe=0..n_pes; if n_pes is the number
 !    of PEs, the last pass only executes a barrier - confirm.
 !    NOTE(review): int(pv_mixing_diagint/dt) or int(pv_mixing_calcint/dt)
 !    equal to 0 would divide by zero below.
 !-----------------------------------------------------------------------
 use pyOM_module
 use pv_mixing_module
 use fcontrol_module
 implicit none
      include "netcdf.inc"
      integer :: ncid,iret,n,npe, corner(4), edges(4)
      real*8 :: a(nx,js_pe:je_pe,nz),fxa,time,fxb,tx
      integer :: itdimid,ilen,itimeid
      integer :: id,i,j,k,ne
      integer :: js,je
      character :: name*24
      real*8,parameter :: spval = -1.0d33


   if (.not. enable_pv_mixing_diag ) return

   if  (mod(itt,int(pv_mixing_diagint/dt))  == 0)  then


     ! time interpolation factor
     ! (tx = mod(itt,ne)/ne, ne = steps per pv_mixing_calcint)
      ne = int(pv_mixing_calcint/dt)
      n = itt/ne
      n = itt-n*ne
      tx = (n*1.0)/(1.0*ne)

      js=max(2,js_pe); je = min(je_pe,ny -1)
      ! PEs write one after another; all PEs pass the barrier each round
      do npe=0,n_pes
       if (my_pe==npe) then
        iret=nf_open('pvmix.cdf',NF_WRITE,ncid)
        iret=nf_set_fill(ncid, NF_NOFILL, iret)
        ! NOTE(review): 'vtendu' is not defined by pv_mixing_diag_init;
        ! this query presumably fails silently and id is overwritten below
        iret=nf_inq_varid(ncid,'vtendu',id)
        ! current number of time records; PE 0 appends record ilen+1,
        ! later PEs read the enlarged length (presumably Time is the
        ! unlimited dimension - TODO confirm in def_grid_cdf)
        iret=nf_inq_dimid(ncid,'Time',itdimid)
        iret=nf_inq_dimlen(ncid, itdimid,ilen)
        iret=nf_inq_varid(ncid,'Time',itimeid)
        if (my_pe==0) then
         ilen=ilen+1
         time = itt*dt ! current_time-initial_time
         fxa = time/86400.0 ! time%days + time%seconds/86400.
         iret= nf_put_vara_double(ncid,itimeid,ilen,1,fxa)
        endif
        ! 4-D (x,y,z,Time) hyperslab of this PE's latitude strip
        Corner = (/1,js_pe,1,ilen/); 
        edges  = (/nx,je_pe-js_pe+1,nz,1/)

        iret=nf_inq_varid(ncid,'K_pv',id)
        a(:,js_pe:je_pe,:)=K_pv(:,js_pe:je_pe,:,2) *tx + K_pv(:,js_pe:je_pe,:,1) *(1-tx)
        where( maskT(:,js_pe:je_pe,:) == 0.) a = spval
        iret= nf_put_vara_double(ncid,id,corner,edges,a)

        iret=nf_inq_varid(ncid,'K_rho',id)
        a(:,js_pe:je_pe,:)=K_rho(:,js_pe:je_pe,:,2) *tx + K_rho(:,js_pe:je_pe,:,1) *(1-tx)
        where( maskT(:,js_pe:je_pe,:) == 0.) a(:,js_pe:je_pe,:) = spval
        iret= nf_put_vara_double(ncid,id,corner,edges,a)

        if (enable_pv_mixing_calc_diff) then 

         iret=nf_inq_varid(ncid,'K_iso',id)
         a(:,js_pe:je_pe,:)=K_iso(:,js_pe:je_pe,:) 
         where( maskT(:,js_pe:je_pe,:) == 0.) a(:,js_pe:je_pe,:) = spval
         iret= nf_put_vara_double(ncid,id,corner,edges,a)

         ! 3-D (x,y,Time) variables: Time is the 3rd dimension, the 4th
         ! corner/edge entry is unused. Only a(:,:,1) is filled; it is
         ! the leading contiguous part of a (column-major), which is what
         ! nf_put_vara_double reads. Masking uses maskT at level nz-1.
         Corner = (/1,js_pe,ilen,1/); 
         edges  = (/nx,je_pe-js_pe+1,1,1/)
         iret=nf_inq_varid(ncid,'om_i',id)
         a(:,js_pe:je_pe,1)=imag(om(:,js_pe:je_pe) )
         where( maskT(:,js_pe:je_pe,nz-1) == 0.) a(:,js_pe:je_pe,1) = spval
         iret= nf_put_vara_double(ncid,id,corner,edges,a)

         iret=nf_inq_varid(ncid,'kx_max',id)
         a(:,js_pe:je_pe,1)=kx_max(:,js_pe:je_pe) 
         where( maskT(:,js_pe:je_pe,nz-1) == 0.) a(:,js_pe:je_pe,1) = spval
         iret= nf_put_vara_double(ncid,id,corner,edges,a)

         iret=nf_inq_varid(ncid,'ky_max',id)
         a(:,js_pe:je_pe,1)=ky_max(:,js_pe:je_pe) 
         where( maskT(:,js_pe:je_pe,nz-1) == 0.) a(:,js_pe:je_pe,1) = spval
         iret= nf_put_vara_double(ncid,id,corner,edges,a)


        endif
        call ncclos (ncid, iret)
       endif
       call fortran_barrier
      enddo
    endif
end subroutine pv_mixing_diag



subroutine pv_mixing_diag_init
 !-----------------------------------------------------------------------
 !    define the diagnostic file 'pvmix.cdf' written by pv_mixing_diag
 !    (on PE 0 only; all PEs meet at a barrier afterwards):
 !      K_pv, K_rho            : (xt,yt,zt,Time)
 !      K_iso                  : (xt,yt,zt,Time)  if enable_pv_mixing_calc_diff
 !      kx_max, ky_max, om_i   : (xt,yt,Time)     if enable_pv_mixing_calc_diff
 !    The grid dimensions and Time are looked up after def_grid_cdf
 !    (defined elsewhere), which presumably creates them - TODO confirm.
 !    netCDF return codes (iret) are not checked.
 !-----------------------------------------------------------------------
 use pyOM_module
 use pv_mixing_module
 implicit none

     include "netcdf.inc"
      integer :: ncid,iret
      integer :: lon_tdim,z_tdim,itimedim,lat_tdim,id
      integer :: dims(4)
      character :: name*24, unit*16
      real*8,parameter :: spval = -1.0d33

   if (.not. enable_pv_mixing_diag ) return

   if (my_pe == 0) then

      call def_grid_cdf('pvmix.cdf')
      iret=nf_open('pvmix.cdf',NF_WRITE,ncid)
      iret=nf_set_fill(ncid, NF_NOFILL, iret)
      call ncredf(ncid, iret)
      ! all diagnostics live on the T-grid
      iret=nf_inq_dimid(ncid,'xt',lon_tdim)
      iret=nf_inq_dimid(ncid,'yt',lat_tdim)
      iret=nf_inq_dimid(ncid,'zt',z_tdim)
      iret=nf_inq_dimid(ncid,'Time',itimedim)

      ! 3-D fields plus time
      dims = (/Lon_tdim,lat_tdim, z_tdim, iTimedim/)
      id = ncvdef (ncid,'K_pv', NCFLOAT,4,dims,iret)
      name = 'lateral diffusivity'; unit = 'm^2/s'
      call dvcdf(ncid,id,name,24,unit,16,spval)

      id = ncvdef (ncid,'K_rho', NCFLOAT,4,dims,iret)
      name = 'diffusivity'; unit = 'm^2/s'
      call dvcdf(ncid,id,name,24,unit,16,spval)

      if (enable_pv_mixing_calc_diff) then 

       id = ncvdef (ncid,'K_iso', NCFLOAT,4,dims,iret)
       name = 'diffusivity'; unit = 'm^2/s'
       call dvcdf(ncid,id,name,24,unit,16,spval)

       ! 2-D fields plus time; the 4th entry of dims is not used (ndims=3)
       dims = (/Lon_tdim,lat_tdim, iTimedim,1/)
       id = ncvdef (ncid,'kx_max', NCFLOAT,3,dims,iret)
       name = 'wavenumber'; unit = '1/m'
       call dvcdf(ncid,id,name,24,unit,16,spval)

       id = ncvdef (ncid,'ky_max', NCFLOAT,3,dims,iret)
       name = 'wavenumber'; unit = '1/m'
       call dvcdf(ncid,id,name,24,unit,16,spval)

       id = ncvdef (ncid,'om_i', NCFLOAT,3,dims,iret)
       name = 'growth rate'; unit = '1/s'
       call dvcdf(ncid,id,name,24,unit,16,spval)
      endif

      call ncclos (ncid, iret)
   endif

   call fortran_barrier
end subroutine pv_mixing_diag_init






subroutine pv_mixing_implicit_umix(nx_,ny_,nz_,A_uu,A2_uu)
 !=======================================================================
 !     implicit vertical friction for u
 !
 !     Operator (K = A2_uu(k) outer factor at level k,
 !               N = A_uu(k) inner factor coupling levels k and k+1):
 !
 !       K( N u_z )_z = K N^{k+1/2}/dz^2 u^{k+1}
 !                    - K (N^{k+1/2}+N^{k-1/2})/dz^2 u^k
 !                    + K N^{k-1/2}/dz^2 u^{k-1}
 !
 !     With u^* the provisional velocity after all other tendencies and
 !     f = aidif_trm*c2dt/dz^2, each column solves
 !
 !       -f K N^{k-1/2} u^{k-1} + (1 + f K (N^{k+1/2}+N^{k-1/2})) u^k
 !       -f K N^{k+1/2} u^{k+1} = u^*
 !
 !     At k=1 only the N^{k+1/2} coupling is kept, at k=nz only the
 !     N^{k-1/2} coupling (no flux through the end levels).
 !
 !     Afterwards fu holds the total tendency (u^{n+1}-u^{n-1})/c2dt.
 !     Entries whose Thomas pivot vanishes are left at zero.
 !=======================================================================
      use pyOM_module
      implicit none
      integer :: nx_,ny_,nz_
      real*8 :: A_uu(nx_,ny_,nz_) ,A2_uu(nx_,ny_,nz_)
      integer :: j,k,js,je
      real*8 :: tri_lo(nx,nz),tri_di(nx,nz),tri_up(nx,nz),tri_piv(nx)
      real*8 :: tri_x(nx,nz),tri_gam(nx,nz),tri_rhs(nx,nz),tri_fac

      js = max(2,js_pe)
      je = min(je_pe,ny -1)

      ! provisional step with all explicit tendencies: u^* = u^{n-1} + c2dt*fu
      u(:,js:je,:,taup1) = u(:,js:je,:,taum1) + c2dt*fu(:,js:je,:)*maskU(:,js:je,:)

      ! weight of the implicit operator, the same for every row
      tri_fac = aidif_trm*c2dt/dz**2

      do j = js, je
         ! --- assemble the tridiagonal system of this row ---
         tri_di(:,1) = 1+tri_fac * A2_uu(:,j,1)*A_uu(:,j,1)
         tri_up(:,1) =  -tri_fac * A2_uu(:,j,1)*A_uu(:,j,1)
         do k = 2, nz-1
            tri_lo(:,k) =  -tri_fac * A2_uu(:,j,k)* A_uu(:,j,k-1)
            tri_di(:,k) = 1+tri_fac * A2_uu(:,j,k)*(A_uu(:,j,k)+A_uu(:,j,k-1) )
            tri_up(:,k) =  -tri_fac * A2_uu(:,j,k)* A_uu(:,j,k)
         end do
         tri_lo(:,nz) =  -tri_fac * A2_uu(:,j,nz)* A_uu(:,j,nz-1)
         tri_di(:,nz) = 1+tri_fac * A2_uu(:,j,nz)* A_uu(:,j,nz-1)

         tri_rhs = u(:,j,:,taup1)*maskU(:,j,:)

         ! --- Thomas algorithm, all i at once ---
         tri_x   = 0.0
         tri_gam = 0.0
         tri_piv = tri_di(:,1)
         where (tri_piv /= 0.0) tri_x(:,1) = tri_rhs(:,1)/tri_piv
         do k = 2, nz
            where (tri_piv /= 0.0) tri_gam(:,k) = tri_up(:,k-1)/tri_piv
            tri_piv = tri_di(:,k) - tri_lo(:,k)*tri_gam(:,k)
            where (tri_piv /= 0.0) tri_x(:,k) = (tri_rhs(:,k) - tri_lo(:,k)*tri_x(:,k-1))/tri_piv
         end do
         ! back substitution
         do k = nz-1, 1, -1
            tri_x(:,k) = tri_x(:,k) - tri_gam(:,k+1)*tri_x(:,k+1)
         end do

         u(:,j,:,taup1) = tri_x
      end do

      ! express the result as the total tendency over the leapfrog step
      fu(:,js:je,:) = (u(:,js:je,:,taup1) - u(:,js:je,:,taum1))/c2dt*maskU(:,js:je,:)
end subroutine pv_mixing_implicit_umix


subroutine pv_mixing_implicit_vmix(nx_,ny_,nz_,A_uu,A2_uu)
 !=======================================================================
 !     implicit vertical friction for v
 !
 !     Same tridiagonal problem as for u: with f = aidif_trm*c2dt/dz^2,
 !     K = A2_uu (outer factor) and N = A_uu (inner factor, A_uu(k)
 !     coupling levels k and k+1), each column solves
 !
 !       -f K N^{k-1/2} v^{k-1} + (1 + f K (N^{k+1/2}+N^{k-1/2})) v^k
 !       -f K N^{k+1/2} v^{k+1} = v^*
 !
 !     where v^* = v^{n-1} + c2dt*fv. Afterwards fv holds the total
 !     tendency (v^{n+1}-v^{n-1})/c2dt. Entries whose Thomas pivot
 !     vanishes are left at zero.
 !=======================================================================
      use pyOM_module
      implicit none
      integer :: nx_,ny_,nz_
      real*8 :: A_uu(nx_,ny_,nz_) ,A2_uu(nx_,ny_,nz_)
      integer :: j,k,js,je
      real*8 :: tri_lo(nx,nz),tri_di(nx,nz),tri_up(nx,nz),tri_piv(nx)
      real*8 :: tri_x(nx,nz),tri_gam(nx,nz),tri_rhs(nx,nz),tri_fac

      js = max(2,js_pe)
      je = min(je_pe,ny -1)

      ! provisional step with all explicit tendencies
      v(:,js:je,:,taup1) = v(:,js:je,:,taum1) + c2dt*fv(:,js:je,:)*maskV(:,js:je,:)

      ! weight of the implicit operator, the same for every row
      tri_fac = aidif_trm*c2dt/dz**2

      do j = js, je
         ! --- assemble the tridiagonal system of this row ---
         tri_di(:,1) = 1+tri_fac * A2_uu(:,j,1)*A_uu(:,j,1)
         tri_up(:,1) =  -tri_fac * A2_uu(:,j,1)*A_uu(:,j,1)
         do k = 2, nz-1
            tri_lo(:,k) =  -tri_fac * A2_uu(:,j,k)* A_uu(:,j,k-1)
            tri_di(:,k) = 1+tri_fac * A2_uu(:,j,k)*(A_uu(:,j,k)+A_uu(:,j,k-1) )
            tri_up(:,k) =  -tri_fac * A2_uu(:,j,k)* A_uu(:,j,k)
         end do
         tri_lo(:,nz) =  -tri_fac * A2_uu(:,j,nz)* A_uu(:,j,nz-1)
         tri_di(:,nz) = 1+tri_fac * A2_uu(:,j,nz)* A_uu(:,j,nz-1)

         tri_rhs = v(:,j,:,taup1)*maskV(:,j,:)

         ! --- Thomas algorithm, all i at once ---
         tri_x   = 0.0
         tri_gam = 0.0
         tri_piv = tri_di(:,1)
         where (tri_piv /= 0.0) tri_x(:,1) = tri_rhs(:,1)/tri_piv
         do k = 2, nz
            where (tri_piv /= 0.0) tri_gam(:,k) = tri_up(:,k-1)/tri_piv
            tri_piv = tri_di(:,k) - tri_lo(:,k)*tri_gam(:,k)
            where (tri_piv /= 0.0) tri_x(:,k) = (tri_rhs(:,k) - tri_lo(:,k)*tri_x(:,k-1))/tri_piv
         end do
         ! back substitution
         do k = nz-1, 1, -1
            tri_x(:,k) = tri_x(:,k) - tri_gam(:,k+1)*tri_x(:,k+1)
         end do

         v(:,j,:,taup1) = tri_x
      end do

      ! express the result as the total tendency over the leapfrog step
      fv(:,js:je,:) = (v(:,js:je,:,taup1) - v(:,js:je,:,taum1))/c2dt*maskV(:,js:je,:)
end subroutine pv_mixing_implicit_vmix






subroutine vertical_eigenvalues0b(nb,rho,u,v,f0,beta,dz,Lr,omax,psimax,kx_max,ky_max,Qx,Qy,ier) 
!------------------------------------------------------------------------
! Approximate quasi-geostrophic vertical eigenvalue problem.
!
! Input
!   nb        number of levels (the stencils below need nb >= 2)
!   rho       density profile; rho(k+1)-rho(k) is bounded below by
!             drho_min, i.e. rho increasing with k is treated as stable
!   u, v      velocity profiles
!   f0, beta  Coriolis parameter and its meridional gradient
!   dz        uniform level spacing
!   Lr        length scale for the wavenumber limits (presumably the
!             Rossby radius - TODO confirm against the caller)
! Output
!   Qx, Qy          background PV gradient (beta plus stretching term)
!   kx_max, ky_max  wavenumber vector of the approximated fastest mode:
!                   direction = principal axis of the (u,v) covariance,
!                   magnitude = approximate growth maximum limited to
!                   [0.1,5]/Lr (fallback 1.6/Lr)
!   omax            complex frequency cmax*|k|; the imaginary part of
!                   cmax is made >= 0 and capped at 0.2/(|k|*86400)
!   psimax          approximate eigenfunction, normalised so that
!                   max(|re|+|im|) = 1; zero if the mode does not grow
!   ier             always 0
!------------------------------------------------------------------------
 implicit none
 real*8, parameter :: rho_0 = 1024, grav = 9.81,N_min=0.001,drho_min = (N_min**2)*rho_0/grav
 integer :: nb,k,ier,n
 real*8 :: rho(nb),u(nb),v(nb),beta,f0,dz,Lr,Qy(nb),Qx(nb)
 real*8 :: kx_max,ky_max
 real*8 :: Nsqr(nb) ,drho(nb)
 ! note: Fortran is case-insensitive, H1 and h1 are the same array
 complex*8 :: cmax,omax,psimax(nb),H1(nb),H2(nb),H3(nb)
 real*8 :: Gij(nb,nb)
 real*8 :: um,vm,uvm,uum,vvm,ukm,ukkm,fxa,h0,theta
 real*8 :: uk(nb),kh_appr,kh2,betas
 complex*8 :: c0,c1,ps1(nb),pp(nb),q,h1s(nb)

 ier=0
 ! density steps bounded below (replaces unstable/neutral steps)
 do k=1,nb-1
  drho(k)=max( rho(k+1)-rho(k), drho_min )
 enddo
 ! vertical differencing operator f0^2 d/dz( 1/N^2 d/dz )
 Gij=0.0
 Gij(1,1) = -1/drho(1)/dz;
 Gij(1,2) =  1/drho(1)/dz;
 do k = 2,nb-1
     Gij(k,k-1) = 1/drho(k-1)/dz;
     Gij(k,k)   = (-1/drho(k-1) -1/drho(k))/dz;
     Gij(k,k+1) =  1/drho(k) /dz;
 enddo
 Gij(nb,nb)   = -1/drho(nb-1) /dz;
 Gij(nb,nb-1) =  1/drho(nb-1) /dz;
 Gij          = Gij*f0**2*rho_0/grav
 ! background PV gradient
 do k=1,nb
   Qy(k) = beta
   Qx(k) = 0.0
   do n=1,nb
       Qy(k) = Qy(k)-Gij(k,n)*u(n)
       Qx(k) = Qx(k)+Gij(k,n)*v(n)
   enddo
 enddo
 ! stability frequency, bounded below by N_min**2
 Nsqr(1)= -(rho(1)-rho(2))/(dz) /rho_0*grav
 do k=2,nb-1
  Nsqr(k)= -(rho(k-1)-rho(k+1))/(2*dz) /rho_0*grav
 enddo
 Nsqr(nb)=-(rho(nb-1)-rho(nb))/(dz) /rho_0*grav
 do k=1,nb
  Nsqr(k) = max(N_min**2,Nsqr(k))
 enddo

 ! direction of approximate minimum: principal axis of the (u,v) covariance
 h0=dz*nb
 um = sum(u*dz)/h0
 vm = sum(v*dz)/h0
 uvm= sum(u*v*dz)/h0
 uum = sum(u**2*dz)/h0
 vvm = sum(v**2*dz)/h0
 theta = 0.5*atan2(2*(uvm-um*vm), uum-um**2-vvm+vm**2)
 uk=u*cos(theta)+v*sin(theta)
 ! depth means of the velocity projected on that direction; all phase
 ! speed estimates below are built from uk (and not from u alone) so
 ! that the result does not depend on the orientation of the flow
 ukm  = sum(uk*dz)/h0
 ukkm = sum(uk**2*dz)/h0

 ! magnitude of wave number of approximate minimum
 c0 = ukm**2 - ukkm
 c0 = ukm + sqrt(c0)
 ps1(1)=0. 
 do k=2,nb
  ps1(k) = ps1(k-1)+dz*(c0-uk(k))**2 
 enddo
 pp(1)=0
 do k=2,nb
  pp(k) = pp(k-1)+dz*( Nsqr(k)/(f0**2*(c0-uk(k))**2)*ps1(k) )
 enddo
 q = 2*c0*sum(Uk*pp*dz)/h0 -c0**2*sum(pp*dz)/h0 - sum(Uk**2*pp*dz)/h0
 q = q/( 2*(c0-ukm))
 ! squared wavenumber estimate; a vanishing imag(q) would give 0/0 or
 ! +-Inf, in which case the fallback below is used
 fxa = -1.0
 if (aimag(q) /= 0.0) fxa = -aimag(c0)/(3*aimag(q))
 if (fxa > 0) then
   kh_appr=  sqrt(fxa) 
 else
    kh_appr=  1.6/Lr
 endif
 kh_appr = min(5.0/Lr, max( 0.1/Lr, kh_appr)  ) ! limit wavenumber
 kx_max=kh_appr*cos(theta)
 ky_max=kh_appr*sin(theta)
 kh2 = kx_max**2 + ky_max**2

 ! eigenvalue at that wavenumber: leading-order phase speed c0 plus a
 ! correction c1 (H1 is used as scratch here)
 betas = beta * kx_max  /kh_appr 
 c0 = ukm**2-ukkm+(betas/(2*kh2))**2 
 c0 = ukm -betas/(2*kh2) + sqrt( c0 )
 h1s(1)=0.0
 do k=2,nb
  h1s(k)=h1s(k-1) + dz*( (c0-Uk(k))**2*kh2 +betas* (c0-Uk(k)))  
 enddo
 H1(1)=0.0
 do k=2,nb
  H1(k)=H1(k-1) + dz*( h1s(k)*Nsqr(k)/( f0**2*(c0-Uk(k))**2 ) )  
 enddo
 c1 = c0*(c0+betas/kh2)*sum(H1*dz)/h0 &
         -(2*c0+betas/kh2)*sum(H1*Uk*dz)/h0 +sum(H1*Uk**2*dz)/h0
 c1 =-c1/(betas/kh2+2*(c0-ukm) )
 cmax = c0+c1
 if (aimag(cmax)<0) cmax = conjg(cmax)
 fxa = aimag(cmax) ! limit imaginary phase speed
 cmax = cmplx(real(cmax), min( 0.2/(kh_appr*86400.0) , fxa ) ) 
 omax = cmax*kh_appr 

 ! eigenfunction for that wavenumber and eigenvalue: series 1+H1+H2+H3
 ! of repeated weighted double integrals, times (uk-cmax)
 if (aimag(omax)>0) then 
   H1(1)=0.0
   do k=2,nb
     H1(k)=H1(k-1)+dz*(cmax-uk(k))*(kh2*(cmax-uk(k))+betas) 
   enddo
   do k=2,nb
    H1(k) = H1(k-1)+dz*Nsqr(k)/f0**2/(cmax-uk(k))**2*H1(k)
   enddo
   H2(1)=0.0
   do k=2,nb
     H2(k)=H2(k-1)+dz*H1(k)*(cmax-uk(k))*(kh2*(cmax-uk(k))+betas) 
   enddo
   do k=2,nb
    H2(k) = H2(k-1)+dz*Nsqr(k)/f0**2/(cmax-uk(k))**2*H2(k)
   enddo
   H3(1)=0.0
   do k=2,nb
     H3(k)=H3(k-1)+dz*H2(k)*(cmax-uk(k))*(kh2*(cmax-uk(k))+betas) 
   enddo
   do k=2,nb
    H3(k) = H3(k-1)+dz*Nsqr(k)/f0**2/(cmax-uk(k))**2*H3(k)
   enddo
   psimax = (1+H1+H2+H3)*(uk-cmax)
   psimax=psimax/maxval(abs(real(psimax)) + abs(aimag(psimax)) )
 else
   psimax = cmplx(0,0); 
 endif

end subroutine vertical_eigenvalues0b




subroutine vertical_eigenvalues0(nb,rho,u,v,f0,beta,dz,Lr,omax,psimax,kx_max,ky_max,Qx,Qy,ier) 
!------------------------------------------------------------------------
! Cheapest approximation of the quasi-geostrophic vertical eigenvalue
! problem. The wavenumber vector has the fixed magnitude 1.6/Lr and
! points along the depth-averaged flow (along x if that flow's speed is
! not above 0.0001). The complex phase speed is estimated as
!    c = <U> - b/(2|k|^2) + sqrt( (b/(2|k|^2))^2 - <U^2> + <U>^2 )
! with U the velocity projected on k, b = beta*kx/|k| and <.> depth
! means; the imaginary part of c is made non-negative.
! Output:
!   Qx, Qy    background PV gradient (beta plus stretching term)
!   omax      c*|k|
!   psimax    approximate eigenfunction, normalised so that
!             max(|re|+|im|) = 1; zero if the mode does not grow
!   ier       always 0
!------------------------------------------------------------------------
 implicit none
 real*8, parameter :: rho_0 = 1024, grav = 9.81,N_min=0.001,drho_min = (N_min**2)*rho_0/grav
 integer :: nb,ier
 real*8 :: rho(nb),u(nb),v(nb),beta,f0,dz,Lr,Qy(nb),Qx(nb)
 real*8 :: kx_max,ky_max
 integer :: lev,m
 real*8 :: stretch(nb,nb),dens_step(nb),N2(nb),uproj(nb)
 real*8 :: dir_x,dir_y,speed,k_abs,k_fac,ubar,u2bar
 complex*8 :: cmax,omax,psimax(nb),g1(nb),g2(nb),g3(nb),betas

 ier=0

 ! density steps, bounded below by drho_min
 do lev=1,nb-1
  dens_step(lev)=max( rho(lev+1)-rho(lev), drho_min )
 enddo

 ! tridiagonal stretching operator scaled by f0^2*rho_0/grav
 stretch=0.0
 stretch(1,1) = -1/dens_step(1)/dz
 stretch(1,2) =  1/dens_step(1)/dz
 do lev = 2,nb-1
   stretch(lev,lev-1) =  1/dens_step(lev-1)/dz
   stretch(lev,lev)   = (-1/dens_step(lev-1) -1/dens_step(lev))/dz
   stretch(lev,lev+1) =  1/dens_step(lev) /dz
 enddo
 stretch(nb,nb)   = -1/dens_step(nb-1) /dz
 stretch(nb,nb-1) =  1/dens_step(nb-1) /dz
 stretch          = stretch*f0**2*rho_0/grav

 ! background PV gradient: beta plus stretching of the mean flow
 do lev=1,nb
   Qy(lev) = beta
   Qx(lev) = 0.0
   do m=1,nb
     Qy(lev) = Qy(lev)-stretch(lev,m)*u(m)
     Qx(lev) = Qx(lev)+stretch(lev,m)*v(m)
   enddo
 enddo

 ! squared buoyancy frequency: one-sided at the end levels, centred in
 ! between, bounded below by N_min**2
 N2(1)  = -(rho(1)-rho(2))/(dz) /rho_0*grav
 do lev=2,nb-1
   N2(lev) = -(rho(lev-1)-rho(lev+1))/(2*dz) /rho_0*grav
 enddo
 N2(nb) = -(rho(nb-1)-rho(nb))/(dz) /rho_0*grav
 N2 = max(N_min**2, N2)

 ! unit vector along the depth-averaged flow
 dir_x=0.0; dir_y=0.0
 do lev=1,nb
   dir_x=dir_x+u(lev)*dz
   dir_y=dir_y+v(lev)*dz
 enddo
 dir_x=dir_x/(nb*dz)
 dir_y=dir_y/(nb*dz)
 speed = sqrt(dir_x**2+dir_y**2)
 if (speed>0.0001) then
   dir_x=dir_x/speed; dir_y=dir_y/speed
 else
   dir_x=1; dir_y=0
 endif

 ! fixed wavenumber magnitude along that direction
 k_fac=1.6
 kx_max=k_fac/Lr*dir_x; ky_max=k_fac/Lr*dir_y
 k_abs=sqrt(kx_max**2+ky_max**2)
 uproj=(kx_max*u+ky_max*v)/k_abs

 ! depth means of the projected velocity and of its square
 ubar=0; u2bar=0
 do lev=1,nb
   ubar=ubar+uproj(lev)*dz
   u2bar=u2bar+uproj(lev)**2*dz
 enddo
 ubar=ubar/(nb*dz); u2bar=u2bar/(nb*dz)

 ! phase speed estimate, imaginary part made non-negative
 betas = beta*kx_max/k_abs
 cmax = (betas/(2*k_abs**2) )**2 - u2bar + ubar**2
 cmax = ubar -betas/(2*k_abs**2) + sqrt(cmax)
 cmax = cmplx(real(cmax),abs(aimag(cmax)))
 omax = cmax*k_abs

 ! loop over wavenumbers to get maximal eigenvalue
 !omax=cmplx(0,0); kmax=1.6
 !do k=1,nk
 !  kx_max=4.2*(k-0.0)/(1.*nk)  /Lr*um; ky_max=4.2*(k-0.0)/(1.*nk)  /Lr*vm
 !  kmag=sqrt(kx_max**2+ky_max**2)
 !  uu=(kx_max*u+ky_max*v)/kmag
 !  ! construct matrixes and solve eigenvalue problem
 !  Cij=0.0; Aij=0.0
 !  Bij=Gij
 !  do n=1,nb
 !      Bij(n,n)=Bij(n,n)-kmag**2
 !      Cij(n,n)=kx_max*u(n)+ky_max*v(n)
 !  enddo
 !  Aij = matmul(Cij,Bij)
 !  do n=1,nb
 !      Aij(n,n)=Aij(n,n)+kx_max*Qy(n)-ky_max*Qx(n)
 !  enddo
 !  lwork=20*nb;allocate(work(lwork))
 !  call DGGEV('N','N',nb,Aij,nb,Bij,nb,omr,omi,mu,eigv,nb,eigv,nb,WORK,LWORK,INFO)
 !  deallocate(work)
 !  if (info/=0) then
 !            print*,' Error in DGGEV (1) info :',info
 !            omi=0.0;omr=0.0
 !  endif
 !  do n=1,nb
 !      if (mu(n)==0.0) then
 !         print*,' Error: mu=0 in DGGEV (1)'
 !          omr(n)=0.0;omi(n)=0.0
 !      else
 !          omr(n)=omr(n)/mu(n)
 !          omi(n)=omi(n)/mu(n)
 !      endif
 !  enddo
 !  n=maxloc(omi,1)
 !  if (omi(n)> imag(omax) ) then
 !     omax=cmplx(omr(n),omi(n))
 !     kmax= 4.2*(k-0.0)/(1.*nk)
 !  endif
 !enddo
 !kx_max=kmax/Lr*um; ky_max=kmax/Lr*vm

 if (aimag(omax)>0) then
   ! approximate eigenfunction: series 1+g1+g2+g3 of repeated weighted
   ! double integrals, times (U-c)
   k_abs=sqrt(kx_max**2+ky_max**2)
   uproj=(kx_max*u+ky_max*v)/k_abs
   cmax = omax/k_abs
   betas = beta*kx_max/k_abs
   g1(1)=0.0
   do lev=2,nb
     g1(lev)=g1(lev-1)+dz*(cmax-uproj(lev))*(k_abs**2*(cmax-uproj(lev))+betas)
   enddo
   do lev=2,nb
     g1(lev) = g1(lev-1)+dz*N2(lev)/f0**2/(cmax-uproj(lev))**2*g1(lev)
   enddo
   g2(1)=0.0
   do lev=2,nb
     g2(lev)=g2(lev-1)+dz*g1(lev)*(cmax-uproj(lev))*(k_abs**2*(cmax-uproj(lev))+betas)
   enddo
   do lev=2,nb
     g2(lev) = g2(lev-1)+dz*N2(lev)/f0**2/(cmax-uproj(lev))**2*g2(lev)
   enddo
   g3(1)=0.0
   do lev=2,nb
     g3(lev)=g3(lev-1)+dz*g2(lev)*(cmax-uproj(lev))*(k_abs**2*(cmax-uproj(lev))+betas)
   enddo
   do lev=2,nb
     g3(lev) = g3(lev-1)+dz*N2(lev)/f0**2/(cmax-uproj(lev))**2*g3(lev)
   enddo
   psimax = (1+g1+g2+g3)*(uproj-cmax)
   psimax = psimax/maxval(abs(real(psimax)) + abs(aimag(psimax)) )
 else
   psimax = cmplx(0,0)
 endif

end subroutine vertical_eigenvalues0







subroutine vertical_eigenvalues1(nb,rho,u,v,f0,beta,dz,Lr,omax,psimax,kx_max,ky_max,Qx,Qy,ier) 
!------------------------------------------------------------------------
! Approximate quasi-geostrophic vertical eigenvalue problem by a scan
! over an nk x nl wavenumber grid
!    kx in [0, kx_span)/Lr,   ky in [-ky_span, ky_span)/Lr.
! For every wavenumber the complex phase speed c0 is estimated in closed
! form (Eady-like) from the projected velocity and its shear at the two
! end levels, scaled with ell = |f0|/(|k| N), and coth of the integrated
! -1/ell; the wavenumber with the largest imag(c0*|k|) is kept.
! Output:
!   omax            complex frequency of that mode; (0,0) if none grows,
!                   then kx_max = 0.2/Lr, ky_max = 0
!   kx_max, ky_max  its wavenumber vector
!   psimax          approximate eigenfunction, normalised so that
!                   max(|re|+|im|) = 1; zero if the mode does not grow
!   ier             always 0
! NOTE(review): unlike the other vertical_eigenvalues* variants this
! routine does not assign Qx and Qy - confirm callers do not use them.
!------------------------------------------------------------------------
 implicit none
 real*8, parameter :: rho_0 = 1024, grav = 9.81, N_min = 0.001, kx_span=4.0, ky_span=4.0
 integer, parameter :: nk=16,nl=32
 integer :: nb,k,n,l,ier
 real*8 :: rho(nb),u(nb),v(nb),beta,f0,dz,Lr,Qy(nb),Qx(nb)
 real*8 :: kx_max,ky_max,kx,ky
 real*8 :: kmag,uu(nb),Nsqr(nb) ,betas,Nso(nb),ell(nb),zeta,uzu,uzl,u0,tzeta
 complex*8 :: cmax,omax,psimax(nb),c0,H1(nb),H2(nb),H3(nb)

 ier=0
 ! buoyancy frequency, bounded below by N_min
 Nsqr(1)= -(rho(1)-rho(2))/(dz) /rho_0*grav
 do k=2,nb-1
  Nsqr(k)= -(rho(k-1)-rho(k+1))/(2*dz) /rho_0*grav
 enddo
 Nsqr(nb)= -(rho(nb-1)-rho(nb))/(dz) /rho_0*grav
 Nsqr = max(N_min**2,Nsqr)
 Nso  = sqrt(Nsqr)

 ! loop over wavenumbers and find maximum
 omax=cmplx(0,0); cmax=cmplx(0,0); kx_max=0.2/Lr; ky_max=0.
 do k=1,nk
   kx=(0+ (k-1.0)/nk*kx_span) /Lr
   do l=1,nl
       ky=((l-1.0)/nl-0.5)*2 *ky_span/Lr
       kmag=sqrt(kx**2+ky**2)
       ! the grid contains the origin (k=1, l=nl/2+1); there is nothing
       ! to estimate and every expression below would divide by kmag=0
       if (kmag <= 0.0) cycle

       ! solve eigenvalue problem approximately
       uu=(kx*u+ky*v)/kmag
       ell=abs(f0)/kmag/Nso
       zeta=0.
       do n=1,nb
        zeta= zeta-1./ell(n)*dz
       enddo
       tzeta=0.0; if (zeta/=0.0) tzeta=1./tanh(zeta)
       uzu=ell(1 )*(uu(1   )-uu(2 ))/dz
       uzl=ell(nb)*(uu(nb-1)-uu(nb))/dz
       u0 = uu(1)+uu(nb)-(uzl-uzu)*tzeta
       c0=U0**2/4 +uzu*uzl-uu(1)*uu(nb)+(uu(1)*uzl-uu(nb)*uzu)*tzeta
       c0=u0/2+sqrt(c0)
       c0=cmplx(real(c0),abs(aimag(c0)))

       if (aimag(c0*kmag)>aimag(omax)) then
          kx_max = kx; ky_max = ky; omax=c0*kmag; cmax=c0
       endif
   enddo
 enddo

 ! approximate eigenfunction: series 1+H1+H2+H3 of repeated weighted
 ! double integrals, times (uu-cmax)
 if (aimag(omax)>0) then 
   kmag=sqrt(kx_max**2+ky_max**2)
   uu=(kx_max*u+ky_max*v)/kmag;
   betas = beta*kx_max/kmag;
   H1(1)=0.0
   do k=2,nb
     H1(k)=H1(k-1)+dz*(cmax-uu(k))*(kmag**2*(cmax-uu(k))+betas) 
   enddo
   do k=2,nb
    H1(k) = H1(k-1)+dz*Nsqr(k)/f0**2/(cmax-uu(k))**2*H1(k)
   enddo
   H2(1)=0.0
   do k=2,nb
     H2(k)=H2(k-1)+dz*H1(k)*(cmax-uu(k))*(kmag**2*(cmax-uu(k))+betas) 
   enddo
   do k=2,nb
    H2(k) = H2(k-1)+dz*Nsqr(k)/f0**2/(cmax-uu(k))**2*H2(k)
   enddo
   H3(1)=0.0
   do k=2,nb
     H3(k)=H3(k-1)+dz*H2(k)*(cmax-uu(k))*(kmag**2*(cmax-uu(k))+betas) 
   enddo
   do k=2,nb
    H3(k) = H3(k-1)+dz*Nsqr(k)/f0**2/(cmax-uu(k))**2*H3(k)
   enddo
   psimax = (1+H1+H2+H3)*(uu-cmax)
   psimax=psimax/maxval(abs(real(psimax)) + abs(aimag(psimax)) )
 else
   psimax = cmplx(0,0)
 endif

end subroutine vertical_eigenvalues1






subroutine vertical_eigenvalues2(nb,rho,u,v,f0,beta,dz,Lr,omaxx,psimaxx,kx_max,ky_max,Qx,Qy,ier) 
!------------------------------------------------------------------------
! quasi geostrophic vertical eigenvalue problem after
! Smith (2007) "The Geography of Linear Baroclinic 
! Instability in Earth's Oceans", J. Mar. Res., 65 (5), 655-683
!
! For each wavenumber (kx,ky) the generalized eigenvalue problem
!    A psi = omega B psi,   B = G - |k|^2 I,
!    A = diag(kx u + ky v) B + diag(kx Qy - ky Qx)
! (G: discrete stretching operator) is solved with LAPACK DGGEV, first
! on a coarse grid (kx >= 0), then on a finer grid around the coarse
! maximum of the growth rate.
! Output:
!   Qx, Qy          background PV gradient (beta plus stretching term)
!   omaxx           eigenvalue with the largest imaginary part
!   kx_max, ky_max  its wavenumber
!   psimaxx         approximate eigenfunction (series expansion),
!                   normalised so that max(|re|+|im|) = 1; zero if the
!                   mode does not grow
!   ier             always 0
!------------------------------------------------------------------------
 implicit none
 real*8, parameter :: rho_0 = 1024, grav = 9.81, drho_min = (0.002**2)*rho_0/grav
 real*8, parameter :: k_span = 4.2
 integer, parameter :: nk=16,nl=16
 integer :: nb,k,n,l,lwork,info,ii(2),ier
 real*8 :: rho(nb),u(nb),v(nb),beta,f0,dz,Lr,Qy(nb),Qx(nb)
 complex*8 :: omaxx,psimaxx(nb),omax(nk,nl)
 real*8 :: Gij(nb,nb),kx,ky,Bij(nb,nb),Aij(nb,nb),Cij(nb,nb)
 real*8 :: omr(nb),omi(nb),mu(nb),eigv(nb,nb)
 real*8, allocatable :: work(:)
 real*8    :: kx_max,ky_max,drho(nb)
 real*8 :: kmag,uu(nb),Nsqr(nb) ,betas
 complex*8 :: H1(nb),c,H2(nb),H3(nb)

 ier=0
 ! check for unstable profile
 do k=1,nb-1
  drho(k)=max( rho(k+1)-rho(k), drho_min )
 enddo
 ! vertical differencing operator
 Gij=0.0
 Gij(1,1) = -1/drho(1)/dz;
 Gij(1,2) =  1/drho(1)/dz;
 do k = 2,nb-1
     Gij(k,k-1) = 1/drho(k-1)/dz;
     Gij(k,k)   = (-1/drho(k-1) -1/drho(k))/dz;
     Gij(k,k+1) =  1/drho(k) /dz;
 enddo
 Gij(nb,nb)   = -1/drho(nb-1) /dz;
 Gij(nb,nb-1) =  1/drho(nb-1) /dz;
 Gij          = Gij*f0**2*rho_0/grav
 ! background PV gradient
 do k=1,nb
   Qy(k) = beta
   Qx(k) = 0.0
   do n=1,nb
       Qy(k) = Qy(k)-Gij(k,n)*u(n)
       Qx(k) = Qx(k)+Gij(k,n)*v(n)
   enddo
 enddo
 lwork=20*nb;allocate(work(lwork))

 ! loop over wavenumbers on coarse large grid (kx >= 0 only; the
 ! entries k > nk/2 of omax keep the value zero)
 omax=cmplx(0,0); Cij=0.0; Aij=0.0
 do k=1,nk/2
   kx=((k-1.0)/nk)*2 *k_span/Lr
   do l=1,nl
       ky=((l-1.0)/nl-0.5)*2 *k_span/Lr
       ! k=(0,0) is on the grid: there A=0 and B=G is singular, so DGGEV
       ! has no meaningful eigenvalue; the mode has no growth (omax=0)
       if (kx**2+ky**2 <= 0.0) cycle
       ! construct matrixes and solve eigenvalue problem
       Bij=Gij
       do n=1,nb
        Bij(n,n)=Bij(n,n)-(kx**2+ky**2)
        Cij(n,n)=kx*u(n)+ky*v(n)
       enddo
       Aij = matmul(Cij,Bij)
       do n=1,nb
        Aij(n,n)=Aij(n,n)+kx*Qy(n)-ky*Qx(n)
       enddo
       call DGGEV('N','N',nb,Aij,nb,Bij,nb,omr,omi,mu,eigv,nb,eigv,nb,WORK,LWORK,INFO)
       if (info/=0) then
           print*,' Error in DGGEV (1) info :',info,'ky,kx=',ky*Lr,kx*Lr
           omi=0.0;omr=0.0
       endif
       ! eigenvalues are (omr + i omi)/mu
       do n=1,nb
         if (mu(n)==0.0) then
           print*,' Error: mu=0 in DGGEV (1)'
           omr(n)=0.0;omi(n)=0.0
         else
           omr(n)=omr(n)/mu(n)
           omi(n)=omi(n)/mu(n)
         endif
       enddo
       n=maxloc(omi,1)
       omax(k,l)=cmplx(omr(n),omi(n))
   enddo
 enddo
 ii=maxloc(aimag(omax))
 omaxx=omax(ii(1),ii(2))
 kx_max = ((ii(1)-1.)/nk)*2 *k_span/Lr
 ky_max = ((ii(2)-1.)/nl-0.5)*2 *k_span/Lr


 ! loop over wavenumbers on finer grid around maximum
 omax=cmplx(0,0); Cij=0.0; Aij=0.0
 do k=1,nk
   kx=kx_max+ ((k-1.0)/nk-0.5)*2 *k_span/Lr /nk*2
   do l=1,nl
       ky=ky_max+ ((l-1.0)/nl-0.5)*2 *k_span/Lr /nl*2
       ! skip the singular point k=(0,0) (see above)
       if (kx**2+ky**2 <= 0.0) cycle
       ! construct matrixes and solve eigenvalue problem
       Bij=Gij
       do n=1,nb
        Bij(n,n)=Bij(n,n)-(kx**2+ky**2)
        Cij(n,n)=kx*u(n)+ky*v(n)
       enddo
       Aij = matmul(Cij,Bij)
       do n=1,nb
        Aij(n,n)=Aij(n,n)+kx*Qy(n)-ky*Qx(n)
       enddo
       call DGGEV('N','N',nb,Aij,nb,Bij,nb,omr,omi,mu,eigv,nb,eigv,nb,WORK,LWORK,INFO)
       if (info/=0) then
           print*,' Error in DGGEV (2) info :',info,'ky,kx=',ky*Lr,kx*Lr
           omi=0.0;omr=0.0
       endif
       do n=1,nb
         if (mu(n)==0.0) then
           print*,' Error: mu=0 in DGGEV (2)'
           omr(n)=0.0;omi(n)=0.0
         else
           omr(n)=omr(n)/mu(n)
           omi(n)=omi(n)/mu(n)
         endif
       enddo
       n=maxloc(omi,1)
       omax(k,l)=cmplx(omr(n),omi(n))
   enddo
 enddo
 ii=maxloc(aimag(omax))
 omaxx=omax(ii(1),ii(2))
 kx_max=kx_max+ ((ii(1)-1.0)/nk-0.5)*2 *k_span/Lr /nk*2
 ky_max=ky_max+ ((ii(2)-1.0)/nl-0.5)*2 *k_span/Lr /nl*2


 ! now get eigenvector of fastest growing mode
 !Bij=Gij; Cij=0.0; Aij=0.0
 !do n=1,nb
 !  Bij(n,n)=Bij(n,n)-(kx_max**2+ky_max**2)
 !  Cij(n,n)=kx_max*u(n)+ky_max*v(n)
 !enddo
 !Aij = matmul(Cij,Bij)
 !do n=1,nb
 !   Aij(n,n)=Aij(n,n)+kx_max*Qy(n)-ky_max*Qx(n)
 !enddo
 !eigv=0.0; 
 !call DGGEV('N','V',nb,Aij,nb,Bij,nb,omr,omi,mu,eigv,nb,eigv,nb,WORK,LWORK,INFO)
 !      if (info/=0) then
 !          print*,' Error in DGGEV (3) info :',info,'ky,kx=',ky*Lr,kx*Lr
 !          omi=0.0;omr=0.0;eigv=0.0
 !          ier=-1
 !      endif
 !n=maxloc(omi,1)
 !psimaxx=cmplx(eigv(:,n),eigv(:,n+1) )

 ! approximate eigenfunction: series 1+H1+H2+H3 of repeated weighted
 ! double integrals, times (uu-c). Only evaluated for a growing mode,
 ! as in the sibling routines: then c has a nonzero imaginary part, so
 ! (c-uu(k))**2 cannot vanish, and the selected wavenumber is not k=0.
 if (aimag(omaxx) > 0) then
   kmag=sqrt(kx_max**2+ky_max**2)
   uu=(kx_max*u+ky_max*v)/kmag;
   betas = beta*kx_max/kmag;
   c=omaxx/kmag;

   Nsqr(1)= -(rho(1)-rho(2))/(dz) /rho_0*grav
   do k=2,nb-1
    Nsqr(k)= -(rho(k-1)-rho(k+1))/(2*dz) /rho_0*grav
   enddo
   Nsqr(nb)= -(rho(nb-1)-rho(nb))/(dz) /rho_0*grav
   H1(1)=0.0
   do k=2,nb
     H1(k)=H1(k-1)+dz*(c-uu(k))*(kmag**2*(c-uu(k))+betas) 
   enddo
   do k=2,nb
    H1(k) = H1(k-1)+dz*Nsqr(k)/f0**2/(c-uu(k))**2*H1(k)
   enddo
   H2(1)=0.0
   do k=2,nb
     H2(k)=H2(k-1)+dz*H1(k)*(c-uu(k))*(kmag**2*(c-uu(k))+betas) 
   enddo
   do k=2,nb
    H2(k) = H2(k-1)+dz*Nsqr(k)/f0**2/(c-uu(k))**2*H2(k)
   enddo
   H3(1)=0.0
   do k=2,nb
     H3(k)=H3(k-1)+dz*H2(k)*(c-uu(k))*(kmag**2*(c-uu(k))+betas) 
   enddo
   do k=2,nb
    H3(k) = H3(k-1)+dz*Nsqr(k)/f0**2/(c-uu(k))**2*H3(k)
   enddo
   psimaxx = (1+H1+H2+H3)*(uu-c)
   psimaxx=psimaxx/maxval(abs(real(psimaxx)) + abs(aimag(psimaxx)) )
 else
   psimaxx = cmplx(0,0)
 endif
 deallocate(work)
end subroutine vertical_eigenvalues2



subroutine vertical_eigenvalues3(nb,rho,u,v,f0,beta,dz,Lr,omax,psimax,kx_max,ky_max,Qx,Qy,ier) 
!------------------------------------------------------------------------
! quasi geostrophic vertical eigenvalue problem after
! Smith (2007) "The Geography of Linear Baroclinic 
! Instability in Earth's Oceans", J. Mar. Res., 65 (5), 655-683
!
! For a horizontal wavenumber (kx,ky) the generalized eigenproblem
!     A psi = omega B psi
! with
!     B = G - (kx**2+ky**2) I
!     A = diag(kx*u+ky*v - i*A_h*(kx**2+ky**2)) B + diag(kx*Qy-ky*Qx)
! is solved with LAPACK CGGEV. G is the discrete operator
! d/dz( f0**2/N**2 d/dz ) with N**2 = grav/rho_0 * drho/dz built from
! the density steps between levels. The wavenumber with the largest
! growth rate aimag(omega) is searched in three stages:
!   1. a coarse grid over kx >= 0 and ky in [-k_span/Lr, k_span/Lr),
!   2. a finer grid (spacing 2/nk of the coarse one) spanning +-1 coarse
!      cell around the best coarse wavenumber (fixed centre),
!   3. one more solve at the best wavenumber to obtain its eigenvector.
!
! input:
!   nb            number of vertical levels (uniform spacing dz)
!   rho(nb)       density; steps rho(k+1)-rho(k) smaller than drho_min
!                 (including unstable steps) are clipped to drho_min
!   u(nb),v(nb)   background velocity components
!   f0, beta      Coriolis parameter and its meridional gradient
!   dz            vertical grid spacing
!   Lr            length scale setting the wavenumber search range
!                 (presumably the Rossby radius -- confirm with caller)
! output:
!   omax          complex frequency of the fastest growing mode
!   psimax(nb)    its vertical structure (CGGEV right eigenvector);
!                 omax and psimax are both zero if the growth rate is < 0
!   kx_max,ky_max wavenumber of that mode
!   Qx(nb),Qy(nb) background PV gradient, Qy = beta - G u, Qx = G v
!   ier           0 on success, -1 if the final CGGEV solve failed,
!                 -2 if nb < 2 (no vertical structure to analyse)
!------------------------------------------------------------------------
 implicit none
 integer,   intent(in)  :: nb
 real*8,    intent(in)  :: rho(nb),u(nb),v(nb),beta,f0,dz,Lr
 complex*8, intent(out) :: omax,psimax(nb)
 real*8,    intent(out) :: kx_max,ky_max,Qx(nb),Qy(nb)
 integer,   intent(out) :: ier
 real*8, parameter :: rho_0 = 1024, grav = 9.81, drho_min = (0.0005**2)*rho_0/grav
 real*8, parameter :: A_h = 50.0     ! coefficient of the damping term -i*A_h*(kx**2+ky**2)
 real*8, parameter :: k_span = 4.2   ! half-width of the coarse search, in units of 1/Lr
 integer, parameter :: nk=16,nl=16   ! grid points per direction in each search stage
 integer :: k,l,n,lwork,info
 real*8  :: kx,ky,kx_c,ky_c
 real*8  :: drho(nb),Gij(nb,nb)
 ! variables for lapack (single precision complex, CGGEV)
 complex :: Bij(nb,nb),Aij(nb,nb),om(nb),mu(nb),eigv(nb,nb)
 complex, allocatable :: work(:)
 real,    allocatable :: rwork(:)

 ier=0
 if (nb<2) then
    ! the differencing operator below needs at least two levels
    Qy=beta; Qx=0.0
    omax=cmplx(0.0,0.0); psimax=cmplx(0.0,0.0)
    kx_max=0.0; ky_max=0.0
    ier=-2
    return
 endif

 ! check for unstable profile: clip density steps from below
 do k=1,nb-1
  drho(k)=max( rho(k+1)-rho(k), drho_min )
 enddo
 ! vertical differencing operator; every row sums to zero, so G
 ! annihilates a vertically constant psi
 Gij=0.0
 Gij(1,1) = -1/drho(1)/dz
 Gij(1,2) =  1/drho(1)/dz
 do k = 2,nb-1
     Gij(k,k-1) = 1/drho(k-1)/dz
     Gij(k,k)   = (-1/drho(k-1) -1/drho(k))/dz
     Gij(k,k+1) =  1/drho(k) /dz
 enddo
 Gij(nb,nb)   = -1/drho(nb-1) /dz
 Gij(nb,nb-1) =  1/drho(nb-1) /dz
 Gij          = Gij*f0**2*rho_0/grav
 ! background PV gradient
 do k=1,nb
   Qy(k) = beta
   Qx(k) = 0.0
   do n=1,nb
       Qy(k) = Qy(k)-Gij(k,n)*u(n)
       Qx(k) = Qx(k)+Gij(k,n)*v(n)
   enddo
 enddo
 lwork=20*nb; allocate(work(lwork),rwork(8*nb) )

 ! stage 1: loop over wavenumbers on coarse grid (kx >= 0 half plane)
 omax=cmplx(0.0,0.0); kx_max=0.; ky_max=0.
 do k=1,nk/2
   kx=(k-1.0)/nk*2 *k_span/Lr
   do l=1,nl
       ky=((l-1.0)/nl-0.5)*2 *k_span/Lr
       call try_wavenumber(kx,ky,'(1)')
   enddo
 enddo

 ! stage 2: loop over wavenumbers on finer grid around maximum.
 ! The centre is frozen first: kx_max/ky_max change inside the loop.
 kx_c=kx_max; ky_c=ky_max
 do k=1,nk
   kx=kx_c+ ((k-1.0)/nk-0.5)*2 *k_span/Lr /nk*2
   do l=1,nl
       ky=ky_c+ ((l-1.0)/nl-0.5)*2 *k_span/Lr /nl*2
       call try_wavenumber(kx,ky,'(2)')
   enddo
 enddo

 ! stage 3: now get eigenvector of fastest growing mode
 eigv=0.0
 call solve_pencil(kx_max,ky_max,'V','(3)',info)
 if (info/=0) then
    eigv=cmplx(0.0,0.0)
    ier=-1
 endif
 n=maxloc(aimag(om),1)
 omax=om(n)
 if (aimag(omax)<0.0) then
    omax=cmplx(0.0,0.0)
    psimax=cmplx(0.0,0.0)
 else 
    psimax=eigv(:,n) 
 endif
 deallocate(work,rwork)

contains

 subroutine solve_pencil(kxx,kyy,jobvr,tag,ierr)
 ! Assemble the pencil (A,B) for wavenumber (kxx,kyy) in the host arrays
 ! Aij/Bij and solve it with CGGEV. On return om holds omega=alpha/beta.
 ! Entries are zero where CGGEV failed or beta=0. With jobvr='V' the
 ! right eigenvectors are written to eigv. tag labels error messages.
  real*8,           intent(in)  :: kxx,kyy
  character(len=1), intent(in)  :: jobvr
  character(len=*), intent(in)  :: tag
  integer,          intent(out) :: ierr
  complex :: cdiag(nb)
  integer :: i,j
  Bij=Gij
  do i=1,nb
     Bij(i,i)=Bij(i,i)-(kxx**2+kyy**2)
     cdiag(i)=kxx*u(i)+kyy*v(i) - cmplx(0.,1.)*A_h*(kxx**2+kyy**2)
  enddo
  ! diag(cdiag)*B is a row scaling of B; no full matrix product needed
  do j=1,nb
     Aij(:,j)=cdiag*Bij(:,j)
  enddo
  do i=1,nb
     Aij(i,i)=Aij(i,i)+kxx*Qy(i)-kyy*Qx(i)
  enddo
  call CGGEV('N',jobvr,nb,Aij,nb,Bij,nb,om,mu,eigv,nb,eigv,nb,work,lwork,rwork,ierr)
  if (ierr/=0) then
     print*,' Error in CGGEV '//tag//' info :',ierr,'ky,kx=',kyy*Lr,kxx*Lr
     om=cmplx(0.0,0.0)
  endif
  do i=1,nb
     if (abs(mu(i))==0.0) then
        print*,' Error: mu=0 in CGGEV '//tag//' n=',i,'ky,kx=',kyy*Lr,kxx*Lr
        om(i)=cmplx(0.0,0.0)
     else
        om(i)=om(i)/mu(i)
     endif
  enddo
 end subroutine solve_pencil

 subroutine try_wavenumber(kxx,kyy,tag)
 ! Solve at (kxx,kyy) and keep it in kx_max/ky_max/omax if its fastest
 ! mode grows faster than the best found so far.
  real*8,           intent(in) :: kxx,kyy
  character(len=*), intent(in) :: tag
  integer :: ifast,ierr
  ! At K=0, A is exactly zero and B=G is singular, so the pencil is
  ! degenerate. Its growth rate is zero and can never beat omax
  ! (strict comparison), so the solve is skipped.
  if (kxx==0.0 .and. kyy==0.0) return
  call solve_pencil(kxx,kyy,'N',tag,ierr)
  ifast=maxloc(aimag(om),1)
  if (aimag(om(ifast))>aimag(omax)) then
     kx_max=kxx; ky_max=kyy; omax=om(ifast)
  endif
 end subroutine try_wavenumber

end subroutine vertical_eigenvalues3










subroutine fillgaps(nx, ny, dat, spval)
!------------------------------------------------------------------------
! Single-pass gap filling of a 2-d field.
! Every point whose value equals spval is replaced by the mean of the
! non-spval values in its 3x3 neighbourhood (clipped at the array edges).
! All gap positions are collected first. They are then filled one after
! another in column-major order (i fastest), so a value filled earlier
! in the pass counts as valid data for later gaps. A gap without any
! valid neighbour keeps spval.
!------------------------------------------------------------------------
      implicit none
      integer :: nx, ny
      real*8  :: dat(nx,ny), spval
      integer, allocatable :: gap_at(:)   ! linear (column-major) index of each gap
      integer :: ncell, ngap, idx, g, ig, jg, i, j
      real*8  :: total, nvalid

      ncell = max(nx,0)*max(ny,0)
      allocate( gap_at(ncell) )

      ! register all gaps before any of them is modified
      ngap = 0
      do idx = 1, ncell
         i = mod(idx-1, nx) + 1
         j = (idx-1)/nx + 1
         if (dat(i,j) == spval) then
            ngap = ngap + 1
            gap_at(ngap) = idx
         end if
      end do

      ! replace each gap by the mean of its valid neighbours
      do g = 1, ngap
         ig = mod(gap_at(g)-1, nx) + 1
         jg = (gap_at(g)-1)/nx + 1
         total  = 0.0d0
         nvalid = 0.0d0
         do j = max(jg-1, 1), min(jg+1, ny)
            do i = max(ig-1, 1), min(ig+1, nx)
               if (dat(i,j) /= spval) then
                  total  = total + dat(i,j)
                  nvalid = nvalid + 1.0d0
               end if
            end do
         end do
         if (nvalid /= 0.0d0) dat(ig,jg) = total/nvalid
      end do

      deallocate(gap_at)
end subroutine fillgaps



subroutine rgrd2(nx,ny,x,y,p,mx,my,xx,yy,q,ier)
!------------------------------------------------------------------------
!     subroutine rgrd2 bilinearly interpolates the values p(i,j) on the
!     orthogonal grid (x(i),y(j)) for i=1,...,nx and j=1,...,ny onto
!     q(ii,jj) on the orthogonal grid (xx(ii),yy(jj)) for ii=1,...,mx
!     and jj=1,...,my (first in x along each source row, then in y).
!
!     each of the x,y grids must be strictly monotonically increasing
!     and each of the xx,yy grids must be monotonically increasing (see
!     ier = 4).  in addition the (X,Y) region
!
!          [xx(1),xx(mx)] X [yy(1),yy(my)]
!
!     must lie within the (X,Y) region
!
!          [x(1),x(nx)] X [y(1),y(ny)].
!
!     extrapolation is not allowed (see ier=3).
!
!     ier is an integer error flag set as follows:
!
!     ier = 0 if no errors in input arguments are detected
!     ier = 1 if min(mx,my) < 1
!     ier = 2 if nx < 2 or ny < 2
!     ier = 3 if xx(1) < x(1) or x(nx) < xx(mx) (or)
!                yy(1) < y(1) or y(ny) < yy(my)
!     ier = 4 if x or y is not strictly increasing (or)
!                xx or yy is not increasing
!
!     q is only written when ier = 0.
!
!     John C. Adams (NCAR 1994), C.Eden (2011)
!------------------------------------------------------------------------
      implicit none
      integer, intent(in)    :: nx,ny,mx,my
      real*8,  intent(in)    :: x(nx),y(ny),p(nx,ny),xx(mx),yy(my)
      real*8,  intent(inout) :: q(mx,my)
      integer, intent(out)   :: ier
      ! automatic work arrays
      integer :: jy(my),ix(mx)
      real*8  :: dy(my),dx(mx),pj(mx),pjp(mx)
      real*8  :: dxm(mx),dxp(mx),dxpp(mx)   ! not used by lint2; kept for its argument list

!     check input arguments
      ier = 1
!     check (xx,yy) grid resolution
      if (min(mx,my) < 1) return
!     check (x,y) grid resolution
      ier = 2
      if (nx < 2) return
      if (ny < 2) return
!     check (xx,yy) grid contained in (x,y) grid
      ier = 3
      if (xx(1) < x(1) .or. xx(mx) > x(nx)) return
      if (yy(1) < y(1) .or. yy(my) > y(ny)) return
!     check monotonicity of grids
      ier = 4
      if (any(x(1:nx-1) >= x(2:nx))) return
      if (any(y(1:ny-1) >= y(2:ny))) return
      if (any(xx(1:mx-1) > xx(2:mx))) return
      if (any(yy(1:my-1) > yy(2:my))) return
!     arguments o.k.
      ier = 0
!     set y interpolation indices and scales
      call linmx(ny,y,my,yy,jy,dy)
!     set indices which depend on x interpolation
      call linmx(nx,x,mx,xx,ix,dx)
!     interpolate in x along source rows, then linearly in y
      call lint2(nx,ny,p,mx,my,q,jy,dy,pj,pjp,ix,dxm,dx,dxp,dxpp)
end subroutine rgrd2


subroutine lint2(nx,ny,p,mx,my,q,jy,dy,pj,pjp,ix,dxm,dx,dxp,dxpp)
!    Linear interpolation in y of source rows that are first interpolated
!    in x by lint1. For each target row jj the source rows j=jy(jj) and
!    j+1 are interpolated onto the target x points (pj and pjp), and
!    q(:,jj) = pj + dy(jj)*(pjp-pj). Work is reused between target rows:
!      - nothing is recomputed while j stays the same,
!      - when j advances by exactly one, pjp is shifted into pj and only
!        the new upper row is interpolated.
!    dxm, dxp and dxpp are not referenced.
     implicit none
     integer :: nx,ny,mx,my
     integer :: jy(my),ix(mx)
     real*8  :: p(nx,ny),q(mx,my),dy(my)
     real*8  :: pj(mx),pjp(mx)
     real*8  :: dxm(mx),dx(mx),dxp(mx),dxpp(mx)
     integer :: jj,jrow,jprev

     jprev = -1
     do jj=1,my
       jrow = jy(jj)
       if (jrow /= jprev) then
         if (jrow == jprev+1) then
           ! moved up by one row: the old upper row becomes the lower row
           pj = pjp
         else
           call lint1(nx,p(:,jrow),mx,pj,ix,dx)
         endif
         call lint1(nx,p(:,jrow+1),mx,pjp,ix,dx)
         jprev = jrow
       endif
       q(:,jj) = pj + dy(jj)*(pjp - pj)
     enddo
end subroutine lint2

subroutine lint1(nx,p,mx,q,ix,dx)
!    Linear interpolation of the 1-d profile p onto mx target points:
!      q(ii) = p(ix(ii)) + dx(ii)*(p(ix(ii)+1) - p(ix(ii)))
!    using the cell indices ix and fractional offsets dx set up by linmx.
     implicit none
     integer :: nx,mx
     integer :: ix(mx)
     real*8  :: p(nx),q(mx),dx(mx)
     q = p(ix) + dx*(p(ix+1) - p(ix))
end subroutine lint1

subroutine linmx(nx,x,mx,xx,ix,dx)
!   Set x grid pointers for the xx grid and the linear interpolation
!   scale terms:
!     ix(ii) = first i (searching from the previous hit) with
!              x(i+1) >= xx(ii), i.e. x(i) < xx(ii) <= x(i+1) for
!              xx(ii) inside (x(1),x(nx)]
!     dx(ii) = (xx(ii)-x(i))/(x(i+1)-x(i))
!   Because the search resumes at the previous hit, xx is assumed to be
!   increasing and x strictly increasing (rgrd2 checks both). Points
!   beyond x(nx) are assigned the last cell, ix = nx-1, which gives
!   linear extrapolation rather than an undefined index.
!   Requires nx >= 2.
   implicit none
   integer :: nx,mx
   integer :: ix(mx)
   real*8  :: x(nx),xx(mx),dx(mx)
   integer :: i,ii,isrt

   isrt = 1
   do ii=1,mx
     ix(ii) = nx-1              ! fallback when xx(ii) > x(nx)
     search: do i=isrt,nx-1
       if (x(i+1) >= xx(ii)) then
         isrt = i
         ix(ii) = i
         exit search
       end if
     enddo search
   enddo
!  set linear scale term
   dx = (xx - x(ix))/(x(ix+1) - x(ix))
end subroutine linmx

