
import numpy

class FortranError(Exception):
   """ raised when a routine of the compiled fortran code reports a failure,
       i.e. returns a nonzero error code
   """

class pyOMError(Exception):
   """ generic exception of class pyOM

       value : object describing the problem (usually a message string);
               optional, so that a bare ``raise pyOMError`` works instead of
               failing with a TypeError. str() of the exception is repr(value).
   """
   def __init__(self, value=None):
       Exception.__init__(self, value)
       self.value = value
   def __str__(self):
       return repr(self.value)

class pyOM:
    """
    Main class describing the model.

    The numerics are done by a compiled fortran module (``self.fortran``)
    whose module variables are reached through ``self.fortran.pyom_module``.
    A concrete experiment subclasses pyOM and overloads the template methods
    at the end of this class (set_parameter, initial_conditions, topography,
    boundary_conditions, ...).
    """

    def __init__(self):
        """ initialize everything; the order matters, e.g. the arrays can only
            be allocated after set_parameter has defined the grid size
        """
        self.load_fortran_code()    # load fortran code with/out MPI bindings
        self.set_parameter()        # set all model parameter by overloaded method
        self.domain_decomposition() # set domain decomposition for parallel execution
        self.print_parameter()      # print all model parameters
        self.allocate_arrays()      # now allocate model variables in fortran module
        self.make_grid()            # setup model grid and other stuff
        self.initial_conditions()   # setup forcing configuration by overloaded method

    @staticmethod
    def _echo(*items):
        """ print the items separated by single blanks on one line

            Same layout as the python 2 statement ``print a, b, c`` for the
            items used in this class, but valid under python 2 and python 3.
        """
        print(' '.join([str(item) for item in items]))

    @staticmethod
    def _is_snapshot_time(time, snapint, dt):
        """ True if `time` is, up to round-off, an integer multiple of `snapint`

            The tolerance is a tiny fraction of the time step `dt`, so round-off
            in time = itt*dt (e.g. 30*0.1 == 3.0000000000000004) does not make
            snapshots go missing. snapint <= 0 means "never".
        """
        if snapint <= 0:
            return False
        tolerance = 1e-6 * dt
        remainder = numpy.mod(time, snapint)
        return remainder < tolerance or snapint - remainder < tolerance

    def load_fortran_code(self):
        """ try to load module and fortran code with MPI bindings
            or try to live without MPI

            Sets self.fortran, self.pyOM_version and, with MPI, self.mpi_comm.
            Raises FortranError if my_mpi_init returns a nonzero code.
        """
        try:
            from mpi4py import MPI
            import pyOM_code_MPI
        except ImportError:
            print(' switching MPI off')
            import pyOM_code
            self.fortran = pyOM_code
            if self.fortran.my_mpi_init(0) != 0:
                raise FortranError('my_mpi_init failed')
        else:
            self.fortran = pyOM_code_MPI
            print(' checking MPI')
            self.mpi_comm = MPI.COMM_WORLD
            # hand the communicator over to fortran as a fortran handle
            if self.fortran.my_mpi_init(self.mpi_comm.py2f()) != 0:
                raise FortranError('my_mpi_init failed')
        self.pyOM_version = self.fortran.pyom_module.version

    def run(self, runlen=0., snapint=0.):
        """
        enter a simple model time stepping loop

        runlen  : integration length in s; int(runlen/dt) steps are taken,
                  starting at self.startitt (0 unless set, e.g. by read_restart)
        snapint : diagnose() is called whenever the model time is a multiple
                  of snapint s; snapint <= 0 disables it
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if not hasattr(self, 'startitt'):
            self.startitt = 0
        self.enditt = self.startitt + int(runlen / M.dt)

        if M.my_pe == 0:
            self._echo(' model integration from time step ', self.startitt, ' to ', self.enditt)
            self._echo(' run length   : ', runlen, 's, i.e. ', runlen / M.dt, ' time steps')
            self._echo(' snap interval: ', snapint, 's, i.e. any ', snapint / M.dt, ' time steps')

        self.run_fortran('check_pyom_module')  # check for consistent model variables

        for self.itt in range(self.startitt, self.enditt):
            self.time = self.itt * M.dt
            self.time_step()
            if self._is_snapshot_time(self.time, snapint, M.dt):
                self.diagnose()
            self.time_goes_by()
        if M.my_pe == 0:
            print(' end of integration ')

    def diagnose(self):
        """ diagnose the model variables, might be extended by further diagnostic

            The base version reports on PE 0 the model time, the time step and
            the pressure solver iteration counts (sor2d_itts, sor3d_itts).
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if M.my_pe == 0:
            print('diagnosing at %f s, itt=%i, pressure solver itt = (%i,%i)'
                  % (self.time, self.itt, M.sor2d_itts, M.sor3d_itts))

    def run_fortran(self, routine):
        """ execute the argument-free fortran routine named `routine` with error handler

            Raises pyOMError if the routine does not exist and FortranError if
            it returns a nonzero error code.
        """
        if not hasattr(self.fortran, routine):
            raise pyOMError("routine " + routine + " not found")
        # look the routine up by name instead of eval-ing a built source string
        ierr = getattr(self.fortran, routine)()
        if ierr != 0:
            self._echo('errorcode ', ierr, ' in fortran routine ', routine)
            raise FortranError('errorcode %s in fortran routine %s' % (ierr, routine))

    def time_step(self):
        """ integrate one time step
        """
        M = self.fortran.pyom_module                       # fortran module with model variables
        self.boundary_conditions()                         # set time dependent vertical boundary conditions
        if M.enable_hydrostatic:
            self.run_fortran('vertical_velocity')          # integrate divergence to get vertical velocity
        self.run_fortran('integrate_buoyancy')             # predict buoyancy
        self.restoring_zones()                             # apply buoyancy restoring zones
        if M.enable_back_state:
            self.run_fortran('background_state_buoyancy')  # effect of background state on buoyancy
        if M.enable_hydrostatic:
            self.run_fortran('convection')                 # parameterize vertical convection
        if M.enable_diag_tracer:
            self.run_fortran('integrate_tracer')           # integrate passive tracer
            if M.enable_isopycnal_diffusion:
                self.run_fortran('isopycnal_diffusion')
            self.tracer_sources()
        self.run_fortran('momentum_tendency')              # predict momentum changes
        if M.enable_vert_friction_trm:
            self.run_fortran('vert_friction_trm')          # apply residual mean formulation
        self.momentum_restoring_zones()                    # apply restoring zones for momentum
        if M.enable_back_state:
            self.run_fortran('background_state_momentum')  # effect of background state on momentum
        self.run_fortran('solve_pressure')                 # solve for the pressure
        self.run_fortran('integrate')                      # now update all new model variables

    def time_goes_by(self):
        """ shift time pointers in fortran module: (taum1, tau, taup1) -> (tau, taup1, taum1)
        """
        M = self.fortran.pyom_module
        # "* 1" makes a new object; the module attribute may refer to fortran
        # memory that is overwritten by the next assignment
        old_taum1 = M.taum1 * 1
        M.taum1 = M.tau
        M.tau = M.taup1
        M.taup1 = old_taum1

    def allocate_arrays(self):
        """ allocate the (zero-filled) model arrays in the fortran module

            Optional arrays are only allocated when their switch is set.
            Raises pyOMError if zonal and meridional background flow are both
            enabled.
        """
        M = self.fortran.pyom_module           # fortran module with model variables
        # reject an invalid configuration before allocating anything
        if M.enable_back_zonal_flow and M.enable_back_meridional_flow:
            raise pyOMError('zonal and meridional background flow is not allowed')

        nx, ny, nz = M.nx, M.ny, M.nz
        # grid axes and Coriolis parameters
        M.xt = numpy.zeros((nx,), 'd')
        M.xu = numpy.zeros((nx,), 'd')
        M.yt = numpy.zeros((ny,), 'd')
        M.yu = numpy.zeros((ny,), 'd')
        M.zt = numpy.zeros((nz,), 'd')
        M.zw = numpy.zeros((nz,), 'd')
        M.coriolis_t = numpy.zeros((ny,), 'd')
        M.coriolis_hor = numpy.zeros((ny,), 'd')

        # masks, water depths and bottom indices
        M.maskt = numpy.zeros((nx, ny, nz), 'd')
        M.masku = numpy.zeros((nx, ny, nz), 'd')
        M.maskv = numpy.zeros((nx, ny, nz), 'd')
        M.maskw = numpy.zeros((nx, ny, nz), 'd')
        M.hu = numpy.zeros((nx, ny), 'd')
        M.hv = numpy.zeros((nx, ny), 'd')
        M.ht = numpy.zeros((nx, ny), 'd')
        M.k_bottom_u = numpy.zeros((nx, ny), 'i')
        M.k_bottom_v = numpy.zeros((nx, ny), 'i')

        # prognostic variables; the trailing axis of length 3 presumably holds
        # the time levels indexed by taum1/tau/taup1
        if M.enable_diag_tracer:
            M.tr = numpy.zeros((nx, ny, nz, 3, M.nt), 'd')
        M.u = numpy.zeros((nx, ny, nz, 3), 'd')
        M.v = numpy.zeros((nx, ny, nz, 3), 'd')
        M.w = numpy.zeros((nx, ny, nz, 3), 'd')
        M.b = numpy.zeros((nx, ny, nz, 3), 'd')
        M.p_full = numpy.zeros((nx, ny, nz, 3), 'd')
        M.p_hydro = numpy.zeros((nx, ny, nz), 'd')
        M.psi = numpy.zeros((nx, ny, nz), 'd')
        M.p_surf = numpy.zeros((nx, ny), 'd')
        M.eta = numpy.zeros((nx, ny, 3), 'd')

        M.fu = numpy.zeros((nx, ny, nz), 'd')
        M.fv = numpy.zeros((nx, ny, nz), 'd')
        M.fw = numpy.zeros((nx, ny, nz), 'd')

        # diffusivity and surface/bottom forcing
        M.k_b = numpy.zeros((nx, ny, nz), 'd')
        M.surface_flux = numpy.zeros((nx, ny), 'd')
        M.bottom_flux = numpy.zeros((nx, ny), 'd')
        M.surface_taux = numpy.zeros((nx, ny), 'd')
        M.surface_tauy = numpy.zeros((nx, ny), 'd')
        M.bottom_taux = numpy.zeros((nx, ny), 'd')
        M.bottom_tauy = numpy.zeros((nx, ny), 'd')

        # pressure solver coefficients (3D only for the non-hydrostatic case)
        M.cf2d = numpy.zeros((nx, ny, 3, 3), 'd')
        if not M.enable_hydrostatic:
            M.cf3d = numpy.zeros((nx, ny, nz, 3, 3, 3), 'd')
        if M.enable_expl_free_surf:
            M.bu = numpy.zeros((nx, ny, 3), 'd')
            M.bv = numpy.zeros((nx, ny, 3), 'd')

        if M.enable_vert_friction_trm:
            M.a_trm = numpy.zeros((nx, ny, nz), 'd')

        if M.enable_back_state:
            M.back = numpy.zeros((nx, ny, nz, 3), 'd')
            M.u0 = numpy.zeros((nx, ny, nz), 'd')

    def domain_decomposition(self):
        """ establish domain decomposition for parallel code execution

            The j-direction (ny rows) is split into n_pes blocks of j_blk rows;
            this PE gets the 1-based rows js_pe..je_pe. With a single PE the
            one block covers the whole domain. Raises pyOMError if the last
            block would contain fewer than four rows.
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        M.n_pes_j = M.n_pes
        if M.n_pes > 1:
            M.j_blk = (M.ny - 1) // M.n_pes_j + 1        # j-extent of each block
            M.my_blk_j = M.my_pe + 1                     # number of PE in j-dir.
            M.js_pe = (M.my_blk_j - 1) * M.j_blk + 1
            M.je_pe = min(M.my_blk_j * M.j_blk, M.ny)
            self.barrier()
            #  check for incorrect domain decomposition
            if M.my_blk_j == M.n_pes_j and M.js_pe >= M.je_pe - 2:
                self._echo(' ERROR: on PE: ', M.my_pe)
                print(' domain decompositon impossible in j-direction')
                print(' choose other number of PEs in j-direction')
                raise pyOMError('domain decomposition impossible in j-direction')
        else:
            M.j_blk = M.ny
            M.my_blk_j = 1
            M.js_pe = 1
            M.je_pe = M.ny
        # print out domain decomposition
        self.barrier()
        if M.n_pes > 1:
            if M.my_pe == 0:
                print(' Domain decomposition:')
            for n in range(M.n_pes):               # every PE, including the last
                if M.my_pe == n:
                    print('')
                    self._echo(' sub domain for PE #', n)
                    self._echo(' my_blk_j=', M.my_blk_j)
                    self._echo(' js_pe=', M.js_pe, ' je_pe=', M.je_pe)
                    print('')
                # all PEs call this once per n: lets PE n print before PE n+1
                self.barrier()

    def barrier(self):
        """
        MPI barrier function; does nothing when running without MPI
        """
        # load_fortran_code sets mpi_comm only when the MPI bindings were loaded
        if hasattr(self, 'mpi_comm'):
            self.mpi_comm.barrier()

    def _apply_cyclic(self, field):
        """ wrap the outermost rows of a 3D (x,y,z) field around for every
            direction flagged periodic: row 0 takes row -2, row -1 takes row 1
        """
        M = self.fortran.pyom_module
        if M.enable_cyclic_x:
            field[0, :, :] = field[-2, :, :]
            field[-1, :, :] = field[1, :, :]
        if M.enable_cyclic_y:
            field[:, 0, :] = field[:, -2, :]
            field[:, -1, :] = field[:, 1, :]

    def make_grid(self):
        """ generate grid, masks, water depths and pressure solver coefficients

            Calls the overloaded set_coriolis() and topography(); topography()
            is expected to edit M.maskt, from which the u, v and w masks follow.
            Raises FortranError if make_coef2d/make_coef3d fail.
        """
        M = self.fortran.pyom_module         # fortran module with model variables

        M.c2dt = 2 * M.dt                    # two times the time step
        M.taum1 = 1; M.tau = 2; M.taup1 = 3  # pointers for time levels of variables

        # time step splitting for explicit free surface: a quarter of the time
        # dx needs at the gravity-wave speed sqrt(9.8*h_0), fitted n times into dt
        h_0 = (M.nz - 2) * M.dz
        if h_0 > 0:
            dtex = M.dx / numpy.sqrt(9.8 * h_0) / 4.
            n = max(1, int(M.dt / dtex))
        else:
            n = 1                            # no wet levels: avoid dividing by zero
        M.dtex = M.dt / n
        if M.enable_expl_free_surf and M.my_pe == 0:
            self._echo(' free surface time step : ', M.dtex, ' s')

        # the grid; z is shifted so that zw[-2] == 0 (top of the upper wet level)
        M.xt[:] = numpy.arange(M.nx) * M.dx
        M.yt[:] = numpy.arange(M.ny) * M.dx
        M.yu[:] = M.yt + M.dx / 2.
        M.xu[:] = M.xt + M.dx / 2.
        M.zw[:] = numpy.arange(M.nz) * M.dz
        M.zt[:] = M.zw - M.dz / 2.
        M.zt[:] = M.zt - M.dz * (M.nz - 2)
        M.zw[:] = M.zw - M.dz * (M.nz - 2)

        # Coriolis parameter
        self.set_coriolis()

        # land mask: 1 (water) inside, 0 on the outermost boxes in every direction
        M.maskt[:, :, :] = 1
        M.maskt[:, :, 0] = 0; M.maskt[:, :, -1] = 0
        M.maskt[:, 0, :] = 0; M.maskt[:, -1, :] = 0
        M.maskt[0, :, :] = 0; M.maskt[-1, :, :] = 0

        # call overloaded method to define topography
        self.topography()
        self._apply_cyclic(M.maskt)

        # masks for other grids: wet only where both neighbouring t-boxes are wet
        M.masku[0:-1, :, :] = numpy.minimum(M.maskt[0:-1, :, :], M.maskt[1:, :, :])
        self._apply_cyclic(M.masku)
        M.maskv[:, 0:-1, :] = numpy.minimum(M.maskt[:, 0:-1, :], M.maskt[:, 1:, :])
        self._apply_cyclic(M.maskv)
        M.maskw[:, :, 0:-1] = numpy.minimum(M.maskt[:, :, 0:-1], M.maskt[:, :, 1:])
        self._apply_cyclic(M.maskw)

        # water depth, needed for external mode
        M.hu[:, :] = M.masku.sum(axis=2) * M.dz
        M.hv[:, :] = M.maskv.sum(axis=2) * M.dz
        M.ht[:, :] = M.maskt.sum(axis=2) * M.dz

        # deepest index, needed for bottom friction: 1-based index of the
        # lowest wet level in each column (level index grows upward, see zw),
        # nz where the column is completely dry
        for mask, k_bottom in ((M.masku, M.k_bottom_u), (M.maskv, M.k_bottom_v)):
            wet = (mask == 1)
            k_bottom[:, :] = numpy.where(wet.any(axis=2), wet.argmax(axis=2) + 1, M.nz)

        # coefficients for poisson equation
        self.run_fortran('make_coef2d')
        if not M.enable_hydrostatic:
            self.run_fortran('make_coef3d')

    def print_parameter(self):
        """
        print relevant model information (on PE 0 only)
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if M.my_pe == 0:
            self._echo(' grid size    : nx=', M.nx, ' ny=', M.ny, ' nz=', M.nz)
            self._echo(' grid spacing : Delta x =', M.dx, 'm  Delta z =', M.dz, 'm')
            self._echo(' domain size  : ', M.nx * M.dx, 'm X ', M.ny * M.dx, 'm X ', M.nz * M.dz, 'm')
            self._echo(' time step    : ', M.dt, 's')
            self._echo(' lateral  diffusivity : K_h=', M.k_h, ' m^2/s')
            self._echo(' vertical diffusivity : K_v=', M.k_v, ' m^2/s')
            self._echo(' lateral  viscosity   : A_h=', M.a_h, ' m^2/s')
            self._echo(' vertical viscosity   : A_v=', M.a_v, ' m^2/s')
            self._echo(' epsilon for 2D solver : ', M.eps2d_sor)

    def write_restart(self, filename='restart.dta', tracer_fname='tracer_restart.dta'):
        """ write a restart (plus a tracer restart if tracers are enabled)

            Stores the step counter self.itt, which is set by run().
            Raises pyOMError if the fortran code has no write_restart.
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if not hasattr(self.fortran, 'write_restart'):
            raise pyOMError("routine write_restart not found")
        self.fortran.write_restart(self.itt, filename)
        if M.enable_diag_tracer:
            self.fortran.tracer_write_restart(tracer_fname)

    def read_restart(self, filename='restart.dta', tracer_fname='tracer_restart.dta'):
        """ read a restart (plus a tracer restart if tracers are enabled)

            The step counter returned by the fortran reader becomes
            self.startitt, where the next run() continues.
            Raises pyOMError if the fortran code has no read_restart.
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        if not hasattr(self.fortran, 'read_restart'):
            raise pyOMError("routine read_restart not found")
        self.startitt = self.fortran.read_restart(filename)
        if M.enable_diag_tracer:
            self.fortran.tracer_read_restart(tracer_fname)

    ############################################################
    # the rest are template methods which should be overloaded #
    ############################################################

    def set_parameter(self):
        """
        set main parameter. This method should be overloaded
        """
        M = self.fortran.pyom_module         # fortran module with model variables
        M.nx = 1; M.nz = 2; M.ny = 1
        M.dx = 1.0; M.dz = 1.0; M.dt = 1.0
        M.eps2d_sor = 1e-6; M.eps3d_sor = 1e-6
        self.lat_ref = 0; self.beta = 0
        self.snapint = 5. / 86400.0
        self.runlen = 0; self.snap_file = 'pyOM.cdf'

    def initial_conditions(self):
        """ setup all initial conditions. This method should be overloaded
        """

    def restoring_zones(self):
        """ add here restoring zones. This method should be overloaded
        """

    def tracer_sources(self):
        """ add here tracer sources. This method should be overloaded
        """

    def momentum_restoring_zones(self):
        """ Momentum restoring zones. This method should be overloaded
        """

    def boundary_conditions(self):
        """ Time dependent surface boundary conditions. This method should be overloaded
        """

    def topography(self):
        """ Definition of topography. This method should be overloaded
        """

    def set_coriolis(self):
        """ vertical and horizontal Coriolis parameter on yt grid
            routine is called after initialization of grid
        """
   

if __name__ == "__main__":
    # this module only defines the model classes; nothing to run on its own
    print('I will do nothing')
