#!/usr/bin/python

import numpy
from pyOM import pyOM

class pyOM_cdf(pyOM):
    """
    pyOM with snapshot output in netcdf format.

    On construction the snapshot file ``self.snap_file`` is created on PE 0
    (plus ``back.cdf`` when ``enable_back_state`` is set); every call to
    ``diagnose`` then appends one time record to the snapshot file.
    """

    def __init__(self):
        pyOM.__init__(self)
        try:     # try to load module with netcdf bindings
            from netCDF4 import Dataset as NF
            print(" using module netCDF4 ")
            self.netcdf_bug = 0
        except ImportError:
            from Scientific.IO.NetCDF import NetCDFFile as NF
            print(" using module Scientific.IO.NetCDF ")
            self.netcdf_bug = 1  # account for bug in this module
        self.NF = NF
        # note that scipy.io.netcdf does not support appending data to a file,
        # therefore we cannot use that module
        self.snap_file = 'pyOM.cdf'
        # value written at land points of buoyancy (see _masked_b)
        self._spval = -1e33 * numpy.ones((1,), 'f')
        self.init_cdf()

    def diagnose(self):
        """Diagnose the model variables, then append a netcdf snapshot."""
        pyOM.diagnose(self)
        self.write_cdf()

    def define_grid_cdf(self, fid):
        """Define grid dimensions and coordinate variables in an open file.

        Creates the dimensions xt/xu (nx), yt/yu (ny), zt/zw (nz) and the
        unlimited 'Time' dimension, plus one float32 coordinate variable per
        spatial dimension filled from the model's coordinate arrays.

        fid : netcdf file object opened for writing (netCDF4.Dataset or
              Scientific.IO.NetCDF.NetCDFFile).
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        nx, ny, nz = int(M.nx), int(M.ny), int(M.nz)
        # (dimension/variable name, size, long_name, model coordinate array)
        axes = (
            ('xt', nx, 'Zonal coordinate on T grid', M.xt),
            ('xu', nx, 'Zonal coordinate on U grid', M.xu),
            ('yt', ny, 'Meridional coordinate on T grid', M.yt),
            ('yu', ny, 'Meridional coordinate on V grid', M.yu),
            ('zt', nz, 'Vertical coordinate on T grid', M.zt),
            ('zw', nz, 'Vertical coordinate on W grid', M.zw),
        )
        for name, size, _, _ in axes:
            fid.createDimension(name, size)
        fid.createDimension('Time', None)
        for name, _, long_name, coord in axes:
            var = fid.createVariable(name, 'f', (name,))
            var.long_name = long_name
            var.units = 'm'
            var[:] = coord.astype('f')

    def init_cdf(self):
        """Initialize netcdf diagnostics.

        On PE 0 only (and only if a netcdf module was loaded), create
        ``self.snap_file`` (mode 'w') with the grid, the 'Time' axis and one
        empty variable per snapshot field, then write the background state to
        ``back.cdf`` if ``enable_back_state`` is set.
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if M.my_pe != 0 or not hasattr(self, 'NF'):
            return

        t_3d = ('Time', 'zt', 'yt', 'xt')
        w_3d = ('Time', 'zw', 'yt', 'xt')
        # (name, dimensions, long_name, units), in file-creation order
        fields = [
            ('b', t_3d, 'Buoyancy', 'm/s^2'),
            ('p', t_3d, 'Pressure', 'm^2/s^2'),
            ('u', ('Time', 'zt', 'yt', 'xu'), 'Zonal velocity', 'm/s'),
            ('v', ('Time', 'zt', 'yu', 'xt'), 'Meridional velocity', 'm/s'),
            ('w', w_3d, 'Vertical velocity', 'm/s'),
            ('K_b', w_3d, 'Vertical diffusivity', 'm^2/s'),
            ('taux', ('Time', 'yt', 'xu'), 'Zonal wind stress', 'm^2/s^2'),
            # placed on the V grid like 'v'; sizes equal the former (yt, xu)
            ('tauy', ('Time', 'yu', 'xt'), 'Meridional wind stress', 'm^2/s^2'),
            ('surface_flux', ('Time', 'yt', 'xt'), 'Surface buoyancy flux', 'm^2/s^4'),
        ]
        if M.enable_vert_friction_trm:
            fields.append(('A_trm', w_3d, 'Vertical viscosity', 'm^2/s'))

        fid = self.NF(self.snap_file, 'w')
        try:
            self.define_grid_cdf(fid)
            var = fid.createVariable('Time', 'f', ('Time',))
            var.long_name = 'Time since start'
            var.units = 'Seconds'
            for name, dims, long_name, units in fields:
                var = fid.createVariable(name, 'f', dims)
                var.long_name = long_name
                var.units = units
                if name == 'b':
                    # land points of b are written as _spval (see _masked_b)
                    var.missing_value = self._spval
        finally:
            fid.close()

        if M.enable_back_state:
            self._write_back_state(M)

    def _write_back_state(self, M):
        """Write background buoyancy (back[..., 0]) and flow u0 to back.cdf."""
        import os
        # remove any old back.cdf first (works around another bug in the
        # netcdf module)
        if os.path.exists('back.cdf'):
            os.remove('back.cdf')
        fid = self.NF('back.cdf', 'w')
        try:
            self.define_grid_cdf(fid)
            var = fid.createVariable('back', 'f', ('zt', 'yt', 'xt'))
            var.long_name = 'Background Buoyancy'
            var.units = 'm/s^2'
            var = fid.createVariable('u0', 'f', ('zt', 'yt', 'xu'))
            var.long_name = 'Background zonal flow'
            var.units = 'm/s'
            if self.netcdf_bug:  # account for bug in this module: row by row
                for j in range(M.ny):
                    fid.variables['back'][:, j, :] = M.back[:, j, :, 0].transpose().astype('f')
                    fid.variables['u0'][:, j, :] = M.u0[:, j, :].transpose().astype('f')
            else:
                fid.variables['back'][:, :, :] = M.back[:, :, :, 0].transpose().astype('f')
                fid.variables['u0'][:, :, :] = M.u0[:, :, :].transpose().astype('f')
        finally:
            fid.close()

    def _append_time(self, fid):
        """Append self.time as a new record on 'Time'; return its index."""
        tid = fid.variables['Time']
        i = numpy.shape(tid)[0]
        tid[i] = self.time
        return i

    def _masked_b(self, rows):
        """Return buoyancy at time level tau for meridional index/slice
        ``rows`` as float32, with points where maskt == 0 set to _spval."""
        M = self.fortran.pyom_module
        b = M.b[:, rows, :, M.tau - 1].astype('f')
        return numpy.where(M.maskt[:, rows, :] == 0., self._spval, b)

    def _write_fields(self, fid, i, rows):
        """Write every snapshot field for meridional index/slice ``rows``
        into time record ``i`` of the open file ``fid``.

        Model arrays are indexed (x, y[, z]) and are transposed to the file
        layout (z, y, x) or (y, x).  Buoyancy is always masked (_masked_b).
        """
        M = self.fortran.pyom_module
        tau = M.tau - 1
        # (variable name, model data restricted to `rows`, has a z axis)
        fields = [
            ('b', self._masked_b(rows), True),
            ('u', M.u[:, rows, :, tau], True),
            ('v', M.v[:, rows, :, tau], True),
            ('w', M.w[:, rows, :, tau], True),
            ('p', M.p_full[:, rows, :, tau], True),
            ('taux', M.surface_taux[:, rows], False),
            ('tauy', M.surface_tauy[:, rows], False),
            ('surface_flux', M.surface_flux[:, rows], False),
            ('K_b', M.k_b[:, rows, :], True),
        ]
        if M.enable_vert_friction_trm:
            fields.append(('A_trm', M.a_trm[:, rows, :], True))
        for name, data, has_z in fields:
            data = data.transpose().astype('f')
            if has_z:
                fid.variables[name][i, :, rows, :] = data
            else:
                fid.variables[name][i, rows, :] = data

    def write_cdf(self):
        """Gather all output fields on PE 0 and append one snapshot record.

        Every PE makes the pe0_recv_* calls; only PE 0 then opens
        ``self.snap_file`` in append mode and writes the whole domain.
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if not hasattr(self, 'NF'):
            return
        tau = M.tau - 1
        # presumably collective calls: keep this order identical on all PEs
        self.fortran.pe0_recv_3d(M.b[:, :, :, tau])
        self.fortran.pe0_recv_3d(M.u[:, :, :, tau])
        self.fortran.pe0_recv_3d(M.v[:, :, :, tau])
        self.fortran.pe0_recv_3d(M.w[:, :, :, tau])
        self.fortran.pe0_recv_3d(M.p_full[:, :, :, tau])
        self.fortran.pe0_recv_2d(M.surface_taux)
        self.fortran.pe0_recv_2d(M.surface_tauy)
        self.fortran.pe0_recv_2d(M.surface_flux)
        self.fortran.pe0_recv_3d(M.k_b)
        if M.enable_vert_friction_trm:
            self.fortran.pe0_recv_3d(M.a_trm)
        if M.my_pe == 0:
            fid = self.NF(self.snap_file, 'a')
            try:
                i = self._append_time(fid)
                self._write_fields(fid, i, slice(None))
            finally:
                fid.close()

    def write_cdf2(self):
        """Append one snapshot record, each PE writing its own rows in turn.

        PEs take turns (separated by barriers) opening ``self.snap_file``
        in append mode; PE 0 creates the new time record and the others
        write into the last existing one, covering rows js_pe..je_pe.
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if not hasattr(self, 'NF'):
            return
        for pe in range(0, max(1, M.n_pes)):
            self.barrier()
            if pe == M.my_pe:
                fid = self.NF(self.snap_file, 'a')
                try:
                    if pe == 0:
                        i = self._append_time(fid)
                    else:
                        i = numpy.shape(fid.variables['Time'])[0] - 1
                    js, je = M.js_pe - 1, M.je_pe   # 1-based inclusive -> 0-based slice
                    if self.netcdf_bug:  # account for bug in this module: row by row
                        for j in range(js, je):
                            self._write_fields(fid, i, j)
                    else:
                        self._write_fields(fid, i, slice(js, je))
                finally:
                    fid.close()
            self.barrier()


if __name__ == "__main__":
    # this module only provides the pyOM_cdf class; running it does nothing
    print('I will do nothing')
