#!/usr/bin/env python
#
# Copyright (C) 2001,2002 Jason R. Mastaler <jason@mastaler.com>
#
# This file is part of TMDA.
#
# TMDA is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.  A copy of this license should
# be included in the file COPYING.
#
# TMDA is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License
# along with TMDA; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

# Based on code from Python's (undocumented) smtpd module
# Copyright (C) 2001,2002 Python Software Foundation.

"""An authenticated ofmip proxy for TMDA.  Tag your outgoing mail through SMTP.

See <URL:http://tmda.net/tmda-ofmipd.html> for complete setup and
usage information.

Usage: %(program)s [OPTIONS]

OPTIONS:
    -h
    --help
        Print this message and exit.

    -V
    --version
        Print TMDA version information and exit.  `-V' prints the full
        version report; `--version' prints only the TMDA version.

    -d
    --debug
        Turn on debugging prints.

    -f
    --foreground
        Run in foreground.

    -b
    --background
        Run in background (default).

    -u <username>
    --username <username>
        The username that this program should run under.  The default
        is to run as the user who starts the program unless that is
        root, in which case an attempt to seteuid user `tofmipd' will be
        made.  Use this option to override these defaults.

    -p <host:port>
    --proxyport <host:port>
        The host:port to listen for incoming connections on.  The
        default is FQDN:8025 (i.e., port 8025 on the fully qualified
        domain name for the local host).

    -R proto[://host[:port]][/dn]
    --remoteauth proto[://host[:port]][/dn]
        Remote server against which to check the username and password.
        - proto can be one of the following:
          `imap' (IMAP4 server)
          `imaps' (IMAP4 server over SSL)
          `pop3' (POP3 server)
          `apop' (POP3 server with APOP authentication)
          `ldap' (LDAP server)
        - host defaults to localhost
        - port defaults to 143 (imap), 993 (imaps), 110 (pop3/apop), 389 (ldap)
        - dn is mandatory for ldap and should contain a `%%s' identifying
          the username
        Examples: -R imaps://myimapserver.net
                  -R pop3://mypopserver.net:2110
                  -R ldap://host.com/cn=%%s,dc=host,dc=com

    -A <program>
    --authprog <program>
        checkpassword-compatible command used to check username/password, e.g.,
              `-A /usr/sbin/checkpassword-pam -s id --stdin -- /bin/true'
        The program must be able to receive the username/password pair
        on its stdin, and in the following format:
              `username\\0password\\0'

    -a <file>
    --authfile <file>
        Path to the file holding authentication information for this
        proxy.  Default location is /etc/tofmipd if running as
        root/tofmipd, otherwise ~user/.tmda/tofmipd.  Use this option
        to override these defaults.

    -C <n>
    --connections <n>
        Do not handle more than n simultaneous connections. If there
        are n active connections, defer acceptance of a new connection
        until one finishes. n must be a positive integer. Default: 20
    
    -c <directory>
    --configdir <directory>
        Base directory to search for the authenticated user's TMDA
        configuration file in.  This might be useful if you wish to
        maintain TMDA files outside the user's home directory.
        
        `username/config' will be appended to form the path, e.g.,
        `-c /var/tmda' will have tmda-ofmipd search for
        `/var/tmda/bobby/config'.  If this option is not used,
        `~user/.tmda/config' will be assumed.
"""

import getopt
import os
import socket
import sys

try:
    import paths
except ImportError:
    # Prepend <prefix>/lib/pythonX.Y/site-packages/TMDA/pythonlib.
    # Build X.Y from version_info: sys.version[:3] would truncate
    # two-digit minor versions (e.g. "3.10" -> "3.1").
    sitedir = os.path.join(sys.prefix, 'lib',
                           'python%d.%d' % sys.version_info[:2],
                           'site-packages', 'TMDA', 'pythonlib')
    sys.path.insert(0, sitedir)

from TMDA import Util
from TMDA import Version

class Devnull:
    """Minimal file-like sink: everything written to it is discarded."""

    def write(self, msg):
        # Intentionally drop msg.
        return None

    def flush(self):
        # Nothing is buffered, so there is nothing to flush.
        return None

# Some defaults; most of them can be overridden on the command line.
FQDN = socket.getfqdn()
# Debug output is discarded unless -d/--debug swaps in sys.stderr.
DEBUGSTREAM = Devnull()
proxyport = FQDN + ':8025'
program = sys.argv[0]
configdir = authprog = foreground = None
# Settings for -R/--remoteauth; 'enable' is set to 1 once -R is parsed.
remoteauth = {'proto': None, 'host': 'localhost', 'port': None,
              'dn': '', 'enable': 0}
# Default port for each supported -R protocol ('pop3s' on 995 is not
# offered).
defaultauthports = {'imap': 143, 'imaps': 993, 'apop': 110,
                    'pop3': 110, 'ldap': 389}
connections = 20

running_as_root = int(os.getuid() == 0)

if running_as_root:
    username, authfile = 'tofmipd', '/etc/tofmipd'
else:
    username = None
    authfile = os.path.join(os.path.expanduser('~'), '.tmda', 'tofmipd')


def warning(msg='', exit=1):
    """Print msg to stderr framed by rows of asterisks; optionally exit.

    msg  -- text to display, wrapped with Util.wraptext().  Nothing is
            printed when it is empty.
    exit -- when true (the default), terminate the program with exit
            status 1 so the invoking shell can tell startup failed.
            (sys.exit() with no argument would report success.)
    """
    delimiter = '*' * 70
    if msg:
        msg = Util.wraptext(msg)
        sys.stderr.write('\n%s\n%s\n%s\n\n' % (delimiter, msg, delimiter))
    if exit:
        sys.exit(1)

# check whether we are running a recent enough Python.  Compare the
# numeric (major, minor) tuple: comparing version strings would rank
# e.g. '2.10' below '2.2'.
if sys.version_info[:2] < (2, 2):
    msg = 'Python 2.2 or greater is required to run ' + program + \
          ' -- Visit http://python.org/download/ to upgrade.'
    warning(msg)

# When started as root, warn (without exiting) that seteuid mode has
# not been security-reviewed.
if running_as_root:
    msg = ('WARNING: The security implications and risks of running %s '
           'in "seteuid" mode have not been fully evaluated.  '
           'If you are uncomfortable with this, quit now and instead run '
           '%s under your non-privileged TMDA user account.'
           % (program, program))
    warning(msg, exit=0)
    
    
def usage(code, msg=''):
    print __doc__ % globals()
    if msg:
        print msg
    sys.exit(code)

try:
    opts, args = getopt.getopt(
        sys.argv[1:],
        'p:u:R:A:a:c:C:dVhfb',
        ['proxyport=', 'username=', 'authfile=', 'remoteauth=',
         'authprog=', 'configdir=', 'connections=',
         'debug', 'version', 'help', 'foreground', 'background'])
except getopt.error, msg:
    # Unknown option or missing argument: show the help text and fail.
    usage(1, msg)

for opt, arg in opts:
    if opt in ('-h', '--help'):
        usage(0)
    if opt == '-V':
        print Version.ALL
        sys.exit()
    if opt == '--version':
        print Version.TMDA
        sys.exit()
    elif opt in ('-d', '--debug'):
        DEBUGSTREAM = sys.stderr
    elif opt in ('-f', '--foreground'):
        foreground = 1
    elif opt in ('-b', '--background'):
        foreground = 0
    elif opt in ('-p', '--proxyport'):
	proxyport = arg
    elif opt in ('-u', '--username'):
        username = arg
    elif opt in ('-R', '--remoteauth'):
        # arg is like: imap://host:port
        try:
            authproto, arg = arg.split('://', 1)
        except ValueError:
            authproto, arg = arg, None
        remoteauth['proto'] = authproto
        remoteauth['port'] = defaultauthports[authproto]
        if authproto not in defaultauthports.keys():
            raise ValueError, 'Protocol not supported: ' + authproto + \
                    '\nPlease pick one of ' + repr(defaultauthports.keys())
        if arg:
            try:
                arg, dn = arg.split('/', 1)
                remoteauth['dn'] = dn
            except ValueError:
                dn = ''
            try:
                authhost, authport = arg.split(':', 1)
            except ValueError:
                authhost = arg
                authport = defaultauthports[authproto]
            if authhost:
                remoteauth['host'] = authhost
            if authport:
                remoteauth['port'] = authport
        print >> DEBUGSTREAM, "auth method: %s://%s:%s/%s" % \
              (remoteauth['proto'], remoteauth['host'],
               remoteauth['port'], remoteauth['dn'])
        remoteauth['enable'] = 1
    elif opt in ('-A', '--authprog'):
        authprog = arg
    elif opt in ('-a', '--authfile'):
        authfile = arg
    elif opt in ('-c', '--configdir'):
        configdir = arg
    elif opt in ('-C', '--connections'):
        connections = arg
        

import asynchat
import asyncore
import base64
import hmac
import md5
import popen2
import random
import time


__version__ = Version.TMDA
# Frequently used string constants.
NEWLINE, EMPTYSTRING, COMMASPACE = '\n', '', ', '


if remoteauth['proto'] == 'imaps':
    import imaplib
    # imaplib gained IMAP4_SSL in Python 2.3.  Only Python 2.2 and
    # earlier need the replacement below; compare the full version
    # tuple, since checking the major number alone matched every 2.x.
    if sys.version_info[:2] <= (2, 2):
        class IMAP4_SSL(imaplib.IMAP4):
            """Minimal IMAP4-over-SSL client for Pythons whose imaplib
            lacks IMAP4_SSL."""

            def open(self, host, port):
                """Setup connection to remote server on "host:port".
                This connection will be used by the routines:
                read, readline, send, shutdown.
                """
                self.host = host
                self.port = port
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                self.sock.connect((host, port))
                self.sslsock = socket.ssl(self.sock)
                # NOTE(review): not used by the overrides below; presumably
                # kept for IMAP4.shutdown(), which closes self.file -- confirm.
                self.file = self.sock.makefile('rb')

            def read(self, size):
                """Read 'size' bytes from remote."""
                return self.sslsock.read(size)

            def readline(self):
                """Read line from remote (byte by byte, up to and
                including the newline)."""
                chars = []
                c = self.sslsock.read(1)
                while c:
                    chars.append(c)
                    if c == '\n':
                        break
                    c = self.sslsock.read(1)
                return ''.join(chars)

            def send(self, data):
                """Send data to remote, retrying until all of it is
                written."""
                remaining = len(data)
                while remaining > 0:
                    sent = self.sslsock.write(data)
                    if sent == remaining:
                        break   # avoid copy
                    data = data[sent:]
                    remaining = remaining - sent
    else:
        IMAP4_SSL = imaplib.IMAP4_SSL

if remoteauth['proto'] == 'ldap':
    # Fail at startup, not at the first login, when LDAP auth cannot work.
    try:
        import ldap
    except ImportError:
        raise ImportError('python-ldap (http://python-ldap.sf.net/) required.')
    # The dn is a template: its %s is replaced by the username when
    # binding.  Put the reason in the exception too, because
    # DEBUGSTREAM discards output unless -d was given.
    if remoteauth['dn'] == '':
        print >> DEBUGSTREAM, "Error: Missing ldap dn\n"
        raise ValueError('Missing ldap dn in --remoteauth '
                         '(ldap://host[:port]/dn)')
    if remoteauth['dn'].find('%s') < 0:
        print >> DEBUGSTREAM, "Error: Invalid ldap dn\n"
        raise ValueError('Invalid ldap dn: it must contain %s '
                         'where the username goes')


# Utility functions
def pipecmd(command, *strings):
    popen2._cleanup()
    cmd = popen2.Popen3(command, 1, bufsize=-1)
    cmdout, cmdin, cmderr = cmd.fromchild, cmd.tochild, cmd.childerr
    if strings:
        # Write to the tochild file object.
        for s in strings:
            cmdin.write(s)
        cmdin.flush()
        cmdin.close()
    # Read from the childerr object; command will block until exit.
    err = cmderr.read().strip()
    cmderr.close()
    # Read from the fromchild object.
    out = cmdout.read().strip()
    cmdout.close()
    # Get exit status from the wait() member function.
    return cmd.wait()


def run_authprog(username, password):
    """Check username/password with the external -A/--authprog command.

    The pair is written to the command's stdin as
    `username<NUL>password<NUL>'.  The value from pipecmd() (the
    command's wait status) is returned as-is: 0 means the credentials
    were accepted, a positive integer means they were not."""
    print >> DEBUGSTREAM, "Trying authprog method"
    credentials = '%s\0%s\0' % (username, password)
    return pipecmd('%s' % authprog, credentials)


def quote_rcpts(rcpttos):
    """Return a new list with every recipient address wrapped in single
    quotes, so a malicious address cannot smuggle extra arguments or
    code into the tmda-inject invocation.

    Inside each address, backslashes are doubled first.  Then every
    single quote is replaced by: quote, three backslashes, quote, quote.
    NOTE(review): the extra backslashes suggest the result passes
    through more than one level of unescaping; confirm against the
    caller before changing this.
    """
    def _quote_one(address):
        escaped = address.replace("\\", "\\\\")
        escaped = escaped.replace("'", "'\\\\\\''")
        return "'" + escaped + "'"
    return [_quote_one(address) for address in rcpttos]


def run_remoteauth(username, password):
    """Authenticate username/password combination against a remote
    resource.  Return 1 upon successful authentication, and 0
    otherwise."""
    print >> DEBUGSTREAM, "trying %s authentication for %s@%s:%s" % \
          (remoteauth['proto'], username, remoteauth['host'],
           remoteauth['port'])
    port = defaultauthports[remoteauth['proto']]
    if remoteauth['proto'] == 'imap':
        import imaplib
        if remoteauth['port']:
            port = int(remoteauth['port'])
        M = imaplib.IMAP4(remoteauth['host'], port)
        try:
            M.login(username, password)
            M.logout()
            return 1
        except:
            print >> DEBUGSTREAM, "imap authentication for %s@%s failed" % \
                  (username, remoteauth['host'])
            return 0
    elif remoteauth['proto'] == 'imaps':
        import imaplib
        if remoteauth['port']:
            port = int(remoteauth['port'])
        M = IMAP4_SSL(remoteauth['host'], port)
        try:
            M.login(username, password)
            M.logout()
            return 1
        except:
            print >> DEBUGSTREAM, "imaps authentication for %s@%s failed" % \
                  (username, remoteauth['host'])
            return 0
    elif remoteauth['proto'] in ('pop3', 'apop'):
        import poplib
        if remoteauth['port']:
            port = int(remoteauth['port'])
        M = poplib.POP3(remoteauth['host'], port)
        try:
            if remoteauth['proto'] == 'pop3':
                M.user(username)
                M.pass_(password)
                M.quit()
                return 1
            else:
                M.apop(username, password)
                M.quit()
                return 1
        except:
            print >> DEBUGSTREAM, "%s authentication for %s@%s failed" % \
                  (remoteauth['proto'], username, remoteauth['host'])
            return 0
    elif remoteauth['proto'] == 'ldap':
        import ldap
        if remoteauth['port']:
            port = int(remoteauth['port'])
        try:
            M = ldap.initialize("ldap://%s:%s" % (remoteauth['host'],
                                                  remoteauth['port']))
            M.simple_bind_s(remoteauth['dn'] % username, password)
            M.unbind_s()
            return 1
        except:
            print >> DEBUGSTREAM, "ldap authentication for %s@%s failed" % \
                  (username, remoteauth['host'])
            return 0
    # proto not implemented
    print >> DEBUGSTREAM, "Error: protocol %s not implemented" % \
            remoteauth['proto']
    return 0


def authfile2dict(authfile):
    """Iterate over a tmda-ofmipd authentication file, and return a
    dictionary containing username:password pairs.  Username is
    returned in lowercase.

    Each entry line has the form `username:password'.  Only the first
    colon separates the fields, so passwords may contain colons.
    Whitespace around both fields is stripped.  Blank lines, and lines
    without any colon, are skipped rather than raising IndexError.
    The file is closed even if reading it fails.
    """
    authdict = {}
    fp = open(authfile, 'r')
    try:
        for line in fp:
            fields = line.strip().split(':', 1)
            if len(fields) != 2:
                # blank or malformed line
                continue
            authdict[fields[0].lower().strip()] = fields[1].strip()
    finally:
        fp.close()
    return authdict


def b64_encode(s):
    """base64-encode s and return the result on one line, without any
    newline.

    The base64 line encoder breaks its output into 76-character lines.
    Stripping just the final newline would leave embedded line breaks
    for inputs longer than 57 bytes (e.g. a CRAM-MD5 ticket built from
    a long FQDN), which would corrupt the single-line SMTP reply that
    carries the result.  So every newline is removed.
    """
    # encodestring() was renamed encodebytes() and removed in Python 3.9.
    encode = getattr(base64, 'encodestring', None) or base64.encodebytes
    encoded = encode(s)
    # encoded[:0] is an empty string of the same type as the result.
    return encoded[:0].join(encoded.split())


def b64_decode(s):
    """base64 decoding.

    Errors from the underlying decoder (e.g. bad padding) propagate to
    the caller.
    """
    # decodestring() was renamed decodebytes() and removed in Python 3.9.
    decode = getattr(base64, 'decodestring', None) or base64.decodebytes
    return decode(s)



# Classes

class SMTPChannel(asynchat.async_chat):
    COMMAND = 0
    DATA = 1
    AUTH = 2
    
    def __init__(self, server, conn, addr):
        asynchat.async_chat.__init__(self, conn)
        # SMTP AUTH
        self.__smtpauth = 0
        self.__auth_resp1 = None
        self.__auth_resp2 = None
        self.__auth_username = None
        self.__auth_password = None
        self.__auth_sasl = None
        self.__sasl_types = ['login', 'cram-md5', 'plain']
        self.__auth_cram_md5_ticket = '<%s.%s@%s>' % (random.randrange(10000),
                                                      int(time.time()), FQDN)
        self.__server = server
        self.__conn = conn
        self.__addr = addr
        self.__line = []
        self.__state = self.COMMAND
        #self.__greeting = 0
        self.__mailfrom = None
        self.__rcpttos = []
        self.__data = ''
        self.__fqdn = FQDN
        self.__peer = conn.getpeername()
        print >> DEBUGSTREAM, 'Peer:', repr(self.__peer)
        self.push('220 %s ESMTP tmda-ofmipd' % (self.__fqdn))
        self.set_terminator('\r\n')

    # Overrides base class for convenience
    def push(self, msg):
        asynchat.async_chat.push(self, msg + '\r\n')

    # Implementation of base class abstract method
    def collect_incoming_data(self, data):
        self.__line.append(data)

    def verify_login(self, b64username, b64password):
        """The LOGIN SMTP authentication method is an undocumented,
        unstandardized Microsoft invention.  Needed to support MS
        Outlook clients."""
        try:
            username = b64_decode(b64username)
            password = b64_decode(b64password)
        except:
            return 501
        self.__auth_username = username.lower()
        self.__auth_password = password
        if remoteauth['enable']:
            # Try first with the remote auth
            if run_remoteauth(username, password):
                return 1
        if authprog:
            # Then with the authprog
            if run_authprog(username, password) == 0:
                return 1
        # Now we can fall back on the authfile
        authdict = authfile2dict(authfile)
        if authdict.get(username.lower(), 0) != password:
            return 0
        else:
            return 1

    def verify_plain(self, response):
        """PLAIN is described in RFC 2595."""
        try:
            response = b64_decode(response)
        except:
            return 501
        try:
            username, username, password = response.split('\0')
        except ValueError:
            return 0
        self.__auth_username = username.lower()
        self.__auth_password = password
        if remoteauth['enable']:
            # Try first with the remote auth
            if run_remoteauth(username, password):
                return 1
        if authprog:
            # Then with the authprog
            if run_authprog(username, password) == 0:
                return 1
        # Now we can fall back on the authfile
        authdict = authfile2dict(authfile)
        if authdict.get(username.lower(), 0) != password:
            return 0
        else:
            return 1

    def verify_cram_md5(self, response, ticket):
        """CRAM-MD5 is described in RFC 2195."""
        try:
            response = b64_decode(response)
        except:
            return 501
        try:
            username, hexdigest = response.split()
        except ValueError:
            return 0
        authdict = authfile2dict(authfile)
        password = authdict.get(username.lower(), 0)
        self.__auth_username = username.lower()
        self.__auth_password = password
        if password == 0:
            return 0
        newhexdigest = hmac.HMAC(password, ticket, digestmod=md5).hexdigest()
        if newhexdigest != hexdigest:
            return 0
        else:
            return 1

    def auth_reset_state(self):
        """As per RFC 2554, the SMTP state is reset if the AUTH fails,
        and once it succeeds."""
        self.__auth_sasl = None
        self.__auth_resp1 = None
        self.__auth_resp2 = None
        self.__state = self.COMMAND

    def auth_notify_required(self):
        """Send a 530 reply.  RFC 2554 says this response may be
        returned by any command other than AUTH, EHLO, HELO, NOOP,
        RSET, or QUIT. It indicates that server policy requires
        authentication in order to perform the requested action."""
        self.push('530 Error: Authentication required')
        
    def auth_notify_fail(self, failcode=0):
        if failcode == 501:
            # base64 decoding failed
            self.push('501 malformed AUTH input')
        else:
            self.push('535 AUTH failed')
        print >> DEBUGSTREAM, 'Auth: ', 'failed for user', \
              "'%s'" % self.__auth_username
        self.__smtpauth = 0

    def auth_notify_succeed(self):
        self.push('235 AUTH successful')
        print >> DEBUGSTREAM, 'Auth: ', 'succeeded for user', \
              "'%s'" % self.__auth_username
        self.__smtpauth = 1

    def auth_verify(self, sasl=None):
        if sasl is None:
            sasl = self.__auth_sasl
        verify = 0
        if sasl == 'plain':
            verify = self.verify_plain(self.__auth_resp1)
        elif sasl == 'cram-md5':
            verify = self.verify_cram_md5(self.__auth_resp1,
                                          self.__auth_cram_md5_ticket)
        elif sasl == 'login':
            verify =  self.verify_login(self.__auth_resp1,
                                        self.__auth_resp2)
        if verify == 1:
            self.auth_notify_succeed()
        else:
            self.auth_notify_fail(verify)
        self.auth_reset_state()
            
    def auth_challenge(self):
        line = EMPTYSTRING.join(self.__line)
        if not self.__auth_resp1:
            # No initial response, issue first server challenge
            if self.__auth_sasl == 'plain':
                self.push('334 ')
            elif self.__auth_sasl == 'cram-md5':
                self.push('334 ' + b64_encode(self.__auth_cram_md5_ticket))
            elif self.__auth_sasl == 'login':
                self.push('334 VXNlcm5hbWU6')
            return
        if self.__auth_resp1 and not self.__auth_resp2:
            # Client sent an initial response
            if self.__auth_sasl == 'plain':
                # Perform authentication
                self.auth_verify()
            elif self.__auth_sasl == 'cram-md5':
                # Perform authentication
                self.auth_verify()
            elif self.__auth_sasl == 'login':
                # Issue second server challenge
                self.push('334 UGFzc3dvcmQ6')
            return
        if self.__auth_resp1 and self.__auth_resp2:
            # Client sent a second response (only if AUTH=LOGIN),
            # perform authentication
            self.auth_verify()
            return

    # Implementation of base class abstract method
    def found_terminator(self):
        line = EMPTYSTRING.join(self.__line)
        print >> DEBUGSTREAM, 'Data:', repr(line)
        self.__line = []
        if self.__state == self.COMMAND:
            if not line:
                self.push('500 Error: bad syntax')
                return
            method = None
            i = line.find(' ')
            if i < 0:
                command = line.upper()
                arg = None
            else:
                command = line[:i].upper()
                arg = line[i+1:].strip()
            method = getattr(self, 'smtp_' + command, None)
            if not method:
                self.push('502 Error: command "%s" not implemented' % command)
                return
            method(arg)
            return
        elif self.__state == self.DATA:
            # Remove extraneous carriage returns and de-transparency according
            # to RFC 2821, Section 4.5.2.
            data = []
            for text in line.split('\r\n'):
                if text and text[0] == '.':
                    data.append(text[1:])
                else:
                    data.append(text)
            self.__data = NEWLINE.join(data)
            status = self.__server.process_message(self.__peer,
                                                   self.__mailfrom,
                                                   self.__rcpttos,
                                                   self.__data,
                                                   self.__auth_username)
            self.__rcpttos = []
            self.__mailfrom = None
            self.__state = self.COMMAND
            self.set_terminator('\r\n')
            if not status:
                self.push('250 Ok')
            else:
                self.push(status)
        elif self.__state == self.AUTH:
            if line == '*':
                # client canceled the authentication attempt
                self.push('501 AUTH exchange cancelled')
                self.auth_reset_state()
                return
            if not self.__auth_resp1:
                self.__auth_resp1 = line
            else:
                self.__auth_resp2 = line
            self.auth_challenge()
        else:
            self.push('451 Internal confusion')
            return

    # ESMTP/SMTP commands

    def smtp_EHLO(self, arg):
        if not arg:
            self.push('501 Syntax: EHLO hostname')
            return
        #self.__greeting = arg
        self.push('250-%s' % self.__fqdn)
        self.push('250 AUTH %s' % (' '.join(map(lambda s: s.upper(),
                                                self.__sasl_types))))

    def smtp_NOOP(self, arg):
        if arg:
            self.push('501 Syntax: NOOP')
        else:
            self.push('250 Ok')

    def smtp_QUIT(self, arg):
        # args is ignored
        self.push('221 Bye')
        self.close_when_done()

    # factored
    def __getaddr(self, keyword, arg):
        address = None
        keylen = len(keyword)
        if arg[:keylen].upper() == keyword:
            address = arg[keylen:].strip()
            if not address:
                pass
            elif address[0] == '<' and address[-1] == '>' and address != '<>':
                # Addresses can be in the form <person@dom.com> but watch out
                # for null address, e.g. <>
                address = address[1:-1]
        return address

    def smtp_MAIL(self, arg):
        # Authentication required first
        if not self.__smtpauth:
            self.auth_notify_required()
            return
        print >> DEBUGSTREAM, '===> MAIL', arg
        address = self.__getaddr('FROM:', arg)
        if not address:
            self.push('501 Syntax: MAIL FROM:<address>')
            return
        if self.__mailfrom:
            self.push('503 Error: nested MAIL command')
            return
        self.__mailfrom = address
        print >> DEBUGSTREAM, 'sender:', self.__mailfrom
        self.push('250 Ok')

    def smtp_RCPT(self, arg):
        print >> DEBUGSTREAM, '===> RCPT', arg
        if not self.__mailfrom:
            self.push('503 Error: need MAIL command')
            return
        address = self.__getaddr('TO:', arg)
        if not address:
            self.push('501 Syntax: RCPT TO: <address>')
            return
        self.__rcpttos.append(address)
        print >> DEBUGSTREAM, 'recips:', self.__rcpttos
        self.push('250 Ok')

    def smtp_RSET(self, arg):
        if arg:
            self.push('501 Syntax: RSET')
            return
        # Resets the sender, recipients, and data, but not the greeting
        self.__mailfrom = None
        self.__rcpttos = []
        self.__data = ''
        self.__state = self.COMMAND
        self.push('250 Ok')

    def smtp_DATA(self, arg):
        if not self.__rcpttos:
            self.push('503 Error: need RCPT command')
            return
        if arg:
            self.push('501 Syntax: DATA')
            return
        self.__state = self.DATA
        self.set_terminator('\r\n.\r\n')
        self.push('354 End data with <CR><LF>.<CR><LF>')

    def smtp_AUTH(self, arg):
        """RFC 2554 - SMTP Service Extension for Authentication"""
        if self.__smtpauth:
            # After an successful AUTH, no more AUTH commands may be
            # issued in the same session.
            self.push('503 Duplicate AUTH')
            return
        if arg:
            args = arg.split()
            if len(args) == 2:
                self.__auth_sasl = args[0]
                self.__auth_resp1 = args[1]
            else:
                self.__auth_sasl = args[0]
        if self.__auth_sasl:
            self.__auth_sasl = self.__auth_sasl.lower()
        if not arg or self.__auth_sasl not in self.__sasl_types:
            self.push('504 AUTH type unimplemented')
            return
        self.__state = self.AUTH
        self.auth_challenge()


class SMTPServer(asyncore.dispatcher):
    """The base class for the backend.  Raises NotImplementedError if
    you try to use it.

    Listens on localaddr and creates one SMTPChannel per accepted
    connection.  remoteaddr is only stored (as self._remoteaddr) for
    subclasses that relay mail elsewhere.
    """
    def __init__(self, localaddr, remoteaddr):
        self._localaddr = localaddr
        self._remoteaddr = remoteaddr
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        # try to re-use a server port if possible
        self.set_reuse_addr()
        self.bind(localaddr)
        self.listen(5)
        print >> DEBUGSTREAM, \
              'tmda-ofmipd started at %s\n\tListening on %s' % \
              (Util.unixdate(), proxyport)

    def readable(self):
        """Only poll the listening socket for new connections while the
        limit set by -C/--connections has not been reached.
        asyncore.socket_map also contains this listening socket, so
        (assuming no other dispatchers) accepting stops once
        `connections' client channels are open."""
        if len(asyncore.socket_map) > int(connections):
            # too many simultaneous connections
            return 0
        else:
            return 1

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore's accept() returns None when the pending
            # connection could not be accepted (EWOULDBLOCK).
            return
        conn, addr = pair
        print >> DEBUGSTREAM, 'Incoming connection from %s' % repr(addr)
        SMTPChannel(self, conn, addr)

    # API for "doing something useful with the message"
    def process_message(self, peer, mailfrom, rcpttos, data,
                        auth_username=None):
        """Override this abstract method to handle messages from the client.

        peer is a tuple containing (ipaddr, port) of the client that made the
        socket connection to our smtp port.

        mailfrom is the raw address the client claims the message is coming
        from.

        rcpttos is a list of raw addresses the client wishes to deliver the
        message to.

        data is a string containing the entire full text of the message,
        headers (if supplied) and all.  It has been `de-transparencied'
        according to RFC 821, Section 4.5.2.  In other words, a line
        containing a `.' followed by other text has had the leading dot
        removed.

        auth_username is the lowercased name the client authenticated
        as; SMTPChannel always passes it as a fifth argument.

        This function should return None, for a normal `250 Ok' response;
        otherwise it returns the desired response string in RFC 821 format.

        """
        raise NotImplementedError


class DebuggingServer(SMTPServer):
    """Simply prints each message it receives on stdout."""
    # Do something with the gathered message
    def process_message(self, peer, mailfrom, rcpttos, data):
        inheaders = 1
        lines = data.split('\n')
        print '---------- MESSAGE FOLLOWS ----------'
        for line in lines:
            # headers first
            if inheaders and not line:
                print 'X-Peer:', peer[0]
                inheaders = 0
            print line
        print '------------ END MESSAGE ------------'


class PureProxy(SMTPServer):
    """Proxies all messages to a real smtpd which does final
    delivery."""
    def process_message(self, peer, mailfrom, rcpttos, data):
        lines = data.split('\n')
        # Look for the last header
        i = 0
        for line in lines:
            if not line:
                break
            i += 1
        lines.insert(i, 'X-Peer: %s' % peer[0])
        data = NEWLINE.join(lines)
        refused = self._deliver(mailfrom, rcpttos, data)
        # TBD: what to do with refused addresses?
        print >> DEBUGSTREAM, 'we got some refusals:', refused

    def _deliver(self, mailfrom, rcpttos, data):
        import smtplib
        refused = {}
        try:
            s = smtplib.SMTP()
            s.connect(self._remoteaddr[0], self._remoteaddr[1])
            try:
                refused = s.sendmail(mailfrom, rcpttos, data)
            finally:
                s.quit()
        except smtplib.SMTPRecipientsRefused, e:
            print >> DEBUGSTREAM, 'got SMTPRecipientsRefused'
            refused = e.recipients
        except (socket.error, smtplib.SMTPException), e:
            print >> DEBUGSTREAM, 'got', e.__class__
            # All recipients were refused.  If the exception had an associated
            # error code, use it.  Otherwise,fake it with a non-triggering
            # exception code.
            errcode = getattr(e, 'smtp_code', -1)
            errmsg = getattr(e, 'smtp_error', 'ignore')
            for r in rcpttos:
                refused[r] = (errcode, errmsg)
        return refused


class TMDAProxy(PureProxy):
    """Using this server for outgoing smtpd, the authenticated user
    will have his mail tagged using his TMDA config file."""
    def process_message(self, peer, mailfrom, rcpttos, data, auth_username):
        if configdir is None:
            # ~user/.tmda/
            tmda_configdir = os.path.join(os.path.expanduser
                                          ('~' + auth_username), '.tmda')
        else:
            tmda_configdir = os.path.join(os.path.expanduser
                                          (configdir), auth_username)
        tmda_configfile = os.path.join(tmda_configdir, 'config')
        execdir = os.path.dirname(os.path.abspath(program))
        inject_path = os.path.join(execdir, 'tmda-inject')
        inject_cmd = '%s --config-file %s' % (inject_path, tmda_configfile)
        # If running as uid 0, fork the tmda-inject process, and
        # then change UID and GID to the authenticated user.
        if running_as_root:
            pid = os.fork()
            if pid == 0:
                os.seteuid(0)
                os.setgid(Util.getgid(auth_username))
                os.setgroups(Util.getgrouplist(auth_username))
                os.setuid(Util.getuid(auth_username))
                # This is so "~" will work in the .tmda/* files.
                os.environ['HOME'] = Util.gethomedir(auth_username)
                try:
                    Util.pipecmd('%s %s' %
                                 (inject_cmd, ' '.join
                                  (quote_rcpts(rcpttos))), data)
                except Exception, err:
                    print >> DEBUGSTREAM, 'Error:', err
                    os._exit(-1)
                os._exit(0)
            else:
                rpid, status = os.wait()
                # Did tmda-inject succeed?
                if status != 0:
                    raise IOError, 'tmda-inject failed!'
        else:
            # no need to fork
            Util.pipecmd('%s %s' % (inject_cmd, ' '.join
                                    (quote_rcpts(rcpttos))), data)


def main():
    # check permissions of authfile
    authfile_mode = Util.getfilemode(authfile)
    if authfile_mode not in (400, 600):
        raise IOError, \
              authfile + ' must be chmod 400 or 600!'
    # try binding to the specified host:port
    host, port = proxyport.split(':', 1)
    proxy = TMDAProxy((host, int(port)),
                      ('localhost', 25))
    if running_as_root:
        pw_uid = Util.getuid(username)
        # check ownership of authfile
        if Util.getfileuid(authfile) != pw_uid:
            raise IOError, \
                  authfile + ' must be owned by UID ' + str(pw_uid)
        # try setegid()
        os.setegid(Util.getgid(username))
        # try setting the supplemental group ids
        os.setgroups(Util.getgrouplist(username))
        # try seteuid()
        os.seteuid(pw_uid)

    # Issue a warning if neither -f nor -b options specified
    #if foreground is None:
    #    print "WARNING: you should specify -b",
    #    print "(background) or -f (foreground) option."
    #    print "The default (background) behavior",
    #    print "could be changed in a future version."
    # Try to fork to go to daemon unless foreground mode
    if not foreground:
        if os.fork() != 0:
            sys.exit()

    # Start the event loop
    try:
        asyncore.loop()
    except KeyboardInterrupt:
        pass


# This is the end my friend.
# Only start the daemon when executed as a script, not when imported.
if __name__ == "__main__":
    main()
