""" Remote classes """

# Copyright 2003-2005 Iustin Pop
#
# This file is part of cfvers.
#
# cfvers is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# cfvers is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cfvers; if not, write to the Free Software Foundation,
# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

# $Id: gateway.py 207 2005-05-29 10:35:55Z iusty $

import Pyro.core
import Pyro.constants
from Pyro.protocol import DefaultConnValidator

import sha
import hmac
import random
import logging

import cfvers.repository
from cfvers.main import Result, Entry, ParsingException, OperationError

class User:
    """Account record describing one client of the server.

    username        -- login name
    client_password -- secret the client authenticates with (checked by
                       PortalValidator.acceptIdentification)
    server_password -- HMAC key used by Portal.checkID to answer client
                       challenges
    addresses       -- client addresses allowed to log in; kept as
                       ``valid_addrs``
    areas           -- repository areas the user may access; kept as
                       ``valid_areas``
    isadmin         -- when true, area restrictions do not apply
    """

    def __init__(self, username, client_password,
                 server_password, addresses, areas, isadmin):
        # Identity and privilege level.
        self.username, self.isadmin = username, isadmin
        # Credentials for both directions of authentication.
        self.client_password, self.server_password = (client_password,
                                                      server_password)
        # Access restrictions.
        self.valid_addrs, self.valid_areas = addresses, areas


class SecureObjBase(Pyro.core.ObjBase):
    """Pyro object base class that whitelists remotely invocable methods.

    Subclasses name their remotely callable methods in ``__rmethods__`` and
    must provide a ``logger`` attribute; it is used here to report refused
    invocations but is not set by this class.
    """

    # Names of the methods remote clients may call (none by default).
    __rmethods__ = ()

    def Pyro_dyncall(self, method, flags, args):
        """Dispatch a remote call, refusing any method not whitelisted.

        Whitelisted calls are forwarded to Pyro.core.ObjBase.Pyro_dyncall;
        any other method name is logged and rejected with
        Pyro.errors.ConnectionDeniedError.
        """
        allowed = self.__class__.__rmethods__
        if method in allowed:
            return Pyro.core.ObjBase.Pyro_dyncall(self, method, flags, args)
        self.logger.critical("Pyro_dyncall refusing remote invocation "
                             "of method '%s' with flags '%s' and "
                             "arguments '%s'", method, flags, args)
        raise Pyro.errors.ConnectionDeniedError("security reasons")


class Portal(SecureObjBase):
    """Per-session remote facade over a cfvers repository.

    Only the methods named in ``__rmethods__`` may be invoked remotely
    (enforced by SecureObjBase.Pyro_dyncall).  Access to repository areas
    is checked against the session's User: admins may touch every area,
    other users only the areas listed in ``user.valid_areas``.
    """

    # Whitelist of remotely callable methods.
    __rmethods__ = (
        "checkID",
        "connect",
        "create",
        "disconnect",
        "commit",
        "rollback",
        "getAreas",
        "getArea",
        "getRevisions",
        "getRevisionItems",
        "getItemByName",
        "getItemByID",
        "getItems",
        "getItemsByDirname",
        "getRevNumbers",
        "getEntryList",
        "getEntry",
        "getEntries",
        "putRevision",
        "addItem",
        "bulkAddItem",
        "addEntry",
        "bulkAddEntry",
        "addArea",
        )

    def __init__(self, local=False, user=None, repo=None):
        """Create a portal for one client session.

        local -- true for in-process use: an anonymous admin User is
                 installed and disconnect() leaves the Pyro daemon alone.
        user  -- the authenticated User (ignored when local is true).
        repo  -- a (cnxmethod, cnxargs) pair later handed to
                 cfvers.repository.open() by connect()/create().  It must
                 be a 2-sequence; the default None cannot be unpacked and
                 raises TypeError.
        """
        SecureObjBase.__init__(self)
        if local:
            self.user = User("", "", "", "", "", True)
        else:
            self.user = user
        self.local = local
        # Opened lazily by connect() or create().
        self.repo = None
        self.repo_meth, self.repo_data = repo
        # Number of checkID() challenges answered so far.
        self.checks = 0
        self.logger = logging.getLogger("Portal")
        return

    def _gotReaped(self):
        """Roll back and close the repository (Pyro reaper hook).

        Presumably invoked by the Pyro daemon when it reaps this object.
        """
        if self.repo is not None:
            self.repo.rollback()
            self.repo.close()
            # Drop the closed handle so no later call reuses it.
            self.repo = None
        return

    def _validate_area(self, area):
        """Raise ConnectionDeniedError unless the user may access `area`."""
        if not self.user.isadmin and area not in self.user.valid_areas:
            self.logger.warning("User %s tried to access area %s", self.user.username, area)
            raise Pyro.errors.ConnectionDeniedError("security reasons")
        return

    def _validate_itemid(self, itemid):
        """Return `itemid` converted to int; raise OperationError if impossible."""
        try:
            newid = int(itemid)
        except (TypeError, ValueError):
            # TypeError covers values such as None, which int() rejects
            # without a ValueError and which used to escape unwrapped.
            raise OperationError("Invalid item ID `%s'" % str(itemid))
        return newid

    def _validate_revno(self, revno):
        """Return `revno` as an int (None passes through); else OperationError."""
        if revno is not None:
            try:
                revno = int(revno)
            except (TypeError, ValueError):
                raise OperationError("Invalid revision number `%s'" % str(revno))
        return revno

    def _get_checked_item(self, itemid):
        """Validate `itemid`, fetch the item and check access to its area.

        Returns the pair (integer item id, item).
        """
        itemid = self._validate_itemid(itemid)
        item = self.repo.getItemByID(itemid)
        # NOTE(review): assumes getItemByID never returns None for an
        # unknown id; if it can, the next line raises AttributeError.
        self._validate_area(item.area)
        return itemid, item

    def checkID(self, token):
        """Answer a server-identity challenge sent by the client.

        Returns (pre, post, digest): two fresh 10-character random salts
        and the hex HMAC-SHA1, keyed with the user's server password, of
        pre + token + post.  This presumably lets a client that knows the
        server password verify the server.  At most four challenges are
        answered per portal; further calls return None.
        """
        if self.checks > 3:
            return None
        # The salts must be unpredictable: prefer the OS entropy source over
        # the default Mersenne Twister when this Python provides it.
        rng = getattr(random, "SystemRandom", random.Random)()
        salt = "".join([chr(rng.randint(32, 127)) for _ in range(20)])
        pre, post = salt[:10], salt[10:]
        self.checks += 1
        hm = hmac.new(self.user.server_password, pre, sha)
        hm.update(token)
        hm.update(post)
        return pre, post, hm.hexdigest()

    def connect(self):
        """Open the existing repository described by the portal's repo pair."""
        self.repo = cfvers.repository.open(cnxmethod=self.repo_meth, cnxargs=self.repo_data)
        return

    def create(self, createopts):
        """Create and open a new repository (admin users only).

        Raises ValueError for non-admin users.
        """
        if not self.user.isadmin:
            self.logger.warning("Non-admin user %s tried to create an area", self.user.username)
            raise ValueError("Invalid operation attempted")
        self.repo = cfvers.repository.open(cnxmethod=self.repo_meth,
                                           cnxargs=self.repo_data,
                                           create=True,
                                           createopts=createopts)
        return

    def disconnect(self):
        """Close the repository and, for remote portals, unregister from Pyro.

        Neither commit() nor rollback() is called here.
        """
        if self.repo is not None:
            self.repo.close()
            # Drop the closed handle so a later reap does not touch it.
            self.repo = None
        if not self.local:
            self.getDaemon().disconnect(self)
        return

    def commit(self):
        """Commit the current repository transaction."""
        self.repo.commit()
        return

    def rollback(self):
        """Roll back the current repository transaction."""
        self.repo.rollback()
        return

    def getAreas(self):
        """Return the areas visible to the user (all of them for admins)."""
        arealist = self.repo.getAreas()
        if not self.user.isadmin:
            arealist = [area for area in arealist if area.name in self.user.valid_areas]
        return arealist

    def getArea(self, areaname):
        """Return the named area after an access check."""
        self._validate_area(areaname)
        return self.repo.getArea(areaname)

    def getRevisions(self, areaname):
        """Return the revisions of the named area after an access check."""
        self._validate_area(areaname)
        return self.repo.getRevisions(areaname)

    def getRevisionItems(self, areaname, revno):
        """Return the items of revision `revno` in the named area."""
        self._validate_area(areaname)
        return self.repo.getRevisionItems(areaname, revno)

    def getItemByName(self, areaname, itemname):
        """Return the item called `itemname` in the named area."""
        self._validate_area(areaname)
        return self.repo.getItemByName(areaname, itemname)

    def getItemByID(self, itemid):
        """Return the item with id `itemid` after an access check."""
        itemid, item = self._get_checked_item(itemid)
        return item

    def getItems(self, areaname):
        """Return the items of the named area as a list (Pyro-transferable)."""
        self._validate_area(areaname)
        itemlist = list(self.repo.getItems(areaname))
        return itemlist

    def getItemsByDirname(self, areaname, dirname):
        """Return the items of the named area located under `dirname`."""
        self._validate_area(areaname)
        return self.repo.getItemsByDirname(areaname, dirname)

    def getRevNumbers(self, itemid):
        """Return the revision numbers of an item after an access check."""
        itemid, item = self._get_checked_item(itemid)
        return self.repo.getRevNumbers(itemid)

    def getEntryList(self, itemid):
        """Return the entry list of an item after an access check."""
        itemid, item = self._get_checked_item(itemid)
        return self.repo.getEntryList(itemid)

    def getEntry(self, itemid, revno, do_payload=True):
        """Return one entry of an item (revno None is passed through).

        do_payload is forwarded unchanged to the repository.
        """
        itemid = self._validate_itemid(itemid)
        revno = self._validate_revno(revno)
        itemid, item = self._get_checked_item(itemid)
        return self.repo.getEntry(itemid, revno, do_payload=do_payload)

    def getEntries(self, options):
        """Return the entries matching `options`.

        Admins get a tuple of every match; other users get a list filtered
        to their valid areas.  Any repository failure is logged and
        reported as ParsingException.
        """
        try:
            # tuple(...) converts generators into a plain tuple
            # (generators cannot be transferred over Pyro).
            entries = tuple(self.repo.getEntries(options))
        except Exception:
            self.logger.exception("Exception while doing getEntries")
            raise ParsingException("Invalid argument to getEntries")
        if not self.user.isadmin:
            entries = [entry for entry in entries if entry.areaname in self.user.valid_areas]
        return entries

    def putRevision(self, ar):
        """Store revision `ar` after checking access to its area."""
        self._validate_area(ar.area)
        return self.repo.putRevision(ar)

    def addItem(self, newrev, item):
        """Add `item` plus its initial entry at revision `newrev`.

        Returns the pair (stored item, new entry).
        """
        self._validate_area(item.area)
        newrev = self._validate_revno(newrev)
        i = self.repo.addItem(item)
        e = Entry.newBorn(i, newrev)
        self.repo.addEntry(e)
        return (i, e)

    def bulkAddItem(self, newrev, bulk):
        """Add every item in `bulk` with an initial entry at revision `newrev`.

        All areas are checked before anything is written.  Returns a list
        of (stored item, new entry) pairs in input order.
        """
        newrev = self._validate_revno(newrev)
        areas = dict.fromkeys([item.area for item in bulk])
        for a in areas.keys():
            self._validate_area(a)
        resus = []
        for item in bulk:
            i = self.repo.addItem(item)
            e = Entry.newBorn(i, newrev)
            self.repo.addEntry(e)
            resus.append((i, e))
        return resus

    def addEntry(self, entry):
        """Store `entry` after validating its item id and area."""
        self._get_checked_item(entry.item)
        self.repo.addEntry(entry)
        return

    def bulkAddEntry(self, bulk):
        """Store a batch of entries, skipping unchanged ones.

        An entry equal to the one returned by repo.getEntry(item, None)
        (presumably the item's latest entry) is not stored again.  All
        items are validated before anything is written.  Returns one
        Result per entry, in input order, with filecontents cleared on the
        returned entries.
        """
        # Validate every entry first so an access violation writes nothing.
        itemids = [self._get_checked_item(entry.item)[0] for entry in bulk]
        results = []
        for itemid, entry in zip(itemids, bulk):
            old = self.repo.getEntry(itemid, None)
            if old is not None and entry == old:
                old.filecontents = None
                results.append(Result(Result.STORED_NOTCHANGED, entry=old))
            else:
                self.repo.addEntry(entry)
                if entry.status == Entry.STATUS_DELETED:
                    rcode = Result.STORED_DELETED
                else:
                    rcode = Result.STORED_OK
                entry.filecontents = None
                results.append(Result(rcode, entry=entry))
        return results

    def addArea(self, area):
        """Create `area` after checking the user may access its name."""
        self._validate_area(area.name)
        self.repo.addArea(area)
        return

class PortalFactory(SecureObjBase):
    """Remote entry point that hands out one Portal per authenticated user.

    Only getPortal() is remotely callable.  The caller's identity is read
    from the ``auth_name`` attribute stored on the connection by
    PortalValidator.acceptIdentification after a successful login.
    """

    __rmethods__ = ("getPortal",)

    def __init__(self, users=None, repo_meth=None, repo_data=None):
        """users     -- mapping of login name to User (None: no users)
        repo_meth -- repository connection method given to each Portal
        repo_data -- repository connection arguments given to each Portal
        """
        SecureObjBase.__init__(self)
        self.users = users
        self.repo = (repo_meth, repo_data)
        self.logger = logging.getLogger("PortalFactory")
        return

    def getPortal(self):
        """Create, register and return a proxy to a Portal for the caller.

        Raises Pyro.errors.ConnectionDeniedError if the authenticated login
        name does not map to a known User.
        """
        caller = self.getLocalStorage().caller
        username = caller.auth_name
        clientip = caller.addr[0]
        # Treat a missing user table as "no users" instead of crashing.
        user = (self.users or {}).get(username)
        if user is None:
            # `user` is None here, so log the requested name; the former
            # user.username raised AttributeError instead of denying cleanly.
            self.logger.warning("Invalid user %s (%s) tried to login", username, clientip)
            raise Pyro.errors.ConnectionDeniedError("security reasons")
        p = Portal(local=False, user=user, repo=self.repo)
        self.getDaemon().connect(p)
        self.logger.info("User %s [%s] logged on", user.username, clientip)
        return p.getProxy()

class PortalValidator(DefaultConnValidator):
    """Pyro connection validator doing a challenge/response password login.

    Protocol:
      1. the client connects to the server (to the daemon);
      2. the server creates a challenge, sends it to the client and
         remembers it;
      3. the client builds a token from the challenge with
         createAuthToken() and sends it back;
      4. the server passes the token and the challenge to
         acceptIdentification() to decide whether the client is allowed.

    The identity is a (username, password) pair; mungeIdent() replaces the
    password with its SHA-1 hex digest.  The token is
    "username:" + sha1hex(munged_password + challenge), as produced by
    createAuthToken().
    """

    def __init__(self):
        # Mapping of login name -> User; replaced through setUsers().
        self.userdict = {}
        self.logger = logging.getLogger("PortalValidator")
        return

    def _xform(self, s):
        """Return the SHA-1 hex digest of `s`."""
        return sha.new(s).hexdigest()

    def _tokens_equal(self, a, b):
        """Compare two token strings in time independent of their content.

        A plain == stops at the first differing character, which can let
        an attacker recover a valid token piece by piece through timing.
        """
        if len(a) != len(b):
            return False
        diff = 0
        for x, y in zip(a, b):
            diff |= ord(x) ^ ord(y)
        return diff == 0

    def setUsers(self, userdict):
        """Replace the login -> User mapping used to authenticate clients."""
        self.userdict = userdict
        return

    def acceptIdentification(self, daemon, connection, token, challenge):
        """Decide whether the client that sent `token` may connect.

        A login succeeds only for a known user, from one of that user's
        valid addresses, with a token equal to the one computed from the
        user's client password and `challenge`.  On success the login name
        is stored in connection.auth_name and (1, 0) is returned; otherwise
        the reason is logged and (0, Pyro.constants.DENIED_SECURITY) is
        returned.
        """
        clientip = connection.addr[0]
        try:
            login, _ = token.split(":", 1)
        except (AttributeError, TypeError, ValueError):
            # Not a string, or no "login:digest" separator.
            self.logger.warning("Invalid token received from %s", clientip)
            return (0, Pyro.constants.DENIED_SECURITY)

        if login not in self.userdict:
            self.logger.warning("Failed login from %s for %s: unknown user", clientip, login)
            return (0, Pyro.constants.DENIED_SECURITY)
        user = self.userdict[login]
        if clientip not in user.valid_addrs:
            self.logger.warning("Failed login from %s for %s: address not allowed", clientip, login)
            return (0, Pyro.constants.DENIED_SECURITY)

        # Munge the stored password the same way mungeIdent() does,
        # presumably as Pyro applies it to the client's identity.
        password = self._xform(user.client_password)
        mytoken = self.createAuthToken((login, password), challenge, None, None, daemon)
        if not self._tokens_equal(mytoken, token):
            self.logger.warning("Failed login from %s for %s: authentication failed", clientip, login)
            return (0, Pyro.constants.DENIED_SECURITY)
        connection.auth_name = login  # store for later reference by Pyro object
        return (1, 0)

    def createAuthToken(self, authid, challenge, peeraddr, URI, daemon):
        """Create the authentication token "login:sha1hex(password+challenge)".

        authid is the (login, munged password) pair.  This method is used by
        acceptIdentification on the server side; peeraddr, URI and daemon
        are not used.
        """
        return "%s:%s" % (authid[0], self._xform(authid[1]+challenge))

    def mungeIdent(self, ident):
        """Return (login, sha1hex(password)) for a (login, password) ident."""
        return (ident[0], self._xform(ident[1]))


# syntax highlighted by Code2HTML, v. 0.9.1