"""
ldaputil.base - basic LDAP functions
(c) by Michael Stroeder <michael@stroeder.com>

This module is distributed under the terms of the
GPL (GNU GENERAL PUBLIC LICENSE) Version 2
(see http://www.gnu.org/copyleft/gpl.html)

Python compatibility note:
This module only works with Python 1.6+ since all string parameters
are assumed to be Unicode objects and string methods are used instead
of the string module.

$Id: base.py,v 1.12 2002/02/26 21:00:55 michael Exp $
"""

# Version string of this module
__version__ = '0.7.0'

import re,urllib,ldap

from ldapurl import LDAPUrl

# Textual names of the three LDAP search scopes
SEARCH_SCOPE_STR = ['base','one','sub']

# Maps a textual search scope name to the matching ldap.SCOPE_* constant
SEARCH_SCOPE = {
  # default for empty search scope string
  '':ldap.SCOPE_BASE,
  # the search scope strings defined in RFC22xx(?)
  # NOTE(review): presumably the scope names used in LDAP URLs (RFC 2255)
  # - confirm and replace the placeholder above
  'base':ldap.SCOPE_BASE,
  'one':ldap.SCOPE_ONELEVEL,
  'sub':ldap.SCOPE_SUBTREE
}

# Building blocks for the regular expression used by is_dn().

# Attribute type: word characters, ';' or '.', optionally followed by
# one or more ';option' parts
attr_type_pattern = ur'[\w;.]+(;[\w_-]+)*'
# Attribute value: either a run of characters containing no comma other
# than backslash-escaped ones, or a double-quoted string (non-greedy)
attr_value_pattern = ur'(([^,]|\\,)+|".*?")'
# A single "type = value" pair; blanks around '=' are allowed
rdn_pattern = attr_type_pattern + ur'[ ]*=[ ]*' + attr_value_pattern
# One or more RDNs separated by commas (blanks around the commas and
# at the very end are allowed)
dn_pattern   = rdn_pattern + r'([ ]*,[ ]*' + rdn_pattern + r')*[ ]*'

# Same structure as above, but the attribute type is restricted to
# 'dc' or the empty string.
# NOTE(review): neither dc_* pattern is referenced elsewhere in this
# module - presumably used by importers; confirm before removing
dc_rdn_pattern = ur'(dc|)[ ]*=[ ]*' + attr_value_pattern
dc_dn_pattern   = dc_rdn_pattern + r'([ ]*,[ ]*' + dc_rdn_pattern + r')*[ ]*'

#rdn_regex   = re.compile('^%s$' % rdn_pattern)
# Compiled pattern matching a complete DN string from start to end
dn_regex      = re.compile(u'^%s$' % unicode(dn_pattern))

# Some widely used types:
# the type objects of byte strings and of Unicode strings, used for
# type comparisons throughout this module
StringType = type('')
UnicodeType = type(u'')

# Names of attribute types related to a server's root DSE.
# NOTE(review): presumably passed as attribute list when reading the
# root DSE; the list is not used inside this module - confirm with callers.
# Every name is listed exactly once (the list formerly repeated
# 'subschemaSubentry' and 'supportedSASLMechanisms').
ROOTDSE_ATTRS = [
  'defaultNamingContext',
  'defaultRnrDN',
  'altServer',
  'namingContexts',
  'subschemaSubentry',
  'supportedLDAPVersion',
  'supportedControl',
  'supportedSASLMechanisms',
  'supportedExtension',
  'supportedFeatures',
  'objectclass',
  'dsServiceName',
  'ogSupportedProfile',
  'netscapemdsuffix',
  'dataversion',
  'dsaVersion',
]

def unicode_list(l,charset='utf-8'):
  """
  Decode every item of l with the given charset and
  return the resulting Unicode objects as a new list
  """
  decoded = []
  for item in l:
    decoded.append(unicode(item,charset))
  return decoded

def unicode_entry(e,charset='utf-8'):
  """
  Take an entry dictionary e (attribute type -> list of values) and
  return a new dictionary with the same keys in which every value list
  was decoded to Unicode objects with the given charset
  """
  return dict([
    (attr_type,unicode_list(values,charset))
    for attr_type,values in e.items()
  ])

def encode_unicode_list(l,charset='utf-8'):
  """
  Return a new list holding every Unicode object of l
  encoded with the given charset
  """
  encoded = []
  for item in l:
    encoded.append(item.encode(charset))
  return encoded

def is_dn(s):
  """
  Return a true value if s matches dn_regex as a whole,
  a false value otherwise. Empty or None input yields 0.
  """
  if not s:
    return 0
  return dn_regex.match(s) is not None

def explode_rdn_attr(attr_type_and_value):
  """
  explode_rdn_attr(attr_type_and_value) -> tuple

  This function takes a single attribute type and value pair
  describing a characteristic attribute forming part of a RDN
  (e.g. 'cn=Michael Stroeder') and returns a 2-tuple
  containing the attribute type and the attribute value.

  The string is split at the first '=' only, so '=' characters
  inside the value are preserved. In the value every backslash is
  removed and the character following it is taken literally, which
  means an escaped backslash yields a single backslash. Hex pair
  escapes are not converted to the character they denote; only
  their backslash is removed. A dangling backslash at the very end
  is dropped.

  Raises ValueError if attr_type_and_value contains no '=' at all.
  """
  # Split only once: the value itself may legally contain '='
  attr_type,attr_value = attr_type_and_value.split('=',1)
  unescaped = []
  escaped = 0
  for ch in attr_value:
    if escaped:
      # Character right after a backslash is always kept as is,
      # even if it is a backslash itself
      unescaped.append(ch)
      escaped = 0
    elif ch=='\\':
      escaped = 1
    else:
      unescaped.append(ch)
  return (attr_type,''.join(unescaped))

def rdn_dict(dn,charset='utf-8'):
  """
  Return the RDN (left-most component) of dn as dictionary mapping
  each attribute type to the list of its values (Unicode objects
  decoded with charset). Returns {} if dn has no RDN.
  """
  rdn = SplitRDN(dn)[0]
  if not rdn:
    return {}
  # ldap.explode_rdn() is fed an encoded string
  if type(rdn) is UnicodeType:
    rdn = rdn.encode(charset)
  attrs = {}
  for type_and_value in ldap.explode_rdn(rdn.strip()):
    attr_type,raw_value = explode_rdn_attr(type_and_value)
    attrs.setdefault(attr_type,[]).append(unicode(raw_value,charset))
  return attrs

def explode_dn(dn,charset='utf-8'):
  """
  Wrapper function for ldap.explode_dn() which returns [] for
  a zero-length DN.

  dn may be a Unicode object (it is encoded with charset before being
  handed to ldap.explode_dn()) or an already encoded string.
  Returns the list of DN components as Unicode objects, each stripped
  of surrounding white space and decoded with the same charset that
  was used for encoding.
  """
  if not dn:
    return []
  if type(dn)==UnicodeType:
    dn = dn.encode(charset)
  dn_list = ldap.explode_dn(dn.strip())
  if dn_list and dn_list!=['']:
    # Decode with charset (not a hard-coded 'utf-8') so that encoding
    # and decoding are symmetric for non-default charsets
    return [ unicode(rdn.strip(),charset) for rdn in dn_list ]
  return []


def normalize_dn(dn):
  """
  Return dn re-assembled from the components returned by explode_dn(),
  joined with a plain ',' (no blanks around the separators)
  """
  return ','.join(explode_dn(dn))


def match_dn(dn1,dn2):
  """
  Compare dn1 and dn2 component by component starting at the right
  end (case-insensitive).

  Returns a tuple (number of matching components, matching suffix).
  The suffix is built from the components of the DN having more
  components (dn2 if both have the same number).
  Returns (0,u'') if either DN is empty or nothing matches.
  """
  list1,list2 = explode_dn(dn1),explode_dn(dn2)
  if not (list1 and list2):
    return (0,u'')
  # Walk along the DN with fewer components
  if len(list1)>len(list2):
    shorter,longer = list2,list1
  else:
    shorter,longer = list1,list2
  matched = 0
  while matched<len(shorter) and \
        shorter[-matched-1].lower()==longer[-matched-1].lower():
    matched += 1
  if not matched:
    return (0,u'')
  return (matched,','.join(longer[-matched:]))

def match_dnlist(dn,dnlist):
  """
  Run match_dn() for every item of dnlist against dn and return the
  matching suffix with the highest match level. The first item wins
  on equal levels. Returns '' if nothing matches.
  """
  best_level = 0
  best_name = ''
  for candidate in dnlist:
    level,name = match_dn(candidate,dn)
    if level>best_level:
      best_level = level
      best_name = name
  return best_name

def extract_referrals(e):
  """
  Extract the referral LDAP URLs from a
  ldap.PARTIAL_RESULTS exception object.

  The first line of the 'info' field is skipped; each following line
  is turned into a LDAPUrl instance.
  Returns a tuple (value of the 'matched' field or None, list of
  LDAPUrl instances).
  Raises ValueError if there is no 'info' field (and, via the tuple
  unpacking, also if 'info' does not contain a line break).
  """
  details = e.args[0]
  if not details.has_key('info'):
    raise ValueError("Referral exception object does not have info field")
  info_lines = details['info'].split('\n',1)
  # Unpacking insists on exactly two parts: first line + URL lines
  first_line, ldap_url_info = [ part.strip() for part in info_lines ]
  ldap_urls = []
  for url_line in ldap_url_info.split('\n'):
    ldap_urls.append(LDAPUrl(url_line))
  return (details.get('matched',None),ldap_urls)

def ParentDN(dn):
  """
  Return dn without its left-most component.
  Yields '' for a DN with a single component and
  None if dn has no components at all.
  """
  components = explode_dn(dn)
  if not components:
    return None
  # Joining an empty remainder gives '' for single-component DNs
  return ','.join(components[1:])

def SplitRDN(dn):
  """
  Split dn into the tuple (left-most component, remaining DN).
  The remaining DN is '' for a DN with a single component.
  Returns (None,None) if dn has no components at all.
  """
  components = explode_dn(dn)
  if not components:
    return None,None
  return components[0], ','.join(components[1:])

def ParentDNList(dn,rootdn=''):
  """
  Return a list of parent DNs of dn, nearest parent first.

  Without rootdn all parents of dn are returned. With rootdn the
  number of entries is limited to (number of components of rootdn)-1.
  NOTE(review): with rootdn set, the limit depends only on the depth
  of rootdn and not on where rootdn sits inside dn - confirm that
  this is the intended behaviour.
  """
  components = explode_dn(dn)
  if rootdn:
    max_level = len(explode_dn(rootdn))
  else:
    max_level = len(components)
  return [
    ','.join(components[level:])
    for level in range(1,max_level)
  ]


def escape_filter_chars(s):
  """
  RFC2254: Replace the characters backslash, '*', '(', ')' and NUL
  in s by a backslash followed by their two-digit hex code
  """
  # The backslash has to be handled first, otherwise the backslashes
  # introduced by the other replacements would be escaped again
  special_chars = ('\\','*','(',')','\000')
  escaped = s
  for special in special_chars:
    hex_form = '\\%02X' % ord(special)
    escaped = escaped.replace(special,hex_form)
  return escaped

def escape_dn_chars(s):
  """
  Escape the characters of s which are special in the string
  representation of DNs (RFC 2253) by prefixing them with a backslash:
  backslash, ',', '+', '"', '<', '>' and ';' anywhere in the string,
  a '#' or blank at the beginning and a blank at the end.
  Empty input is returned unchanged.
  """
  if s:
    # Backslash goes first so that the escapes added below stay intact
    for ch in ['\\',',','+','"','<', '>',';']:
      s = s.replace(ch,'\\'+ch)
    # Trailing blank is handled before the leading character so that a
    # string consisting of a single blank gets escaped only once
    if s[-1]==' ':
      s = ''.join([s[:-1],'\\ '])
    if s[0]=='#' or s[0]==' ':
      s = ''.join(['\\',s])
  return s


def escape_binary_dnchars(s):
  """
  Replace every character of s with a code below 32 or above 127
  by a backslash followed by its hex code (at least two digits)
  """
  pieces = []
  for ch in s:
    code = ord(ch)
    if 32<=code<128:
      pieces.append(ch)
    else:
      pieces.append('\\%02X' % code)
  return ''.join(pieces)

# Regex object used for finding hex-escaped characters:
# a backslash followed by exactly two hex digits (upper or lower case)
hex_dnchars_regex = re.compile(r'\\[0-9a-fA-F][0-9a-fA-F]')

def unescape_binary_dnchars(s,charset='utf-8'):
  """
  Replace every hex escape found by hex_dnchars_regex (backslash plus
  two hex digits) in s by the single character with that code.
  A Unicode object is encoded with charset first.
  """
  if type(s) is UnicodeType:
    s = s.encode(charset)
  plain_parts = hex_dnchars_regex.split(s)
  escapes = hex_dnchars_regex.findall(s)
  # There is exactly one escape between two neighbouring plain parts
  pieces = [plain_parts[0]]
  for pos in range(len(escapes)):
    pieces.append(chr(int(escapes[pos][1:],16)))
    pieces.append(plain_parts[pos+1])
  return ''.join(pieces)

def sanitize_entry(entry):
  """
  Sanitize a dictionary holding a LDAP entry to be compliant
  to X.500 data model (e.g. delete multiple occurrences of
  same attribute values).

  entry maps attribute types to sequences of attribute values.
  Returns a new dictionary with the same keys where each value is a
  list containing every distinct attribute value once, in the order
  of its first occurrence. The input is not modified.
  """
  result = {}
  for attr_type,attr_values in entry.items():
    seen = {}
    unique_values = []
    for attr_value in attr_values:
      # Keep only the first occurrence so the original order survives
      if attr_value not in seen:
        seen[attr_value] = None
        unique_values.append(attr_value)
    result[attr_type] = unique_values
  return result

def SearchTree(l=None,ldap_dn='',attrsonly=1,timeout=-1,charset='utf-8'):
  """
  Walk the sub tree at ldap_dn by repeated one-level searches on the
  connection object l and return all DNs visited (including ldap_dn
  itself) as list of Unicode string objects.

  A byte string ldap_dn is decoded with charset first. Entries for
  which the search raises ldap.NO_SUCH_OBJECT are treated as having
  no children.
  """
  if type(ldap_dn) is StringType:
    ldap_dn = unicode(ldap_dn,charset)
  pending = [ldap_dn]
  found = []
  while pending:
    current_dn = pending.pop()
    found.append(current_dn)
    try:
      children = l.search_st(
        current_dn.encode(charset),
        ldap.SCOPE_ONELEVEL,
        '(objectclass=*)',
        ['objectclass'],
        attrsonly,
        timeout
      )
    except ldap.NO_SUCH_OBJECT:
      children = []
    # Only the DNs of the search results are of interest here
    child_dns = [ unicode(child[0],charset) for child in children ]
    pending.extend(child_dns)
  return found


class USERNAME_NOT_FOUND(ldap.LDAPError):
  """
  Exception raised by SmartLogin() if the search for the user
  entry did not return any entry.

  The first constructor argument has to be a dictionary with
  the key 'desc' holding the error message.
  """
  def __str__(self):
    error_info = self.args[0]
    return error_info['desc']

class USERNAME_NOT_UNIQUE(ldap.LDAPError):
  """
  Exception raised by SmartLogin() if the search for the user
  entry returned more than one entry.

  The first constructor argument has to be a dictionary with
  the key 'desc' holding the error message.
  """
  def __str__(self):
    error_info = self.args[0]
    return error_info['desc']

def SmartLogin(
  l,                        # LDAP connection object created with ldap.open()
  username='',              # User name or complete bind DN (UTF-8 encoded)
  searchroot='',            # search root for user entry search
  filtertemplate='(uid=%s)',# template string for LDAP filter
  scope=ldap.SCOPE_SUBTREE, # search scope
  attrnamesonly=1,          # retrieve only attribute names when searching
  timeout=-1
):
  """
  Determine the bind DN for a user name (no bind is done here).

  Returns '' for an empty username.
  If username is a valid DN (see is_dn()) it is returned in normalized
  form without further action.
  Otherwise filtertemplate is used to construct a LDAP search filter
  containing username instead of %s and the connection l is searched
  below searchroot (None is treated as ''). The DN of the single
  entry found is returned in normalized form.

  The user name is passed through escape_filter_chars() before it is
  inserted into the filter so that characters like '*', '(' or ')'
  typed in by a user cannot alter the structure of the filter.

  Raises USERNAME_NOT_FOUND if no entry was found or the search
  raised ldap.NO_SUCH_OBJECT, and USERNAME_NOT_UNIQUE if more than one
  entry was found. Other exceptions raised by the search (e.g.
  ldap.FILTER_ERROR) are passed on to the calling application.
  """
  if not username:
    return ''
  if is_dn(username):
    return normalize_dn(username)
  # Never insert user input into a filter without escaping it
  searchfilter = filtertemplate.replace('%s',escape_filter_chars(username))
  if searchroot is None:
    searchroot = ''
  # Try to find a unique entry with filtertemplate
  try:
    result = l.search_st(
      searchroot,
      scope,
      searchfilter,
      ['objectclass'],
      attrnamesonly,
      timeout
    )
  except ldap.NO_SUCH_OBJECT:
    raise USERNAME_NOT_FOUND({'desc':'Smart login did not find a matching user entry.'})
  if not result:
    raise USERNAME_NOT_FOUND({'desc':'Smart login did not find a matching user entry.'})
  if len(result)!=1:
    raise USERNAME_NOT_UNIQUE({'desc':'More than one matching user entries.'})
  return normalize_dn(result[0][0])

def test():
  """
  Self-test of some functions of this module.

  Each section runs a function against a table of inputs and expected
  results. Nothing is raised on a mismatch; a message describing the
  unexpected result is printed instead, so a run printing only the
  section headlines means all checks passed.
  """

  # Maps candidate string -> expected result of is_dn()
  print '\nTesting function is_dn():'
  ldap_dns = {
    u'o=Michaels':1,
    u'iiii':0,
    u'"cn="Mike"':0,
  }
  for ldap_dn in ldap_dns.keys():
    result_is_dn = is_dn(ldap_dn)
    if result_is_dn !=ldap_dns[ldap_dn]:
      print 'is_dn("%s") returns %d instead of %d.' % (
        ldap_dn,result_is_dn,ldap_dns[ldap_dn]
      )

  # Maps raw string -> expected escaped string; both Unicode and
  # byte strings are checked
  print '\nTesting function escape_dn_chars():'
  ldap_dns = {
    u'#\\,+"<>; ':u'\\#\\\\\\,\\+\\"\\<\\>\\;\\ ',
    '#\\,+"<>; ':'\\#\\\\\\,\\+\\"\\<\\>\\;\\ ',
    u'Str\xf6der':u'Str\xf6der',
    # NOTE(review): 'Strder' looks like the byte string counterpart of
    # u'Str\xf6der' with the non-ASCII character lost - confirm
    'Strder':'Strder',
    '':'',
  }
  for ldap_dn in ldap_dns.keys():
    result_escape_dn_chars = escape_dn_chars(ldap_dn)
    if result_escape_dn_chars !=ldap_dns[ldap_dn]:
      print 'escape_dn_chars(%s) returns %s instead of %s.' % (
        repr(ldap_dn),
        repr(result_escape_dn_chars),repr(ldap_dns[ldap_dn])
      )

  # Maps 'type=value' string -> expected (type,unescaped value) tuple
  print '\nTesting function explode_rdn_attr():'
  ldap_dns = {
    'cn=Michael Stroeder':('cn','Michael Stroeder'),
    'cn=whois\+\+':('cn','whois++'),
    'cn=\#dummy\ ':('cn','#dummy '),
    'cn;lang-en-EN=Michael Stroeder':('cn;lang-en-EN','Michael Stroeder'),
    'cn=':('cn',''),
  }
  for ldap_dn in ldap_dns.keys():
    result_explode_rdn_attr = explode_rdn_attr(ldap_dn)
    if result_explode_rdn_attr !=ldap_dns[ldap_dn]:
      print 'explode_rdn_attr(%s) returns %s instead of %s.' % (
        repr(ldap_dn),
        repr(result_explode_rdn_attr),repr(ldap_dns[ldap_dn])
      )

  # Round trip: unescaping an escaped string has to give the input again
  print '\nTesting functions escape_binary_dnchars() and unescape_binary_dnchars():'
  binary_strings = [
    '\000\001\002\003\004\005\006',
    # NOTE(review): presumably 'Michael Str<o-umlaut>der' with the
    # non-ASCII character lost - confirm
    'Michael Strder',
  ]
  for s in binary_strings:
    if unescape_binary_dnchars(escape_binary_dnchars(s))!=s:
      print 'Testing of unescape_binary_dnchars(escape_binary_dnchars()) failed for string %s.' % (
        repr(s)
      )

  # Maps (dn1,dn2) -> expected (match level,matching suffix);
  # the suffix is compared case-insensitively below
  print '\nTesting function match_dn():'
  match_dn_tests = {
    ('O=MICHAELS','o=michaels'):(1,u'O=MICHAELS'),
    ('CN=MICHAEL STROEDER,O=MICHAELS','o=michaels'):(1,u'O=MICHAELS'),
    ('CN=MICHAEL STROEDER,O=MICHAELS',''):(0,u''),
    ('CN=MICHAEL STROEDER,O=MICHAELS','     '):(0,u''),
    ('CN=MICHAEL STROEDER,O=MICHAELS','  cn=Michael Stroeder,o=Michaels  '):(2,u'cn=Michael Stroeder,o=Michaels'),
    ('CN=MICHAEL STROEDER,O=MICHAELS','mail=michael@stroeder.com,  cn=Michael Stroeder,o=Michaels  '):(2,u'cn=Michael Stroeder,o=Michaels'),
  }
  for dn1,dn2 in match_dn_tests.keys():
    result_match_dn = match_dn(dn1,dn2)
    if result_match_dn[0] !=match_dn_tests[(dn1,dn2)][0] or \
       result_match_dn[1].lower() !=match_dn_tests[(dn1,dn2)][1].lower():
      print 'match_dn(%s,%s) returns:\n%s\ninstead of:\n%s\n' % (
        repr(dn1),repr(dn2),
        repr(result_match_dn),
        repr(match_dn_tests[(dn1,dn2)])
      )


# Run the self-test when this module is executed as a script
if __name__ == '__main__':
  test()
