# GNU Enterprise Common - DBSIG2 DB Driver - Schema Creation
#
# Copyright 2001-2005 Free Software Foundation
#
# This file is part of GNU Enterprise
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: Creation.py 6851 2005-01-03 20:59:28Z jcater $

import string
from gnue.common.datasources.drivers.Base.Schema.Creation import \
    Creation as Base

# =============================================================================
# Exceptions
# =============================================================================

class NumericTransformationError (Base.Error):
  """
  Error whose message states that no numeric transformation exists for a
  number of the given length and scale.
  """

  def __init__ (self, length, scale):
    """
    @param length: length of the number; formatted with %d into the message
    @param scale: scale of the number; formatted with %d into the message
    """
    template = \
        u_("No numeric transformation for number (%(length)d,%(scale)d)")
    details = {'length': length, 'scale': scale}
    Base.Error.__init__ (self, template % details)

# =============================================================================
# This class implements the base schema creation for SQL-like drivers
# =============================================================================

class Creation (Base.Creation):

  ALTER_MULTIPLE   = True
  EXTRA_PRIMARYKEY = False
  END_COMMAND      = ";"

  # ---------------------------------------------------------------------------
  # Create a table
  # ---------------------------------------------------------------------------

  def createTable (self, tableDefinition, codeOnly = False):
    """
    Build the code-tuple which creates a table and, unless codeOnly is set,
    execute it.

    The code-tuple returned by the base class is extended by a single
    'CREATE TABLE' statement covering all fields, followed by the code for
    every index and every constraint listed in the definition.

    @param tableDefinition: dictionary describing the table to be created.
        The keys read here are 'name' (required) and the optional 'fields',
        'primarykey', 'indices' and 'constraints'. A 'primarykey' is only
        honoured if 'fields' is given too.
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """

    res = Base.Creation.createTable (self, tableDefinition, codeOnly)
    tableName = tableDefinition ['name']

    try:
      if 'fields' in tableDefinition:
        fCode = self.createFields (tableName, tableDefinition ['fields'],
                                   False)

        if 'primarykey' in tableDefinition:
          pkCode = self.createPrimaryKey (tableName,
                                          tableDefinition ['primarykey'], True)
          if self.EXTRA_PRIMARYKEY:
            # The key is a separate statement, so it goes into the result
            self.mergeTuple (res, pkCode)
          else:
            # The key is a clause within the column list of CREATE TABLE
            self.mergeTuple (fCode, pkCode)

        code = u"CREATE TABLE %s (%s)%s" % \
            (tableName, ", ".join (fCode [1]), self.END_COMMAND)
        self.mergeTuple (res, (fCode [0], [code], fCode [2]))

      # Create all requested indices
      if 'indices' in tableDefinition:
        for ixDef in tableDefinition ['indices']:
          self.mergeTuple (res, self.createIndex (tableName, ixDef, True))

      # Add all constraints
      if 'constraints' in tableDefinition:
        for constraintDef in tableDefinition ['constraints']:
          self.mergeTuple (res,
              self.createConstraint (tableName, constraintDef, True))

      if not codeOnly:
        self._executeCodeTuple (res)

    finally:
      # createIndex () and createConstraint () record the names they have
      # used in self.lookup. Reset it even if something failed, otherwise a
      # later call would treat those indices as already created and skip them.
      self.lookup = {}

    return res


  # ---------------------------------------------------------------------------
  # Modify a table
  # ---------------------------------------------------------------------------

  def modifyTable (self, tableDefinition, codeOnly = False):
    """
    Build the code-tuple which modifies an existing table and, unless
    codeOnly is set, execute it.

    Old indices are dropped first, then new fields are added (in a single
    'ALTER TABLE' statement if ALTER_MULTIPLE is set, otherwise one statement
    per field), and finally new indices and constraints are created.

    @param tableDefinition: dictionary describing the modification. The keys
        read here are 'name' (required) and the optional 'old_indices',
        'fields', 'indices' and 'constraints'.
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """

    res   = Base.Creation.modifyTable (self, tableDefinition, codeOnly)
    table = tableDefinition ['name']

    try:
      if 'old_indices' in tableDefinition:
        for ixName in tableDefinition ['old_indices']:
          self.mergeTuple (res, self.dropIndex (table, ixName, True))

      fields = tableDefinition.get ('fields')
      if fields:
        if self.ALTER_MULTIPLE:
          # All new fields are added by a single statement
          fCode = self.createFields (table, fields, True)
          code = u"ALTER TABLE %s ADD (%s)%s" % \
              (table, ", ".join (fCode [1]), self.END_COMMAND)
          self.mergeTuple (res, (fCode [0], [code], fCode [2]))

        else:
          # One statement per new field
          for fDef in fields:
            fCode = self.createFields (table, [fDef], True)
            code = u"ALTER TABLE %s ADD %s%s" % \
                (table, ", ".join (fCode [1]), self.END_COMMAND)
            self.mergeTuple (res, (fCode [0], [code], fCode [2]))

      # Create all requested indices
      if 'indices' in tableDefinition:
        for ixDef in tableDefinition ['indices']:
          self.mergeTuple (res, self.createIndex (table, ixDef, True))

      # Add all constraints
      if 'constraints' in tableDefinition:
        for constraintDef in tableDefinition ['constraints']:
          self.mergeTuple (res,
              self.createConstraint (table, constraintDef, True))

      if not codeOnly:
        self._executeCodeTuple (res)

    finally:
      # createIndex () and createConstraint () record the names they have
      # used in self.lookup. Reset it even if something failed, otherwise a
      # later call would treat those indices as already created and skip them.
      self.lookup = {}

    return res


  # ---------------------------------------------------------------------------
  # Create or modify fields for a table
  # ---------------------------------------------------------------------------

  def createFields (self, tableName, fields, forAlter = False):
    """
    Collect the code of all given fields in a single code-tuple. Starting
    with the code-tuple of the base class, the result of _processField () is
    merged in for every field, keeping the order of the fields.

    @param tableName: name of the table the fields belong to
    @param fields: a list of field definition dictionaries
    @param forAlter: passed on to _processField (); TRUE if the fields are
        used in a table modification, FALSE for a table creation
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """
    result = Base.Creation.createFields (self, tableName, fields, forAlter)

    for fieldDef in fields:
      fieldCode = self._processField (tableName, fieldDef, forAlter)
      self.mergeTuple (result, fieldCode)

    return result


  # ---------------------------------------------------------------------------
  # Create a primary key
  # ---------------------------------------------------------------------------

  def createPrimaryKey (self, tableName, keyDefinition, codeOnly = False):
    """
    Build the code for a primary key of the given table. If the constant
    EXTRA_PRIMARYKEY is true, a complete 'alter table'-statement is added to
    the epilogue, otherwise the bare primary key constraint clause is added
    to the body, so the caller can embed it into a table creation statement.

    This function only generates code, it never executes anything.

    @param tableName: name of the table for which a key should be created
    @param keyDefinition: dictionary of the primary key to be created; the
        keys 'name' and 'fields' (a sequence of field names) are read
    @param codeOnly: passed on to the base class, otherwise not used here
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """
    res = Base.Creation.createPrimaryKey (self, tableName, keyDefinition,
                                          codeOnly)
    fields  = ", ".join (keyDefinition ['fields'])
    keyName = self._shortenName (keyDefinition ['name'])
    code = u"CONSTRAINT %s PRIMARY KEY (%s)" % (keyName, fields)

    if self.EXTRA_PRIMARYKEY:
      # This is a statement of its own, so it has to be terminated like all
      # the other statements generated by this class.
      res [2].append (u"ALTER TABLE %s ADD %s%s" % \
          (tableName, code, self.END_COMMAND))
    else:
      # Only a fragment of a statement, so no terminator here
      res [1].append (code)

    return res


  # ---------------------------------------------------------------------------
  # Create a new index for a table
  # ---------------------------------------------------------------------------

  def createIndex (self, tableName, indexDefinition, codeOnly = False):
    """
    Build the 'CREATE INDEX' statement for the given table and, unless
    codeOnly is set, execute it. The (shortened) index name is recorded in
    self.lookup; if it is already listed there, no code is generated at all.

    @param tableName: name of the table for which an index should be created
    @param indexDefinition: dictionary of the index to be created; the keys
        'name' and 'fields' (a sequence of field names) are required, a true
        value for the optional key 'unique' requests a unique index
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action. All three sequences are empty if the
        index is already listed in self.lookup.
    """
    res = Base.Creation.createIndex (self, tableName, indexDefinition,
                                     codeOnly)
    indexName = self._shortenName (indexDefinition ['name'])
    lookupKey = "INDEX_%s" % indexName

    # Do not create the same index twice
    if lookupKey in self.lookup:
      return ([], [], [])

    if indexDefinition.get ('unique'):
      prefix = "UNIQUE "
    else:
      prefix = ""

    res [1].append (u"CREATE %sINDEX %s ON %s (%s)%s" % \
        (prefix, indexName, tableName,
         ", ".join (indexDefinition ['fields']), self.END_COMMAND))

    if not codeOnly:
      self._executeCodeTuple (res)

    self.lookup [lookupKey] = True

    return res


  # ---------------------------------------------------------------------------
  # Drop an old index
  # ---------------------------------------------------------------------------

  def dropIndex (self, tableName, indexName, codeOnly = False):
    """
    Build the 'DROP INDEX' statement for an index and, unless codeOnly is
    set, execute it. The statement is put into the prologue of the
    code-tuple and uses the shortened index name.

    @param tableName: name of the table to drop an index from; passed on to
        the base class, the generated statement does not contain it
    @param indexName: name of the index to be dropped
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """
    result = Base.Creation.dropIndex (self, tableName, indexName, codeOnly)

    shortName = self._shortenName (indexName)
    statement = u"DROP INDEX %s%s" % (shortName, self.END_COMMAND)
    result [0].append (statement)

    if not codeOnly:
      self._executeCodeTuple (result)

    return result


  # ---------------------------------------------------------------------------
  # Create a constraint
  # ---------------------------------------------------------------------------

  def createConstraint (self, tableName, constraintDef, codeOnly = False):
    """
    Build the 'ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY' statement for
    the given table and, unless codeOnly is set, execute it.

    The constraint name is shortened and recorded in self.lookup. If the name
    is already listed there, its tail is replaced by a counter (0 to 10)
    until an unused name is found; the length of the name is kept. If all
    attempts collide, the last candidate is used anyway.

    @param tableName: name of the table for which a constraint should be
        created
    @param constraintDef: dictionary of the constraint to be created; the
        keys 'name', 'fields', 'reftable' and 'reffields' are read. A
        foreign key is generated regardless of any other key.
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """
    res = Base.Creation.createConstraint (self, tableName, constraintDef,
                                          codeOnly)
    baseName = self._shortenName (constraintDef ['name'])
    cName = baseName
    tries = 0
    while ("CONSTRAINT_%s" % cName) in self.lookup and tries <= 10:
      # Replace as many trailing characters as the counter has digits, so
      # the candidate never gets longer than the shortened name.
      suffix = "%d" % tries
      cName = "%s%s" % (baseName [:-len (suffix)], suffix)
      tries += 1

    fields    = constraintDef ['fields']
    reftable  = constraintDef ['reftable']
    reffields = constraintDef ['reffields']

    res [1].append (u"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) "
                     "REFERENCES %s (%s)%s" % \
        (tableName, cName, ", ".join (fields), reftable,
         ", ".join (reffields), self.END_COMMAND))

    self.lookup ["CONSTRAINT_%s" % cName] = True

    if not codeOnly:
      self._executeCodeTuple (res)

    return res


  # ---------------------------------------------------------------------------
  # Create code for a single field definition
  # ---------------------------------------------------------------------------

  def _processField (self, tableName, fieldDef, forAlter = False):
    """
    Build the code-tuple which defines a single field of the table
    tableName. The composed '<name> <type>' string is appended to the body,
    then a 'defaultwith' attribute, a default value and the nullable flag
    are processed in this order.

    @param tableName: the table this field belongs to.
    @param fieldDef: the dictionary describing the field. The optional keys
        'defaultwith', 'default' and 'nullable' are evaluated; an empty
        'default' is ignored.
    @param forAlter: If TRUE this function produces code for a table
        modification, otherwise for a table creation.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """

    res = Base.Creation._processField (self, tableName, fieldDef, forAlter)
    res [1].append (self._composeField (tableName, fieldDef, forAlter))

    if 'defaultwith' in fieldDef:
      self._defaultwith (res, tableName, fieldDef, forAlter)

    default = fieldDef.get ('default')
    if default:
      # Add the keyword unless the definition already starts with it
      if default [:8].upper () != 'DEFAULT ':
        default = "DEFAULT %s" % default
      self._setColumnDefault (res, tableName, fieldDef, forAlter, default)

    self._integrateNullable (res, tableName, fieldDef, forAlter)

    return res


  # ---------------------------------------------------------------------------
  # Handle the nullable flag of a field 
  # ---------------------------------------------------------------------------

  def _integrateNullable (self, code, tableName, fieldDef, forAlter):
    """
    This function handles the nullable flag of a field. If the field is not
    nullable the last line of the code's body sequence will be modified on a
    create-action, or an 'alter table'-statement is added to the code's
    epilogue. @see: _setColumnDefault ()

    @param code: code-tuple which get's the result. If forAlter is FALSE this
        function assumes the field's code is the last line in code.body
    @param tableName: name of the table the field belongs to
    @param fieldDef: dictionary describing the field
    @param forAlter: if TRUE, the field definition is used in a table
        modification, otherwise in a table creation.
    """
    if fieldDef.has_key ('nullable') and not fieldDef ['nullable']:
      self._setColumnDefault (code, tableName, fieldDef, forAlter, "NOT NULL")


  # ---------------------------------------------------------------------------
  # Process a defaultwith attribute
  # ---------------------------------------------------------------------------

  def _defaultwith (self, code, tableName, fieldDef, forAlter):
    """
    This function could be overriden by any descendants to create code for
    special defaults like 'serial' or 'timestamp'.

    @param code: code-tuple to merge the result in
    @param tableName: name of the table
    @param fieldDef: dictionary describing the field with the default
    @param forAlter: TRUE if the definition is used in a table modification
    """
    pass


  # ---------------------------------------------------------------------------
  # Set a default value for a given column
  # ---------------------------------------------------------------------------

  def _setColumnDefault (self, code, tableName, fieldDef, forAlter, default):
    """
    This function sets a default value for a given column. If it is called for
    a table modification the epilogue of the code-block will be modified.
    On a table creation, this function assumes the field's code is in the last
    line of the code-block's body sequence.

    @param code: code-tuple which get's the result. If forAlter is FALSE this
        function assumes the field's code is the last line in code.body
    @param tableName: name of the table the field belongs to
    @param fieldDef: dictionary describing the field
    @param forAlter: if TRUE, the field definition is used in a table
        modification, otherwise in a table creation.
    @param default: string with the default value for the column
    """
    if forAlter:
      code [2].append (u"ALTER TABLE %s ALTER COLUMN %s SET %s%s" % \
          (tableName, fieldDef ['name'], default, self.END_COMMAND))
    else:
      code [1][-1] += " %s" % default


  # ---------------------------------------------------------------------------
  # Compose a field from 'fieldname fieldtype'
  # ---------------------------------------------------------------------------

  def _composeField (self, tableName, fieldDefinition, forAlter):
    """
    This function composes a field definition of the form <fieldname>
    <fieldtype> where the latter one has been translated using the
    _translateType function.

    @param tableName: name of the table the field belongs to
    @param fieldDefinition: the dictionary describing the field
    @return: string containing fieldname and fieldtype
    """
    res = "%s %s" % (fieldDefinition ['name'],
                     self._translateType (fieldDefinition))
    return res
    


  # ---------------------------------------------------------------------------
  # Execute all parts of a code-tuple
  # ---------------------------------------------------------------------------

  def _executeCodeTuple (self, code):
    """
    This function executes the given code-tuple using the instances connection.
    @param code: tuple of n code-sequences. All elements of each sequence is
        treated as single statement which gets executed via
        conneciton.makecursor ()
    """
    if self.connection is not None:
      for block in range (len (code)):
        for statement in code [block]:
          if len (statement):
            cursor = self.connection.makecursor (statement)
            cursor.close ()

      self.connection.commit ()


  # ---------------------------------------------------------------------------
  # A string becomes either varchar or text 
  # ---------------------------------------------------------------------------

  def string (self, fieldDefinition):
    """
    This function creates a native datatype for a string field. If the
    definition has a 'length' entry it results in a 'varchar'-, otherwise in
    a 'text'-field.

    @param fieldDefinition: dictionary describing the field
    @return: 'varchar (<length>)' or 'text'
    """

    if 'length' in fieldDefinition:
      return "varchar (%s)" % fieldDefinition ['length']
    else:
      return "text"


  # ---------------------------------------------------------------------------
  # Keep date as is
  # ---------------------------------------------------------------------------

  def date (self, fieldDefinition):
    """
    Native datatype for 'date'-fields, which is kept as is.

    @param fieldDefinition: dictionary describing the field (not evaluated)
    @return: 'date'
    """
    nativeType = "date"
    return nativeType


  # ---------------------------------------------------------------------------
  # Keep time as is
  # ---------------------------------------------------------------------------

  def time (self, fieldDefinition):
    """
    Native datatype for 'time'-fields, which is kept as is.

    @param fieldDefinition: dictionary describing the field (not evaluated)
    @return: 'time'
    """
    nativeType = "time"
    return nativeType


  # ---------------------------------------------------------------------------
  # Keep datetime as is
  # ---------------------------------------------------------------------------

  def datetime (self, fieldDefinition):
    """
    Native datatype for 'datetime'-fields, which is kept as is.

    @param fieldDefinition: dictionary describing the field (not evaluated)
    @return: 'datetime'
    """
    nativeType = "datetime"
    return nativeType



# =============================================================================
# Module self-test code
# =============================================================================

if __name__ == '__main__':
  # Self test: generate (but do not execute) the code for creating and for
  # modifying a sample table, and dump the resulting code-tuples.
  #
  # All output uses print with a single parenthesised argument, which gives
  # the same output as a print statement and also parses as a function call.

  def dumpTuple (aTuple):
    """
    Print the three sequences of a code-tuple, one statement per line.

    @param aTuple: code-tuple (prologue, body, epilogue) to be printed
    """
    sections = (("Prologue:", "---------"),
                ("Body:",     "------"),
                ("Epilogue:", "---------"))
    for (index, (title, rule)) in enumerate (sections):
      print ("\n%s" % title)
      print (rule)
      for line in aTuple [index]:
        print (">>  %s" % line)

  cr = Creation ()
  print ("Hey!")
  fields = [{'name'    : 'gnue_id',
             'type'    : 'string',
             'length'  : 32,
             'nullable': False},
            {'name'    : 'address_code',
             'type'    : 'string',
             'length'  : 2,
             'nullable': True},
            {'name'    : 'fooserial',
             'type'    : 'string',
             'length'  : 6,
             'nullable': False,
             'defaultwith': 'serial',
             'default': ''}]

  tdef = {'name': 'address_country',
          'fields': fields,
          'primarykey': {
            'name': 'pk_gnue_id_address_country',
            'fields': ['gnue_id']},
          'indices': [
            {'name': 'code_index',
             'unique': True,
             'fields': ['address_code']},
            {'name': 'silly_index',
             'fields': ['address_code', 'gnue_id']}],
          'constraints': [
            {'name': 'fake_constraint',
             'type': 'foreignkey',
             'fields': ['address_code', 'fake'],
             'reftable': 'foobar',
             'reffields': ['gnue_id', 'trash']}
          ]
         }

  # codeOnly is set, so the code is generated only
  res = cr.createTable (tdef, True)
  dumpTuple (res)

  # The same definition without the primary key, used as a modification
  del tdef ['primarykey']
  res = cr.modifyTable (tdef, True)
  dumpTuple (res)
