#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2000-2005 Free Software Foundation
#
# FILE:
# _dbsig/DBdriver.py
#
# DESCRIPTION:
"""
Generic implementation of dbdriver using Python DB-SIG v2
specification.
"""
#
# NOTES:
# The classes below are meant to be extended
#

# Only the RecordSet class is exported by "from ... import *".
__all__ = ['RecordSet']

from types import *

from gnue.common.datasources import GConditions, Exceptions
from gnue.common.datasources.drivers.Base import RecordSet as BaseRecordSet
from gnue.common.apps import errors
from string import join


######################################################################
#
# Record set shared by all DB-SIG based drivers
#
class RecordSet(BaseRecordSet):

  def _postChanges(self, recordNumber=None):
    do = self._parent._dataObject
    if not self.isPending(): return

    if self._deleteFlag:
      s = self._buildDeleteStatement()
    elif self._insertFlag:
      #
      # Check for empty primary key and set with the sequence value if so
      #
      if hasattr(do,'primarykey') and hasattr(do,'primarykeyseq'):
        if do.primarykey and do.primarykeyseq and ',' not in do.primarykey and \
           hasattr(do._connection,'getsequence') and \
           self.getField(do.primarykey) is None:
          try:
            self.setField(do.primarykey,do._connection.getsequence(do.primarykeyseq))
          except do._connection._DatabaseError:
            raise exceptions.InvalidDatasourceDefintion, \
                errors.getException () [2]
      s = self._buildInsertStatement()
    elif self._updateFlag:
      s = self._buildUpdateStatement()

    else:
      # The record does not has a direct modification, so isPending () returns
      # True because a detail-record has pending changes
      return

    if isinstance (s, TupleType):
      # when useParameters is not set
      (statement, parameters) = s
    else:
      # when useParameters is set
      (statement, parameters) = (s, None)

    gDebug (8, "_postChanges: statement=%s" % statement)

    try:
      do._connection.sql (statement, parameters)

      # Set _initialData to be the just-now posted values
      if not self._deleteFlag:
        self._initialData = {}
        self._initialData.update(self._fields)

    except do._connection._DatabaseError, err:
      if recordNumber:
        raise Exceptions.ConnectionError, "\nERROR POSTING RECORD # %s\n\n%s" % (recordNumber,errors.getException () [2])
      else:
        raise Exceptions.ConnectionError, errors.getException () [2]

    self._updateFlag = False
    self._insertFlag = False
    self._deleteFlag = False

    return True


  # If a vendor can do any of these more efficiently (i.e., use a known
  # PRIMARY KEY or ROWID, then override these methods. Otherwise, leave
  # as default.  Note that these functions are specific to DB-SIG based
  # drivers (i.e., these functions are not in the base RecordSet class)

  # This function is only used with "useParameters" set in gnue.conf
  def _where (self):
    do = self._parent._dataObject
    if self._initialData.has_key(do._primaryIdField):
      where = [do._primaryIdFormat % \
               self._initialData [do._primaryIdField]]
      parameters = {}
    else:
      where = []
      parameters = {}
      for field in self._initialData.keys ():
        if self._parent.isFieldBound (field):
          if self._initialData [field] == None:
            where.append ("%s IS NULL" % field)
          else:
            key = 'old_' + field
            where.append ("%s=%%(%s)s" % (field, key))
            parameters [key] = self._initialData [field]

    return (join (where, ' AND '), parameters)

  def _buildDeleteStatement(self):
    do = self._parent._dataObject
    if gConfig ('useParameters'):
      (where, parameters) = self._where ()
      statement = 'DELETE FROM %s WHERE %s' % \
                  (do.table, where)
      return (statement, parameters)
    else:
      if self._initialData.has_key(do._primaryIdField):
        where = [do._primaryIdFormat % \
            self._initialData[do._primaryIdField]  ]
      else:
        where = []
        for field in self._initialData.keys():
          if self._parent.isFieldBound(field):
            if self._initialData[field] == None:
              where.append ("%s IS NULL" % field)
            else:
              where.append ("%s=%s" % (field,
                do._toSqlString(self._initialData[field])))

      statement = "DELETE FROM %s WHERE %s" % \
         (do.table, join(where,' AND ') )
      return statement

  def _buildInsertStatement(self):
    do = self._parent._dataObject
    if gConfig ('useParameters'):
      fields = []
      values = []
      parameters = {}

      for field in self._modifiedFlags.keys ():
        if self._parent.isFieldBound (field):
          key = 'new_' + field
          fields.append (field)
          values.append ('%%(%s)s' % key)
          parameters [key] = self._fields [field]

      statement = "INSERT INTO %s (%s) VALUES (%s)" % \
                  (do.table,
                   join (fields,', '),
                   join (values,', '))

      return (statement, parameters)
    else:
      vals = []
      fields = []

      for field in self._modifiedFlags.keys():
        if self._parent.isFieldBound(field):
          fields.append (field)
          if self._fields[field] == None or self._fields[field] == '':
            vals.append ("NULL") #  % (self._fields[field]))
          else:
            vals.append (do._toSqlString(self._fields[field]))

      return "INSERT INTO %s (%s) VALUES (%s)" % \
         (do.table, join(fields,','), \
          join(vals,',') )

  def _buildUpdateStatement(self):
    do = self._parent._dataObject
    if gConfig ('useParameters'):
      (where, parameters) = self._where ()
      updates = []
      for field in self._modifiedFlags.keys():
        key = 'new_' + field
        updates.append ("%s=%%(%s)s" % (field, key))
        parameters[key] = self._fields [field]

      statement = "UPDATE %s SET %s WHERE %s" % \
                  (do.table,
                   join(updates, ', '),
                   where)
      return (statement, parameters)
    else:
      updates = []
      for field in self._modifiedFlags.keys():
        updates.append ("%s=%s" % (field,
           do._toSqlString(self._fields[field])))

      if do._primaryIdField:
        where = [do._primaryIdFormat % \
            self._initialData[do._primaryIdField]  ]
      else:
        where = []
        for field in self._initialData.keys():
          if self._initialData[field] == None:
            where.append ("%s IS NULL" % field)
          else:
            where.append ("%s=%s" % (field, do._toSqlString(self._initialData[field])))

      return "UPDATE %s SET %s WHERE %s" % \
         (do.table, join(updates,','), \
          join(where,' AND ') )

  def _requery(self):
    """
    Requery a posted record to capture any changes made by the database
    """
    return  # TODO: Will test tomorrow

    do = self._parent._dataObject
    if not do.primarykey or not do._fieldReferences or \
       self._fields.get(key,None) is None :
      return False
    fields = self._fieldReferences.keys()
    where = []
    for key in do.primarykeys.split(','):
      where.append("%s=%s"%(key,do._toSqlString(self._fields[key])))
    select = "SELECT %s FROM %s WHERE %s" % (
            join(fields,','), do.table, join(where,' and '))
    print select
    try:
      cursor = self._connection.makecursor(select)
      f = cursor.fetchone()
      for i in range(len(f)):
        self.setField(fields[i], f[i], False)
    except do._connection._DatabaseError, err:
      raise Exceptions.ConnectionError, errors.getException () [2]
    return True
