#include "options.inc"
c
c-----------------------------------------------------------------------
c     diagnose the meridional heat (and tracer) transport and various
c     decompositions of it, and write them to a NetCDF or binary file
c-----------------------------------------------------------------------
c
      subroutine init_heat_tr_cdf
c-----------------------------------------------------------------------
c     Create the NetCDF output file heat_tr_file (PE 0 only) and define
c     its contents:
c       - dimension and variable 'Latitude_u' (length jmt, set to yu)
c       - unlimited dimension and variable 'Time' (units 'days')
c       - for every tracer n=1..nt six variables on (Latitude_u,Time):
c         <nm>_tr, <nm>_tr_adv, <nm>_tr_diff, <nm>_tr_over,
c         <nm>_tr_gyre and <nm>_tr_ekm, where <nm> is 'heat' (n=1),
c         'salt' (n=2) or 'tracer_NN' (n>2)
c     The file is closed again; records are appended by diag_heat_tr.
c     Return codes (iret) are not checked, as elsewhere in this file.
c-----------------------------------------------------------------------
      use spflame_module
      implicit none
#ifdef netcdf_diagnostics
#include "netcdf.inc"
      integer ncid,iret,oldmode,n,k
      integer lat_udim,itimedim,lat_uid,itimeid
      integer id, advid, diffid, overid, gyreid, ekmid
      integer dims(4), corner(4), edges(4)
      real :: spval=-9.9e12
c     NOTE(review): lname holds only 24 characters (and dvcdf is told
c     24), so long names such as 'Overturning heat transport' or any
c     'tracer_NN' name are silently truncated -- consider enlarging it
c     once the interface of dvcdf has been checked.
      character name*24, unit*16,lname*24
#ifdef netcdf_real4
      real (kind=4)  :: v2(jmt)
#else
      real           :: v2(jmt)
#endif

      if (my_pe == 0) then
       print*,' initializing NetCDF output file ',heat_tr_file
       ncid = nccre (heat_tr_file, NCCLOB, iret)
c      nf_set_fill returns the status and stores the previous fill mode
c      in its third argument; keep the two in separate variables
c      (an argument also defined on the left-hand side is illegal)
       iret=nf_set_fill(ncid, NF_NOFILL, oldmode)
       call store_info_cdf(ncid)
c      dimensions
       lat_udim  = ncddef(ncid, 'Latitude_u',  jmt, iret)
       itimedim  = ncddef(ncid, 'Time', nf_unlimited, iret)
c      grid variables
       dims(1)  = lat_udim
       lat_uid  = ncvdef (ncid,'Latitude_u', NCFLOAT,1,dims,iret)
       dims(1)  = itimedim
       itimeid  = ncvdef(ncid,'Time',       NCFLOAT,1,dims,iret)
c      attributes of the grid
       name = 'Latitude on U grid      '; unit = 'degrees_N       '
       call ncaptc(ncid, lat_uid, 'long_name', NCCHAR, 24, name, iret)
       call ncaptc(ncid, lat_uid, 'units',     NCCHAR, 16, unit, iret)
       name = 'Time                    '; unit = 'days            '
       call ncaptc(ncid, itimeid, 'long_name', NCCHAR, 24, name, iret)
       call ncaptc(ncid, itimeid, 'units',     NCCHAR, 16, unit, iret)
       call ncaptc(ncid, itimeid,'time_origin',NCCHAR, 20,
     &  '31-DEC-1899 00:00:00', iret)
c      transport variables are 2-d: (Latitude_u, Time)
       dims=(/lat_udim,itimedim,1,1/)
       do n=1,nt
        if (n==1) then
         name='heat'; unit='PW'
        elseif (n==2) then
         name='salt'; unit='10^10cm^2/s'
        else
         write(name,'("tracer_",i2)') n
         call replace_space_zero(name)
         unit=' '
        endif
        k=len_trim(name)
        id    =ncvdef(ncid,name(1:k)//'_tr',      NCFLOAT,2,dims,iret)
        advid =ncvdef(ncid,name(1:k)//'_tr_adv',  NCFLOAT,2,dims,iret)
        diffid=ncvdef(ncid,name(1:k)//'_tr_diff', NCFLOAT,2,dims,iret)
        overid=ncvdef(ncid,name(1:k)//'_tr_over', NCFLOAT,2,dims,iret)
        gyreid=ncvdef(ncid,name(1:k)//'_tr_gyre', NCFLOAT,2,dims,iret)
        ekmid =ncvdef(ncid,name(1:k)//'_tr_ekm',  NCFLOAT,2,dims,iret)
        lname = 'Northward '//name(1:k)//' transport'
        call dvcdf(ncid,id,lname,24,unit,16,spval)
        lname = 'Advective '//name(1:k)//' transport'
        call dvcdf(ncid,advid,lname,24,unit,16,spval)
        lname = 'Diffusive '//name(1:k)//' transport'
        call dvcdf(ncid,diffid,lname,24,unit,16,spval)
        lname = 'Overturning '//name(1:k)//' transport'
        call dvcdf(ncid,overid,lname,24,unit,16,spval)
        lname = 'Gyre '//name(1:k)//' transport'
        call dvcdf(ncid,gyreid,lname,24,unit,16,spval)
        lname = 'Ekman '//name(1:k)//' transport'
        call dvcdf(ncid,ekmid,lname,24,unit,16,spval)
       enddo ! n
       call ncendf(ncid, iret)
c      latitude axis values
       corner(1) = 1; edges(1) = jmt
       v2=yu
       call ncvpt(ncid, lat_uid, corner, edges,v2, iret)
       call ncclos (ncid, iret)
       print*,' done'
      endif ! my_pe==0
#else
c     nothing to initialise for the binary output format
#endif
      end subroutine init_heat_tr_cdf


      subroutine diag_heat_tr(n,adv_fn,diff_fn)
c-----------------------------------------------------------------------
c     Meridional transport of tracer n across each row j and its
c     decomposition. Contributions are accumulated over i (and k) on
c     this PE and then summed along the row with sum_along_jrow:
c       adv_tr    : sum of 0.5*adv_fn*dxt (*dzt)
c       diff_tr   : minus the masked sum of diff_fn*dxt*csu (*dzt)
c       total_tr  : adv_tr + diff_tr
c       over_tr   : zonally integrated v times zonal-mean tracer
c       gyre_tr   : adv_tr - over_tr
c       barotr_tr : vertically integrated adv_vnt times the depth-mean
c                   tracer of each column
c       ekm_tr    : part built from smf(:,:,1), the surface-minus-mean
c                   tracer and the (limited) Coriolis factor
c       barocl_tr : adv_tr - barotr_tr - ekm_tr
c     Results are scaled by pwatts (n=1), csalt (n=2) or 1 (n>2) and
c     written either to the NetCDF file heat_tr_file (one record per
c     call with n==1; barotr_tr/barocl_tr are not written there) or,
c     once n==nt, for all tracers to the old FLAME binary format.
c
c     n       : tracer index (1 = heat, 2 = salt, >2 passive tracer)
c     adv_fn  : presumably the meridional advective tracer flux
c               (TODO confirm against the caller)
c     diff_fn : presumably the meridional diffusive tracer flux
c               (TODO confirm against the caller)
c-----------------------------------------------------------------------
      use spflame_module
      implicit none
      integer, intent(in) :: n
      real, intent(in) :: adv_fn(is_pe:ie_pe,km,js_pe-1:je_pe)
      real, intent(in) :: diff_fn(is_pe:ie_pe,km,js_pe-1:je_pe)

      integer i,j,k, is,ie,js,je
#ifdef netcdf_diagnostics
#include "netcdf.inc"
      integer ncid,iret,oldmode,npe,htrid,timeid,timedim,nrec,lnm
      real :: tt
      integer corner(4), edges(4)
c     buffers handed to nf_put_vara_real; presumably netcdf_real4 is
c     set when the default REAL is 8 bytes (TODO confirm), so they get
c     an explicit 4-byte kind -- this includes the time value
#ifdef netcdf_real4
      real (kind=4) :: var(js_pe:je_pe), tt_out(1)
#else
      real          :: var(js_pe:je_pe), tt_out(1)
#endif
#else
      integer :: io
      character (len=80) :: iotext
      character (len=60) :: expnam
      real :: reltim = 0.
#endif
      character (len=80) name
      real :: totdxn,totdxs,vbr,tbrs,tbrn,totz,vbrz,tbrz,mask
      real :: tempdiff_fn,tempadv_fn,factor
      real,dimension(:,:),save,allocatable :: over_tr,gyre_tr,adv_tr
      real,dimension(:,:),save,allocatable :: diff_tr,barotr_tr
      real,dimension(:,:),save,allocatable :: total_tr,ekm_tr,barocl_tr
      real,dimension(:,:,:),allocatable :: buf
      logical, save :: first=.true.
      real, parameter :: pwatts = 4.186e-15, csalt = 1.e-10
      real, parameter :: small = 1.e-10

c     the (jmt,nt) result arrays persist between calls so that the
c     binary output can collect all tracers before writing
      if (first) then
       first=.false.
       allocate(over_tr(jmt,nt), gyre_tr(jmt,nt), adv_tr(jmt,nt) )
       allocate(diff_tr(jmt,nt), barotr_tr(jmt,nt) )
       allocate(total_tr(jmt,nt), ekm_tr(jmt,nt), barocl_tr(jmt,nt))
      endif

      over_tr(:,n)   = 0. ; gyre_tr(:,n)   = 0.
      adv_tr(:,n)    = 0. ; diff_tr(:,n)   = 0.
      total_tr(:,n)  = 0. ; ekm_tr(:,n)    = 0.
      barocl_tr(:,n) = 0. ; barotr_tr(:,n) = 0.

c     exclude the first and last global columns and rows
      is=max(is_pe,2); ie=min(ie_pe,imt-1)
      js=max(js_pe,2); je=min(je_pe,jmt-1)

      do j=js,je
c
c       overturning, advective and diffusive parts, level by level
c
        do k=1,km
         totdxn = small; totdxs = small
         vbr    = 0.; tbrs   = 0.; tbrn   = 0.
         do i=is,ie
           totdxn = totdxn + dxt(i)*tmask(i,k,j+1)
#ifdef partial_cell
     &                 *dht(i,k,j+1)
#endif
           totdxs = totdxs + dxt(i)*tmask(i,k,j)
#ifdef partial_cell
     &                 *dht(i,k,j)
#endif
           vbr    = vbr  + u(i,k,j,2,tau)*dxu(i)*csu(j)
#ifdef partial_cell
     &                 *dhu(i,k,j)
#endif
           tbrn   = tbrn + t(i,k,j+1,n,tau)*tmask(i,k,j+1)*dxt(i)
#ifdef partial_cell
     &                 *dht(i,k,j+1)
#endif
           tbrs   = tbrs + t(i,k,j,n,tau)*tmask(i,k,j)*dxt(i)
#ifdef partial_cell
     &                 *dht(i,k,j)
#endif
         enddo
c        zonal means of the tracer on rows j+1 and j
         tbrn = tbrn/totdxn; tbrs = tbrs/totdxs
         over_tr(j,n) = over_tr(j,n) + vbr*0.5*(tbrn+tbrs)
#ifndef partial_cell
     &                                                *dzt(k)
#endif
         do i=is,ie
          tempdiff_fn = diff_fn(i,k,j)*tmask(i,k,j+1)*tmask(i,k,j)
     &                  *dxt(i)*csu(j)
#ifndef partial_cell
     &                  *dzt(k)
#endif
          tempadv_fn  = 0.5*adv_fn(i,k,j)*dxt(i)
#ifndef partial_cell
     &                  *dzt(k)
#endif
          adv_tr(j,n)  = adv_tr(j,n)  + tempadv_fn
          diff_tr(j,n) = diff_tr(j,n) - tempdiff_fn
         enddo
        enddo
        call sum_along_jrow(adv_tr(j,n),1)
        call sum_along_jrow(diff_tr(j,n),1)
        call sum_along_jrow(over_tr(j,n),1)

        gyre_tr(j,n)  = adv_tr(j,n)-over_tr(j,n)
        total_tr(j,n) = adv_tr(j,n)+diff_tr(j,n)
c
c       barotropic and Ekman parts, column by column
c
        do i=is,ie
c        Coriolis factor, kept at least 1e-5 in magnitude so that the
c        division by it below stays bounded
         factor=4.*cori(i,j,1)
         factor = sign(1.,factor)*max(1.e-5,abs(factor))

         totz = 0.; vbrz = 0.; tbrz = 0.
         do k=1,km
          mask = tmask(i,k,j)*tmask(i,k,j+1)
          vbrz = vbrz + adv_vnt(i,k,j)*dxt(i)
#ifndef partial_cell
     &                                       *dzt(k)
#endif
          tbrz = tbrz +mask*(t(i,k,j,n,tau)+t(i,k,j+1,n,tau))
#ifdef partial_cell
     &                           *min(dht(i,k,j),dht(i,k,j+1))
#else
     &                           *dzt(k)
#endif
          totz = totz + mask
#ifdef partial_cell
     &                          *min(dht(i,k,j),dht(i,k,j+1))
#else
     &                          *dzt(k)
#endif
         enddo
         if (totz /= 0.0) then
c         depth mean of t(j)+t(j+1), i.e. twice the mean tracer
          tbrz = tbrz/totz
          barotr_tr(j,n) = barotr_tr(j,n) + vbrz*tbrz*0.5
          ekm_tr(j,n) = ekm_tr(j,n) -
     &          (smf(i,j,1)*dxu(i)*umask(i,1,j) +
     &           smf(i-1,j,1)*dxu(i-1)*umask(i-1,1,j))
     &          *(t(i,1,j,n,tau)+t(i,1,j+1,n,tau)-tbrz)
     &          *csu(j)/factor
         endif
        enddo
        call sum_along_jrow(barotr_tr(j,n),1)
        call sum_along_jrow(ekm_tr(j,n),1)
c       baroclinic part = advective transport minus its barotropic and
c       Ekman parts (barocl_tr itself was zeroed above, so it must not
c       appear on the right-hand side)
        barocl_tr(j,n) = adv_tr(j,n)-barotr_tr(j,n)-ekm_tr(j,n)
      enddo

      if (my_pe==0.and.n==1)
     &  print*,' --> writing heat_tr to file ',
     &        heat_tr_file(1:len_trim(heat_tr_file))

#ifdef netcdf_diagnostics
c     write to netcdf format here: PEs 0, n_pes_i, 2*n_pes_i, ... take
c     turns writing their rows js_pe:je_pe (presumably one PE per
c     processor row suffices because of sum_along_jrow above)
      do npe=0,n_pes,n_pes_i
       call barrier
       if (my_pe==npe) then
        iret=nf_open(heat_tr_file,NF_WRITE,ncid)
        iret=nf_set_fill(ncid, NF_NOFILL, oldmode)
        iret=nf_inq_varid(ncid,'Time',timeid)
        iret=nf_inq_dimid(ncid,'Time',timedim)
        iret=nf_inq_dimlen(ncid, timedim,nrec)
c       PE 0 appends a new time record for the first tracer; all other
c       writes go to the last existing record
        if (my_pe==0.and.n==1) then
         nrec=nrec+1
         corner(1)=nrec
         edges(1)=1
         call read_stamp(current_stamp,tt)
         print*,' at stamp=',current_stamp,
     &          ' (days since origin : ',tt,')',
     &          ' (time steps in file : ',nrec,')'
         tt_out(1)=tt
         iret= nf_put_vara_real (ncid,timeid, corner,edges,tt_out)
        endif

        if (n==1) then
         name='heat'
         factor=pwatts
        elseif (n==2) then
         name='salt'
         factor=csalt
        else
         write(name,'("tracer_",i2)') n
         call replace_space_zero(name)
         factor=1.
        endif
        lnm=len_trim(name)
c       all six variables share the same hyperslab: rows js_pe:je_pe of
c       time record nrec
        corner=(/js_pe,nrec,1,1/); edges=(/je_pe-js_pe+1,1,1,1/)

        iret=nf_inq_varid(ncid,name(1:lnm)//'_tr',htrid)
        var = total_tr(js_pe:je_pe,n)*factor
        iret= nf_put_vara_real (ncid,htrid ,corner, edges,var)

        iret=nf_inq_varid(ncid,name(1:lnm)//'_tr_adv',htrid)
        var = adv_tr(js_pe:je_pe,n)*factor
        iret= nf_put_vara_real (ncid,htrid ,corner, edges,var)

        iret=nf_inq_varid(ncid,name(1:lnm)//'_tr_diff',htrid)
        var = diff_tr(js_pe:je_pe,n)*factor
        iret= nf_put_vara_real (ncid,htrid ,corner, edges,var)

        iret=nf_inq_varid(ncid,name(1:lnm)//'_tr_over',htrid)
        var = over_tr(js_pe:je_pe,n)*factor
        iret= nf_put_vara_real (ncid,htrid ,corner, edges,var)

        iret=nf_inq_varid(ncid,name(1:lnm)//'_tr_gyre',htrid)
        var = gyre_tr(js_pe:je_pe,n)*factor
        iret= nf_put_vara_real (ncid,htrid ,corner, edges,var)

        iret=nf_inq_varid(ncid,name(1:lnm)//'_tr_ekm',htrid)
        var = ekm_tr(js_pe:je_pe,n)*factor
        iret= nf_put_vara_real (ncid,htrid ,corner, edges,var)

        call ncclos (ncid, iret)
       endif
       call barrier
      enddo
#else
c     write the old flame format here

c     scale column n to output units (factor 1 leaves values exact)
      if (n==1) then
        factor=pwatts
      elseif (n==2) then
        factor=csalt
      else
        factor=1.
      endif
      over_tr(:,n)   = over_tr(:,n)*factor
      gyre_tr(:,n)   = gyre_tr(:,n)*factor
      barotr_tr(:,n) = barotr_tr(:,n)*factor
      barocl_tr(:,n) = barocl_tr(:,n)*factor
      ekm_tr(:,n)    = ekm_tr(:,n)*factor
      adv_tr(:,n)    = adv_tr(:,n)*factor
      diff_tr(:,n)   = diff_tr(:,n)*factor
      total_tr(:,n)  = total_tr(:,n)*factor

c     after the last tracer collect everything on PE 0 and write it
      if (n==nt) then
       if (my_pe == 0)  then
        call getunit (io, heat_tr_file,'u s a ieee')
        iotext = 'no iotext'
        expnam = 'SPFLAME, some experiment'
        write (io) current_stamp, iotext, expnam
        write (io) jmt, nt, reltim
       endif
       call pe0_recv_merid_vec(over_tr,nt,1)
       call pe0_recv_merid_vec(gyre_tr,nt,1)
       call pe0_recv_merid_vec(barotr_tr,nt,1)
       call pe0_recv_merid_vec(barocl_tr,nt,1)
       call pe0_recv_merid_vec(ekm_tr,nt,1)
       call pe0_recv_merid_vec(adv_tr,nt,1)
       call pe0_recv_merid_vec(diff_tr,nt,1)
       call pe0_recv_merid_vec(total_tr,nt,1)
c
       if (my_pe == 0)  then
c       record layout: for m=1..nt, j=1..jmt the eight values
c       over, gyre, barotr, barocl, ekm, adv, diff, total
        allocate(buf(8,jmt,nt))
        buf(1,:,:)=over_tr;   buf(2,:,:)=gyre_tr
        buf(3,:,:)=barotr_tr; buf(4,:,:)=barocl_tr
        buf(5,:,:)=ekm_tr;    buf(6,:,:)=adv_tr
        buf(7,:,:)=diff_tr;   buf(8,:,:)=total_tr
        write (io) current_stamp, iotext, expnam
        write (io) buf
        write (io) current_stamp, iotext, expnam
c       NOTE(review): the earlier version wrote over, gyre, barotr and
c       barocl here (i.e. buf(1:4,:,:)); only buf(1,1,1) is written
c       now -- confirm that readers of this format expect that.
        write (io) buf(1,1,1)
        deallocate(buf)
        close(io)
       endif
      endif

#endif
      end subroutine diag_heat_tr

