"""
ldapbase.py - basic LDAP functions
(c) by Michael Stroeder <michael@stroeder.com>

This module is distributed under the terms of the
GPL (GNU GENERAL PUBLIC LICENSE) Version 2
(see http://www.gnu.org/copyleft/gpl.html)
"""

__version__ = '0.2.0'

import sys, string, re, UserDict, urllib, ldap

# DNS resource record type number of SRV records
DNS_TYPE_SRV = 33

# ldap-module modification type constants -> textual names
ldap_mod_str = {
  ldap.MOD_ADD:     'add',
  ldap.MOD_REPLACE: 'replace',
  ldap.MOD_DELETE:  'delete',
}

# Scope names indexed by the integer value of an ldap-module scope
# constant (see create_ldap_url()); assumes SCOPE_BASE, SCOPE_ONELEVEL
# and SCOPE_SUBTREE are 0, 1 and 2 - TODO confirm for the ldap-module used
ldap_searchscope_str = ['base','one','sub']

# Scope names as used in LDAP URLs -> ldap-module scope constants
ldap_searchscope = {
  'base': ldap.SCOPE_BASE,
  'one':  ldap.SCOPE_ONELEVEL,
  'sub':  ldap.SCOPE_SUBTREE,
}

# Building blocks for a lax syntax check of DNs (used by is_dn()) and
# for matching host[:port] strings.

# attribute type, optionally followed by ;options
attr_pattern = r'[\w;.]+(;[\w_-]+)*'
# attribute value: a run of characters or a double-quoted string.
# NOTE(review): this is not a raw string, so '\\,' reaches the regex
# engine as \, which matches a plain comma. ([^,]|\,) therefore matches
# any character, values may swallow unescaped commas, and dn_regex in
# effect only checks that a string starts with an attribute type, '='
# and at least one more character. Switching to r'\\,' looks intended,
# but makes the anchored dn_regex able to fail late on long inputs,
# which may cause heavy backtracking on crafted user input (is_dn() is
# applied to login names) - review carefully before changing.
data_pattern = '(([^,]|\\,)+|".*?")'
# one RDN: attr = value, whitespace allowed around '='
rdn_pattern = attr_pattern + r'[\s]*=[\s]*' + data_pattern
# comma-separated RDNs, optional trailing blanks
dn_pattern   = rdn_pattern + r'([\s]*,[\s]*' + rdn_pattern + r')*[ ]*'
# dotted IPv4 address or dotted host name, optionally followed by :port
host_pattern = r'([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+|[a-zA-Z]+[a-zA-Z0-9-]*(\.[a-zA-Z]+[a-zA-Z0-9-]*)*)+(:[0-9]*)*'

#rdn_regex   = re.compile('^%s$' % rdn_pattern)
# Unicode pattern anchored at both ends
dn_regex      = re.compile(u'^%s$' % unicode(dn_pattern))
# not anchored: match() only checks a prefix of the string
host_regex = re.compile(host_pattern)


def unicode_list(l,charset='utf-8'):
  """
  return list of Unicode objects
  """
  return map(
    lambda i,c=charset:unicode(i,c),
    l
  )

def encode_unicode_list(l,charset='utf-8'):
  """
  Encode every Unicode object in l with charset and return the
  resulting list of byte strings.
  """
  return [item.encode(charset) for item in l]

def is_ldap_url(s):
  """
  Fail-safe wrapper function for ldap.is_ldap_url()

  Unicode input is UTF-8 encoded before the check. Returns the result
  of ldap.is_ldap_url(), or 0 if that raises TypeError.
  """
  if type(s) is type(u''):
    s = s.encode('utf-8')
  try:
    return ldap.is_ldap_url(s)
  except TypeError:
    return 0

def is_dn(s):
  """
  Return a true value if s matches dn_regex (a lax DN syntax check).

  Returns 0 for an empty or otherwise false s.
  """
  if not s:
    return 0
  return dn_regex.match(s) is not None

def explode_dn(dn):
  """
  Wrapper function for ldap.explode_dn()

  dn is UTF-8 encoded before splitting (presumably it is a Unicode
  object); the RDNs are returned as a list of Unicode objects.
  Returns [] for an empty (or otherwise false) dn.
  """
  if not dn:
    return []
  return [unicode(rdn,'utf-8') for rdn in ldap.explode_dn(dn.encode('utf-8'))]

def normalize_dn(dn):
  """
  Return dn re-assembled from its exploded RDNs, joined by ','
  without surrounding whitespace.
  """
  return ','.join(explode_dn(dn))

def match_dn(dn,dnlist):
  """
  Find best matching parent DN of dn in dnlist.

  Every entry of dnlist which is a DN (see is_dn()) and whose RDNs
  equal the trailing RDNs of dn (compared case-insensitively) is a
  candidate; the one with most RDNs wins (the first one on ties).
  Returns the normalized winner or None if no candidate was found.
  """
  dn_rdns_lower = [string.lower(rdn) for rdn in explode_dn(dn)]
  best_rdns = None
  best_len = 0
  for parentdn in dnlist:
    if not is_dn(parentdn):
      continue
    parent_rdns = explode_dn(parentdn)
    parent_len = len(parent_rdns)
    if parent_len<=best_len:
      # cannot beat the current best match
      continue
    if dn_rdns_lower[-parent_len:]==[string.lower(rdn) for rdn in parent_rdns]:
      best_len = parent_len
      best_rdns = parent_rdns
  if best_rdns:
    return ','.join(best_rdns)
  return None

def extract_referrals(e):
  """
  Extract the referral LDAP URLs from a ldap.PARTIAL_RESULTS
  exception object and return tuple (matched,ldap_urls).

  e.args[0] must be a dictionary. Its 'info' value is split into
  lines; the first line is not part of the URL list and is skipped,
  every following non-blank line is one LDAP URL (surrounding
  whitespace stripped). matched is the 'matched' value or None.

  Raises ValueError if there is no 'info' field or if it does not
  contain any LDAP URL.
  """
  error_dict = e.args[0]
  info = error_dict.get('info',None)
  if info is None:
    raise ValueError("Referral exception object does not have info field")
  ldap_urls = []
  for line in info.split('\n')[1:]:
    line = line.strip()
    # blank lines would otherwise show up as empty URLs
    if line:
      ldap_urls.append(line)
  if not ldap_urls:
    raise ValueError("Referral exception object does not contain any LDAP URL")
  return (error_dict.get('matched',None),ldap_urls)

def ParentDN(dn):
  """
  Return the parent-DN of dn.

  Returns '' if dn consists of a single RDN and None if dn is empty.
  """
  rdns = explode_dn(dn)
  if not rdns:
    return None
  return ','.join(rdns[1:])

def SplitRDN(dn):
  """
  Return tuple (RDN,base DN) of dn.

  The base DN is '' if dn consists of a single RDN; None is returned
  instead of a tuple if dn is empty.
  """
  rdns = explode_dn(dn)
  if not rdns:
    return None
  return rdns[0], ','.join(rdns[1:])

def ParentDNList(dn):
  """
  Return the list of all parent-DNs of dn, nearest parent first
  (dn itself is not included).
  """
  rdns = explode_dn(dn)
  return [','.join(rdns[start:]) for start in range(1,len(rdns))]


def parse_ldap_url(ldap_url):
  """
  parse a LDAP URL and return (host,dn,attrs,scope,filter,extensions)

  host         LDAP host (host[:port] part, possibly empty)
  dn           distinguished name (search root)
  attrs        list of attribute types
  scope        integer search scope for ldap-module
  filter       LDAP search filter
  extensions   list of extensions (not URL-decoded)

  Missing or empty fields get the defaults of RFC 2255: dn '', no
  attributes, base scope, filter (objectclass=*) and no extensions.
  The scope keyword is matched case-insensitively.

  Raises ValueError if ldap_url does not contain '://' or if the
  scope is not one of base, one or sub.
  """
  dn = ''
  attr_list = []
  search_scope = ldap.SCOPE_BASE
  ldap_filter = '(objectclass=*)'
  extensions = []
  try:
    # the scheme in front of '://' is not checked
    scheme,rest = ldap_url.split('://',1)
  except ValueError:
    raise ValueError("LDAP URL must contain '://'")
  try:
    host,rest = rest.split('/',1)
  except ValueError:
    # no '/' after host[:port]: only the host is given
    host = rest
  else:
    paramlist = rest.split('?')
    paramlist_len = len(paramlist)
    # NOTE(review): unquote_plus() also turns '+' into a space, although
    # '+' separates multi-valued RDNs and RFC 2255 does not use '+' for
    # spaces; kept for compatibility with existing callers.
    dn = urllib.unquote_plus(paramlist[0])
    if paramlist_len>=2 and paramlist[1]:
      attr_list = urllib.unquote_plus(paramlist[1]).split(',')
    if paramlist_len>=3 and paramlist[2]:
      try:
        search_scope = ldap_searchscope[paramlist[2].lower()]
      except KeyError:
        raise ValueError("Search scope must be either one of base, one or sub")
    if paramlist_len>=4 and paramlist[3]:
      ldap_filter = urllib.unquote_plus(paramlist[3])
    if paramlist_len>=5 and paramlist[4]:
      extensions = paramlist[4].split(',')
  return (host,dn,attr_list,search_scope,ldap_filter,extensions)


def create_ldap_url(
  hostport='',              # host:port of LDAP server
  dn='',                    # DN of search root
  attrs=None,               # list with attribute types (None: no attributes)
  scope=ldap.SCOPE_BASE,    # Search scope
                            # (can be either string or constant from ldap-module)
  filter='(objectclass=*)', # LDAP filter string according to RFC2254
  extensions='',            # Extensions (comma-separated string or list)
  urlencode=0               # Apply URL-encoding to dn, attrs and filter
):
  """
  build LDAP URL of search query according to RFC2255

  hostport     host:port
  dn           distinguished name
  attrs        list with attribute types
  scope        search scope (name string or integer ldap-module constant)
  filter       LDAP search filter
  extensions   extensions as string or list of strings; appended
               verbatim (no URL-encoding) as fifth field if non-empty
  urlencode    if true, URL-encode dn, every attribute type and filter
               (Unicode strings are UTF-8 encoded first)

  ldapurl    = scheme "://" [hostport] ["/"
                   [dn ["?" [attrs] ["?" [scope]
                   ["?" [filter] ["?" extensions]]]]]]

  Raises TypeError if scope is neither a string nor an integer.
  """
  # Deliberately not named "urlencode": a nested function of that name
  # would shadow the urlencode parameter and make encoding unconditional.
  def quote_url_part(s):
    """URL-encode s, escaping ',' and '/' too (UTF-8 for Unicode)"""
    if type(s)==type(u''):
      # urllib.quote() cannot handle non-ASCII Unicode strings
      s = s.encode('utf-8')
    return string.replace(string.replace(urllib.quote(s),',','%2C'),'/','%2F')

  if attrs is None:
    attrs = []

  if urlencode:
    dn     = quote_url_part(dn)
    filter = quote_url_part(filter)
    # encode each attribute type on its own: the separating commas
    # must stay literal commas
    attrs_str = string.join(map(quote_url_part,attrs),',')
  else:
    attrs_str = string.join(attrs,',')

  if type(scope) in (type(''),type(u'')):
    scope_str = scope
  elif type(scope)==type(ldap.SCOPE_SUBTREE):
    scope_str = ldap_searchscope_str[scope]
  else:
    raise TypeError('scope has to be of type either string or integer.')

  if type(extensions) in (type([]),type(())):
    extensions = string.join(extensions,',')

  ldap_url = 'ldap://%s/%s?%s?%s?%s' % (
    hostport,
    dn,
    attrs_str,
    scope_str,
    filter
  )
  if extensions:
    ldap_url = '%s?%s' % (ldap_url,extensions)
  return ldap_url

def add_modifylist(entry,ldap_charset='utf-8'):
  """
  Build modify list for call of method add()

  Returns one (attribute type,list of values) tuple per attribute of
  entry, with type and values encoded with ldap_charset.
  """
  return [
    (
      attr_type.encode(ldap_charset),
      [value.encode(ldap_charset) for value in entry[attr_type]]
    )
    for attr_type in entry.keys()
  ]

def modify_modifylist(old_entry,new_entry,ldap_charset='utf-8'):
  """
  Build differential modify list for call of method modify()

  old_entry     dictionary attribute type -> list of current values,
                encoded with ldap_charset
  new_entry     dictionary attribute type -> list of desired values
                (presumably Unicode objects, they are encoded here)
  ldap_charset  charset for decoding old and encoding new values

  Attribute types are compared case-insensitively (an exact key match
  in old_entry is preferred). For each attribute of new_entry:
  - present in old_entry, values differ:
    new values [] or [''] -> MOD_DELETE of all old values (nothing is
    done if there are no old values either), otherwise MOD_REPLACE
  - not present in old_entry: one MOD_ADD per non-empty new value
  Attributes only present in old_entry are left untouched.
  """
  modlist = []
  # lower-cased attribute type -> key actually used in old_entry,
  # so that e.g. 'CN' in old_entry is found for 'cn' in new_entry
  old_keys = {}
  for old_attr in old_entry.keys():
    old_keys[string.lower(old_attr)] = old_attr
  for attr in new_entry.keys():
    new_data = new_entry[attr]
    attr_encoded = attr.encode(ldap_charset)
    old_attr = old_keys.get(string.lower(attr),None)
    if old_attr is None:
      # attribute not present yet: add every non-empty new value
      for data in new_data:
        if data:
          modlist.append((ldap.MOD_ADD,attr_encoded,data.encode(ldap_charset)))
      continue
    old_data = unicode_list(
      old_entry.get(attr,old_entry[old_attr]),
      ldap_charset
    )
    if new_data==old_data:
      continue
    if new_data==[] or new_data==['']:
      if old_data:
        # delete all existing values of the attribute
        modlist.append(
          (
            ldap.MOD_DELETE,
            attr_encoded,
            encode_unicode_list(old_data,ldap_charset)
          )
        )
      # else: old and new value lists are both empty - nothing to do
    else:
      # replace the existing values with the new ones
      modlist.append(
        (
          ldap.MOD_REPLACE,
          attr_encoded,
          encode_unicode_list(new_data,ldap_charset)
        )
      )
  return modlist

def escape_filter_chars(s):
  """
  RFC2254: Replace the special filter characters backslash, '*', '(',
  ')' and NUL in s by a backslash followed by their two-digit
  uppercase hex code.
  """
  # Backslash comes first, otherwise the backslashes inserted for the
  # other characters would be escaped again.
  for special in ('\\','*','(',')','\000'):
    s = s.replace(special,'\\%02X' % ord(special))
  return s

def escape_binary_dnchars(s):
  """
  Convert NON-ASCII and control characters in DN to escaped hex form.

  Every character with code < 32 or >= 128 is replaced by a backslash
  followed by its uppercase hex code (at least two digits; code points
  above 255 yield more digits). Other characters are kept as they are.
  """
  pieces = []
  for ch in s:
    code = ord(ch)
    if 32<=code<128:
      pieces.append(ch)
    else:
      pieces.append('\\%02X' % code)
  return ''.join(pieces)


# Regex object used for finding hex-escaped characters
# (a backslash followed by exactly two hex digits)
hex_dnchars_regex = re.compile(r'\\[0-9a-fA-F][0-9a-fA-F]')

def unescape_binary_dnchars(s):
  """
  Convert hex-escaped characters in DN to binary NON-ASCII chars.

  Every backslash followed by two hex digits is replaced by chr() of
  that hex value; all other text is left unchanged.
  """
  return hex_dnchars_regex.sub(
    lambda m: chr(int(m.group(0)[1:],16)),
    s
  )

def SearchTree(l=None,ldap_dn='',attrsonly=1,timeout=-1):
  """
  Return the DNs of all entries of the sub tree at ldap_dn
  (ldap_dn itself first), collected by one-level searches via
  l.search_st(). Entries whose search raises ldap.NO_SUCH_OBJECT
  contribute no children.
  """
  pending = [ldap_dn]
  found = []
  while pending:
    current = pending.pop()
    found.append(current)
    try:
      children = l.search_st(
        current,ldap.SCOPE_ONELEVEL,'(objectclass=*)',['objectclass'],
        attrsonly,timeout
      )
    except ldap.NO_SUCH_OBJECT:
      continue
    pending.extend([child[0] for child in children])
  return found


def GetNamingContexts(l,who='',cred='',auth_type=ldap.AUTH_SIMPLE):
  """
  This function tries to return a list of possible base-DNs
  (as Unicode objects) for which a server holds the entries.

  l is a LDAP object returned by ldap.open(host). As a side effect
  l is bound with who, cred and auth_type (None for who or cred is
  treated as ''); the bind is not undone. l.options is not modified.

  The following steps are taken for querying the server configuration:
  1. reading the namingContexts attribute
     in the RootDSE of a LDAPv3 server
  2. if step 1 yields no value: reading the database attribute
     of the cn=config entry of a UMich-derived LDAPv2 server
  An empty list is returned if neither step yields a base-DN.
  Errors of l.bind_s() and search errors not listed in
  not_available below are passed to the caller.
  """
  # Search errors which only mean that this kind of configuration
  # information is not available from the server
  not_available = (
    ldap.NO_SUCH_OBJECT,
    ldap.PARTIAL_RESULTS,
    ldap.UNDEFINED_TYPE,
    ldap.INAPPROPRIATE_MATCHING,
    ldap.OPERATIONS_ERROR,
  )
  if who is None:
    who = ''
  if cred is None:
    cred = ''
  l.bind_s(who,cred,auth_type)
  result = []
  # 1. namingContexts attribute from RootDSE of LDAPv3 server
  try:
    r = l.search_s('',ldap.SCOPE_BASE,"(objectclass=*)",['namingContexts'],0)
  except not_available:
    r = []
  if r:
    rootdse = r[0][1]
    # attribute type names are compared case-insensitively
    for attr_type in rootdse.keys():
      if string.lower(attr_type)=='namingcontexts':
        result = list(rootdse[attr_type])
        break
  if not result:
    # 2. database attribute of cn=config from UMich LDAPv2 server derivate
    try:
      r = l.search_s('cn=config',ldap.SCOPE_BASE,"(objectclass=*)",['database'],0)
    except not_available:
      r = []
    if r:
      # values are split at ' : '; everything after the first part
      # (presumably the database type) is taken as suffix list
      for d in r[0][1].get('database',[]):
        try:
          dbtype,basedn = string.split(d,' : ',1)
        except ValueError:
          # no ' : ' separator in this value: ignore it
          continue
        result.extend(map(string.strip,string.split(basedn,' : ')))
  return [unicode(i,'utf-8') for i in result]


class USERNAME_NOT_FOUND(ldap.LDAPError):
  """
  Raised by SmartLogin() when its search for a user entry yields no
  entry at all (including a search raising ldap.NO_SUCH_OBJECT).
  """

class USERNAME_NOT_UNIQUE(ldap.LDAPError):
  """
  Simple exception class raised when SmartLogin() finds
  more than one entry matching the search
  """
  pass

def SmartLogin(
  l,                        # LDAP connection object created with ldap.open()
  username='',              # User name or complete bind DN (UTF-8 encoded)
  searchroot='',            # search root for user entry search
  filtertemplate='(uid=%s)',# template string for LDAP filter
  scope=ldap.SCOPE_SUBTREE, # search scope
  attrnamesonly=1           # retrieve only attribute names when searching
):
  """
  Do a smart login, i.e. determine the bind-DN to be used for username.

  If username is a valid DN it's normalized and returned without
  searching. Otherwise every %s in filtertemplate is replaced by
  username, with LDAP filter special characters escaped according to
  RFC 2254 (so e.g. '*' in username is matched literally). The
  resulting filter is used to search below searchroot; exactly one
  entry must match and its normalized DN is returned.

  Returns None if username is empty or the found entry has an empty DN.

  The calling application has to handle all possible exceptions:
  USERNAME_NOT_FOUND, USERNAME_NOT_UNIQUE (both subclasses of
  ldap.LDAPError), ldap.FILTER_ERROR and other ldap.LDAPError
  exceptions raised by l.search_s()
  """
  if not username:
    return None
  if is_dn(username):
    return normalize_dn(username)
  # Escape username: it usually comes from user input and must not be
  # able to change the structure of the search filter.
  searchfilter = string.replace(
    filtertemplate,'%s',escape_filter_chars(username)
  )
  if searchroot is None:
    searchroot = ''
  # Try to find a unique entry with filtertemplate
  try:
    result = l.search_s(searchroot,scope,searchfilter,['objectclass'],attrnamesonly)
  except ldap.NO_SUCH_OBJECT:
    raise USERNAME_NOT_FOUND({'desc':'Smart login did not find a matching user entry.'})
  if not result:
    raise USERNAME_NOT_FOUND({'desc':'No matching user entry'})
  if len(result)!=1:
    raise USERNAME_NOT_UNIQUE({'desc':'More than one matching user entries'})
  who = result[0][0]
  if who:
    return normalize_dn(who)
  return None


def test():
  """Test functions"""
  ldap_dns = {
    u'o=Michaels':1,
    u'iiii':0
  }
  print 'Testing function is_dn():'
  for ldap_dn in ldap_dns.keys():
    result_is_dn = is_dn(ldap_dn)
    if result_is_dn !=ldap_dns[ldap_dn]:
      print 'is_dn("%s") returns %d instead of %d.' % (
        ldap_dn,result_is_ldap_dn,ldap_dns[ldap_dn]
      )
  ldap_urls = {
    # Examples from RFC2255
    u'ldap:///o=University%20of%20Michigan,c=US':1,
    u'ldap://ldap.itd.umich.edu/o=University%20of%20Michigan,c=US':1,
    u'ldap://ldap.itd.umich.edu/o=University%20of%20Michigan,':1,
    u'ldap://host.com:6666/o=University%20of%20Michigan,':1,
    u'ldap://ldap.itd.umich.edu/c=GB?objectClass?one':1,
    u'ldap://ldap.question.com/o=Question%3f,c=US?mail':1,
    u'ldap://ldap.netscape.com/o=Babsco,c=US??(int=%5c00%5c00%5c00%5c04)':1,
    u'ldap:///??sub??bindname=cn=Manager%2co=Foo':1,
    u'ldap:///??sub??!bindname=cn=Manager%2co=Foo':1,
    # More examples from various sources
    u'ldap://ldap.nameflow.net:1389/c%3dDE':1,
  }
  print 'Testing function is_ldap_url():'
  for ldap_url in ldap_urls.keys():
    result_is_ldap_url = is_ldap_url(ldap_url)
    if result_is_ldap_url !=ldap_urls[ldap_url]:
      print 'is_ldap_url("%s") returns %d instead of %d.' % (
        ldap_url,result_is_ldap_url,ldap_urls[ldap_url]
      )
  print 'Testing functions escape_binary_dnchars() and unescape_binary_dnchars():'
  i=0
  binary_strings = [
    'Michael Strder',
    '\000\001\002\003\004\005\006'
  ]
  for s in binary_strings:
    print repr(s),escape_binary_dnchars(s)
    if unescape_binary_dnchars(escape_binary_dnchars(s))!=s:
      print 'Testing of unescape_binary_dnchars(escape_binary_dnchars()) failed for string.' % (
        repr(s)
      )
if __name__ == '__main__':
  test()
