"""
ldapsession.py - higher-level class for handling LDAP connections
(c) by Michael Stroeder <michael@stroeder.com>

This module is distributed under the terms of the
GPL (GNU GENERAL PUBLIC LICENSE) Version 2
(see http://www.gnu.org/copyleft/gpl.html)

$Id: ldapsession.py,v 1.146 2002/02/14 08:17:47 michael Exp $
"""

__version__ = '0.4.2'

import sys,socket,time,types,ldap,ldaputil.base

from ldapurl import LDAPUrl

# Constants for options
# Numeric values for switching an option off or on
OPT_OFF=0
OPT_ON=1
# Maps option names to the corresponding option constants of module ldap
OPT_DICT={'referrals':ldap.OPT_REFERRALS,'restart':ldap.OPT_RESTART}

# Names of option attributes
# NOTE(review): 'deref' has no counterpart in OPT_DICT - presumably it is
# handled elsewhere, confirm against the code using OPTIONS_ATTRS
OPTIONS_ATTRS = ['referrals','restart','deref']

# Values for parameter startTLS of LDAPSession.open() and
# LDAPSession._startTLS(); any non-zero value makes _startTLS() attempt
# StartTLS
START_TLS_NO = 0
START_TLS_TRY = 1
START_TLS_REQUIRED = 2

# Attribute types requested from the RootDSE by LDAPSession.getRootDSE().
# Each found attribute is stored as instance attribute of the same name
# and removed again by LDAPSession._forgetRootDSEAttrs().
# Every attribute type must be listed only once.
ROOTDSE_ATTRS = [
  'namingContexts',
  'subschemaSubentry',
  'defaultNamingContext',
  'defaultRnrDN',
  'altServer',
  'ogSupportedProfile',
  'supportedSASLMechanisms',
  'supportedLDAPVersion',
  'supportedControl',
  'supportedFeatures',
  'vendorName','vendorVersion',
]

# Number of seconds a result stored in LDAPSession.read_cache by
# LDAPSession.readEntry() is returned from the cache before the
# entry is read from the server again
READ_CACHE_EXPIRE = 20

class LDAPSession:
  """
  Class for handling LDAP connection objects
  """

  def __init__(
    self,on_behalf=None,useThreadLock=1,traceLevel=0,traceFile=sys.stdout
  ):
    """
    Set up the initial state of a not yet connected session.

    on_behalf
        Optional free-form information (IP address, host name etc.)
        about the proxy client. Attribute onBehalf is only set
        if this is not None.
    useThreadLock
        Stored in attribute _useThreadLock.
    traceLevel, traceFile
        Stored in attributes _traceLevel and _traceFile which
        are passed on when the connection is opened.
    """
    # Parameters handed over by the caller
    self._useThreadLock = useThreadLock
    self._traceLevel = traceLevel
    self._traceFile = traceFile
    if on_behalf is not None:
      # IP address, host name or other free form information
      # of proxy client
      self.onBehalf = on_behalf
    # Currently used DN, search root and the naming contexts
    self._dn = u''
    self.searchRoot = u''
    self.namingContexts = []
    # Authentication method used.
    self.authMethod = ldap.AUTH_SIMPLE
    # Character set/encoding of data stored on this particular host
    self.charset = 'utf-8'
    # Time-stamped results of readEntry(), see readEntry()
    self.read_cache = {}
    # Default timeout 60 seconds
    self.timeout = 60
    # Initial state of Manage DSA IT mode
    self.manageDsaItenabled = 0
    # Capable of returning only attribute types with a search call
    self.onlyAttrTypes = 1
    # Supports feature described in draft-zeilenga-ldap-opattrs
    self.supportsAllOpAttr = 0

  def _retryConnect(
    self,host,port,trace_level,trace_file,max_conn_try=1,conn_try_delay=0.050
  ):
    """
    Call ldap.open() and retry if ldap.LDAPError is raised.

    At most max_conn_try attempts are made with a pause of
    conn_try_delay seconds after each failed attempt. The
    exception of the last failed attempt is re-raised.

    Returns the object returned by ldap.open().
    """
    attempts = 0
    while 1:
      attempts = attempts+1
      try:
        return ldap.open(host,port,trace_level,trace_file)
      except ldap.LDAPError:
        # Give up if the maximum number of attempts is reached
        if attempts>=max_conn_try:
          raise
      # Wait a moment before trying again
      time.sleep(conn_try_delay)

  def _supportedLDAPVersion(self):
    """
    Try to determine the highest supported protocol version
    by trying to bind anonymously
    """
    # Set protocol version to LDAPv3
    self.l.set_option(ldap.OPT_PROTOCOL_VERSION,ldap.VERSION3)
    # first try LDAPv3 bind
    try:
      # Try to bind to provoke error reponse if protocol
      # version is not supported
      self.l.bind_s('','',ldap.AUTH_SIMPLE)
    except ldap.PROTOCOL_ERROR,e:
      # Make sure that error just happened because of wrong
      # protocol version
      if hasattr(e,'args') and \
         type(e.args)==type(()) and \
         type(e.args[0])==type({}) and \
         e.args[0].get('info','').lower()=='version not supported':
        # Drop connection completely
        self.l.unbind_s() ; del self.l
        # Reconnect to host
        self.l = self._retryConnect(
          self.hostaddr,self.port,self._traceLevel,self._traceFile
        )
        # Switch to new connection to LDAPv2
        self.l.set_option(ldap.OPT_PROTOCOL_VERSION,ldap.VERSION2)
      else:
        # Raise any other error exception
        raise
      # Set currently determined protocol version
      version = 2
    else:
      self.who = ''
      version = 3
    return version

  def _startTLS(self,startTLS=START_TLS_TRY):
    """
    StartTLS if possible and requested

    startTLS
        One of the START_TLS_* constants. StartTLS is only attempted
        if this is non-zero, the connection uses LDAPv3 or newer and
        the connection object has a method start_tls_s().

    Returns the value of startTLS if TLS was started, 0 otherwise.
    The result is also stored in attribute startedTLS.

    A failed StartTLS is not treated as error here, not even if
    START_TLS_REQUIRED was given. The caller has to check the result.
    """
    startedTLS = 0
    if startTLS and \
       self.currentLDAPVersion>=3 and \
       hasattr(self.l,'start_tls_s'):
      try:
        self.l.start_tls_s()
      except ldap.LDAPError:
        # Only LDAP errors mean "no TLS available". Everything else
        # (e.g. KeyboardInterrupt) must not be swallowed silently.
        startedTLS = 0
      else:
        startedTLS = startTLS
    self.startedTLS = startedTLS
    return startedTLS

  def open(
    self,host,timeout=60,startTLS=START_TLS_TRY
  ):
    """
    Open a LDAP connection with separate DNS lookup
    
    host
        Either a (Unicode) string or a list of strings
        containing host:port of host(s) to connect to.
        If host is a list connecting is tried until a
        connect to a host in the list was successful.
        Port 389 is used if no port is given.
        A list passed in is not modified.
    timeout
        Stored in attribute timeout.
    startTLS
        One of the START_TLS_* constants passed to _startTLS().

    Raises ValueError if host is empty and TypeError if host is
    neither string nor list. If no host could be reached the
    socket.error or ldap.LDAPError of the last host tried is raised.

    On success the attributes l, host, hostaddr, port and
    currentLDAPVersion are set.
    """
    if not host:
      raise ValueError("No host string or list specified for %s.open()." % (
        self.__class__.__name__
      ))
    self.timeout = timeout
    if type(host) in [types.StringType,types.UnicodeType]:
      host_list = [host]
    elif type(host)==types.ListType:
      # Work on a copy since hosts which could not be reached are
      # removed below. The caller's list must stay untouched.
      host_list = host[:]
    else:
      raise TypeError("Parameter host must be either list of strings or single string.")
    while host_list:
      hostport = host_list[0].encode('ascii')
      try:
        host_name,port = hostport.split(':',1)
      except ValueError:
        # No port given => use default LDAP port
        host_name,port = hostport,389
      else:
        port = int(port)
      # Do DNS lookup here to provoke DNS-related socket.error
      # exceptions being raised
      try:
        hostaddr = socket.gethostbyname(host_name)
      except socket.error:
        # Remove current host from list
        host_list.pop(0)
        if host_list:
          # Try next host
          continue
        else:
          raise
      # Try connecting to LDAP host
      try:
        self.l = self._retryConnect(
          hostaddr,port,self._traceLevel,self._traceFile
        )
      except ldap.LDAPError:
        # Remove current host from list
        host_list.pop(0)
        if host_list:
          # Try next host
          continue
        else:
          raise
      else:
        break
    self.host = hostport
    self.hostaddr,self.port = hostaddr,port
    # Get and save currently used LDAP protocol version
    self.currentLDAPVersion = self._supportedLDAPVersion()
    # Start TLS
    self._startTLS(startTLS)
    

  def unbind(self):
    """
    Unbind and forget the LDAP connection object if there is one.

    ldap.LDAPError and AttributeError raised while unbinding are
    ignored. Attribute l is removed in any case.
    """
    if not hasattr(self,'l'):
      # Nothing to close
      return
    try:
      self.l.unbind_s()
    except (ldap.LDAPError,AttributeError):
      pass
    del self.l

  def getUmichConfig(self):
    """
    Read attribute database of entry cn=config (UMich LDAPv2
    server derivates) and return the DNs found therein.

    Each value of attribute database is split at ' : '. The first
    field is skipped, all following fields are taken as DNs.
    Values without any ' : ' are ignored.

    Returns a list of Unicode strings which is empty if the entry
    could not be read.
    """
    try:
      config_result = self.readEntry(
        'cn=config',['database'],
      )
    except (
      ldap.NO_SUCH_OBJECT,
      ldap.PARTIAL_RESULTS,
      ldap.UNDEFINED_TYPE,
      ldap.INAPPROPRIATE_MATCHING,
      ldap.OPERATIONS_ERROR,
    ):
      return []
    if not config_result:
      return []
    suffix_list = []
    for database_value in config_result[0][1].get('database',[]):
      fields = database_value.split(' : ',1)
      if len(fields)!=2:
        # No separator found in this value
        continue
      for suffix in fields[1].split(' : '):
        suffix_list.append(suffix.strip())
    return [
      unicode(suffix,self.charset)
      for suffix in suffix_list
    ]

  def _forgetRootDSEAttrs(self):
    """
    Remove all attributes listed in ROOTDSE_ATTRS from this
    object and reset attribute supportsAllOpAttr to 0.
    """
    self_has = hasattr
    for rootdse_attr in ROOTDSE_ATTRS:
      if not self_has(self,rootdse_attr):
        continue
      delattr(self,rootdse_attr)
    self.supportsAllOpAttr = 0

  def getRootDSE(self):
    """
    Read the Root DSE and update the session state from it.

    All attributes listed in ROOTDSE_ATTRS found in the Root DSE
    are stored as attributes of this object (vendorName and
    vendorVersion as single value, all others as list).
    Attribute hasRootDSE is set to 1 if reading the Root DSE did
    not fail. If no naming contexts were found getUmichConfig()
    is tried. Finally searchRoot is determined for the current DN.
    """
    self._forgetRootDSEAttrs()
    self.namingContexts = []
    self.hasRootDSE = 0
    # With LDAPv2 a bind is required before any other operation
    if self.currentLDAPVersion<=2 and not hasattr(self,'who'):
      self.bind('','',ldap.AUTH_SIMPLE)
    try:
      rootdse_result = self.readEntry(
        '',ROOTDSE_ATTRS+['objectClass']
      )
    except (
      ldap.NO_SUCH_OBJECT,
      ldap.PARTIAL_RESULTS,
      ldap.UNDEFINED_TYPE,
      ldap.INAPPROPRIATE_MATCHING,
      ldap.OPERATIONS_ERROR,
      ldap.UNWILLING_TO_PERFORM,
    ):
      # No Root DSE available
      pass
    else:
      self.hasRootDSE = 1
      rootDSE = {}
      if rootdse_result:
        rootDSE = rootdse_result[0][1]
      for attr_name in ROOTDSE_ATTRS:
        # Also look for the lower-cased attribute type name
        raw_values = rootDSE.get(attr_name,rootDSE.get(attr_name.lower(),[]))
        values = ldaputil.base.unicode_list(raw_values,self.charset)
        if not values:
          continue
        if attr_name in ('vendorName','vendorVersion'):
          # These are single-valued attributes by definition
          setattr(self,attr_name,values[0])
        else:
          setattr(self,attr_name,values)
      has_opattr_feature = \
        '1.3.6.1.4.1.4203.1.5.1' in self.__dict__.get('supportedFeatures',[])
      self.supportsAllOpAttr = has_opattr_feature or \
        ('OpenLDAProotDSE' in rootDSE.get('objectClass',[]))
    if not hasattr(self,'namingContexts') or not self.namingContexts:
      self.namingContexts=self.getUmichConfig()
    if not self.namingContexts:
      self.namingContexts = []
      self.searchRoot = ''
      return
    # This is a work-around for Lotus Domino and other
    # misbehaving LDAP servers returning null-byte in namingContexts
    without_null_byte = []
    for naming_context in self.namingContexts:
      if naming_context=='\000':
        naming_context = ''
      without_null_byte.append(naming_context)
    self.namingContexts = without_null_byte
    self.namingContexts = [
      ldaputil.base.normalize_dn(naming_context)
      for naming_context in self.namingContexts
    ]
    self.searchRoot=ldaputil.base.match_dnlist(self._dn,self.namingContexts)
    return # getRootDSE()

  def isLeafEntry(self,dn):
    """
    Check whether the entry dn has no subordinate entries.

    A one-level search is started below dn and abandoned after
    the first result was received. Returns a true value if that
    result contains no data, 0 if ldap.NO_SUCH_OBJECT was raised.
    """
    nothing_yet = (None,None)
    try:
      msgid = self.l.search(
        dn.encode(self.charset),
        ldap.SCOPE_ONELEVEL,'(objectClass=*)',
        ['objectClass'],self.onlyAttrTypes
      )
      first_result = nothing_yet
      # Wait until the first result has arrived
      while first_result==nothing_yet:
        first_result = self.l.result(msgid,0,self.timeout)
      # No further results are needed
      self.l.abandon(msgid)
    except ldap.NO_SUCH_OBJECT:
      return 0
    return first_result[1] is None

  def getObjectClasses(self,dn):
    """
    Return the values of attribute objectClass of entry dn.

    The attribute is looked up as 'objectClass' and 'objectclass'.
    An empty list is returned if the entry has neither of them.
    Raises ldap.NO_SUCH_OBJECT if reading the entry returned nothing.
    """
    search_result = self.readEntry(dn,['objectClass'])
    if not search_result:
      raise ldap.NO_SUCH_OBJECT
    attrs = search_result[0][1]
    return attrs.get('objectClass',attrs.get('objectclass',[]))

  def getSubschemaEntryDN(self,dn):
    """
    Get DN of subschemaSubentry for dn

    The attribute is looked up as 'subschemaSubentry' and, for
    servers returning lower-cased attribute type names, as
    'subschemasubentry'.

    Returns the first attribute value or None if the entry could
    not be read or does not contain the attribute.
    """
    # Search for DN of subschemaSubentry
    search_result = self.readEntry(dn,['subschemaSubentry'])
    if not search_result:
      return None
    entry = search_result[0][1]
    values = entry.get(
      'subschemaSubentry',
      entry.get('subschemasubentry',[None])
    )
    if not values:
      # Attribute type present without any value
      return None
    return values[0]

  def getSubschemaEntry(self,dn):
    """
    Read the whole subschemaSubentry for dn

    The DN of the sub schema sub entry is determined with
    getSubschemaEntryDN(). The attributes objectClasses,
    attributeTypes, matchingRules and ldapSyntaxes are read
    from it with search filter (objectClass=subschema).

    Returns the first item of the search result.
    Raises ldap.NO_SUCH_OBJECT if no sub schema sub entry DN was
    found or the sub schema sub entry could not be read.
    """
    subschemaSubentryDN = self.getSubschemaEntryDN(dn)
    if not subschemaSubentryDN:
      raise ldap.NO_SUCH_OBJECT
    # Read the whole schema entry.
    # The filter has to be passed by keyword since the third
    # positional parameter of readEntry() is only_attrtypes.
    search_result = self.readEntry(
      subschemaSubentryDN,
      ['objectClasses','attributeTypes','matchingRules','ldapSyntaxes'],
      search_filter='(objectClass=subschema)'
    )
    if not search_result:
      raise ldap.NO_SUCH_OBJECT
    return search_result[0]

  def getAttributeTypes(self,dn):
    """
    Return the attribute type names present in entry dn.

    For an empty dn '+' and the attribute types listed in
    ldaputil.base.ROOTDSE_ATTRS are requested explicitly,
    otherwise no attribute list is passed to readEntry().
    Raises ldap.NO_SUCH_OBJECT if reading the entry returned nothing.
    """
    # NOTE(review): this module defines its own ROOTDSE_ATTRS - confirm
    # that ldaputil.base.ROOTDSE_ATTRS exists and is the intended list
    if dn:
      wanted_attrs = None
    else:
      wanted_attrs = ['+']+ldaputil.base.ROOTDSE_ATTRS
    search_result = self.readEntry(
      dn,wanted_attrs,self.onlyAttrTypes
    )
    if not search_result:
      raise ldap.NO_SUCH_OBJECT
    return search_result[0][1].keys()

  def readEntry(
    self,dn,attrtype_list=None,only_attrtypes=0,
    search_filter='(objectClass=*)',no_cache=0
  ):
    """
    Read a single entry

    dn
        DN of the entry to read (base of a base-level search).
    attrtype_list
        List of attribute types to be read, None for no list.
    only_attrtypes
        Currently ignored. The search is always done with
        attribute values to get a more efficient caching.
    search_filter
        Search filter the entry has to match.
    no_cache
        If true a cached result is not used. The result read from
        the server is stored in the cache anyway.

    Returns the search result list. Non-empty results are kept in
    attribute read_cache for READ_CACHE_EXPIRE seconds. The cache
    is looked up by DN, attribute list and search filter.
    """
    if attrtype_list is None:
      attrs_id = '__'
    else:
      attrs_id = ','.join(attrtype_list)
    # The search filter has to be part of the key. Otherwise a result
    # read with one filter would be returned for another filter.
    cache_key = (attrs_id,search_filter)
    if not no_cache:
      cached = self.read_cache.get(dn,{}).get(cache_key,None)
      if cached is not None:
        timestamp,read_cache_result = cached
        if timestamp+READ_CACHE_EXPIRE>time.time():
          return read_cache_result
        # Cached result is too old
        del self.read_cache[dn][cache_key]
    # Read single entry from LDAP server
    search_result = self.l.search_st(
      dn.encode(self.charset),
      ldap.SCOPE_BASE,
      search_filter,
      attrtype_list,
      0, # FAKE!!! Always set to zero to get a more efficient caching
      self.timeout
    )
    if search_result:
      # Create DN-level cache dictionary
      if not self.read_cache.has_key(dn):
        self.read_cache[dn]={}
      # Store the read entry in the time-stamped read_cache
      self.read_cache[dn][cache_key] = (time.time(),search_result)
    return search_result

  def existingEntry(self,dn,suppress_referrals=0):
    """
    Check whether entry dn can be read.

    Returns 1 if reading the entry raised no exception and 0 if
    ldap.NO_SUCH_OBJECT was raised. ldap.PARTIAL_RESULTS yields 0
    if suppress_referrals is true and is re-raised otherwise.
    """
    try:
      self.readEntry(dn,[])
    except ldap.NO_SUCH_OBJECT:
      return 0
    except ldap.PARTIAL_RESULTS:
      if not suppress_referrals:
        raise
      return 0
    return 1
      
  def setCache(self,timeout=0,maxmem=0):
    """
    Enable or disable caching in python-ldap's LDAPObject
    
    If timeout is zero or maxmem is zero the caching is completely
    switched off by calling self.l.destroy_cache().

    Nothing happens if the connection object does not provide the
    needed method destroy_cache() or enable_cache().
    """
    if not maxmem or not timeout:
      # Switch off caching. The check for the method must be done
      # for both conditions, not only for the zero timeout.
      if hasattr(self.l,'destroy_cache'):
        self.l.destroy_cache()
    else:
      try:
        self.l.enable_cache(timeout,maxmem)
      except AttributeError:
        pass

  def flushCache(self):
    """
    Forget everything stored in read_cache and flush the cache of
    the connection object. AttributeError raised by the latter
    is ignored.
    """
    # Start over with an empty read cache
    self.read_cache = {}
    try:
      self.l.flush_cache()
    except AttributeError:
      # No cache to flush
      pass
    return

  def uncacheEntry(self,dn):
    """
    Remove everything cached for entry dn from read_cache and
    from the cache of the connection object. AttributeError
    raised by the latter is ignored.
    """
    if self.read_cache.has_key(dn):
      del self.read_cache[dn]
    try:
      self.l.uncache_entry(dn.encode(self.charset))
    except AttributeError:
      # No cache to remove the entry from
      pass
    return

  def addEntry(self,dn,modlist):
    """
    Add entry dn with the attributes given in modlist.
    The DN is encoded with the charset of this session.
    """
    encoded_dn = dn.encode(self.charset)
    self.l.add_s(encoded_dn,modlist)

  def modifyEntry(self,dn,modlist):
    """
    Apply the modifications in modlist to entry dn.
    Cached data of the entry is removed before.
    """
    # Cached data gets invalid by the modification
    self.uncacheEntry(dn)
    encoded_dn = dn.encode(self.charset)
    self.l.modify_s(encoded_dn,modlist)

  def renameEntry(self,dn,new_rdn,new_superior=None,delold=1):
    """
    Rename an entry

    dn
        DN of the entry to be renamed.
    new_rdn
        New RDN of the entry.
    new_superior
        DN of the new superior entry. If None or equal to the
        current superior entry only the RDN is changed.
    delold
        Passed to modrdn_s() or rename_s().

    Returns the new DN of the entry built from new_rdn and the
    superior DN. No trailing comma is added if the superior DN
    is empty.
    Raises ValueError if the entry has to be moved beneath another
    superior entry but the connection does not use LDAPv3.
    """
    # Cached data of the old entry gets invalid
    self.uncacheEntry(dn)
    old_superior = ldaputil.base.ParentDN(ldaputil.base.normalize_dn(dn))
    if new_superior:
      new_superior = ldaputil.base.normalize_dn(new_superior)
    if new_superior is None or new_superior==old_superior:
      # Entry stays beneath its superior entry
      self.l.modrdn_s(
        dn.encode(self.charset),new_rdn.encode(self.charset),delold
      )
      result_superior = old_superior
    elif self.currentLDAPVersion>=3:
      self.l.rename_s(
        dn.encode(self.charset),new_rdn.encode(self.charset),
        new_superior.encode(self.charset),delold
      )
      result_superior = new_superior
    else:
      raise ValueError("LDAPv3 required for using parameter new_superior")
    # Leave out an empty superior DN to avoid a DN like 'cn=foo,'
    return ','.join([
      dn_part
      for dn_part in [new_rdn,result_superior]
      if dn_part
    ])

  def deleteEntry(self,dn):
    """
    Delete entry dn.
    Cached data of the entry is removed before.
    """
    # Cached data gets invalid by the deletion
    self.uncacheEntry(dn)
    encoded_dn = dn.encode(self.charset)
    self.l.delete_s(encoded_dn)

  def setDN(self,dn):
    """
    Store the normalized dn as currently used DN and determine
    the matching searchRoot from the naming contexts.

    Nothing is done if dn is None or if the normalized dn is
    equal to the currently used DN.
    """
    if dn is None:
      return
    normalized_dn = ldaputil.base.normalize_dn(dn)
    if normalized_dn==self._dn:
      # DN did not change
      return
    self._dn = normalized_dn
    self.searchRoot=ldaputil.base.match_dnlist(
      normalized_dn,self.namingContexts
    )

  def manageDsaIt(self,enable):
    """
    Switch manageDSAit mode (see draft-zeilenga-ldap-namedref)
    on or off and remember the state in manageDsaItenabled.
    The Root DSE is read again after switching.

    Nothing is done if the connection does not use at least
    LDAPv3 or if enable is None.
    """
    if self.currentLDAPVersion<3 or enable is None:
      return
    self.l.manage_dsa_it(enable)
    self.getRootDSE()
    # Save the new state
    self.manageDsaItenabled = enable

  def bind(
    self,
    who='',
    cred='',
    method=ldap.AUTH_SIMPLE,
    filtertemplate='(uid=%s)',
    loginSearchRoot=None
  ):
    """
    Bind to the host and refresh the state of the session.

    who
        Bind DN. A non-empty value not being a DN according to
        ldaputil.base.is_dn() is passed to ldaputil.base.SmartLogin()
        first and the result of that is used for binding.
    cred
        Credential passed to bind_s().
    method
        Authentication method passed to bind_s().
    filtertemplate
        Passed to ldaputil.base.SmartLogin().
    loginSearchRoot
        Passed to ldaputil.base.SmartLogin(). Attribute searchRoot
        is used if this is None.

    After a successful bind the value used for binding is stored
    in attribute who, all cached data is flushed and the Root DSE
    is read again.
    """
    # Smart login
    is_login_name = who and not ldaputil.base.is_dn(who)
    if is_login_name:
      if self.searchRoot is None:
        self.getRootDSE()
      if loginSearchRoot is None:
        loginSearchRoot = self.searchRoot
      who = ldaputil.base.SmartLogin(
        self.l,
        who.encode(self.charset),
        loginSearchRoot,
        filtertemplate,
        attrnamesonly=self.onlyAttrTypes,
        timeout=self.timeout
      )
    # Call the bind
    try:
      self.l.bind_s(
        who.encode(self.charset),
        cred.encode(self.charset),
        method
      )
    finally:
      # Make sure reference to credential is deleted in any case
      del cred
    # Store bind DN
    self.who = who
    # Flush old data from cache
    self.flushCache()
    # Access to root DSE might have changed after binding
    # as another entity
    self.getRootDSE()

  def valid(self):
    """
    Return a true value if a connection object is stored
    in attribute l, a false value otherwise.
    """
    has_connection = hasattr(self,'l')
    return has_connection

  def __repr__(self):
    """
    Return a string showing whether the session is connected
    followed by name:value pairs of the attributes being set.
    """
    # Pairs of shown name and name of the attribute holding the value.
    # The currently used DN is stored in attribute _dn.
    shown_attrs = [
      ('host','host'),
      ('who','who'),
      ('dn','_dn'),
      ('currentLDAPVersion','currentLDAPVersion'),
      ('onBehalf','onBehalf'),
      ('startedTLS','startedTLS'),
      ('hasRootDSE','hasRootDSE'),
    ]
    return '<LDAPSession%s: %s>' % (
      ' connected'*self.valid(),
      ','.join(
        [
          '%s:%s' % (shown_name,repr(getattr(self,attr_name)))
          for shown_name,attr_name in shown_attrs
          if hasattr(self,attr_name)
        ]
      )
    )
