/*
 * JBoss, the OpenSource J2EE webOS
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 *
 * Created on Feb 4, 2004
 */
package org.jboss.security.auth.spi;

import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.Principal;
import java.security.acl.Group;
import java.security.cert.X509Certificate;
import java.util.Map;
import java.util.ArrayList;
import java.util.Enumeration;
import java.io.IOException;

import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.FailedLoginException;
import javax.security.auth.login.LoginException;

import org.jboss.security.SecurityDomain;
import org.jboss.security.auth.callback.ObjectCallback;

/**
 * <dl>
 * <dt><b>Title: </b></dt><dd>Base Login Module that uses X509Certificates as
 * credentials for authentication</dd>
 * <dt><b>Description: </b></dt><dd>This login module uses an X509Certificate
 * as the credential. It takes the certificate as an object and checks whether
 * the entry stored under the supplied alias in the truststore/keystore holds
 * the same certificate. Subclasses of this module should implement the
 * getRoleSets() method defined by AbstractServerLoginModule. Much of this
 * module was patterned after the UserNamePasswordLoginModule.</dd>
 * </dl>
 * @author <a href="mailto:jasone@greenrivercomputing.com">Jason Essington</a>
 * @author Scott.Stark@jboss.org
 * @version $Revision: 1.1.2.3 $
 */
public class BaseCertLoginModule extends AbstractServerLoginModule
{
   /** The principal established by the current login attempt. */
   private Principal identity;
   /** The certificate presented by the caller; null when the caller logged in
    * as the unauthenticated identity. */
   private X509Certificate credential;
   /** The security domain whose truststore (or keystore) validates
    * certificates; stays null if the lookup in initialize() failed. */
   private SecurityDomain domain = null;

   /** Override the super version to pickup the following options after first
    * calling the super method.
    * 
    * @option securityDomain: the name of the SecurityDomain to obtain the
    * trust and keystore from. Defaults to "java:/jaas/other".
    * 
    * @see SecurityDomain
    * 
    * @param subject the Subject to update after a successful login.
    * @param callbackHandler the CallbackHandler that will be used to obtain the
    *    the user identity and credentials.
    * @param sharedState a Map shared between all configured login module instances
    * @param options the parameters passed to the login module.
    */
   public void initialize(Subject subject, CallbackHandler callbackHandler,
      Map sharedState, Map options)
   {
      super.initialize(subject, callbackHandler, sharedState, options);

      // Get the security domain and default to "other"
      String sd = (String) options.get("securityDomain");
      if (sd == null)
         sd = "java:/jaas/other";

      if (log.isDebugEnabled())
         log.debug("securityDomain=" + sd);

      try
      {
         Object tempDomain = new InitialContext().lookup(sd);
         // instanceof is false for null, so a null lookup result is reported
         // by the else branch below.
         if (tempDomain instanceof SecurityDomain)
         {
            domain = (SecurityDomain) tempDomain;
            if (log.isDebugEnabled())
               log.debug("found domain: " + domain.getClass().getName());
         }
         else
         {
            log.error("The domain " + sd + " is not a SecurityDomain. All authentication using this module will fail!");
         }
      }
      catch (NamingException e)
      {
         log.error("Unable to find the securityDomain named: " + sd, e);
      }

      if (log.isDebugEnabled())
         log.debug("exit: initialize(Subject, CallbackHandler, Map, Map)");
   }

   /** 
    * Authenticate the caller using an alias and an X509Certificate.
    * <p>
    * If super.login() reports that shared credentials exist, the shared
    * name becomes the identity and the shared password is accepted only if
    * it is an X509Certificate. Otherwise the alias and certificate are
    * obtained from {@link #getAliasAndCert()}. If neither is supplied, the
    * caller becomes unauthenticatedIdentity. When that identity is null, or
    * when an alias or certificate was supplied, the certificate must match
    * the one stored under the alias in the security domain's store.
    *
    * @return true on success; false if the shared-state password is not an
    *    X509Certificate
    * @throws LoginException if the callbacks fail or the shared-state
    *    identity is missing or cannot be created
    * @throws FailedLoginException if the supplied certificate does not match
    *    the stored certificate
    */
   public boolean login() throws LoginException
   {
      if (log.isDebugEnabled())
         log.debug("enter: login()");

      // The same module instance can be driven through login() more than
      // once (e.g. a retry after a failed attempt). Clear any identity left
      // over from an earlier call; otherwise a stale identity would make
      // the identity == null check below skip certificate validation.
      identity = null;
      credential = null;

      // See if shared credentials exist
      if (super.login() == true)
      {
         // Setup our view of the user
         Object username = sharedState.get("javax.security.auth.login.name");
         if (username instanceof Principal)
         {
            identity = (Principal) username;
         }
         else
         {
            if (username == null)
               throw new LoginException("No javax.security.auth.login.name found in the shared state");
            String name = username.toString();
            try
            {
               identity = createIdentity(name);
            }
            catch(Exception e)
            {
               log.debug("Failed to create principal", e);
               LoginException le = new LoginException("Failed to create principal: "+ e.getMessage());
               le.initCause(e);
               throw le;
            }
         }

         Object password = sharedState.get("javax.security.auth.login.password");
         if (password instanceof X509Certificate)
            credential = (X509Certificate) password;
         else if (password != null)
         {
            log.debug("javax.security.auth.login.password is not X509Certificate");
            super.loginOk = false;
            return false;
         }
         return true;
      }

      super.loginOk = false;
      Object[] info = getAliasAndCert();
      String alias = (String) info[0];
      credential = (X509Certificate) info[1];

      // Nothing supplied at all: fall back to the unauthenticated identity.
      if (alias == null && credential == null)
      {
         identity = unauthenticatedIdentity;
         super.log.trace("Authenticating as unauthenticatedIdentity=" + identity);
      }

      // Any caller that is not the (non-null) unauthenticated identity must
      // present a certificate matching the stored one.
      if (identity == null)
      {
         try
         {
            identity = createIdentity(alias);
         }
         catch(Exception e)
         {
            log.debug("Failed to create identity for alias:"+alias, e);
         }

         if (!validateCredential(alias, credential))
         {
            super.log.debug("Bad credential for alias=" + alias);
            throw new FailedLoginException("Supplied Credential did not match existing credential for " + alias);
         }
      }

      if (getUseFirstPass() == true)
      {
         // Add authentication info to shared state map
         sharedState.put("javax.security.auth.login.name", alias);
         sharedState.put("javax.security.auth.login.password", credential);
      }
      super.loginOk = true;
      super.log.trace("User '" + identity + "' authenticated, loginOk=" + loginOk);
      
      if (log.isDebugEnabled())
         log.debug("exit: login()");
      return true;
   }

   /** Override to add the X509Certificate to the public credentials.
    * The certificate is only added when one was supplied. The
    * unauthenticated identity has none, and a null element is not a
    * meaningful credential.
    * @return the result of the super commit
    * @throws LoginException if the super commit fails
    */ 
   public boolean commit() throws LoginException
   {
      boolean ok = super.commit();
      if( ok == true && credential != null )
      {
         // Add the cert to the public credentials
         subject.getPublicCredentials().add(credential);
      }
      return ok;
   }

   /** Subclasses need to override this to provide the roles for authorization.
    * @return an empty array in this base implementation
    * @throws LoginException declared for subclasses that look up roles
    */ 
   protected Group[] getRoleSets() throws LoginException
   {
      return new Group[0];
   }

   /** @return the principal established by the last login attempt, or null */
   protected Principal getIdentity()
   {
      return identity;
   }

   /** @return the X509Certificate supplied by the caller, or null */
   protected Object getCredentials()
   {
      return credential;
   }

   /** @return the name of the current identity, or null if there is none */
   protected String getUsername()
   {
      String username = null;
      if (getIdentity() != null)
         username = getIdentity().getName();
      return username;
   }

   /**
    * Use the CallbackHandler to collect the caller's alias and certificate.
    * The certificate callback may yield either a single X509Certificate or
    * an X509Certificate[] chain, in which case element 0 is used.
    *
    * @return a two-element array: [0] the alias (String, may be null) and
    *    [1] the certificate (X509Certificate, may be null)
    * @throws LoginException if there is no CallbackHandler, the handler
    *    fails, or the credential is of an unsupported type
    */
   protected Object[] getAliasAndCert() throws LoginException
   {
      if (log.isDebugEnabled())
         log.debug("enter: getAliasAndCert()");
      Object[] info = { null, null };
      // prompt for an alias and certificate
      if (callbackHandler == null)
      {
         throw new LoginException("Error: no CallbackHandler available to collect authentication information");
      }
      NameCallback nc = new NameCallback("Alias: ");
      ObjectCallback oc = new ObjectCallback("Certificate: ");
      Callback[] callbacks = { nc, oc };
      String alias = null;
      X509Certificate cert = null;
      try
      {
         callbackHandler.handle(callbacks);
         alias = nc.getName();
         Object tmpCert = oc.getCredential();
         if (tmpCert != null)
         {
            if (tmpCert instanceof X509Certificate)
            {
               cert = (X509Certificate) tmpCert;
               if (log.isDebugEnabled())
                  log.debug("found cert " + cert.getSerialNumber().toString(16) + ":" + cert.getSubjectDN().getName());
            }
            else if( tmpCert instanceof X509Certificate[] )
            {
               // Use the first certificate of the chain; an empty chain
               // leaves cert null.
               X509Certificate[] certChain = (X509Certificate[]) tmpCert;
               if( certChain.length > 0 )
                  cert = certChain[0];
            }
            else
            {
               String msg = "Don't know how to obtain X509Certificate from: "
                  +tmpCert.getClass();
               log.warn(msg);
               throw new LoginException(msg);
            }
         }
      }
      catch (IOException e)
      {
         log.debug("Failed to invoke callback", e);
         LoginException le = new LoginException("Failed to invoke callback: "+e.toString());
         le.initCause(e);
         throw le;
      }
      catch (UnsupportedCallbackException uce)
      {
         LoginException le = new LoginException("CallbackHandler does not support: "
            + uce.getCallback());
         le.initCause(uce);
         throw le;
      }

      info[0] = alias;
      info[1] = cert;
      if (log.isDebugEnabled())
         log.debug("exit: getAliasAndCert()");
      return info;
   }

   /**
    * Check whether the supplied certificate equals the certificate stored
    * under the alias. The domain's truststore is used, falling back to its
    * keystore when no truststore is available.
    *
    * @param alias the store alias to look up; null never matches
    * @param cert the certificate presented by the caller; null never matches
    * @return true only if the stored entry is an X509Certificate equal to cert
    */
   protected boolean validateCredential(String alias, X509Certificate cert)
   {
      if (log.isDebugEnabled())
         log.debug("enter: validateCredentail(String, X509Certificate)");
      boolean isValid = false;

      if (domain != null && cert != null)
      {
         // if we don't have a trust store, we'll just use the key store.
         KeyStore store = domain.getTrustStore();
         if (store == null)
            store = domain.getKeyStore();
         if (store == null)
         {
            log.warn("KeyStore is null!");
         }
         else if (alias == null)
         {
            // Common KeyStore implementations (e.g. JKS) throw a
            // NullPointerException for a null alias; treat it as a mismatch.
            log.debug("No alias supplied. Unable to validate the certificate.");
         }
         else
         {
            X509Certificate storeCert = null;
            try
            {
               // A non-X509 entry cannot match; avoid a ClassCastException.
               Object storeEntry = store.getCertificate(alias);
               if (storeEntry instanceof X509Certificate)
                  storeCert = (X509Certificate) storeEntry;
               if (log.isDebugEnabled())
               {
                  StringBuffer buf = new StringBuffer("\n\tSupplied Credential: ");
                  buf.append(cert.getSerialNumber().toString(16));
                  buf.append("\n\t\t");
                  buf.append(cert.getSubjectDN().getName());
                  buf.append("\n\n\tExisting Credential: ");
                  if( storeCert != null )
                  {
                     buf.append(storeCert.getSerialNumber().toString(16));
                     buf.append("\n\t\t");
                     buf.append(storeCert.getSubjectDN().getName());
                     buf.append("\n");
                  }
                  else
                  {
                     ArrayList aliases = new ArrayList();
                     Enumeration en = store.aliases();
                     while (en.hasMoreElements())
                     {
                        aliases.add(en.nextElement());
                     }
                     buf.append("No match for alias: "+alias+", we have aliases " + aliases);
                  }
                  log.debug(buf.toString());
               }
            }
            catch (KeyStoreException e)
            {
               log.warn("failed to find the certificate for " + alias, e);
            }

            // Certificate.equals(null) is false, so a missing entry fails.
            if (cert.equals(storeCert))
               isValid = true;
         }
      }
      else
      {
         log.warn("Domain or Credential is null. Unable to validate the certificate.");
      }

      if (log.isDebugEnabled())
      {
         log.debug("The supplied certificate "
               + (isValid ? "matched" : "DID NOT match")
               + " the certificate in the keystore.");

         log.debug("exit: validateCredentail(String, X509Certificate)");
      }
      return isValid;
   }

}
