"""Relational algebra for PyORQ

This module implements the relational algebra for PyORQ. A more or less formal
definition is given by the following grammar:

relation   := not_rel | and_rel | or_rel | comparison
not_rel    := '~' relation
and_rel    := relation '&' relation
or_rel     := relation '|' relation
comparison := expr cmp_op expr
cmp_op     := '==' | '!=' | '<' | '<=' | '>' | '>='
expr       := min_expr | bin_expr | term
min_expr   := '-' expr
bin_expr   := expr bin_op expr
bin_op     := '+' | '-' | '*' | '/'
term       := attr_ref | int | long | float | str ...

Relations are built from expression and relation objects.

All objects have the following methods:

py_repr()
   Build a representation that can be evaluated by Python. Used by the nodb()
   interface to build a list comprehension that should produce the same
   result as an SQL query.

sql_repr(db)
   Build a representation that can be used in a WHERE clause. The db argument
   is used to call back into the database interface to build representations
   for builtin values.

free_variable()
   Return the free variable of the relation (the persistent class being
   queried), or None if there is none.

update_bound_variables(d)
   Used to find the join clauses necessary to build the cross reference
   between tables. A query of the form A.b.c.d == 3 produces a bound-variable
   dict {('_x', 'b'): A.b.ptype, ('_x', 'b', 'c'): A.b.ptype.c.ptype}. This
   dictionary is used to produce aliases and join clauses.

The relation objects also have an __iter__ method. This is where the magic
starts.
"""

##########################################################################
# Expressions

import datetime

class binary_repr_mixin(object):
    """Render a binary node as a parenthesised ``(lhs op rhs)`` string.

    Mixed into binary relations and binary expressions. The host class must
    provide ``lhs`` and ``rhs`` operands (each offering py_repr() and
    sql_repr(db)) plus the attributes ``op_py_repr`` / ``op_sql_repr`` that
    spell the operator in Python and in SQL respectively.
    """
    def py_repr(self):
        left = self.lhs.py_repr()
        right = self.rhs.py_repr()
        return '(%s %s %s)' % (left, self.op_py_repr, right)

    def sql_repr(self, db):
        left = self.lhs.sql_repr(db)
        right = self.rhs.sql_repr(db)
        return '(%s %s %s)' % (left, self.op_sql_repr, right)

class expression(object):
    """Base class for expression objects

    Defines arithmetic operations that return new expression objects.
    Defines comparison operators that return new relation objects.
    Non-expression operands are wrapped with maketerm() first.
    """
    # Defining __eq__ makes Python 3 set __hash__ to None; keep the
    # identity-based hash Python 2 provides so expressions remain usable as
    # dict keys / set members.
    __hash__ = object.__hash__

    def __add__(self, other):
        """x+y --> expr_add(x, y)"""
        return expr_add(self, maketerm(other))
    def __radd__(self, other):
        """y+x --> expr_add(y, x), where y is not an expression"""
        return expr_add(maketerm(other), self)
    def __sub__(self, other):
        """x-y --> expr_sub(x, y)"""
        return expr_sub(self, maketerm(other))
    def __rsub__(self, other):
        """y-x --> expr_sub(y, x), where y is not an expression"""
        return expr_sub(maketerm(other), self)
    def __mul__(self, other):
        """x*y --> expr_mul(x, y)"""
        return expr_mul(self, maketerm(other))
    def __rmul__(self, other):
        """y*x --> expr_mul(y, x), where y is not an expression"""
        return expr_mul(maketerm(other), self)
    def __div__(self, other):
        """x/y --> expr_div(x, y)"""
        return expr_div(self, maketerm(other))
    def __rdiv__(self, other):
        """y/x --> expr_div(y, x), where y is not an expression"""
        return expr_div(maketerm(other), self)
    # '/' dispatches to __truediv__ under "from __future__ import division"
    # and always in Python 3; without these aliases division raises TypeError.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__
    def __neg__(self):
        """-x --> expr_neg(x)"""
        return expr_neg(self)

    def __eq__(self, other):
        """x==y --> comp_eq(x, y)"""
        return comp_eq(self, maketerm(other))
    def __ne__(self, other):
        """x!=y --> comp_ne(x, y)"""
        return comp_ne(self, maketerm(other))
    def __ge__(self, other):
        """x>=y --> comp_ge(x, y)"""
        return comp_ge(self, maketerm(other))
    def __gt__(self, other):
        """x>y --> comp_gt(x, y)"""
        return comp_gt(self, maketerm(other))
    def __le__(self, other):
        """x<=y --> comp_le(x, y)"""
        return comp_le(self, maketerm(other))
    def __lt__(self, other):
        """x<y --> comp_lt(x, y)"""
        return comp_lt(self, maketerm(other))

class expr_neg(expression):
    """Arithmetic negation of an expression, created by unary minus (-x)"""

    def __init__(self, arg):
        self.arg = arg

    def free_variable(self):
        # Negating does not change which class the expression ranges over
        return self.arg.free_variable()

    def update_bound_variables(self, d):
        # Joins are exactly those required by the negated operand
        self.arg.update_bound_variables(d)

    def py_repr(self):
        inner = self.arg.py_repr()
        return '(-%s)' % inner

    def sql_repr(self, db):
        inner = self.arg.sql_repr(db)
        return '(-%s)' % inner

class binary_expression(expression, binary_repr_mixin):
    """Base class for expressions created by binary arithmetic operators

    Subclasses set op_py_repr / op_sql_repr; binary_repr_mixin renders them
    as ``(lhs op rhs)``.
    """
    def __init__(self, lhs, rhs):
        """Combine two expression operands.

        Raises:
          ValueError -- if lhs and rhs have different free variables
        """
        lf = lhs.free_variable()
        rf = rhs.free_variable()
        # A side without a free variable (a plain value) combines with
        # anything; compare against None rather than relying on truthiness,
        # since a class object could define its own falsiness.
        if lf is not None and rf is not None and lf is not rf:
            raise ValueError("The relation has more than one free variable")
        self.lhs = lhs
        self.rhs = rhs

    def free_variable(self):
        """Return the free variable of whichever operand has one, else None."""
        fv = self.lhs.free_variable()
        if fv is None:
            fv = self.rhs.free_variable()
        return fv

    def update_bound_variables(self, d):
        self.lhs.update_bound_variables(d)
        self.rhs.update_bound_variables(d)
        
class expr_add(binary_expression):
    """Sum of two expressions: lhs + rhs"""
    op_py_repr = op_sql_repr = '+'
class expr_sub(binary_expression):
    """Difference of two expressions: lhs - rhs"""
    op_py_repr = op_sql_repr = '-'
class expr_mul(binary_expression):
    """Product of two expressions: lhs * rhs"""
    op_py_repr = op_sql_repr = '*'
class expr_div(binary_expression):
    """Quotient of two expressions: lhs / rhs"""
    op_py_repr = op_sql_repr = '/'

###############################################################
# Terms

class term(expression):
    """Base class for the leaves of an expression tree.

    Concrete terms are constants (value) and references to persistent
    properties (reference and its subclasses).
    """

class value(term):
    """A term representing a Python value (e.g. int, float, str, date)

    This object is returned by the maketerm function.
    """
    # Numeric types rendered verbatim in SQL; Python 3 has no separate long.
    try:
        _numeric_types = (int, long, float)
    except NameError:
        _numeric_types = (int, float)

    def __init__(self, val):
        self.val = val
    def free_variable(self):
        # A constant does not refer to any persistent class
        return None
    def update_bound_variables(self, d):
        # A constant needs no joins
        pass
    def py_repr(self):
        return repr(self.val)
    def sql_repr(self, db):
        """Return an SQL literal for the wrapped value.

        Strings and date/time values are quoted through the db callbacks.

        Raises:
          TypeError -- if the value has no SQL representation here
        """
        val = self.val
        if isinstance(val, self._numeric_types):
            return str(val)
        if isinstance(val, str):
            return db.quote_str(val)
        # datetime.datetime is a subclass of datetime.date, so it must be
        # tested first; otherwise datetimes are quoted as plain dates.
        if isinstance(val, datetime.datetime):
            return db.quote_datetime(val)
        if isinstance(val, datetime.date):
            return db.quote_date(val)
        if isinstance(val, datetime.time):
            return db.quote_time(val)
        # Previously fell through and returned None, which ended up as the
        # text "None" inside the generated SQL.
        raise TypeError('cannot represent %r in SQL' % (val,))
        

class reference(term):
    """Common base for terms that refer to a persistent property.

    ``parent`` is either the persistent class at the root of an attribute
    chain or another reference; ``prop`` is the property object whose
    ``name`` is used when spelling the reference in Python and in SQL.
    The root of every chain is rendered as ``_x``.
    """
    def __init__(self, parent, prop):
        self.parent = parent
        self.prop = prop

    def free_variable(self):
        # Walk up the chain; the root class is the free variable
        parent = self.parent
        return parent.free_variable() if isinstance(parent, reference) else parent

    def py_repr(self):
        parent = self.parent
        head = parent.py_repr() if isinstance(parent, reference) else '_x'
        return '%s.%s' % (head, self.prop.name)

    def make_alias(self):
        # Alias of the joined table reached through this reference,
        # e.g. _x_b for A.b and _x_b_c for A.b.c
        if not isinstance(self.parent, reference):
            return '_x_' + self.prop.name
        return '%s_%s' % (self.parent.make_alias(), self.prop.name)

    def sql_repr(self, db):
        # Columns of a chained reference are read from the parent's alias
        parent = self.parent
        head = parent.make_alias() if isinstance(parent, reference) else '_x'
        return '%s.%s' % (head, self.prop.name)
    

class value_attr(reference):
    """A reference to a persistent value (pval)"""
    def __init__(self, parent, prop):
        """A reference to a persistent value (pval)

        Arguments:
          parent -- either a persistent class or a ref_attr object
          prop   -- a pval object
        """
        reference.__init__(self, parent, prop)

    def update_bound_variables(self, d):
        # The value is a column on its parent's table, so only the chain of
        # references above it (if any) contributes joins.
        parent = self.parent
        if isinstance(parent, reference):
            parent.update_bound_variables(d)
        
class ref_attr(reference):
    """A reference to a persistent ref (pref)

    Attribute access continues the chain (A.b.c), and ``==`` / ``!=``
    compare object identity rather than value.
    """
    # Overriding __eq__ makes Python 3 set __hash__ to None; keep the
    # identity-based hash Python 2 provides.
    __hash__ = object.__hash__

    def __init__(self, parent, prop):
        """A reference to a persistent ref (pref)

        Arguments:
          parent -- either a persistent class or a ref_attr object
          prop   -- a pref object
        """
        self.parent = parent
        self.prop = prop

    def update_bound_names_and_variables(self, l, d):
        """Record a bound variable for this reference and every ref above it.

        Arguments:
          l -- list of name components, extended in place (starts as ['_x'])
          d -- dict mapping tuple(l) to the ptype of the referenced class
        """
        if isinstance(self.parent, ref_attr):
            self.parent.update_bound_names_and_variables(l, d)
        l.append(self.prop.name)
        d[tuple(l)] = self.prop.ptype

    def update_bound_variables(self, d):
        bound_name = ['_x']
        self.update_bound_names_and_variables(bound_name, d)

    def __getattr__(self, attr):
        """Attribute access of reference attribute

        If attr is the name of a pval object p of self.prop.ptype, return
        value_attr(self, p)
        If attr is the name of a pref object p of self.prop.ptype, return
        ref_attr(self, p)
        Else, raise an AttributeError
        """
        # 'parent'/'prop' are normally instance attributes; if they are
        # missing (instance created via __new__, e.g. by copy or pickle),
        # self.prop below would re-enter __getattr__ and recurse forever.
        if attr in ('parent', 'prop'):
            raise AttributeError(attr)
        ptype = self.prop.ptype
        for p in ptype.persistent_values:
            if p.name == attr:
                return value_attr(self, p)
        for p in ptype.persistent_refs:
            if p.name == attr:
                return ref_attr(self, p)
        raise AttributeError('%s not a property of %s' % (attr, ptype))

    def __eq__(self, other):
        """x==y --> comp_is(x, y)

        If x is a ref_attr and y is either a ref_attr or an instance,
        equality implies identity (i.e: 'is', rather than '==')
        """
        return comp_is(self, maketerm(other))
    def __ne__(self, other):
        """x!=y --> comp_not_is(x, y)

        If x is a ref_attr and y is either a ref_attr or an instance,
        inequality implies unequal identity (i.e: 'not is', rather than '!=')
        """
        return comp_not_is(self, maketerm(other))

    # Inequalities are not defined for ref_attr objects
    def __gt__(self, other):
        raise NotImplementedError('Comparison on object reference')
    __ge__ = __le__ = __lt__ = __gt__


def maketerm(val):
    """Wrap val in a value term unless it already is an expression.

    Expressions (including references) are returned unchanged; anything
    else is treated as a constant.
    """
    return val if isinstance(val, expression) else value(val)

######################################################################
# Relations
    
class relation(object):
    """The base class for relations.

    Defines logical operations that return new relation objects.
    Defines an __iter__ method that evaluates the query.
    """
    def __and__(self, other):
        """x & y --> rel_and(x, y)"""
        return rel_and(self, other)

    def __or__(self, other):
        """x | y --> rel_or(x, y)"""
        return rel_or(self, other)

    def __invert__(self):
        """~x --> rel_not(x)"""
        return rel_not(self)

    def __iter__(self):
        """Yield all instances that satisfy the relation

        The free variable cls is the class being queried; the query is run
        through cls.database.query_generator for every class returned by
        cls.all_subclasses().

        Raises:
          ValueError -- (on first iteration) if the relation has no free
                        variable, e.g. it only compares constants
        """
        cls = self.free_variable()
        if cls is None:
            # Without a free variable there is no class (and no database)
            # to query; report that instead of failing on None.database.
            raise ValueError("The relation has no free variable")
        db = cls.database

        for subclass in cls.all_subclasses():
            # evaluate the query with subclass as the free variable
            for instance in db.query_generator(subclass, self):
                yield instance

class rel_not(relation):
    """Logical negation of a relation, created by ~rel"""

    def __init__(self, arg):
        """Wrap arg so the relation reads NOT arg"""
        self.arg = arg

    def free_variable(self):
        # Negation queries the same class as its operand
        return self.arg.free_variable()

    def update_bound_variables(self, d):
        self.arg.update_bound_variables(d)

    def py_repr(self):
        inner = self.arg.py_repr()
        return '(not %s)' % inner

    def sql_repr(self, db):
        inner = self.arg.sql_repr(db)
        return '(NOT %s)' % inner

class binary_relation(relation, binary_repr_mixin):
    """Base class for relations created by applying a binary logical
    operation to two relations

    Subclasses set op_py_repr / op_sql_repr; binary_repr_mixin renders them
    as ``(lhs op rhs)``.
    """

    def __init__(self, lhs, rhs):
        """Combine two relations.

        Raises:
          ValueError -- if lhs and rhs do not have the same free variable
        """
        if lhs.free_variable() is not rhs.free_variable():
            raise ValueError("The relation has more than one free variable")
        self.lhs = lhs
        self.rhs = rhs

    def free_variable(self):
        # lhs and rhs have the same free variable (checked in __init__)
        return self.lhs.free_variable()
    def update_bound_variables(self, d):
        self.lhs.update_bound_variables(d)
        self.rhs.update_bound_variables(d)

class rel_and(binary_relation):
    """Conjunction of two relations, created by rel1 & rel2"""
    op_sql_repr = ' AND '
    op_py_repr = ' and '
class rel_or(binary_relation):
    """Disjunction of two relations, created by rel1 | rel2"""
    op_sql_repr = ' OR '
    op_py_repr = ' or '

class comparison_relation(relation, binary_repr_mixin):
    """Relations obtained by comparing two expression objects"""
    def __init__(self, lhs, rhs):
        """Compare two expression operands.

        Raises:
          ValueError -- if lhs and rhs have different free variables
        """
        lf = lhs.free_variable()
        rf = rhs.free_variable()
        # A side without a free variable (a plain value) compares with
        # anything; test against None rather than relying on truthiness,
        # since a class object could define its own falsiness.
        if lf is not None and rf is not None and lf is not rf:
            raise ValueError("The relation has more than one free variable")
        self.lhs = lhs
        self.rhs = rhs

    def free_variable(self):
        """Return the free variable of whichever operand has one, else None."""
        fv = self.lhs.free_variable()
        if fv is None:
            fv = self.rhs.free_variable()
        return fv
    def update_bound_variables(self, d):
        self.lhs.update_bound_variables(d)
        self.rhs.update_bound_variables(d)

class comp_eq(comparison_relation):
    """Equality comparison: lhs == rhs (SQL '=')"""
    op_py_repr, op_sql_repr = '==', '='
class comp_ne(comparison_relation):
    """Inequality comparison: lhs != rhs"""
    op_py_repr = op_sql_repr = '!='
class comp_gt(comparison_relation):
    """Greater-than comparison: lhs > rhs"""
    op_py_repr = op_sql_repr = '>'
class comp_ge(comparison_relation):
    """Greater-or-equal comparison: lhs >= rhs"""
    op_py_repr = op_sql_repr = '>='
class comp_lt(comparison_relation):
    """Less-than comparison: lhs < rhs"""
    op_py_repr = op_sql_repr = '<'
class comp_le(comparison_relation):
    """Less-or-equal comparison: lhs <= rhs"""
    op_py_repr = op_sql_repr = '<='

class identity_comparison(comparison_relation):
    """Comparisons of identity, rather than value

    Each operand is either a reference, whose identity is read from the
    columns <name>_tid and <name>_oid on its parent's alias, or a value
    wrapping an object with ``tid`` and ``oid`` attributes.
    """
    def update_bound_variables(self, d):
        # The tid/oid columns of a reference live on its parent's table, so
        # only the parent chain needs joining, not the referenced table.
        for operand in (self.lhs, self.rhs):
            if (isinstance(operand, reference) and
                    isinstance(operand.parent, reference)):
                operand.parent.update_bound_variables(d)

    @staticmethod
    def _py_operand(operand):
        """Python spelling of one operand's oid."""
        if isinstance(operand, reference):
            return '%s.oid' % operand.py_repr()
        return str(operand.val.oid)

    @staticmethod
    def _sql_operands(operand, db):
        """Return the (tid, oid) SQL fragments for one operand."""
        if isinstance(operand, reference):
            head = operand.sql_repr(db) + '_'
            return head + 'tid', head + 'oid'
        # str() rather than repr() for the oid: under Python 2, repr() of a
        # long appends an 'L' suffix, which is not valid SQL (py_repr above
        # already uses str()).
        # NOTE(review): tid keeps repr() as before; its type is not visible
        # here -- confirm whether it can also be a long.
        return repr(operand.val.tid), str(operand.val.oid)

    def py_repr(self):
        l = self._py_operand(self.lhs)
        r = self._py_operand(self.rhs)
        return '(%s %s %s)' % (l, self.op_py_repr, r)

    def sql_repr(self, db):
        lhs_tid, lhs_oid = self._sql_operands(self.lhs, db)
        rhs_tid, rhs_oid = self._sql_operands(self.rhs, db)
        op = self.op_sql_repr
        return '(%s%s%s AND %s%s%s)' % (lhs_tid, op, rhs_tid,
                                        lhs_oid, op, rhs_oid)
        
class comp_is(identity_comparison):
    """Identities are equal (x is y)"""
    op_py_repr = '=='
    op_sql_repr = '='
class comp_not_is(identity_comparison):
    """Identities are not equal (x is not y)"""
    op_py_repr = '!='
    op_sql_repr = '='
    def sql_repr(self, db):
        # The inherited SQL is "(tid = tid AND oid = oid)", i.e. the identity
        # test itself; with op '=' it was emitted unchanged, making x != y
        # select the same rows as x == y. Negate the whole conjunction
        # (using '!=' inside the AND would be wrong as well).
        return '(NOT %s)' % identity_comparison.sql_repr(self, db)
