# GNU Enterprise Common - Base DB Driver - Schema Creation
#
# Copyright 2001-2005 Free Software Foundation
#
# This file is part of GNU Enterprise
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: Creation.py 6851 2005-01-03 20:59:28Z jcater $

import gnue

# =============================================================================
# Exceptions
# =============================================================================

class Error (gException):
  """
  Base class of all exceptions defined in this module.
  """

class NoCreationError (Error):
  """
  Raised by Creation.createDatabase () if a driver does not implement the
  creation of a database.
  """
  def __init__ (self):
    # Use u_ () like every other message in this module, so all exceptions
    # defined here build their message the same way.
    msg = u_("Database creation not implemented by this driver")
    Error.__init__ (self, msg)

class DefinitionError (Error):
  """
  Common base class of MissingKeyError and PrimaryKeyError.
  """

class MissingKeyError (DefinitionError):
  """
  A definition dictionary lacks a required attribute. Descendants only
  replace the class attribute MSG to get a more specific message.

  @param attribute: name of the missing attribute; it is inserted into MSG.
  """
  MSG = u_("The definition has no attribute '%s'")

  def __init__ (self, attribute):
    # self.MSG is looked up on the instance, so the MSG of the most derived
    # class wins.
    message = self.MSG % attribute
    DefinitionError.__init__ (self, message)

class TableDefinitionError (MissingKeyError):
  """
  A table definition lacks a required attribute.
  """
  MSG = u_("The table definition has no attribute '%s'")

class FieldDefinitionError (MissingKeyError):
  """
  A field definition lacks a required attribute.
  """
  MSG = u_("The field definition has no attribute '%s'")

class PrimaryKeyDefinitionError (MissingKeyError):
  """
  A primary key definition lacks a required attribute.
  """
  MSG = u_("Primarykey definition has no attribute '%s'")

class PrimaryKeyFieldsError (Error):
  """
  The 'fields' sequence of a primary key definition is empty.

  @param table: name of the table the primary key belongs to
  @param name: name of the primary key
  """
  def __init__ (self, table, name):
    values   = {'table': table, 'name': name}
    template = u_("Primarykey '%(name)s' of table '%(table)s' has no fields")
    Error.__init__ (self, template % values)

class PrimaryKeyError (DefinitionError):
  """
  A table definition used for a table modification contains a primary key.

  @param table: name of the offending table
  """
  def __init__ (self, table):
    template = u_("Table '%s' has a primary key which is not allowed on "
                  "table modification")
    DefinitionError.__init__ (self, template % table)

class IndexDefinitionError (MissingKeyError):
  """
  An index definition lacks a required attribute.
  """
  MSG = u_("Index definition has no attribute '%s'")

class IndexFieldsError (Error):
  """
  The 'fields' sequence of an index definition is empty.

  @param table: name of the table the index belongs to
  @param name: name of the index
  """
  def __init__ (self, table, name):
    values   = {'table': table, 'name': name}
    template = u_("Index '%(name)s' of table '%(table)s' has no fields")
    Error.__init__ (self, template % values)

class ConstraintDefinitionError (MissingKeyError):
  """
  A constraint definition lacks a required attribute.
  """
  MSG = u_("Constraint definition has no attribute '%s'")

class ConstraintFieldsError (Error):
  """
  One of the field sequences of a constraint definition is empty.

  @param table: name of the table the constraint belongs to
  @param name: name of the constraint
  @param fields: key of the empty sequence ('fields' or 'reffields')
  """
  def __init__ (self, table, name, fields):
    values   = {'table': table, 'name': name, 'fields': fields}
    template = \
        u_("Constraint '%(name)s' of table '%(table)s' has no '%(fields)s'")
    Error.__init__ (self, template % values)

class ConstraintTypeError (Error):
  """
  A constraint definition requests a constraint type which is not supported.

  @param table: name of the table the constraint belongs to
  @param name: name of the constraint
  @param cType: the unsupported constraint type
  """
  def __init__ (self, table, name, cType):
    values   = {'type': cType, 'name': name, 'table': table}
    template = u_("Type '%(type)s' of constraint '%(name)s' in table "
                  "'%(table)s' not supported")
    Error.__init__ (self, template % values)

class MissingTypeTransformationError (Error):
  """
  No method transforming the given field type into a native type was found.

  @param typename: the field type without a transformation method
  """
  def __init__ (self, typename):
    template = u_("No type transformation for '%s' found")
    Error.__init__ (self, template % typename)


class LengthError (Error):
  """
  An identifier exceeds the maximum identifier length.

  @param identifier: the offending identifier
  @param maxlen: maximum number of characters allowed; a false value (e.g.
      None) is reported as 0.
  """
  def __init__ (self, identifier, maxlen):
    # '%(maxlength)d' needs a number, so fall back to 0 for None
    limit = maxlen or 0

    # NOTE(review): the misspelt 'idendifier' is kept on purpose, since the
    # text is a runtime string handed to u_ () - presumably a translation
    # lookup key; confirm before correcting it.
    template = u_("The idendifier '%(identifier)s' exceeds the maximum length "
                  "of %(maxlength)d characters")
    Error.__init__ (self, template % {'identifier': identifier,
                                      'maxlength' : limit})


# =============================================================================
# Base class for drivers schema creation support
# =============================================================================
class Creation:

  MAX_NAME_LENGTH = None        # Max. length of an identifier
  END_COMMAND     = ""          # Character used for command termination
  
  # ---------------------------------------------------------------------------
  # Constructor
  # ---------------------------------------------------------------------------

  def __init__ (self, connection = None, introspector = None):
    self.connection   = connection
    self.introspector = introspector
    self.lookup       = {}

    if connection is not None and introspector is None:
      self.introspector = connection.introspector


  # ---------------------------------------------------------------------------
  # Create a database
  # ---------------------------------------------------------------------------

  def createDatabase (self):
    """
    Descendants can override this function to create a database. Usually all
    information needed could be gathered from the connection object.
    """
    raise NoCreationError


  # ---------------------------------------------------------------------------
  # Create a table from a table definition
  # ---------------------------------------------------------------------------

  def createTable (self, tableDefinition, codeOnly = False):
    """
    This function creates a table using the given definition and returns a
    code-tuple, which can be used to to this.

    @param tableDefinition: a dictionary of the table to be created
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    self._validateTable (tableDefinition)
    return ([], [], [])
  

  # ---------------------------------------------------------------------------
  # Create a primary key
  # ---------------------------------------------------------------------------

  def createPrimaryKey (self, tableName, keyDefinition, codeOnly = False):
    """
    This function creates a primary key for the given table using the primary
    key definition.

    @param tableName: name of the table for which a key should be created
    @param keyDefinition: a dictionary of the primary key to be created 
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    self._validatePrimaryKey (tableName, keyDefinition)
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Create an index
  # ---------------------------------------------------------------------------

  def createIndex (self, tableName, indexDefinition, codeOnly = False):
    """
    This function creates an index for the given table using the index
    definition.

    @param tableName: name of the table for which an index should be created
    @param indexDefinition: a dictionary of the index to be created 
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    self._validateIndex (tableName, indexDefinition)
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Drop an index
  # ---------------------------------------------------------------------------

  def dropIndex (self, tableName, indexName, codeOnly = False):
    """
    This function drops an index from the given table using the index
    definition.

    @param tableName: name of the table to drop an index from
    @param indexName: name of the index to be dropped 
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Create a constraint
  # ---------------------------------------------------------------------------

  def createConstraint (self, tableName, constraintDef, codeOnly = False):
    """
    This function creates a constraint for the given table using the constraint
    definition.

    @param tableName: name of the table for which an index should be created
    @param constraintDef: a dictionary of the constraint to be created 
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    self._validateConstraint (tableName, constraintDef)
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Modify a table
  # ---------------------------------------------------------------------------

  def modifyTable (self, tableDefinition, codeOnly = False):
    """
    This function modifies a table according to the given definition.

    @param tableDefinition: a dictionary of the table to be modified
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    self._validateTable (tableDefinition, True)
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Create fields for a table
  # ---------------------------------------------------------------------------

  def createFields (self, tableName, fields, forAlter = False):
    """
    This function creates all listed fields in the given table. If forAlter is
    TRUE this function should create the fields for a table modification.

    @param tableName: name of the table for which fields should be created or
        modified.
    @param fields: a list of field definition dictionaries, describing the
        fields to be created or modified.
    @param forAlter: if TRUE the fields should be modified, otherwise created
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    for field in fields:
      self._validateField (tableName, field)
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Check wether an element exists or not
  # ---------------------------------------------------------------------------

  def exists (self, elementName, elementType = None):
    """
    This function examines, wether an element exists in a datamodel or not.
    It's doing this using the given introspector. If no introspecor is
    available the result is FALSE.

    @param elementName: name of the element to be examined
    @param elementType: type of the element to be examined (optional)
    @return: TRUE if the element was found, otherwise FALSE
    """
    if self.introspector is not None:
      return self.introspector.find (name = elementName, type = elementType)
    else:
      return False


  # ---------------------------------------------------------------------------
  # Validate a given table definition
  # ---------------------------------------------------------------------------

  def validate (self, tableDef):
    """
    This function validates all parts of a table definition.
    @param tableDef: dictionary describing the table and it's parts.
    """
    self._validateTable (tableDef)
    tableName = tableDef['name']

    if tableDef.has_key ('primarykey'):
      self._validatePrimaryKey (tableName, tableDef ['primarykey'])

    if tableDef.has_key ('fields'):
      for field in tableDef ['fields']:
        self._validateField (tableName, field)

    if tableDef.has_key ('indices'):
      for index in tableDef ['indices']:
        self._validateIndex (tableName, index)

    if tableDef.has_key ('constraints'):
      for constraint in tableDef ['constraints']:
        self._validateConstraint (tableName, constraint)


  # ---------------------------------------------------------------------------
  # Make sure to release all references
  # ---------------------------------------------------------------------------

  def close (self):
    """
    This function releases all circular references held by the creator instance
    """

    self.introspector = None
    self.connection   = None
    self.lookup       = {}


  # ---------------------------------------------------------------------------
  # Call the appropriate method for a type-transformation
  # ---------------------------------------------------------------------------

  def _translateType (self, fieldDefinition):
    """
    This function calls the appropriate method for a type-conversion according
    to the field definition's datatype and returns this method's result.

    @param fieldDefinition: dictionary describing the field.
    @return: a string with the native data type for the field definition.
    """
    if not fieldDefinition.has_key ('type'):
      raise FieldDefinitionError, ('type')

    aMethod = self.__findMethod (self.__class__, fieldDefinition ['type'])
    if aMethod is None:
      raise MissingTypeTransformationError, (fieldDefinition ['type'])

    return aMethod (self, fieldDefinition)


  # ---------------------------------------------------------------------------
  # Create code for a single field
  # ---------------------------------------------------------------------------

  def _processField (self, tableName, fieldDef, forAlter = False):
    """
    This function creates a portion of code which defines the given field in
    the table tableName. 
    
    @param tableName: the table this field belongs to.
    @param fieldDef: the dictionary describing the field.
    @param forAlter: If TRUE this function produces code for a table
        modification, otherwise for a table creation.
    @return: a tuple of sequences (prologue, body, epliogue) containing the
        code to perform the action.
    """
    return ([], [], [])


  # ---------------------------------------------------------------------------
  # Create a usable name for a seuquence like object
  # ---------------------------------------------------------------------------

  def _getSequenceName (self, tableName, fieldDefinition):
    """
    This function creates a name for a sequence like object using the table-
    and fieldname. It respects a given restriction of identifier length.

    @param tableName: name of the table
    @param fieldDefinition: dictionary describing the field
    @return: string with a name for the given sequence
    """

    res = "%s_%s_seq" % (tableName, fieldDefinition ['name'])
    if self._nameTooLong (res):
      res = "%s_%s_seq" % (tableName, id (fieldDefinition))

    if self._nameTooLong (res):
      res = "%s_seq" % (id (fieldDefinition))

    return self._shortenName (res)


  # ---------------------------------------------------------------------------
  # Check if an identifier is too long
  # ---------------------------------------------------------------------------

  def _nameTooLong (self, aName):
    """
    This function returns TRUE if @aName exceeds MAX_NAME_LENGTH, otherwise
    FALSE. 
    """
    return (self.MAX_NAME_LENGTH is not None) and \
           (len (aName) > self.MAX_NAME_LENGTH)


  # ---------------------------------------------------------------------------
  # Make sure a given identifier doesn't exceed maximum length
  # ---------------------------------------------------------------------------

  def _shortenName (self, aName):
    """
    This function makes sure the given name doesn't exceed the maximum
    identifier length.
    @param aName: identifier to be checked
    @return: identifier with extra characters cut off
    """
    if self._nameTooLong (aName):
      return aName [:self.MAX_NAME_LENGTH]
    else:
      return aName


  # ---------------------------------------------------------------------------
  # Merge all sequences in the given tuples 
  # ---------------------------------------------------------------------------

  def mergeTuple (self, mergeInto, mergeFrom):
    """
    This function merges the sequences in the given tuples and returns the
    first one (which is changes as a side effect too).
    @param mergeInto: tuple with sequences which gets extended
    @param mergeFrom: tuple with sequences which mergeInto gets extended with
    @return: tuple of the same length as mergeInto with all sequences merged
        together.
    """
    for ix in range (len (mergeInto)):
      mergeInto [ix].extend (mergeFrom [ix])
    return mergeInto


  # ---------------------------------------------------------------------------
  # Validate a table definition
  # ---------------------------------------------------------------------------

  def _validateTable (self, tableDef, forAlter = False):
    """
    This function validates a table definition.
    @param tableDef: dictionary describing the table

    @raise TableDefinitionError: If tableDef has no key 'name'
    """
    self.__validateDefinition (tableDef, ['name'], TableDefinitionError)
    if self._nameTooLong (tableDef ['name']):
      raise LengthError, (tableDef ['name'], self.MAX_NAME_LENGTH)

    if forAlter and tableDef.has_key ('primarykey'):
      raise PrimaryKeyError, (tableDef ['name'])


  # ---------------------------------------------------------------------------
  # Validate a given primary key definition
  # ---------------------------------------------------------------------------

  def _validatePrimaryKey (self, tableName, keyDefinition):
    """
    This function validates a primarykey definition.
    @param tableName: name of the table the primary key belongs to
    @param keyDefinition: dictionary describing the primary key

    @raise PrimaryKeyDefinitionError: if 'name' or 'fields' are missing in the
        definition.
    @raise PrimaryKeyFieldsError: if 'fields' is an empty sequence
    """
    self.__validateDefinition (keyDefinition, ['name', 'fields'],
                               PrimaryKeyDefinitionError)

    if not len (keyDefinition ['fields']):
      raise PrimaryKeyFieldsError, (tableName, keyDefinition ['name'])

    for field in keyDefinition ['fields']:
      if self._nameTooLong (field):
        raise LengthError, (field, self.MAX_NAME_LENGTH)


  # ---------------------------------------------------------------------------
  # Validate a given index definition
  # ---------------------------------------------------------------------------

  def _validateIndex (self, tableName, indexDefinition):
    """
    This function validates an index definition.
    @param tableName: name of the table
    @param indexDefinition: dictionary describing the index

    @raise IndexDefinitionError: if 'name' or 'fields' are missing in the
        definition.
    @raise IndexFieldsError: if 'fields' is an empty sequence
    """
    self.__validateDefinition (indexDefinition, ['name', 'fields'],
                               IndexDefinitionError)
    if not len (indexDefinition ['fields']):
      raise IndexFieldsError, (tableName, indexDefinition ['name'])

    for field in indexDefinition ['fields']:
      if self._nameTooLong (field):
        raise LengthError, (field, self.MAX_NAME_LENGTH)


  # ---------------------------------------------------------------------------
  # Validate a given constraint definition
  # ---------------------------------------------------------------------------

  def _validateConstraint (self, tableName, constDef):
    """
    This function validates a constraint definition.
    @param tableName: name of the table the constraint belongs to
    @param constDef: the dictionary describing the constraint

    @raise ConstraintDefinitionError: if 'name' or 'fields' are missing in the
        definition.
    @raise ConstraintFieldsError: if 'fields' or 'reffields' is an empty
        sequence.
    """
    self.__validateDefinition (constDef,
        ['name', 'fields', 'reftable', 'reffields'], ConstraintDefinitionError)

    if not len (constDef ['fields']):
      raise ConstraintFieldsError, (tableName, constDef ['name'], 'fields')
    if not len (constDef ['reffields']):
      raise ConstraintFieldsError, (tableName, constDef ['name'], 'reffields')

    if constDef.has_key ('type') and constDef ['type'] != 'foreignkey':
      raise ConstraintTypeError, (tableName, constDef ['name'],
                                  constDef ['type'])

    if self._nameTooLong (constDef ['reftable']):
      raise LengthError, (constDef ['reftable'], self.MAX_NAME_LENGTH)

    for field in constDef ['fields'] + constDef ['reffields']:
      if self._nameTooLong (field):
        raise LengthError, (field, self.MAX_NAME_LENGTH)


  # ---------------------------------------------------------------------------
  # Validate a field definition
  # ---------------------------------------------------------------------------

  def _validateField (self, tableName, fieldDef):
    """
    This function validates a field definition.
    @param tableName: name of the table
    @param fieldDef: dictionary describing the field

    @raise FieldDefinitionError: If the dictionary has no 'name' and 'type'
        keys.
    """
    self.__validateDefinition (fieldDef, ['name', 'type'], FieldDefinitionError)
    if self._nameTooLong (fieldDef ['name']):
      raise LengthError, (fieldDef ['name'], self.MAX_NAME_LENGTH)


  # ---------------------------------------------------------------------------
  # Validate all keys in an arbitrary definition
  # ---------------------------------------------------------------------------

  def __validateDefinition (self, definition, keys, defError):
    """
    This function raises an exception if a key in the given sequence is missing
    in the definition.
    @param definition: dictionary to be checked
    @param keys: sequence of keys which must exist in definition
    @param defError: DefinitionError class raised on a missing key
    """
    for key in keys:
      if not definition.has_key (key):
        raise defError, (key)


  # ---------------------------------------------------------------------------
  # find a method in a class or its' superclasses
  # ---------------------------------------------------------------------------

  def __findMethod (self, aClass, aMethod):
    """
    This function looks for a method in a class and all its' superclasses.

    @param aClass: the class where the search starts
    @param aMethod: name of the method to be looked for
    @return: function pointer to the method found or None if search failed.
    """

    if aClass.__dict__.has_key (aMethod):
      return aClass.__dict__ [aMethod]
    else:
      for base in aClass.__bases__:
        result = self.__findMethod (base, aMethod)
        if result is not None:
          return result

    return None

