#!/usr/bin/python -OO
# Copyright 2005 Gregor Kaufmann <tdian@users.sourceforge.net>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
sabnzbd.nzbqueue - NZB download queue, article cache and download history
"""

# Short module name; used as the "[%s]" prefix of this module's log messages.
__NAME__ = "nzbqueue"

import re
import os
import binascii
import logging
import tempfile
import sabnzbd
import time
import datetime

from sabnzbd.nzbstuff import NzbObject

from sabnzbd.constants import *

try:
    import _yenc
    HAVE_YENC = True

except ImportError:
    HAVE_YENC = False

#-------------------------------------------------------------------------------

class HistoryItem:
    """ One entry of the download history.

        While self.nzo is set the entry still references the live nzo
        object; cleanup() snapshots the values needed for display and
        then drops that reference.
    """
    def __init__(self, nzo):
        self.nzo = nzo
        self.filename, self.bytes_downloaded = (nzo.get_filename(),
                                                nzo.get_bytes_downloaded())
        self.completed = time.time()
        # Only filled in by cleanup(); until then the value lives on self.nzo
        self.unpackstrht = None

    def cleanup(self):
        """ Copy final stats from the nzo and release it (no-op if already done). """
        nzo = self.nzo
        if not nzo:
            return
        self.bytes_downloaded = nzo.get_bytes_downloaded()
        self.unpackstrht = nzo.get_unpackstrht()
        self.completed = time.time()
        self.nzo = None
            
class CrcError(Exception):
    """ yEnc CRC mismatch detected while decoding an article.

        needcrc -- CRC announced by the =yend line (None when it is missing)
        gotcrc  -- CRC computed over the decoded data
    """
    def __init__(self, needcrc, gotcrc):
        Exception.__init__(self)
        self.needcrc, self.gotcrc = needcrc, gotcrc
        
#-------------------------------------------------------------------------------
# yEnc decode table for str.translate(): byte value b maps to (b - 42) mod 256,
# undoing the +42 shift of yEnc encoding (escapes are resolved beforehand).
YDEC_TRANS = ''.join(chr((code - 42) % 256) for code in range(256))
class NzbQueue:
    """ Ordered queue of nzo objects, plus the download history and an
        in-memory cache of decoded article data.

        Per-instance state:
          __nzo_list / __nzo_table  -- queued nzos in order / by nzo_id
          __downloaded_items        -- HistoryItem list, oldest first
          __nzf_table               -- f_mode only: nzf date -> [nzf, ...]
          __article_list / __article_table
                                    -- cached articles in insertion order /
                                       article -> data (always the same set)
          try_list                  -- servers get_article() found no work for
    """
    def __init__(self, f_mode = False, auto_sort = False, cache_limit = 0):
        """ Create the queue and reload any queue saved by save().

            f_mode      -- hand out articles per file, ordered by nzf date,
                           instead of per nzo
            auto_sort   -- re-sort the nzo list by average date on each add
            cache_limit -- article cache size, measured with len() of the
                           cached data; 0 disables the cache, a negative
                           value makes it unlimited
        """
        self.__cache_limit = cache_limit
        self.__cache_size = 0

        self.__downloaded_items = []
        self.__nzo_list = []
        self.__article_list = []

        self.__nzo_table = {}
        self.__nzf_table = {}
        self.__article_table = {}

        self.__f_mode = f_mode
        self.__auto_sort = auto_sort

        self.try_list = []

        nzo_ids = []

        data = sabnzbd.load_data(sabnzbd.QUEUE_FILE_NAME, remove = False)

        if data:
            try:
                queue_vers, nzo_ids, self.__downloaded_items = data
                if queue_vers != sabnzbd.__queueversion__:
                    logging.info("[%s] Outdated queuefile found, discarding",
                                 __NAME__)
                    self.__downloaded_items = []
                    nzo_ids = []
            except (ValueError, TypeError):
                # Not a (version, nzo_ids, history) triple: treat as corrupt.
                # Unpacking is all-or-nothing, so the defaults above survive.
                logging.exception("[%s] Error loading %s, corrupt file " + \
                                  "detected", __NAME__, sabnzbd.QUEUE_FILE_NAME)

            for nzo_id in nzo_ids:
                nzo = sabnzbd.load_data(nzo_id, remove = False)
                if nzo:
                    self.add(nzo, save = False)

    def __init__stage2__(self):
        """ Pass history items that still hold their nzo (cleanup() not yet
            run) to sabnzbd.postprocess_nzo.
        """
        for hist_item in self.__downloaded_items:
            if hist_item.nzo:
                sabnzbd.postprocess_nzo(hist_item.nzo)

    def save(self):
        """ Save every queued nzo under its nzo_id, then the
            (queue version, nzo_ids, history) triple under QUEUE_FILE_NAME.
        """
        logging.info("[%s] Saving queue", __NAME__)

        nzo_ids = []
        # Aggregate nzo_ids and save each nzo
        for nzo in self.__nzo_list:
            nzo_ids.append(nzo.nzo_id)
            sabnzbd.save_data(nzo, nzo.nzo_id)

        sabnzbd.save_data((sabnzbd.__queueversion__, nzo_ids,
                           self.__downloaded_items), sabnzbd.QUEUE_FILE_NAME)

    def generate_future(self, msg, repair, unpack, delete):
        """ Create, queue and return a placeholder ("future") nzo object. """
        future_nzo = NzbObject(msg, repair, unpack, delete, None, True)
        self.add(future_nzo)
        return future_nzo

    def insert_future(self, future, filename, data, cat_root = None,
                      cat_tail = None):
        """ Re-initialise a placeholder nzo in place with real NZB data.

            The nzo keeps its nzo_id and queue position.  If it was removed
            from the queue meanwhile nothing happens; if re-initialisation
            fails the item is removed from the queue (not added to history).
        """
        nzo_id = future.nzo_id
        if nzo_id in self.__nzo_table:
            try:
                logging.info("[%s] Regenerating item: %s", __NAME__, nzo_id)
                repair, unpack, delete = future.get_repair_opts()
                future.__init__(filename, repair, unpack, delete,
                                nzb = data, futuretype = False,
                                cat_root = cat_root, cat_tail = cat_tail)
                future.nzo_id = nzo_id
                self.save()
                sabnzbd.backup_nzb(filename, data)

                if self.__f_mode:
                    self.__make_filemode(future)

                if self.__auto_sort:
                    self.sort_by_avg_age()

                self.reset_try_list()
            except Exception:
                logging.exception("[%s] Error while adding %s, removing", __NAME__,
                                  nzo_id)
                self.remove(nzo_id, False)
        else:
            logging.info("[%s] Item %s no longer in queue, omitting", __NAME__,
                         nzo_id)

    def change_opts(self, nzo_id, pp):
        """ Forward pp to set_opts() of the given nzo, if it is queued. """
        if nzo_id in self.__nzo_table:
            self.__nzo_table[nzo_id].set_opts(pp)

    def add(self, nzo, pos = -1, save=True):
        """ Queue nzo at index pos (append when pos is -1).

            Assigns a new nzo_id when the nzo has none, optionally saves the
            queue, registers the nzo's files in f_mode and re-sorts when
            auto_sort is on.  Resets both the queue and nzo try_lists.
        """
        # Reset try_lists
        self.reset_try_list()
        nzo.reset_try_list()

        if not nzo.nzo_id:
            nzo.nzo_id = sabnzbd.get_new_id('nzo')

        if nzo.nzo_id:
            nzo.deleted = False
            self.__nzo_table[nzo.nzo_id] = nzo
            if pos > -1:
                self.__nzo_list.insert(pos, nzo)
            else:
                self.__nzo_list.append(nzo)
            if save:
                self.save()

            if self.__f_mode:
                self.__make_filemode(nzo)

        if self.__auto_sort:
            self.sort_by_avg_age()

    def remove(self, nzo_id, add_to_history = True):
        """ Take nzo_id out of the queue.

            With add_to_history the nzo is recorded once in the history;
            otherwise its data and saved articles are purged.  The nzo's
            stored copy is removed and the queue is saved.
        """
        if nzo_id in self.__nzo_table:
            nzo = self.__nzo_table.pop(nzo_id)
            nzo.deleted = True
            self.__nzo_list.remove(nzo)

            for nzf in nzo.get_all_nzfs():
                self.remove_from_nzf_table(nzf)

            if add_to_history:
                # Make sure item is only represented once in history
                for hist_item in self.__downloaded_items:
                    if hist_item.nzo and hist_item.nzo.nzo_id == nzo.nzo_id:
                        break
                else:
                    self.__downloaded_items.append(HistoryItem(nzo))
            else:
                nzo.purge_data()
                self.__purge_articles(nzo.saved_articles)

            sabnzbd.remove_data(nzo_id)
            self.save()

    def remove_nzf(self, nzo_id, nzf_id):
        """ Remove one file from a queued nzo; drop the nzo (without history)
            when nzo.remove_nzf() reports it is done.
        """
        if nzo_id in self.__nzo_table:
            nzo = self.__nzo_table[nzo_id]
            nzf = nzo.get_nzf_by_id(nzf_id)

            if nzf:
                self.remove_from_nzf_table(nzf)

                post_done = nzo.remove_nzf(nzf)
                if post_done:
                    self.remove(nzo_id, add_to_history = False)

    def remove_from_nzf_table(self, nzf):
        """ Unregister nzf from the f_mode date table; drop empty date slots. """
        date = nzf.get_date()
        if date in self.__nzf_table:
            nzf_list = self.__nzf_table[date]
            if nzf in nzf_list:
                logging.debug("[%s] Removing %s from nzf_table[%s]",
                             __NAME__, nzf, date)
                nzf_list.remove(nzf)
            if not nzf_list:
                logging.debug("[%s] Removing %s from nzf_table",
                             __NAME__, date)
                self.__nzf_table.pop(date)

    def switch(self, item_id_1, item_id_2):
        """ Move item_id_1 to the list index currently held by item_id_2.

            Nothing happens unless both ids are queued (and distinct).
        """
        item_id_pos1 = -1
        item_id_pos2 = -1
        for i, nzo in enumerate(self.__nzo_list):
            if item_id_1 == nzo.nzo_id:
                item_id_pos1 = i
            elif item_id_2 == nzo.nzo_id:
                item_id_pos2 = i
            if item_id_pos1 > -1 and item_id_pos2 > -1:
                item = self.__nzo_list.pop(item_id_pos1)
                self.__nzo_list.insert(item_id_pos2, item)
                break

    def move_up_bulk(self, nzo_id, nzf_ids):
        """ Delegate to nzo.move_up_bulk(nzf_ids) if nzo_id is queued. """
        if nzo_id in self.__nzo_table:
            self.__nzo_table[nzo_id].move_up_bulk(nzf_ids)

    def move_top_bulk(self, nzo_id, nzf_ids):
        """ Delegate to nzo.move_top_bulk(nzf_ids) if nzo_id is queued. """
        if nzo_id in self.__nzo_table:
            self.__nzo_table[nzo_id].move_top_bulk(nzf_ids)

    def move_down_bulk(self, nzo_id, nzf_ids):
        """ Delegate to nzo.move_down_bulk(nzf_ids) if nzo_id is queued. """
        if nzo_id in self.__nzo_table:
            self.__nzo_table[nzo_id].move_down_bulk(nzf_ids)

    def move_bottom_bulk(self, nzo_id, nzf_ids):
        """ Delegate to nzo.move_bottom_bulk(nzf_ids) if nzo_id is queued. """
        if nzo_id in self.__nzo_table:
            self.__nzo_table[nzo_id].move_bottom_bulk(nzf_ids)

    def sort_by_avg_age(self):
        """ Sort the queue in place with nzo_date_cmp (oldest average date
            first; nzos without a date compare as "now").
        """
        logging.info("[%s] Sorting by average date...", __NAME__)
        self.__nzo_list.sort(cmp=nzo_date_cmp)

    def add_to_try_list(self, server):
        """ Mark server as having nothing to fetch (idempotent). """
        if server not in self.try_list:
            logging.debug("[%s] Appending %s to NZBQ.try_list",
                          __NAME__, server)
            self.try_list.append(server)

    def remove_from_try_list(self, server):
        """ Unmark server, if present. """
        if server in self.try_list:
            logging.debug("[%s] Removing %s from NZBQ.try_list",
                          __NAME__, server)
            self.try_list.remove(server)

    def reset_try_list(self):
        """ Clear the try_list so every server is asked again. """
        if self.try_list:
            logging.debug("[%s] Reseting NZBQ.try_list", __NAME__)
            self.try_list = []

    def get_article(self, server):
        """ Return the next article server should fetch, or None.

            f_mode: scan nzfs by ascending date key, finishing lazy imports
            first (nzfs that fail to import are unregistered).  Otherwise scan
            nzos in queue order, skipping those whose try_list has server.
            In non-f_mode, finding nothing adds server to self.try_list.
        """
        if self.__f_mode:
            article = None
            nzf_remove_list = []

            for key in sorted(self.__nzf_table):
                for nzf in self.__nzf_table[key]:
                    if (server not in nzf.try_list) and (nzf.is_active()):
                        if not nzf.import_finished:
                            nzf.finish_import()
                            if not nzf.import_finished:
                                logging.error("[%s] Error importing %s", __NAME__, nzf)
                                nzf_remove_list.append(nzf)
                                continue

                        article = nzf.get_article(server)
                        if article:
                            break
                if article:
                    break

            for nzf in nzf_remove_list:
                self.remove_from_nzf_table(nzf)

            # NOTE(review): f_mode returns even when nothing was found, so
            # the server is never added to try_list here -- confirm intent.
            return article

        for nzo in self.__nzo_list:
            # Don't try to get an article if server is in nzo.try_list
            if server not in nzo.try_list:
                article = nzo.get_article(server)
                if article:
                    return article

        # No articles for this server, block server (until reset issued)
        self.add_to_try_list(server)

    def register_article(self, article, lines = None):
        """ Account for a finished article download.

            lines -- raw article lines; when falsy nothing is decoded

            Decodes and stores the data, removes the article from its nzo,
            skips files whose extension is in sabnzbd.CLEANUP_LIST, hands
            finished files to the assembler and, once the whole nzo is done,
            moves it to history and notifies the assembler with (nzo, None).

            Raises IOError / CrcError from decode(); in that case the article
            is not removed from its nzo.
        """
        nzf = article.nzf
        nzo = nzf.nzo

        if nzo.deleted or nzf.deleted:
            logging.debug("[%s] Discarding article %s, no longer in queue",
                          __NAME__, article.article)
            return

        if lines:
            try:
                self.decode(article, lines)
                nzf.increase_article_count()
            except (IOError, CrcError):
                # Caller handles these; a bare raise keeps the traceback
                raise
            except Exception:
                logging.exception("[%s] Unknown exception in decode()",
                                  __NAME__)

        file_done, post_done, reset = nzo.remove_article(article)

        filename = nzf.get_filename()
        if filename:
            ext = os.path.splitext(filename)[1]
            if ext in sabnzbd.CLEANUP_LIST:
                # Unwanted file type: drop the whole file instead of assembling
                logging.info("[%s] Skipping %s", __NAME__, nzf)
                file_done, reset = (False, True)
                post_done = post_done or nzo.remove_nzf(nzf)
                self.remove_from_nzf_table(nzf)

        if reset:
            self.reset_try_list()

        if file_done:
            self.remove_from_nzf_table(nzf)

            if sabnzbd.DO_SAVE:
                sabnzbd.save_data(nzo, nzo.nzo_id)

            filename = nzf.get_filename()
            _type = nzf.get_type()

            # Only start decoding if we have a filename and type
            if filename and _type:
                sabnzbd.assemble_nzf((nzo, nzf))
            else:
                logging.warning('[%s] %s -> Unknown encoding', __NAME__,
                                filename)

        if post_done:
            self.remove(nzo.nzo_id, True)

            # Notify assembler to call postprocessor
            sabnzbd.assemble_nzf((nzo, None))

            # sabnzbd.AUTOSHUTDOWN only True on os.name == 'nt'
            if sabnzbd.AUTOSHUTDOWN and not self.__nzo_list:
                # sabnzbd.LOCK is released while the worker threads are
                # stopped and joined, and always re-acquired afterwards,
                # even if stopping one of them fails.
                sabnzbd.LOCK.release()
                try:
                    sabnzbd.ASSEMBLER.stop()
                    try:
                        sabnzbd.ASSEMBLER.join()
                    except Exception:
                        pass

                    sabnzbd.POSTPROCESSOR.stop()
                    try:
                        sabnzbd.POSTPROCESSOR.join()
                    except Exception:
                        pass
                finally:
                    sabnzbd.LOCK.acquire()
                self.save()
                sabnzbd.system_shutdown()

    def purge(self):
        """ Forget the whole download history. """
        self.__downloaded_items = []

    def queue_info(self, for_cli = False):
        """ Return (total bytes, bytes left, cached article count,
            cache size, cache limit, list of nzo.gather_info() results).
        """
        bytes_left = 0
        total_bytes = 0
        pnfo_list = []
        for nzo in self.__nzo_list:
            pnfo = nzo.gather_info(for_cli = for_cli)

            total_bytes += pnfo[PNFO_BYTES_FIELD]
            bytes_left += pnfo[PNFO_BYTES_LEFT_FIELD]
            pnfo_list.append(pnfo)

        return (total_bytes, bytes_left, len(self.__article_list),
                self.__cache_size, self.__cache_limit, pnfo_list)

    def history_info(self):
        """ Return ({completed time: [(filename, unpackstrht, loaded), ...]},
            total bytes downloaded by history items, sabnzbd.get_bytes()).

            loaded is True while the history item still holds its nzo.
        """
        history_info = {}
        bytes_downloaded = 0
        for hist_item in self.__downloaded_items:
            bytes_downloaded += hist_item.bytes_downloaded

            if hist_item.nzo:
                unpackstrht = hist_item.nzo.get_unpackstrht()
                loaded = True
            else:
                unpackstrht = hist_item.unpackstrht
                loaded = False

            history_info.setdefault(hist_item.completed, []).append(
                (hist_item.filename, unpackstrht, loaded))
        return (history_info, bytes_downloaded, sabnzbd.get_bytes())

    def is_empty(self):
        """ True when the queue holds nothing but placeholder (future) nzos. """
        for nzo in self.__nzo_list:
            if not nzo.futuretype:
                return False
        return True

    def cleanup_nzo(self, nzo):
        """ Purge nzo's data and saved articles, and let its history entry
            snapshot its stats and drop the nzo reference.
        """
        nzo.purge_data()
        self.__purge_articles(nzo.saved_articles)

        for hist_item in self.__downloaded_items:
            # refresh fields & delete nzo reference
            if hist_item.nzo and hist_item.nzo == nzo:
                hist_item.cleanup()
                logging.debug('[%s] %s cleaned up', __NAME__,
                              nzo.get_filename())

    def decode(self, article, data):
        """ Decode the raw lines of one article and store the result.

            data -- list of article lines (the list may be modified in place)

            Posts with =ybegin and =yend lines are yEnc-decoded and
            CRC-checked.  Posts without =ybegin are stored as their lines
            joined with CRLF, after removing a uuencode 'begin'/'end' frame
            if present.  nzf filename and type are set from the headers.

            Raises CrcError when the yEnc CRC does not match; the decoded
            data is saved before raising.
        """
        data = strip(data)
        ## No point in continuing if we don't have any data left
        if not data:
            return

        nzf = article.nzf
        yenc, data = yCheck(data)
        ybegin, ypart, yend = yenc
        decoded_data = None

        #Deal with non-yencoded posts
        if not ybegin:
            # yCheck may have stripped a stray =yend line, leaving no data
            if data and data[0].startswith('begin '):
                nzf.set_filename(data[0].split(None, 2)[2])
                nzf.set_type('uu')
                data.pop(0)
            if data and data[-1] == 'end':
                data.pop()
                if data and data[-1] == '`':
                    data.pop()

            decoded_data = '\r\n'.join(data)

        #Deal with yenc encoded posts
        elif ybegin and yend:
            if 'name' in ybegin:
                nzf.set_filename(ybegin['name'])
            else:
                logging.debug("[%s] Possible corrupt header detected " + \
                              "=> ybegin: %s", __NAME__, ybegin)
            nzf.set_type('yenc')
            # Decode data
            if HAVE_YENC:
                decoded_data, crc = _yenc.decode_string(''.join(data))[:2]
                partcrc = '%08X' % ((crc ^ -1) & 0xFFFFFFFF)
            else:
                data = ''.join(data)
                # Resolve every escape first, then undo the +42 shift once
                for i in (0, 9, 10, 13, 27, 32, 46, 61):
                    data = data.replace('=%c' % (i + 64), chr(i))
                decoded_data = data.translate(YDEC_TRANS)
                crc = binascii.crc32(decoded_data)
                partcrc = '%08X' % (crc & 0xFFFFFFFF)

            if ypart:
                crcname = 'pcrc32'
            else:
                crcname = 'crc32'

            if crcname in yend:
                _partcrc = yend[crcname].upper().rjust(8, '0')
            else:
                _partcrc = None
                logging.debug("[%s] Corrupt header detected " + \
                              "=> yend: %s", __NAME__, yend)

            if _partcrc != partcrc:
                self.__save_article(article, decoded_data)
                raise CrcError(_partcrc, partcrc)

        # A =ybegin without =yend leaves decoded_data as None: nothing stored
        if decoded_data:
            self.__save_article(article, decoded_data)

    def load_article(self, article):
        """ Return the stored data of article, or None if there is none.

            Cached data is taken out of the cache; otherwise the data is
            loaded (and removed) from disk when the article has an art_id.
            The article is also dropped from its nzo's saved_articles.
        """
        data = None

        # __article_table holds exactly the cached articles (O(1) lookup)
        if article in self.__article_table:
            data = self.__article_table.pop(article)
            self.__article_list.remove(article)
            self.__cache_size -= len(data)
            logging.info("[%s] Loaded %s from cache", __NAME__, article)
            logging.debug("[%s] cache_size -> %s", __NAME__, self.__cache_size)
        elif article.art_id:
            data = sabnzbd.load_data(article.art_id, remove = True, do_pickle = False)

        nzo = article.nzf.nzo
        if article in nzo.saved_articles:
            nzo.saved_articles.remove(article)

        return data

    def flush_articles(self):
        """ Write every cached article to disk and empty the cache. """
        self.__cache_size = 0
        while self.__article_list:
            article = self.__article_list.pop(0)
            data = self.__article_table.pop(article)
            self.__flush_article(article, data)

    def __flush_article(self, article, data):
        """ Save article data to disk under its art_id, unless its nzo has
            left the queue (then the data is discarded).
        """
        if article.nzf.nzo in self.__nzo_list:
            art_id = article.get_art_id()
            if art_id:
                logging.info("[%s] Flushing %s to disk", __NAME__, article)
                logging.debug("[%s] cache_size -> %s", __NAME__, self.__cache_size)
                sabnzbd.save_data(data, art_id, do_pickle = False)
            else:
                logging.warning("[%s] Flushing %s failed -> no art_id", __NAME__, article)
        else:
            logging.debug("[%s] %s discarded", __NAME__, article)

    def __add_to_cache(self, article, data):
        """ Insert or replace article data in the cache, keeping its
            original position when it was already cached.
        """
        if article in self.__article_table:
            self.__cache_size -= len(self.__article_table[article])
        else:
            self.__article_list.append(article)

        self.__article_table[article] = data
        self.__cache_size += len(data)
        logging.info("[%s] Added %s to cache", __NAME__, article)
        logging.debug("[%s] cache_size -> %s", __NAME__, self.__cache_size)

    def __save_article(self, article, data):
        """ Store decoded article data in the cache or on disk.

            Negative cache_limit: always cache.  Positive: evict (flush)
            the oldest cached articles until data fits, else write it to
            disk.  Zero: always write to disk.
        """
        saved_articles = article.nzf.nzo.saved_articles
        if article not in saved_articles:
            saved_articles.append(article)

        if self.__cache_limit:
            if self.__cache_limit < 0:
                self.__add_to_cache(article, data)

            else:
                data_size = len(data)

                while (self.__cache_size > (self.__cache_limit - data_size)) \
                and self.__article_list:
                    ## Flush oldest article in cache
                    old_article = self.__article_list.pop(0)
                    old_data = self.__article_table.pop(old_article)
                    self.__cache_size -= len(old_data)
                    ## No need to flush if this is a refreshment article
                    if old_article != article:
                        self.__flush_article(old_article, old_data)

                ## Does our article fit into our limit now?
                if (self.__cache_size + data_size) <= self.__cache_limit:
                    self.__add_to_cache(article, data)
                else:
                    self.__flush_article(article, data)

        else:
            self.__flush_article(article, data)

    def __purge_articles(self, articles):
        """ Drop the given articles from the cache and delete their files. """
        logging.debug("[%s] Purgable articles -> %s", __NAME__, articles)
        for article in articles:
            if article in self.__article_table:
                self.__article_list.remove(article)
                data = self.__article_table.pop(article)
                self.__cache_size -= len(data)
            if article.art_id:
                sabnzbd.remove_data(article.art_id)

    def __make_filemode(self, nzo):
        """ Register all of nzo's files in the f_mode date table. """
        for nzf in nzo.get_all_nzfs():
            self.__nzf_table.setdefault(nzf.get_date(), []).append(nzf)

    def debug(self):
        """ Return shallow copies of the internal state, for inspection. """
        return (self.__cache_limit, self.__cache_size, self.__downloaded_items[:],
                self.__nzo_list[:], self.__article_list[:], self.__nzo_table.copy(),
                self.__nzf_table.copy(), self.__article_table.copy(), self.try_list[:])
                
#-------------------------------------------------------------------------------

def yCheck(data):
    """ Locate the yEnc header and trailer lines of an article.

        data -- list of article lines

        Looks for '=ybegin ' (optionally followed directly by '=ypart ')
        within the first 10 lines, and for '=yend ' within the last 10
        lines of what remains.

        Returns ((ybegin, ypart, yend), body): each header is a dict from
        ySplit() or None when not found, and body is data with the found
        header/trailer lines removed (data itself when nothing was cut).
    """
    ybegin = None
    ypart = None
    yend = None

    ## Check head
    for i, line in enumerate(data[:10]):
        if line.startswith('=ybegin '):
            # Cap the number of splits so '=' inside the final name= value
            # stays intact: line=, size= and name= are assumed present,
            # part= and total= are counted only when they actually occur.
            splits = 3
            if ' part=' in line:
                splits += 1
            if ' total=' in line:
                splits += 1

            ybegin = ySplit(line, splits)

            if i + 1 < len(data) and data[i + 1].startswith('=ypart '):
                ypart = ySplit(data[i + 1])
                data = data[i + 2:]
            else:
                data = data[i + 1:]
            break

    ## Check tail
    for i in range(-1, -11, -1):
        try:
            line = data[i]
        except IndexError:
            break
        if line.startswith('=yend '):
            yend = ySplit(line)
            data = data[:i]
            break

    return ((ybegin, ypart, yend), data)

# Example: =ybegin part=1 line=128 size=123 name=-=DUMMY=- abc.par
# Keys are alphanumeric words directly followed by '='.
YSPLIT_RE = re.compile(r'([a-zA-Z0-9]+)=')
def ySplit(line, splits = None):
    """ Parse the key=value pairs of a yEnc header line into a dict.

        splits -- maximum number of keys to split on (None/0: all of them);
                  lets a trailing name= value keep any '=' it contains

        Values are whitespace-stripped; returns {} if the split does not
        yield key/value pairs.
    """
    if splits:
        parts = YSPLIT_RE.split(line, splits)[1:]
    else:
        parts = YSPLIT_RE.split(line)[1:]

    if len(parts) % 2:
        return {}

    fields = {}
    for key, value in zip(parts[::2], parts[1::2]):
        fields[key] = value.strip()

    return fields
    
def strip(data):
    """ Trim article lines in place and return the same list.

        Leading and trailing empty lines are removed, and a doubled leading
        dot ('..') on any line is reduced to a single dot.
    """
    while data and not data[0]:
        del data[0]

    while data and not data[-1]:
        del data[-1]

    for idx, text in enumerate(data):
        if text.startswith('..'):
            data[idx] = text[1:]
    return data
    
#-------------------------------------------------------------------------------

def nzo_date_cmp(nzo1, nzo2):
    """ cmp()-style comparison of two nzos by average article date.

        Two undated nzos compare equal; a single undated nzo is compared
        as if its date were the current time.
    """
    first, second = nzo1.get_avg_date(), nzo2.get_avg_date()

    if first is None:
        if second is None:
            return 0
        first = datetime.datetime.now()
    elif second is None:
        second = datetime.datetime.now()

    return cmp(first, second)
