# odbc test suite kindly contributed by Frank Millman.
import sys
import os
import unittest
import odbc
import tempfile

from pywin32_testutil import str2bytes, str2memory, TestSkipped

# We use the DAO ODBC driver
from win32com.client.gencache import EnsureDispatch
from win32com.client import constants
import pythoncom

class TestStuff(unittest.TestCase):
    """Round-trip tests for the ``odbc`` module against a scratch table.

    Every test runs against a freshly created table named by
    ``self.tablename``.  The database is whatever the
    ``TEST_ODBC_CONNECTION_STRING`` environment variable points at or, when
    that variable is not set, a temporary MS Access database created via DAO.
    """

    def setUp(self):
        """Connect to the test database and (re)create the test table.

        Raises:
            TestSkipped: if no connection string was supplied and none of the
                DAO engine versions tried below can be dispatched.
        """
        self.tablename = "pywin32test_users"
        self.db_filename = None
        self.conn = self.cur = None
        try:
            # Test any database if a connection string is supplied...
            conn_str = os.environ['TEST_ODBC_CONNECTION_STRING']
        except KeyError:
            # Create a local MSAccess DB for testing.
            self.db_filename = tempfile.NamedTemporaryFile().name + '.mdb'

            # Create a brand-new database - what is the story with these?
            for suffix in (".36", ".35", ".30"):
                try:
                    dbe = EnsureDispatch("DAO.DBEngine" + suffix)
                    break
                except pythoncom.com_error:
                    pass
            else:
                # The loop never hit ``break``: no engine could be created.
                raise TestSkipped("Can't find a DB engine")

            workspace = dbe.Workspaces(0)

            newdb = workspace.CreateDatabase(self.db_filename,
                                             constants.dbLangGeneral,
                                             constants.dbEncrypt)

            newdb.Close()

            conn_str = "Driver={Microsoft Access Driver (*.mdb)};dbq=%s;Uid=;Pwd=;" \
                       % (self.db_filename,)
        ## print 'Connection string:', conn_str
        self.conn = odbc.odbc(conn_str)
        # And we expect a 'users' table for these tests.
        self.cur = self.conn.cursor()
        ## self.cur.setoutputsize(1000)
        try:
            # Remove any table left behind by a previous run; it is fine if
            # there is nothing to drop.
            self.cur.execute("""drop table %s""" %self.tablename)
        except (odbc.error, odbc.progError):
            pass

        ## This needs to be adjusted for sql server syntax for unicode fields
        ##  - memo -> TEXT
        ##  - varchar -> nvarchar
        self.assertEqual(self.cur.execute(
            """create table %s (
                    userid varchar(25),
                    username varchar(25),
                    bitfield bit,
                    intfield integer,
                    floatfield float,
                    datefield datetime,
                    rawfield varbinary(100),
                    longtextfield memo,
                    longbinaryfield image
            )""" %self.tablename),-1)

    def tearDown(self):
        """Drop the test table and release the cursor, connection and DB file.

        A failure to drop the table is printed rather than raised, and a
        failure to delete the temporary database file is ignored.
        """
        if self.cur is not None:
            try:
                self.cur.execute("""drop table %s""" %self.tablename)
            except (odbc.error, odbc.progError) as why:
                print("Failed to delete test table %s" %self.tablename, why)

            self.cur.close()
            self.cur = None
        if self.conn is not None:
            self.conn.close()
            self.conn = None
        if self.db_filename is not None:
            try:
                os.unlink(self.db_filename)
            except OSError:
                pass

    def test_insert_select(self, userid='Frank', username='Frank Millman'):
        """Insert one row, then select it by lower-cased userid and username.

        ``execute`` is expected to return 1 for the insert and 0 for each
        select.
        """
        self.assertEqual(self.cur.execute("insert into %s (userid, username) \
            values (?,?)" %self.tablename, [userid, username]),1)
        self.assertEqual(self.cur.execute("select * from %s \
            where userid = ?" %self.tablename, [userid.lower()]),0)
        self.assertEqual(self.cur.execute("select * from %s \
            where username = ?" %self.tablename, [username.lower()]),0)

    def test_insert_select_unicode(self, userid='Frank', username="Frank Millman"):
        """Same checks as ``test_insert_select``; also the entry point used by
        ``test_insert_select_unicode_ext`` for non-ASCII values.
        """
        self.assertEqual(self.cur.execute("insert into %s (userid, username)\
            values (?,?)" %self.tablename, [userid, username]),1)
        self.assertEqual(self.cur.execute("select * from %s \
            where userid = ?" %self.tablename, [userid.lower()]),0)
        self.assertEqual(self.cur.execute("select * from %s \
            where username = ?" %self.tablename, [username.lower()]),0)

    def test_insert_select_unicode_ext(self):
        """Run the insert/select checks with non-ASCII characters."""
        userid = "t-\xe0\xf2"
        username = "test-\xe0\xf2 name"
        self.test_insert_select_unicode(userid, username)

    def _test_val(self, fieldName, value):
        """Check that ``value`` survives a round trip through ``fieldName``.

        The insert/select cycle is repeated 100 times for the row whose
        userid is 'Frank'; each time exactly one row must come back and its
        only column must compare equal to ``value``.
        """
        for x in range(100):
            # Start each iteration without a leftover 'Frank' row.
            self.cur.execute("delete from %s where userid='Frank'" %self.tablename)
            self.assertEqual(self.cur.execute(
                "insert into %s (userid, %s) values (?,?)" % (self.tablename, fieldName),
                ["Frank", value]), 1)
            self.cur.execute("select %s from %s where userid = ?" % (fieldName, self.tablename),
                             ["Frank"])
            rows = self.cur.fetchmany()
            # assertEqual rather than the failUnlessEqual alias, which was
            # removed from unittest in Python 3.12.
            self.assertEqual(1, len(rows))
            row = rows[0]
            self.assertEqual(row[0], value)

    def testBit(self):
        """Round-trip both values of a bit column."""
        self._test_val('bitfield', 1)
        self._test_val('bitfield', 0)

    def testInt(self):
        """Round-trip small integers and ``sys.maxsize``."""
        self._test_val('intfield', 1)
        self._test_val('intfield', 0)
        self._test_val('intfield', sys.maxsize)

    def testFloat(self):
        """Round-trip a fractional value and zero through a float column."""
        self._test_val('floatfield', 1.01)
        self._test_val('floatfield', 0)

    def testVarchar(self):
        """Round-trip a short string through a varchar column."""
        self._test_val('username', 'foo')

    def testLongVarchar(self):
        """ Test a long text field in excess of internal cursor data size (65536)"""
        self._test_val('longtextfield', 'abc' * 70000)

    def testLongBinary(self):
        """ Test a long raw field in excess of internal cursor data size (65536)"""
        self._test_val('longbinaryfield', str2memory('\0\1\2' * 70000))

    def testRaw(self):
        """Round-trip binary data, including an embedded NUL."""
        ## Test binary data
        # '\8' is not an octal escape, so the data has always ended with a
        # literal backslash followed by '8'.  The backslash is spelled out
        # to keep that value without an invalid-escape warning.
        self._test_val('rawfield', str2memory('\1\2\3\4\0\5\6\7\\8'))

    def test_widechar(self):
        """Test a unicode character that would be mangled if bound as plain character.
            For example, previously the below was returned as ascii 'a'
        """
        self._test_val('username', '\u0101')

    def testDates(self):
        """Round-trip datetime values through a datetime column."""
        import datetime
        for v in (
            (1900, 12, 25, 23, 39, 59),
            ):
            d = datetime.datetime(*v)
            self._test_val('datefield', d)

    def test_set_nonzero_length(self):
        """Update a username to a shorter value and check the fetched length."""
        self.assertEqual(self.cur.execute("insert into %s (userid,username) "
            "values (?,?)" %self.tablename, ['Frank', 'Frank Millman']),1)
        self.assertEqual(self.cur.execute("update %s set username = ?" %self.tablename,
            ['Frank']),1)
        self.assertEqual(self.cur.execute("select * from %s" %self.tablename), 0)
        self.assertEqual(len(self.cur.fetchone()[1]),5)

    def test_set_zero_length(self):
        """Insert an empty username (userid passed through ``str2bytes``) and
        check it is fetched back with length 0.
        """
        self.assertEqual(self.cur.execute("insert into %s (userid,username) "
            "values (?,?)" %self.tablename, [str2bytes('Frank'), '']),1)
        self.assertEqual(self.cur.execute("select * from %s" %self.tablename), 0)
        self.assertEqual(len(self.cur.fetchone()[1]),0)

    def test_set_zero_length_unicode(self):
        """Insert an empty username with a plain-string userid and check it is
        fetched back with length 0.
        """
        self.assertEqual(self.cur.execute("insert into %s (userid,username) "
            "values (?,?)" %self.tablename, ['Frank', '']),1)
        self.assertEqual(self.cur.execute("select * from %s" %self.tablename), 0)
        self.assertEqual(len(self.cur.fetchone()[1]),0)

# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()