#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it 
# and/or modify it under the terms of the GNU General Public 
# License as published by the Free Software Foundation; either 
# version 2, or(at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be 
# useful, but WITHOUT ANY WARRANTY; without even the implied 
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public 
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2000, 2001 Free Software Foundation
#
# FILE:
# _pgsql/DBdriver.py
#
# DESCRIPTION:
# A core PostgreSQL implementation of the DB driver that the other
# PostgreSQL drivers can extend.
#
# NOTES:
#


from string import lower, join
import sys
from gnue.common import GDebug, GDataObjects
from gnue.common.dbdrivers._dbsig.DBdriver \
   import DBSIG_RecordSet, DBSIG_ResultSet, DBSIG_DataObject

from gnue.common.dbdrivers._dbsig.DBdriver import DBSIG_DataObject_Object as PGSQL_DataObject_Object
from gnue.common.dbdrivers._dbsig.DBdriver import DBSIG_DataObject_SQL as PGSQL_DataObject_SQL

class PGSQL_RecordSet(DBSIG_RecordSet):
  """
  Record set class for the PostgreSQL drivers.

  Adds nothing to DBSIG_RecordSet; it gives the PostgreSQL drivers a class
  of their own to install as the result set's record set class.
  """

class PGSQL_ResultSet(DBSIG_ResultSet):
  """
  Result set for the PostgreSQL drivers.

  Identical to DBSIG_ResultSet except that its records are created as
  PGSQL_RecordSet instances.
  """

  def __init__(self, dataObject, cursor=None, defaultValues=None,
               masterRecordSet=None):
    # Build a fresh dict per call: a literal {} default would be a single
    # dict shared by every result set created without explicit defaults,
    # so any mutation of it would leak between result sets.
    if defaultValues is None:
      defaultValues = {}
    DBSIG_ResultSet.__init__(self, dataObject, cursor, defaultValues,
                             masterRecordSet)
    self._recordSetClass = PGSQL_RecordSet

class PGSQL_DataObject(DBSIG_DataObject):
  def __init__(self, pgdriver=None, pgresultset=None):
    DBSIG_DataObject.__init__(self)
    if pgdriver:
      self._pgdriver = pgdriver
      self._DatabaseError = self._pgdriver.DatabaseError
    if pgresultset:
      self._resultSetClass = pgresultset


  # TODO: leaving here as placeholder for a common connect function for std postgresql drivers
  def connect(self, connectData={}):
    pass
    GDebug.printMesg(1,"Postgresql database driver initializing")
    try:
      #self._dataConnection = self._pgdriver.connect("%s::%s:%s:%s::" % \
      #       (connectData['host'],
      #        connectData['dbname'],
      #        connectData['_username'],
      #        connectData['_password']))
      self._dataConnection = self._pgdriver.connect('user=%s password=%s host=%s dbname=%s' %
                                          (connectData['_username'], connectData['_password'], connectData['host'], connectData['dbname']))
      self.triggerExtensions = PGSQL_TriggerExtensions(self._dataConnection)
    except self._DatabaseError, value:
      GDebug.printMesg(1,"%s::%s:%s:***::" % \
             (connectData['host'],
              connectData['dbname'],
              connectData['_username']))
      GDebug.printMesg(1,"Exception %s " % value)
      raise GDataObjects.LoginError, value

    try: 
      encoding = connectData['encoding']
      GDebug.printMesg(1,'Setting postgresql client_encoding to %s' % encoding)
      cursor = self._dataConnection.cursor()
      cursor.execute("SET CLIENT_ENCODING TO '%s'" % encoding)
      cursor.close()
    except KeyError: 
      pass
    except self._DatabaseError: 
      try: 
        cursor.close()
      except: 
        pass

    self._postConnect()

  def _postConnect(self):
    self.triggerExtensions = PGSQL_TriggerExtensions(self._dataConnection)

  #
  # Schema (metadata) functions
  #

  # Return a list of the types of Schema objects this driver provides
  def getSchemaTypes(self):
    return [('view','View',1), ('table','Table',1)]

  # Return a list of Schema objects
  def getSchemaList(self, type=None):
    includeTables = (type in ('table','sources', None))
    includeViews = (type in ('view','sources', None))

    inClause = []
    if includeTables:
      inClause.append ("'r'")
    if includeViews:
      inClause.append ("'v'")

    # TODO: This excludes any system tables and views. Should it?
    statement = "select relname, relkind from pg_class " + \
            "where relkind in (%s) " % (join(inClause,',')) + \
            "and relname not like 'pg_%' " + \
            "order by relname"

    cursor = self._dataConnection.cursor()
    cursor.execute(statement)

    list = []
    for rs in cursor.fetchall():
      list.append(GDataObjects.Schema(attrs={'id':lower(rs[0]), 'name':rs[0],
                         'type':rs[1] == 'v' and 'view' or 'table'},
                         getChildSchema=self.__getFieldSchema))

    cursor.close()
    return list


  # Find a schema object with specified name
  def getSchemaByName(self, name, type=None):
    statement = "select relname, relkind, oid from pg_class " + \
            "where relname = '%s'" % (name)

    cursor = self._dataConnection.cursor()
    cursor.execute(statement)

    rs = cursor.fetchone()
    if rs:
      schema = GDataObjects.Schema(attrs={'id':rs[2], 'name':rs[0],
                           'type':rs[1] == 'v' and 'view' or 'table'},
                           getChildSchema=self.__getFieldSchema)
    else:
      schema = None
      
    cursor.close()
    return schema

  # Get fields for a table
  def __getFieldSchema(self, parent):

    statement = "select attname, pg_attribute.oid, typname, " + \
            " attnotnull, atthasdef, atttypmod " + \
            "from pg_attribute, pg_type " + \
            "where attrelid = %d and " % (parent.id) + \
            "pg_type.oid = atttypid and attnum >= 0" + \
            "order by attnum"

    cursor = self._dataConnection.cursor()
    cursor.execute(statement)

    list = []
    for rs in cursor.fetchall():

      attrs={'id': rs[1], 'name': rs[0],
             'type':'field', 'nativetype': rs[2],
             'required': rs[3] and not rs[4]}

      if rs[2] in ('int8','int2','int4','numeric',
                   'float4','float8','money','bool'):
        attrs['datatype']='number'
      elif rs[2] in ('date','time','timestamp','abstime','reltime'):
        attrs['datatype']='date'
      else:
        attrs['datatype']='text'

      if rs[5] != -1:
        attrs['length'] = rs[5]

      list.append(GDataObjects.Schema(attrs=attrs))

    cursor.close()
    return list

#
#  Extensions to Trigger Namespaces
#
class PGSQL_TriggerExtensions:
  """
  Database helper functions made available to trigger code.

  Wraps a DB-API connection.  Each call opens its own cursor and always
  closes it.
  """

  def __init__(self, connection):
    self.__connection = connection

  # Return the current timestamp, according to the database
  def getTimeStamp(self):
    return self.__singleQuery("select current_timestamp")

  # Return the next value of sequence 'name'
  def getSequence(self, name):
    # Double embedded single quotes so the name cannot end the literal.
    quoted = ("%s" % (name,)).replace("'", "''")
    return self.__singleQuery("select nextval('%s')" % quoted)

  # Run the SQL statement 'statement'; errors propagate to the caller
  def sql(self, statement):
    cursor = self.__connection.cursor()
    try:
      cursor.execute(statement)
    finally:
      cursor.close()

  # Used internally: run statement and return the first column of the
  # first row, or None if there is no row or the query fails (best effort;
  # the failure is logged).
  def __singleQuery(self, statement):
    cursor = self.__connection.cursor()
    try:
      try:
        cursor.execute(statement)
        row = cursor.fetchone()
      except Exception:
        GDebug.printMesg(1,"**** Unable to execute extension query")
        GDebug.printMesg(1,"**** %s" % sys.exc_info()[1])
        return None
    finally:
      cursor.close()

    if not row:
      return None
    return row[0]



