import unittest
from pyorq import ptype
from pyorq.pprop import *

def fill_db():
    """Define the persistent test classes and populate the database with them.

    Reads the module-global ``db`` (assigned by test_db() before this is
    called) and publishes the classes A, B, C, D plus the (tid, oid) lists
    below as module globals, which the test cases in this file read.

    Objects written:
      * five A objects with a = 0..4, each referenced by one D
            -> a_oids[i], d1_oids[i]
      * five B objects with a = 0..4, each referenced by one C and one D
            -> b_oids[i], c_oids[i], d2_oids[i]
    In every list, index i belongs to the object(s) built from value i.
    """
    global A, B, C, D, a_oids, b_oids, c_oids, d1_oids, d2_oids
    class A(ptype.pobject):
        database = db
        a = pint()      # pyorq.pprop property; the tests query it with int comparisons
        def __init__(self, i):
            self.a = i

    # B inherits A's attribute; the *_inh tests expect queries on A to also
    # return the B rows.
    class B(A):
        pass

    class C(ptype.pobject):
        database = db
        ref = pref(B)   # reference property declared with target class B
        def __init__(self, i):
            # NOTE: despite its name, i is the referenced object, not an int.
            self.ref = i

    class D(ptype.pobject):
        database = db
        ref = pref(A)   # reference property declared with target class A
        def __init__(self, i):
            # NOTE: despite its name, i is the referenced object, not an int.
            self.ref = i

    # Clear rows left over from a previous run (presumably what
    # empty_table does -- confirm against the interface implementation).
    db.empty_table('A')
    db.empty_table('B')
    db.empty_table('C')
    db.empty_table('D')
    
    a_oids = []
    b_oids = []
    c_oids = []
    d1_oids = []
    d2_oids = []
    
    # The referenced object is committed before the object referring to it;
    # presumably the referrer needs the target's oid when it is stored --
    # confirm against pyorq. tid/oid are read only after commit.
    for i in range(5):
        a = A(i)
        d = D(a)
        a.commit(); d.commit()
        a_oids.append((a.tid, a.oid))
        d1_oids.append((d.tid, d.oid))
    for i in range(5):
        b = B(i)
        c, d = C(b), D(b)
        b.commit(); c.commit(); d.commit()
        b_oids.append((b.tid, b.oid))
        c_oids.append((c.tid, c.oid))
        d2_oids.append((d.tid, d.oid))

class test_simple_queries(unittest.TestCase):
    def setUp(self):
        A.database.clear_cache()

    def match_lists(self, results, oid_list):
        self.failUnlessEqual(len(results), len(oid_list))
        for obj in results:
            for tid, oid in oid_list:
                if obj.tid == tid and obj.oid == oid:
                    break
            else:
                self.fail("%s not in %s" % ((obj.tid, obj.oid),
                                            oid_list))

    def test_EQ_query(self):
        q = B.a == 3
        self.match_lists(list(q), b_oids[3:4])
    
    def test_EQ_query_inh(self):
        q = A.a == 3
        self.match_lists(list(q), a_oids[3:4]+b_oids[3:4])

    def test_LE_query(self):
        q = B.a <= 3
        self.match_lists(list(q), b_oids[0:4])

    def test_LE_query_inh(self):
        q = A.a <= 3
        self.match_lists(list(q), a_oids[0:4]+b_oids[0:4])
        
    def test_LT_query(self):
        q = B.a < 3
        self.match_lists(list(q), b_oids[0:3])

    def test_LT_query_inh(self):
        q = A.a < 3
        self.match_lists(list(q), a_oids[0:3]+b_oids[0:3])

    def test_GT_query(self):
        q = B.a > 3
        self.match_lists(list(q), b_oids[4:5])
        
    def test_GT_query_inh(self):
        q = A.a > 3
        self.match_lists(list(q), a_oids[4:5]+b_oids[4:5])

    def test_GE_query(self):
        q = B.a >= 3
        self.match_lists(list(q), b_oids[3:5])

    def test_GE_query_inh(self):
        q = A.a >= 3
        self.match_lists(list(q), a_oids[3:5]+b_oids[3:5])

    def test_AND_query(self):
        q = (B.a > 1) & (B.a <= 3)
        self.match_lists(list(q), b_oids[2:4])

    def test_AND_query_inh(self):
        q = (A.a > 1) & (A.a <= 3)
        self.match_lists(list(q), a_oids[2:4]+b_oids[2:4])
        
    def test_OR_query(self):
        q = (B.a < 1) | (B.a >= 3)
        self.match_lists(list(q), b_oids[0:1]+b_oids[3:])
        
    def test_OR_query_inh(self):
        q = (A.a < 1) | (A.a >= 3)
        self.match_lists(list(q), a_oids[0:1]+a_oids[3:]+b_oids[0:1]+b_oids[3:])

    def test_ADD_query(self):
        q = B.a+1 == 3
        self.match_lists(list(q), b_oids[2:3])

    def test_ADD_query_inh(self):
        q = A.a+1 == 3
        self.match_lists(list(q), a_oids[2:3]+b_oids[2:3])

    def test_SUB_query(self):
        q = B.a-1 == 3
        self.match_lists(list(q), b_oids[4:5])

    def test_SUB_query_inh(self):
        q = A.a-1 == 3
        self.match_lists(list(q), a_oids[4:5]+b_oids[4:5])

class test_pref_queries(unittest.TestCase):            
    def setUp(self):
        A.database.clear_cache()

    def match_lists(self, results, oid_list):
        self.failUnlessEqual(len(results), len(oid_list))
        for obj in results:
            for tid, oid in oid_list:
                if obj.tid == tid and obj.oid == oid:
                    break
            else:
                self.fail("%s not in %s" % ((obj.tid, obj.oid),
                                            oid_list))

    def test_instance_comp(self):
        b = list(B.a == 3)[0]
        q = (C.ref == b)
        self.match_lists(list(q), c_oids[3:4])

    def test_EQ_query(self):
        q = C.ref.a == 3
        self.match_lists(list(q), c_oids[3:4])
    
    def test_EQ_query_inh(self):
        q = D.ref.a == 3
        self.match_lists(list(q), d1_oids[3:4]+d2_oids[3:4])
        
    def test_LE_query(self):
        q = C.ref.a <= 3
        self.match_lists(list(q), c_oids[0:4])

    def test_LE_query_inh(self):
        q = D.ref.a <= 3
        self.match_lists(list(q), d1_oids[0:4]+d2_oids[0:4])
        
    def test_LT_query(self):
        q = C.ref.a < 3
        self.match_lists(list(q), c_oids[0:3])

    def test_LT_query_inh(self):
        q = D.ref.a < 3
        self.match_lists(list(q), d1_oids[0:3]+d2_oids[0:3])

    def test_GT_query(self):
        q = C.ref.a > 3
        self.match_lists(list(q), c_oids[4:5])
        
    def test_GT_query_inh(self):
        q = D.ref.a > 3
        self.match_lists(list(q), d1_oids[4:5]+d2_oids[4:5])

    def test_GE_query(self):
        q = C.ref.a >= 3
        self.match_lists(list(q), c_oids[3:5])

    def test_GE_query_inh(self):
        q = D.ref.a >= 3
        self.match_lists(list(q), d1_oids[3:5]+d2_oids[3:5])

    def test_AND_query(self):
        q = (C.ref.a > 1) & (C.ref.a <= 3)
        self.match_lists(list(q), c_oids[2:4])

    def test_AND_query_inh(self):
        q = (D.ref.a > 1) & (D.ref.a <= 3)
        self.match_lists(list(q), d1_oids[2:4]+d2_oids[2:4])
        
    def test_OR_query(self):
        q = (C.ref.a < 1) | (C.ref.a >= 3)
        self.match_lists(list(q), c_oids[0:1]+c_oids[3:])
        
    def test_OR_query_inh(self):
        q = (D.ref.a < 1) | (D.ref.a >= 3)
        self.match_lists(list(q), d1_oids[0:1]+d1_oids[3:]+d2_oids[0:1]+d2_oids[3:])

    def test_ADD_query(self):
        q = C.ref.a+1 == 3
        self.match_lists(list(q), c_oids[2:3])

    def test_ADD_query_inh(self):
        q = D.ref.a+1 == 3
        self.match_lists(list(q), d1_oids[2:3]+d2_oids[2:3])

    def test_SUB_query(self):
        q = C.ref.a-1 == 3
        self.match_lists(list(q), c_oids[4:5])

    def test_SUB_query_inh(self):
        q = D.ref.a-1 == 3
        self.match_lists(list(q), d1_oids[4:5]+d2_oids[4:5])

def make_suite():
    """Return a TestSuite containing every test_* method of both query classes.

    Uses TestLoader.loadTestsFromTestCase rather than the deprecated
    unittest.makeSuite (removed in Python 3.13); the default loader uses the
    same 'test' method prefix and sorted order.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    suite.addTest(loader.loadTestsFromTestCase(test_simple_queries))
    suite.addTest(loader.loadTestsFromTestCase(test_pref_queries))
    return suite

def test_db(external_lib_name, interface_name, **kwargs):
    global db 

    # Try to import the external libray
    try:
        mod = __import__(external_lib_name)
    except ImportError:
        print "%s not available" % external_lib_name
        return

    try:
        mod = __import__('pyorq.interface.'+interface_name)
        mod = getattr(mod.interface, interface_name)
        db = getattr(mod, interface_name)(**kwargs)
    except:
        print "Unable to instantiate interface %s" % interface_name
        return

    print "Running with %s" % interface_name
    fill_db()
    unittest.TextTestRunner(verbosity=1).run(make_suite())



if __name__ == '__main__':
    # Each entry: (external library probed first, pyorq interface name,
    # constructor keyword arguments). Backends run in this order.
    for lib_name, iface_name, iface_kwargs in [
        ("sys", "nodb", {}),
        ("pyPgSQL.libpq", "postgresql_db", {"database": "testdb"}),
        ("_sqlite", "sqlite_db", {"database": "testdb"}),
        ("_mysql", "mysql_db", {"db": "testdb"}),
    ]:
        test_db(lib_name, iface_name, **iface_kwargs)
