from __future__ import nested_scopes

import copy, new, re, string, sys, warnings, weakref
import threading
from types import DictType, IntType, ListType, LongType, StringType, TupleType, UnicodeType

import _sqlite
import threadutils

# mx.DateTime is an optional dependency.  have_datetime records whether it
# could be imported; it guards the date/time aliases below as well as the
# DateTime handling in _quote() and the DateTime converters in Connection.
try:
    from mx import DateTime
    have_datetime = 1
except ImportError:
    have_datetime = 0

# Matches the PySQLite extension statement
#     pysqlite_pragma expected_types = type1, type2, ...
# Group 1 captures everything after the '=' (the comma-separated type names).
EXPECTED_TYPES_REGEX = re.compile(r"\s*pysqlite_pragma\s*expected_types\s*=(.*)")
# Matches the extension statement "pysqlite_pragma reset_expected_types".
RESET_EXPECTED_TYPES_REGEX = re.compile(r"\s*pysqlite_pragma\s*reset_expected_types")

if have_datetime:
    # Make the required Date/Time constructors visible in the PySQLite module.
    Date = DateTime.Date
    Time = DateTime.Time
    Timestamp = DateTime.Timestamp
    DateFromTicks = DateTime.DateFromTicks
    TimeFromTicks = DateTime.TimeFromTicks
    TimestampFromTicks = DateTime.TimestampFromTicks

    # And also the DateTime types.
    DateTimeType = DateTime.DateTimeType
    DateTimeDeltaType = DateTime.DateTimeDeltaType

class DBAPITypeObject:
    """Type object in the style of the DB-API 2.0 specification (PEP 249).

    An instance compares equal (via __cmp__) to every value that was passed
    to its constructor; any other value is ordered against the tuple of
    those values."""

    def __init__(self, *values):
        # All type codes this object stands for.
        self.values = values

    def __cmp__(self, other):
        # Membership means equality.
        if other in self.values:
            return 0
        # Otherwise fall back to ordering against the values tuple.
        result = -1
        if other < self.values:
            result = 1
        return result

class SubclassResponsibilityError(NotImplementedError):
    """Raised by abstract methods of abstract classes.

    This specialisation of NotImplementedError signals that no
    implementation will ever be supplied at the raising location; a
    subclass is expected to override the method instead."""

def _quote(value):
    """_quote(value) -> string

    This function transforms the Python value into a string suitable to send to
    the SQLite database in a SQL statement.  This function is automatically
    applied to all parameters sent with an execute() call.  Because of this a
    SQL statement string in an execute() call should only use '%s' [or
    '%(name)s'] for variable substitution without any quoting."""

    if value is None:
        return 'NULL'
    elif isinstance(value, StringType):
        return "'%s'" % re.sub("'", "''", value)
    elif isinstance(value, LongType):
        return str(value)
    elif hasattr(value, '_quote'):
        return value._quote()
    elif have_datetime and type(value) in \
            [DateTime.DateTimeType, DateTime.DateTimeDeltaType]:
        return "'%s'" % value
    else:
        return repr(value)

def _quoteall(vdict):
    """_quoteall(vdict) -> dict or tuple

    Apply _quote() to every element of a dictionary (or PgResultSet row), a
    list or a tuple, so the result can be used with the %-operator on a SQL
    statement.  Dictionary-like input yields a dict with the same keys;
    list/tuple input yields a tuple.  A plain or unicode string counts as a
    single value and yields a one-element tuple.

    Raises TypeError for any other argument type."""
    if type(vdict) is DictType or isinstance(vdict, PgResultSet):
        quoted = {}
        for key, item in vdict.items():
            quoted[key] = _quote(item)
        return quoted

    if isinstance(vdict, StringType) or isinstance(vdict, UnicodeType):
        # A string is a sequence too, but it is one SQL value rather than a
        # sequence of single-character values.
        return (_quote(vdict),)

    if type(vdict) in [ListType, TupleType]:
        return tuple([_quote(item) for item in vdict])

    raise TypeError("argument to _quoteall must be a sequence or dictionary!")

class PgResultSet:
    """A DB-API query result set for a single row.
    This class emulates a sequence with the added feature of being able to
    reference a column as attribute or with dictionary access in addition to a
    zero-based numeric index."""

    def __init__(self, value):
        self.__dict__['baseObj'] = value

    def __getattr__(self, key):
        if self.__class__._xlatkey.has_key(key):
            return self.baseObj[self.__class__._xlatkey[key]]
        raise AttributeError, key

    # We define a __setattr__ routine that will only allow the attributes that
    # are the column names to be updated.  All other attributes are read-only.
    def __setattr__(self, key, value):
        if key in ('baseObj'):
            raise AttributeError, "%s is read-only." % key

        if self.__class__._xlatkey.has_key(key):
            self.__dict__['baseObj'][self.__class__._xlatkey(key)] = value
        else:
            raise AttributeError, key

    def __len__(self):
        return len(self.baseObj)

    def __getitem__(self, key):
        if isinstance(key, StringType):
            key = self.__class__._xlatkey[key]
        return self.baseObj[key]

    def __setitem__(self, key, value):
        if isinstance(key, StringType):
            key = self.__class__._xlatkey[key]

        self.baseObj[key] = value

    def __getslice__(self, i, j):
        klass = make_PgResultSetClass(self.__class__._desc_[i:j])
        obj = klass(self.baseObj[i:j])
        return obj

    def __repr__(self):
        return repr(self.baseObj)

    def __str__(self):
        return str(self.baseObj)

    def __cmp__(self, other):
        return cmp(self.baseObj, other)

    def description(self):
        return _desc_

    def keys(self):
        _k = []
        for _i in _desc_:
            _k.append(_i[0])
        return _k

    def values(self):
        return self.baseObj[:]

    def items(self):
        _items = []
        for i in range(len(self.baseObj)):
            _items.append((_desc_[i][0], self.baseObj[i]))

        return _items

    def has_key(self, key):
        return self.__class__._xlatkey.has_key(key)

    def get(self, key):
        return self[key]

def make_PgResultSetClass(description):
    """Create a concrete PgResultSet subclass for the given description.

    The new class stores the description as _desc_ and a name -> index map
    as _xlatkey; the name of column i is description[i][0].  If a name
    occurs more than once, the last index wins."""
    name_to_index = {}
    for position in range(len(description)):
        name_to_index[description[position][0]] = position

    return new.classobj("PgResultSetConcreteClass", (PgResultSet,),
                        {'_desc_': description, '_xlatkey': name_to_index})

class BaseCursor:
    """Abstract cursor class implementing what all cursor classes have in
    common.

    Subclasses must provide _execute_sql(sql), _after_execute_sql() and
    _invalidate().  BaseCursor handles unicode encoding and quoting of
    parameters, the pysqlite_pragma expected_types extension, starting the
    implicit transaction, and registering the cursor with its connection."""

    def __init__(self, conn):
        self.arraysize = 1

        # Add ourselves to the list of cursors for our owning connection.
        # Only a weak proxy to the connection is kept.
        self.con = weakref.proxy(conn)
        self.con.cursors[id(self)] = self

        self._reset()
        if not self.con.autocommit:
            # Only the first created cursor begins the transaction.
            if not self.con.inTransaction:
                self.con._begin()

    def _reset(self):
        # closed is a trinary variable:
        #     == None => Cursor has not been opened.
        #     ==    0 => Cursor is open.
        #     ==    1 => Cursor is closed.
        self.closed = None
        self.rowcount = -1
        self.description = None
        self.expected_types = None

    def _checkNotClosed(self, methodname=None):
        """Raise ProgrammingError if close() has been called on this cursor."""
        if self.closed:
            raise _sqlite.ProgrammingError(
                "%s failed - the cursor is closed." % (methodname or ""))

    def _encodeIfUnicode(self, value):
        """Encode value with the connection's client_encoding if it is
        exactly a unicode object; return anything else unchanged."""
        if type(value) is UnicodeType:
            return value.encode(*self.con.client_encoding)
        return value

    def _unicodeConvert(self, obj):
        """Encode all unicode strings that can be found in obj into
        byte-strings using the encoding specified in the connection's
        constructor, available here as self.con.client_encoding.

        Lists and tuples come back as new lists, dicts as new dicts and
        PgResultSet rows as new rows of the same class; the argument itself
        is never modified.  Only the top level of containers is converted."""
        if isinstance(obj, StringType):
            return obj
        elif isinstance(obj, UnicodeType):
            return obj.encode(*self.con.client_encoding)
        elif isinstance(obj, ListType) or isinstance(obj, TupleType):
            return [self._encodeIfUnicode(item) for item in obj]
        elif isinstance(obj, DictType):
            converted_obj = {}
            for k, v in obj.items():
                converted_obj[k] = self._encodeIfUnicode(v)
            return converted_obj
        elif isinstance(obj, PgResultSet):
            # Build a fresh row of the same class.  A shallow copy would
            # share baseObj with the caller's row, so encoding in place
            # would modify the caller's data (and fail for tuple rows).
            return obj.__class__(
                [self._encodeIfUnicode(item) for item in obj.values()])
        else:
            return obj

    def _convert_types(self, l):
        """Apply the converters selected by the expected_types pragma.

        Returns l unchanged when no expected types are active or l is not a
        tuple (e.g. None).  Otherwise returns a list in which the i-th value
        has been passed through the i-th converter.  None values are not
        converted.  Values beyond the declared types are passed through
        unconverted instead of being dropped.

        Raises ProgrammingError if a converter raises ValueError."""
        def converter_wrapper(cvt, value, typeName):
            if value is None:
                return value
            try:
                return cvt(value)
            except ValueError:
                raise _sqlite.ProgrammingError(
                    "Conversion failed for value '%s' of expected type '%s'"%(str(value),typeName))

        if self.expected_types is None:
            return l

        if not isinstance(l, TupleType):
            # Also covers l being None.
            return l

        converted = [converter_wrapper(cvt, item, typeName)
                     for (cvt, typeName), item in zip(self.expected_types, l)]
        # zip() stops at the shorter sequence; keep undeclared columns.
        converted.extend(list(l[len(self.expected_types):]))
        return converted

    def _setExpectedTypes(self, typelist):
        """Install the converters named in an expected_types pragma.

        typelist is the comma-separated list of type names.  Raises
        ProgrammingError for a name missing from self.con.converters; in
        that case the previously active expected types stay in effect."""
        converters = self.con.converters
        expected = []
        for name in typelist.split(","):
            name = name.strip()
            if not converters.has_key(name):
                raise _sqlite.ProgrammingError(
                    "execute failed - undefined expected type: '%s'" % name)
            expected.append((converters[name], name))
        self.expected_types = expected

    def execute(self, SQL, *parms):
        """Execute the statement SQL with the given parameters.

        Parameters are given either as one dict / list / tuple / PgResultSet
        (for '%(name)s' or '%s' placeholders) or as separate positional
        arguments (for '%s').  Each value is encoded from unicode with the
        connection's client_encoding and quoted with _quote(), so the
        placeholders in SQL must not be quoted.

        The extension statements 'pysqlite_pragma expected_types = ...' and
        'pysqlite_pragma reset_expected_types' are not sent to the database;
        they set or clear the converters applied to fetched rows."""
        # This method prepares the execution of an SQL statement by doing all
        # the necessary magic with the parameters. The actual execution of the
        # SQL statement is delegated to a self._execute_sql method. If
        # _execute_sql was successfully called, self._after_execute_sql is
        # called. This split was made to make the IterCursor possible.
        self._checkNotClosed("execute")

        if RESET_EXPECTED_TYPES_REGEX.match(SQL):
            self.expected_types = None
            return

        expected_types = EXPECTED_TYPES_REGEX.findall(SQL)
        if len(expected_types) == 1:
            self._setExpectedTypes(expected_types[0])
            return

        # Outside autocommit mode statements run inside a transaction; start
        # a new one if commit()/rollback() ended the previous one.
        # (_begin() itself sets inTransaction.)
        if not self.con.autocommit and not self.con.inTransaction:
            self.con._begin()

        SQL = self._unicodeConvert(SQL)

        if not parms:
            # If there are no parameters, just execute the query.
            self._execute_sql(SQL)
        else:
            first = parms[0]
            if (len(parms) == 1 and
                    (type(first) in [DictType, ListType, TupleType] or
                     isinstance(first, PgResultSet))):
                quoted = _quoteall(self._unicodeConvert(first))
            else:
                quoted = tuple(map(_quote, self._unicodeConvert(parms)))

            self._execute_sql(SQL % quoted)

        self.closed = 0
        self._after_execute_sql()

    def _execute_sql(self, sql):
        raise SubclassResponsibilityError(self.__class__)

    def _after_execute_sql(self):
        raise SubclassResponsibilityError(self.__class__)

    def _invalidate(self):
        raise SubclassResponsibilityError(self.__class__)

    def executemany(self, query, parm_sequence):
        """Execute query once for every parameter set in parm_sequence."""
        self._checkNotClosed("executemany")

        # self.con is a weak proxy and never None; check the connection's
        # own closed flag instead (the old check raised an undefined name).
        if self.con.closed:
            raise _sqlite.ProgrammingError(
                "executemany failed - Connection is closed.")

        for parms in parm_sequence:
            self.execute(query, parms)

    def close(self):
        """Close the cursor and detach it from its connection.

        Raises ProgrammingError if the connection or the cursor is already
        closed."""
        if self.con and self.con.closed:
            raise _sqlite.ProgrammingError(
                  "This cursor's connection is already closed.")
        if self.closed:
            raise _sqlite.ProgrammingError("This cursor is already closed.")
        self.closed = 1

        # Disassociate ourselves from our connection.  The entry may already
        # be gone, or the connection may have been collected; neither should
        # prevent closing.
        try:
            del self.con.cursors[id(self)]
        except (KeyError, weakref.ReferenceError):
            pass

        self._invalidate()

    def __del__(self):
        # Disassociate ourselves from our connection.  This is best effort:
        # a bare except is kept deliberately because attributes (or module
        # globals during interpreter shutdown) may already be gone here.
        try:
            del self.con.cursors[id(self)]
        except:
            pass
        self._invalidate()

    def setinputsizes(self, sizes):
        """Does nothing, required by DB API."""
        self._checkNotClosed("setinputsizes")

    def setoutputsize(self, size, column=None):
        """Does nothing, required by DB API."""
        self._checkNotClosed("setoutputsize")

class StandardCursor(BaseCursor):
    """Cursor that receives the complete result of a statement from
    self.con.db.execute() and hands out its rows (self.rs.row_list) wrapped
    in a PgResultSet subclass built from the result's column definitions."""

    def __init__(self, con):
        # BaseCursor.__init__ already calls self._reset(), which dispatches
        # to StandardCursor._reset, so no second reset is needed here.
        BaseCursor.__init__(self, con)
        # rowcount is -1 until the first execute(); with the record pointer
        # at -1 too, fetch*() before any execute() return None / [].
        self.current_recnum = -1

    def _reset(self):
        BaseCursor._reset(self)

        self.rs = None
        self.current_recnum = 0

    def _execute_sql(self, sql):
        self.rs = self.con.db.execute(sql)

    def _after_execute_sql(self):
        # Rewind to the first row and build a row class for these columns.
        self.current_recnum = 0
        self.rowcount = len(self.rs.row_list)
        self.description = self.rs.col_defs
        self.PgResultSetClass = make_PgResultSetClass(self.description[:])

    def _invalidate(self):
        # Nothing to invalidate for this cursor type.
        pass

    #
    # DB-API methods:
    #

    def fetchone(self):
        """Return the next row, or None when no rows are left."""
        self._checkNotClosed("fetchone")

        # No records, nothing executed yet, or all records consumed.
        if self.current_recnum >= self.rowcount:
            return None

        retval = self.PgResultSetClass(
            self._convert_types(self.rs.row_list[self.current_recnum]))

        # Advance only after a successful conversion.
        self.current_recnum += 1

        return retval

    def fetchmany(self, howmany=None):
        """Return up to howmany (default: self.arraysize) further rows."""
        self._checkNotClosed("fetchmany")

        if howmany is None:
            howmany = self.arraysize

        if self.current_recnum >= self.rowcount:
            return []

        end = min(self.current_recnum + howmany, self.rowcount)
        retval = [self.PgResultSetClass(self._convert_types(item))
                  for item in self.rs.row_list[self.current_recnum:end]]

        self.current_recnum = end

        return retval

    def fetchall(self):
        """Return all remaining rows as a list."""
        self._checkNotClosed("fetchall")

        if self.current_recnum >= self.rowcount:
            return []

        retval = [self.PgResultSetClass(self._convert_types(item))
                  for item in self.rs.row_list[self.current_recnum:]]

        self.current_recnum = self.rowcount

        return retval

    #
    # Optional DB-API extensions from PEP 0249:
    #

    def __iter__(self):
        warnings.warn("DB-API extension cursor.__iter__() used")
        return self

    def next(self):
        warnings.warn("DB-API extension cursor.next() used")
        item = self.fetchone()
        if item is not None:
            return item
        # Old-style iteration (before Python 2.2) ends with IndexError.
        if sys.version_info[:2] >= (2,2):
            raise StopIteration
        raise IndexError

    def scroll(self, value, mode="relative"):
        """Move the record pointer; raises IndexError if out of range."""
        warnings.warn("DB-API extension cursor.scroll() used")
        self._checkNotClosed("scroll")
        if mode == "relative":
            new_recnum = self.current_recnum + value
        elif mode == "absolute":
            new_recnum = value
        else:
            raise ValueError("invalid mode parameter")
        if new_recnum >= 0 and new_recnum < self.rowcount:
            self.current_recnum = new_recnum
        else:
            raise IndexError

    def __getattr__(self, key):
        # Only reached when normal attribute lookup (instance dict, then
        # class) fails.
        if key == "sql":
            # The sql attribute is a PySQLite extension.
            return self.con.db.sql()
        elif key == "rownumber":
            warnings.warn("DB-API extension cursor.rownumber used")
            return self.current_recnum
        elif key == "lastrowid":
            warnings.warn("DB-API extension cursor.lastrowid used")
            return self.con.db.sqlite_last_insert_rowid()
        elif key == "connection":
            warnings.warn("DB-API extension cursor.connection used")
            return self.con
        else:
            raise AttributeError(key)

# IterCursor is only imported on Python 2.2 or newer (the py22features module
# presumably relies on 2.2-only features -- confirm).  have_itercursor tells
# Connection._invalidateIterCursors whether IterCursor instances can exist.
if sys.version_info[:2] >= (2,2):
    from py22features import IterCursor
    have_itercursor = 1
else:
    have_itercursor = 0

class UnicodeConverter:
    def __init__(self, encoding):
        self.encoding = encoding

    def __call__(self, val):
        return unicode(val, *self.encoding)

class Connection:

    default_cursor = StandardCursor

    def __init__(self, db, mode=0755, converters={}, autocommit=0, client_encoding=None, *arg, **kwargs):
        if type(client_encoding) not in (TupleType, ListType):
            self.client_encoding = (client_encoding or sys.getdefaultencoding(),)
        else:
            self.client_encoding = client_encoding

        # These are the converters we provide by default ...
        self.converters = {"str": str, "int": int, "long": long, "float": float,
                           "unicode": UnicodeConverter(self.client_encoding)}

        # ... and DateTime/DateTimeDelta, if we have the mx.DateTime module.
        if have_datetime:
            self.converters.update({"DateTime": DateTime.DateTimeFrom,
                                    "DateTimeDelta": DateTime.DateTimeDeltaFrom})
        self.converters.update(converters)

        self.autocommit = autocommit

        # The addition of the IterCursor cursor implementation unfortunately
        # opened a whole can of worms because it uses multithreading
        # internally. It means we have to guard the low-level database
        # connection self.db against concurrent usage, which could otherwise
        # easily happen if an IterCursor is not yet exhausted and the SQLite
        # library is thus still executing sqlite_exec, but the Python callback
        # in IterCursor is blocking on insertion to its Queue.Queue instance.
        #
        # The solution was to add add a wrapper class, implemented in the
        # threadutils module, that automagically makes an entire object
        # 'synchronized', meaning only one method of the object can be active
        # at any one point in time. The SynchronizedObject constructor has two
        # parameters, the first one being the object to guard from concurrent
        # access, the second one an optional function to call before waiting on
        # the lock.
        #
        # The reason for this is to cope with the following scenario: A user
        # creates an IterCursor on a connection, but the IterCursor is not
        # exhausted yet. Now the user invokes a method on another cursor or on
        # the connection object that would use the low-level connection object,
        # and thus result in concurrent usage of the SQLite library. This would
        # crash the SQLite library.
        #
        # The only means of doing the library call is to exhaust all
        # IterCursors, so that none of them has their producer thread running
        # and so none of them is in the sqlite_exec function of SQLite.
        #
        # The _invalidateIterCursors method of the connection class does
        # exactly that, and we provide it as cleanup_handler to the
        # SynchronizedObject constructor.

        cleanup_handler = lambda x: self._invalidateIterCursors()
        self.db = threadutils.SynchronizedObject(_sqlite.connect(db, mode), cleanup_handler)

        self.closed = 0
        self.inTransaction = 0

        self.cursorclass = self.default_cursor
        self.cursors = weakref.WeakValueDictionary()

    def __del__(self):
        if not self.closed:
            self.close()

    def _checkNotClosed(self, methodname):
        if self.closed:
            raise _sqlite.ProgrammingError, \
                  "%s failed - Connection is closed." % methodname

    def __anyCursorsLeft(self):
        return len(self.cursors.data.keys()) > 0

    def __closeCursors(self, doclose=0):
        """__closeCursors() - closes all cursors associated with this connection"""
        if self.__anyCursorsLeft():
            cursors = map(lambda x: x(), self.cursors.data.values())

            for cursor in cursors:
                try:
                    if doclose:
                        cursor.close()
                    else:
                        cursor._reset()
                except weakref.ReferenceError:
                    pass

    def _begin(self):
        self.db.begin()
        self.inTransaction = 1

    def _invalidateIterCursors(self, exception=None):
        """Invalidate all IterCursors associated with this connection.
        If exception is given, this one cursor is not invalidated."""
        # Look into the comments in __init__ for the rationale.
        if not have_itercursor:
            return
        if self.__anyCursorsLeft():
            cursors = map(lambda x: x(), self.cursors.data.values())
            for cursor in cursors:
                if isinstance(cursor, IterCursor) and cursor is not exception:
                    cursor._invalidate()

    #
    # PySQLite extensions:
    #

    def create_function(self, name, nargs, func):
        self.db.create_function(name, nargs, func)

    def create_aggregate(self, name, nargs, agg_class):
        agg = agg_class()
        stepfunc = lambda *args: agg_class.step(agg, *args)
        finfunc = lambda *args:  agg_class.finalize(agg, *args)
        self.db.create_aggregate(name, nargs, stepfunc, finfunc)

    #
    # DB-API methods:
    #

    def commit(self):
        self._checkNotClosed("commit")
        if self.autocommit:
            raise _sqlite.ProgrammingError, "Commit failed - autocommit is on."

        self.db.commit()
        self.inTransaction = 0

    def rollback(self):
        self._checkNotClosed("rollback")
        if self.autocommit:
            raise _sqlite.ProgrammingError, "Rollback failed - autocommit is on."

        self.db.rollback()
        self.inTransaction = 0

    def close(self):
        self._checkNotClosed("close")

        self.__closeCursors(1)

        if self.inTransaction:
            self.rollback()

        self.db.close()
        self.closed = 1

    def cursor(self, cursorclass=None):
        self._checkNotClosed("cursor")
        return (cursorclass or self.cursorclass)(self)

    #
    # Optional DB-API extensions from PEP 0249:
    #

    def __getattr__(self, key):
        if key in self.__dict__.keys():
            return self.__dict__[key]
        elif key in ('IntegrityError', 'InterfaceError', 'InternalError',
                     'NotSupportedError', 'OperationalError',
                     'ProgrammingError', 'Warning'):
            warnings.warn("DB-API extension connection.%s used" % key)
            return getattr(_sqlite, key)
        else:
            raise AttributeError, key
