# Author: Njaal Borch <Njaal.Borch@norut.no>

from pysqlite2 import dbapi2 as sqlite
import re
import os.path

def sql_escape(string):

    """Escape a string for embedding inside a single-quoted SQL literal.

    SQLite (like standard SQL) escapes a quote inside a string literal
    by doubling it; a backslash is an ordinary character there, so a
    backslash-escaped quote would terminate the literal early and break
    the statement.  Only single quotes are touched -- the caller must
    still wrap the result in '...'.  Bound query parameters ("?") are
    preferable to this helper wherever they can be used.

    """

    return string.replace("'", "''")

class CollectionDB:

    """A CollectionDB is a class to bundle resources together.  It is
    based on a sqlite backend.  (The original design notes mention
    serializing to XML / an RSS video feed; no such serializer is part
    of this class.)

    Every statement binds its values as parameters ("?" placeholders)
    instead of splicing them into the SQL text, so names, descriptions
    and ids containing quotes are stored verbatim and cannot alter the
    query.

    """

    # Episode-name patterns used by smart_add, compiled once.
    # "title s1e12 therest" -- matched against the raw (uncleaned) name
    _SEASON_EPISODE_RE = re.compile(r"(.*)[sS](\d+)[eE](\d+)(.*)")
    # "title 1x21 therest"
    _CROSS_RE = re.compile(r"(.*) (\d+)[xX](\d+)(.*)")
    # "title 0123 therest" -> season 01, episode 23
    _FOUR_DIGIT_RE = re.compile(r"(.*) (\d\d)(\d\d)(.*)")
    # "title 123 therest" -> season 1, episode 23
    _THREE_DIGIT_RE = re.compile(r"(.*) (\d)(\d\d)(.*)")
    # "title season 3 ..." (complete seasons); matched lower-cased
    _SEASON_RE = re.compile(r"(.*)season\W*(\d+)(.*)")

    def __init__(self, db_name="Collections.data"):

        """Create a new collection DB.  db_name can be ':memory:' to
        create an in-memory db only, or a filename to persist it.

        The connection is opened with isolation_level=None (autocommit)
        and the tables are created if missing.

        """

        self.db = sqlite.connect(db_name, isolation_level=None)
        self.prepare()

    def cursor(self):
        """Return a new cursor on the underlying connection."""
        return self.db.cursor()

    def prepare(self):

        """Prepare the database - check that all tables exist, or
        create them if they don't.

        """

        cursor = self.db.cursor()

        # Get list of existing tables
        cursor.execute("select tbl_name from sqlite_master "
                       "where type='table' order by tbl_name")
        tables = [row[0] for row in cursor.fetchall()]

        if "collection" not in tables:
            cursor.execute("""
            create table collection(
            name TEXT,
            id BIGINT UNSIGNED PRIMARY KEY, 
            description TEXT,
            type SMALLINT UNSIGNED)
            """)
            cursor.execute("CREATE INDEX c_name_idx ON collection(name)")

        if "collection_entry" not in tables:
            cursor.execute("""
            create table collection_entry(
            name TEXT,
            id TEXT, 
            description TEXT,
            url TEXT NOT NULL, 
            type SMALLINT UNSIGNED,
            major INT UNSIGNED DEFAULT 0,
            minor INT UNSIGNED DEFAULT 0,
            duplicate INT UNSIGNED DEFAULT 0,
            quality TEXT DEFAULT '',
            collection_id BIGINT UNSIGNED NOT NULL,
            time_stamp TEXT DEFAULT NULL,
            size BIGINT UNSIGNED DEFAULT 0,
            PRIMARY KEY (name, collection_id),
            FOREIGN KEY (collection_id) REFERENCES collection(id)
            )
            """)
            # NOTE(review): no explicit index on collection_id; sqlite does
            # not create one for a FOREIGN KEY automatically -- confirm
            # whether one is wanted.

    def clear(self):

        """Clear the entire database (all entries and all collections)!

        """

        cursor = self.db.cursor()
        cursor.execute("DELETE FROM collection_entry")
        cursor.execute("DELETE FROM collection")

    def get_names(self):

        """Return a list of the names of all known collections, each
        encoded as UTF-8 (undecodable characters replaced).

        """

        # TODO: Implement sorting on size/something
        cursor = self.db.cursor()
        cursor.execute("SELECT name FROM collection")
        return [row[0].encode("utf-8", "replace") for row in cursor.fetchall()]

    @staticmethod
    def _to_text(value):
        """Decode byte strings as UTF-8 so they can be bound as text;
        other values are returned unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8", "replace")
        return value

    def _clean_name(self, name):
        """Normalise a collection name the way it is stored in the DB."""
        return self.clean_up_string(self._to_text(name), fixCaps=True)

    def get_collection(self, name=None, id=None):

        """Get a Collection object for a given collection, looked up
        either by name (normalised with clean_up_string(fixCaps=True))
        or by id.  The name wins if both are given.

        Raises Exception("Missing parameter") if neither is given, and
        NotFoundException if no matching collection exists.

        """

        if not name and id is None:
            raise Exception("Missing parameter")

        cursor = self.db.cursor()
        if name:
            cursor.execute("SELECT id FROM collection WHERE name=?",
                           (self._clean_name(name),))
        else:
            cursor.execute("SELECT id FROM collection WHERE id=?", (id,))
        row = cursor.fetchone()
        if not row:
            raise NotFoundException()

        return Collection(self, row[0])

    def remove_collection(self, name=None, id=None):

        """Remove a collection, looked up either by name or by id (the
        name wins if both are given).  Only the collection row is
        deleted; its collection_entry rows are left in place.

        Raises Exception("Missing parameter") if neither is given, and
        NotFoundException if no matching collection was deleted.

        """

        if not name and id is None:
            raise Exception("Missing parameter")

        cursor = self.db.cursor()
        if name:
            cursor.execute("DELETE FROM collection WHERE name=?",
                           (self._clean_name(name),))
        else:
            cursor.execute("DELETE FROM collection WHERE id=?", (id,))
        if cursor.rowcount < 1:
            raise NotFoundException()

    def new_collection(self, name, description, type=0, id=None):

        """Create a new collection and return its Collection object.

        The name is normalised with clean_up_string(fixCaps=True).  If
        no (or a falsy) id is given, one is drawn from sqlite's RANDOM().
        After a collection has been created (or fetched with
        get_collection), all update functions on the Collection object
        propagate to the DB directly.

        Raises Exception if the insert affected no rows; driver errors
        (e.g. a duplicate id, which is the table's primary key)
        propagate.

        """

        assert name

        cursor = self.db.cursor()
        if not id:
            # The collection id is not the sqlite rowid, so pick it first
            cursor.execute("SELECT RANDOM()")
            id = cursor.fetchone()[0]

        cursor.execute(
            "INSERT INTO collection(name, id, description, type) "
            "VALUES (?, ?, ?, ?)",
            (self._clean_name(name), id, self._to_text(description), type))

        # TODO: Some decent error handling here?
        if cursor.rowcount == 0:
            raise Exception("Create failed for unknown reason")

        return self.get_collection(id=id)

    def clean_up_string(self, str, fixCaps=False):

        """Clean up a name: turn '.', '_' and '-' into spaces, drop
        single quotes, collapse runs of spaces and strip the ends.  With
        fixCaps, the first character is upper-cased and the rest
        lower-cased.  Returns the cleaned string ("" stays "").

        The parameter keeps its historical name ``str`` (shadowing the
        builtin) so keyword callers keep working.

        """

        text = str
        for separator in ('.', '_', '-'):
            text = text.replace(separator, " ")

        text = text.replace("'", "")  # Just remove these, they are annoying

        while "  " in text:
            text = text.replace("  ", " ")

        # Strip *before* capitalising: otherwise a leading separator
        # ("_foo") left a space at position 0 that got "capitalised"
        # while the real first letter was lower-cased.
        text = text.strip()
        if fixCaps and text:
            text = text[0].upper() + text[1:].lower()

        return text

    def _smart_add_make_coll(self, title, major, minor, rest, ext, quality):

        """Internal helper for smart_add: find or create the collection
        named after 'title' (using 'rest' as a new collection's
        description) and return (collection, (major, minor, 0, quality)).

        Raises NotFoundException if the title is empty after cleaning.

        """

        title = self.clean_up_string(title, fixCaps=True)
        if not title:
            # Nothing usable before the episode marker, e.g. "S01E02.avi"
            raise NotFoundException()
        rest = self.clean_up_string(rest)

        try:
            coll = self.get_collection(title)
        except NotFoundException:
            coll = self.new_collection(title, rest)

        return (coll, (int(major), int(minor), 0, quality))  # TODO: Add type

    def smart_add(self, name, url):

        """Try to parse the name, and if it is recognized, return the
        guessed collection together with the information needed to
        create the entry.  The collection is created if needed, except
        for the bare-number patterns ("title 0123", "title 123"), which
        are only accepted for collections that already exist.

        Returns a tuple (collection, (major, minor, type, quality)) or
        raises NotFoundException if unable to understand the name.
        Raises Exception if name or url is empty.

        """

        if not (name and url):
            raise Exception("Require both name and url")

        # Split the extension off
        (name, ext) = os.path.splitext(name)

        # Detect the quality tag; the first hit in this order wins
        quality = ""
        lname = name.lower()
        for q in ["720p", "1080i", "1080p", "720i"]:
            if q in lname:
                quality = q
                break

        # "title s1e12 therest" -- on the raw name
        m = self._SEASON_EPISODE_RE.match(name)
        if m:
            (title, major, minor, rest) = m.groups()
            return self._smart_add_make_coll(title, major, minor,
                                             rest, ext, quality)

        cleaned = self.clean_up_string(name)

        # "title 1x21 therest"
        m = self._CROSS_RE.match(cleaned)
        if m:
            (title, major, minor, rest) = m.groups()
            return self._smart_add_make_coll(title, major, minor,
                                             rest, ext, quality)

        # "title 0123 therest", then "title 123 therest".  These give
        # too many false positives, so they only count when the
        # collection already exists.
        for pattern in (self._FOUR_DIGIT_RE, self._THREE_DIGIT_RE):
            m = pattern.match(cleaned)
            if not m:
                continue
            (title, major, minor, rest) = m.groups()
            title = self.clean_up_string(title, fixCaps=True)
            if not title:
                continue
            try:
                self.get_collection(title)
            except NotFoundException:
                continue  # Unknown collection - try the next pattern
            return self._smart_add_make_coll(title, major, minor,
                                             rest, ext, quality)

        # "title season 3" (complete seasons)
        # TODO: Also require already existing collection?
        m = self._SEASON_RE.match(cleaned.lower())
        if m:
            (title, major, rest) = m.groups()
            return self._smart_add_make_coll(title, major, 0,
                                             rest, ext, quality)

        raise NotFoundException()

    def remove_entry_from_all(self, id):

        """Erase the entry with the given id from all collections.

        """

        cursor = self.db.cursor()
        cursor.execute("DELETE FROM collection_entry WHERE id=?",
                       (self._to_text(id),))

class Collection:
    """The Collection class is an object describing an actual
    collection.  It is used to manipulate the collection directly.

    Every statement binds its values as parameters ("?" placeholders),
    so names, descriptions, urls and ids containing quotes are stored
    verbatim.

    """

    # Column list shared by get_entry and get_entries (order matters:
    # it is the order of the returned tuples).
    _ENTRY_COLUMNS = ("id, name, description, url, major, minor, type, "
                      "quality, time_stamp, size")

    # Ranks the quality tags that smart_add detects, best first.  Plain
    # text ordering would put "720p" above "1080p"; unknown tags
    # (including "") rank lowest and fall back to text ordering.
    _QUALITY_RANK = ("CASE quality WHEN '1080p' THEN 4 WHEN '1080i' THEN 3 "
                     "WHEN '720p' THEN 2 WHEN '720i' THEN 1 ELSE 0 END")

    def __init__(self, db, id):

        """Create a new collection object.  This should never be done
        directly, but only done by the CollectionDB.  Create a
        collection by using CollectionDB.new_collection()

        Do not use the internal variables directly, but use the set/get
        functions in order to ensure database persistence.

        """

        self.__id = id
        self.__db = db

    @staticmethod
    def _to_text(value):
        """Decode byte strings as UTF-8 so they can be bound as text;
        other values are returned unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8", "replace")
        return value

    def _execute(self, sql, params=()):
        """Run one statement with bound parameters; return the cursor."""
        cursor = self.__db.cursor()
        cursor.execute(sql, params)
        return cursor

    def _fetch_column(self, column):
        """Return one column of this collection's row.  'column' is an
        internal constant, never caller input.  Raises NotFoundException
        if the collection row is missing."""
        cursor = self._execute("SELECT %s FROM collection WHERE id=?" % column,
                               (self.__id,))
        row = cursor.fetchone()
        # rowcount is -1 for SELECT in DB-API, so test the row itself
        if row is None:
            raise NotFoundException("Could not find collection %s" % self.__id)
        return row[0]

    def get_id(self):
        """Return the id of this collection."""
        return self.__id

    def set_description(self, description):
        """Set the description; raises Exception if no row was updated."""
        cursor = self._execute("UPDATE collection SET description=? WHERE id=?",
                               (self._to_text(description), self.__id))
        if cursor.rowcount == 0:
            raise Exception("Update failed") # TODO: Get reason

    def get_description(self):
        """Return the description; raises NotFoundException if missing."""
        return self._fetch_column("description")

    def set_name(self, name):
        """Rename the collection.  The name is normalised exactly like
        CollectionDB.get_collection normalises lookups (fixCaps=True), so
        the collection can still be found by name afterwards.  Raises
        Exception if no row was updated."""
        clean = self.__db.clean_up_string(self._to_text(name), fixCaps=True)
        cursor = self._execute("UPDATE collection SET name=? WHERE id=?",
                               (clean, self.__id))
        if cursor.rowcount == 0:
            raise Exception("Update failed") # TODO: Get reason

    def get_name(self):
        """Return the name; raises NotFoundException if missing."""
        return self._fetch_column("name")

    def get_entry(self, id=None, name=None):

        """Get one entry of this collection by id (preferred) or by
        name.  If several rows match, the first one sqlite returns is
        used.  Returns a tuple
        (id, name, description, url, major, minor, type, quality, time_stamp, size)
        or raises NotFoundException.

        """

        assert (id or name)
        if id:
            condition, value = "id=?", self._to_text(id)
        else:
            condition, value = "name=?", self._to_text(name)
        cursor = self._execute(
            "SELECT " + self._ENTRY_COLUMNS + " FROM collection_entry "
            "WHERE collection_id=? AND " + condition,
            (self.__id, value))
        row = cursor.fetchone()
        if not row:
            raise NotFoundException()
        return tuple(row)

    def get_entries(self, major=None, minor=None,
                    quality=None, reverse=False,
                    dupes=True):

        """Get a list of entries from this collection as tuples of the
        form
        (id, name, description, url, major, minor, type, quality, time_stamp, size)

        major, minor and quality filter the result when not None (0 is
        a valid filter value).  Entries are ordered by major, then
        minor, then best quality first (1080p, 1080i, 720p, 720i, then
        anything else), then insertion order.  If 'reverse' is True,
        major and minor are sorted highest first.

        If dupes is False, only the first entry per (major, minor) is
        kept -- i.e. the highest quality one, or the one added first if
        the qualities are equal.

        """

        # TODO: Allow other sorts?
        clauses = ["SELECT", self._ENTRY_COLUMNS,
                   "FROM collection_entry WHERE collection_id=?"]
        params = [self.__id]

        if major is not None:
            clauses.append("AND major=?")
            params.append(major)
        if minor is not None:
            clauses.append("AND minor=?")
            params.append(minor)
        if quality is not None:
            clauses.append("AND quality=?")
            params.append(self._to_text(quality).lower())

        direction = " DESC" if reverse else ""
        clauses.append("ORDER BY major%s, minor%s, %s DESC, quality DESC, "
                       "duplicate" % (direction, direction, self._QUALITY_RANK))

        cursor = self._execute(" ".join(clauses), tuple(params))

        seen = set()
        entries = []
        for row in cursor.fetchall():
            if not dupes:
                episode = (row[4], row[5])  # (major, minor)
                if episode in seen:
                    continue
                seen.add(episode)
            entries.append(tuple(row))
        return entries

    def add_entry(self, id, name, description, url, major=0, minor=0, type=0, quality=None, time_stamp=None, size=0):

        """Add an entry to the collection and return True.

        Entries sharing (major, minor) get increasing 'duplicate'
        numbers so they keep insertion order.  quality is lower-cased
        ("" if not given); a falsy time_stamp is stored as NULL, any
        other value as its text form.  Raises Exception if nothing was
        inserted; driver errors (e.g. the same name twice in one
        collection, which violates the primary key) propagate.

        """

        # First we need to find out if this is a duplicate
        # TODO: Should lock table here if parallel access can happen!
        cursor = self._execute(
            "SELECT MAX(duplicate) FROM collection_entry "
            "WHERE collection_id=? AND major=? AND minor=?",
            (self.__id, major, minor))
        highest = cursor.fetchone()[0]
        dupe = 0 if highest is None else highest + 1

        quality = self._to_text(quality).lower() if quality else ""

        if not time_stamp:
            time_stamp = None  # stored as SQL NULL
        elif isinstance(time_stamp, bytes):
            time_stamp = self._to_text(time_stamp)
        else:
            # Same text the old '%s' interpolation produced
            time_stamp = u"%s" % (time_stamp,)

        cursor.execute(
            "INSERT INTO collection_entry (id, name, description, url, "
            "type, major, minor, duplicate, quality, collection_id, "
            "time_stamp, size) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (self._to_text(id), self._to_text(name),
             self._to_text(description), self._to_text(url),
             type, major, minor, dupe, quality, self.__id, time_stamp, size))
        if cursor.rowcount == 0:
            raise Exception("Insert failed for uknown reason")

        return True

    def remove_entry_by_id(self, id):

        """Remove the entry with the given id from this collection only.

        """

        # The connection lives in the name-mangled self.__db; the old
        # self.db attribute never existed on Collection.
        self._execute("DELETE FROM collection_entry "
                      "WHERE id=? AND collection_id=?",
                      (self._to_text(id), self.__id))
        
        
class NotFoundException(Exception):
    """Raised when a requested collection or collection entry does not exist."""


