/*
 * JBoss, the OpenSource WebOS
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 */
package org.jboss.web.tomcat.security;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.net.Socket;
import java.security.Principal;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;

import org.apache.catalina.Connector;
import org.apache.catalina.Context;
import org.apache.catalina.Host;
import org.apache.catalina.HttpRequest;
import org.apache.catalina.Request;
import org.apache.catalina.Response;
import org.apache.catalina.ValveContext;
import org.apache.catalina.Wrapper;
import org.apache.catalina.valves.ValveBase;
import org.apache.coyote.tomcat5.CoyoteRequest;
import org.apache.coyote.tomcat5.CoyoteRequestFacade;
import org.apache.tomcat.util.buf.MessageBytes;

/** A valve that associates the Principal as obtained from the authentication
 * layer with the request. This allows any custom principal established by
 * the authentication layer to be seen by the web app.
 *  
 * @author Scott.Stark@jboss.org
 * @version $Revision: 1.2.2.4 $
 */
public class CustomPrincipalValve
   extends ValveBase
{
   /** Replaces the request passed down the pipeline with a wrapper whose
    * getUserPrincipal() returns the caller principal held by a
    * JBossGenericPrincipal, when the current user principal is one.
    *
    * The wrapper can only be built around a CoyoteRequest whose servlet
    * request is an HttpServletRequest. Any other request, or a request whose
    * user principal is not a JBossGenericPrincipal, is passed to the next
    * valve unchanged instead of failing with a ClassCastException.
    *
    * @param request the current catalina request
    * @param response the current catalina response
    * @param context the valve context used to invoke the next valve
    * @throws IOException propagated from the next valve
    * @throws ServletException propagated from the next valve
    */
   public void invoke(Request request, Response response, ValveContext context)
      throws IOException, ServletException
   {
      Request nextRequest = request;
      ServletRequest servletRequest = request.getRequest();
      // Only a CoyoteRequest carrying an HTTP servlet request can be wrapped;
      // anything else is forwarded as-is rather than failing a cast.
      if( request instanceof CoyoteRequest
         && servletRequest instanceof HttpServletRequest )
      {
         Principal user = ((HttpServletRequest) servletRequest).getUserPrincipal();
         if( user instanceof JBossGenericPrincipal )
         {
            // Get the actual principal to use as the getUserPrincipal value
            CoyoteRequest coyoteRequest = (CoyoteRequest) request;
            JBossGenericPrincipal genericUser = (JBossGenericPrincipal) user;
            Principal realUser = genericUser.getCallerPrincipal();
            nextRequest = new UserPrinicpalRequest(coyoteRequest, realUser);
         }
      }

      context.invokeNext(nextRequest, response);
   }

   /** A catalina Request and HttpServletRequest wrapper around a CoyoteRequest.
    * getUserPrincipal() returns the principal supplied to the constructor, and
    * getRequest() returns a UserPrinicipalServletRequest facade that exposes
    * that same principal. Every other method delegates to the wrapped
    * CoyoteRequest.
    */
   static class UserPrinicpalRequest
      implements HttpRequest, HttpServletRequest
   {
      /** The wrapped request that all other calls are forwarded to. */
      private final CoyoteRequest delegate;
      /** Facade handed out by getRequest(); it exposes userPrincipal. */
      private final UserPrinicipalServletRequest httpRequest;
      /** The principal reported by getUserPrincipal(). */
      private final Principal userPrincipal;

      /**
       * @param delegate the request to wrap
       * @param userPrincipal the principal to report from getUserPrincipal()
       */
      UserPrinicpalRequest(CoyoteRequest delegate, Principal userPrincipal)
      {
         this.delegate = delegate;
         this.userPrincipal = userPrincipal;
         this.httpRequest = new UserPrinicipalServletRequest(delegate, userPrincipal);
      }

      // ---- javax.servlet.http.HttpServletRequest / ServletRequest ----

      public String getAuthType()
      {
         return delegate.getAuthType();
      }

      public Cookie[] getCookies()
      {
         return delegate.getCookies();
      }

      public long getDateHeader(String name)
      {
         return delegate.getDateHeader(name);
      }

      public String getHeader(String name)
      {
         return delegate.getHeader(name);
      }

      public Enumeration getHeaders(String name)
      {
         return delegate.getHeaders(name);
      }

      public Enumeration getHeaderNames()
      {
         return delegate.getHeaderNames();
      }

      public int getIntHeader(String name)
      {
         return delegate.getIntHeader(name);
      }

      public String getMethod()
      {
         return delegate.getMethod();
      }

      public String getPathInfo()
      {
         return delegate.getPathInfo();
      }

      public String getPathTranslated()
      {
         return delegate.getPathTranslated();
      }

      public String getContextPath()
      {
         return delegate.getContextPath();
      }

      public String getQueryString()
      {
         return delegate.getQueryString();
      }

      public String getRemoteUser()
      {
         return delegate.getRemoteUser();
      }

      public boolean isUserInRole(String role)
      {
         return delegate.isUserInRole(role);
      }

      /** Returns the principal supplied at construction, not the delegate's
       * principal.
       */
      public Principal getUserPrincipal()
      {
         return userPrincipal;
      }

      public String getRequestedSessionId()
      {
         return delegate.getRequestedSessionId();
      }

      public String getRequestURI()
      {
         return delegate.getRequestURI();
      }

      public StringBuffer getRequestURL()
      {
         return delegate.getRequestURL();
      }

      public String getServletPath()
      {
         return delegate.getServletPath();
      }

      public HttpSession getSession(boolean create)
      {
         return delegate.getSession(create);
      }

      public HttpSession getSession()
      {
         return delegate.getSession();
      }

      public boolean isRequestedSessionIdValid()
      {
         return delegate.isRequestedSessionIdValid();
      }

      public boolean isRequestedSessionIdFromCookie()
      {
         return delegate.isRequestedSessionIdFromCookie();
      }

      public boolean isRequestedSessionIdFromURL()
      {
         return delegate.isRequestedSessionIdFromURL();
      }

      public boolean isRequestedSessionIdFromUrl()
      {
         return delegate.isRequestedSessionIdFromUrl();
      }

      public Object getAttribute(String name)
      {
         return delegate.getAttribute(name);
      }

      public Enumeration getAttributeNames()
      {
         return delegate.getAttributeNames();
      }

      public String getCharacterEncoding()
      {
         return delegate.getCharacterEncoding();
      }

      public void setCharacterEncoding(String env) throws UnsupportedEncodingException
      {
         delegate.setCharacterEncoding(env);
      }

      public int getContentLength()
      {
         return delegate.getContentLength();
      }

      public String getContentType()
      {
         return delegate.getContentType();
      }

      public ServletInputStream getInputStream() throws IOException
      {
         return delegate.getInputStream();
      }

      public String getParameter(String name)
      {
         return delegate.getParameter(name);
      }

      public Enumeration getParameterNames()
      {
         return delegate.getParameterNames();
      }

      public String[] getParameterValues(String name)
      {
         return delegate.getParameterValues(name);
      }

      public Map getParameterMap()
      {
         return delegate.getParameterMap();
      }

      public String getProtocol()
      {
         return delegate.getProtocol();
      }

      public String getScheme()
      {
         return delegate.getScheme();
      }

      public String getServerName()
      {
         return delegate.getServerName();
      }

      public int getServerPort()
      {
         return delegate.getServerPort();
      }

      public BufferedReader getReader() throws IOException
      {
         return delegate.getReader();
      }

      public String getRemoteAddr()
      {
         return delegate.getRemoteAddr();
      }

      public String getRemoteHost()
      {
         return delegate.getRemoteHost();
      }

      public void setAttribute(String name, Object o)
      {
         delegate.setAttribute(name, o);
      }

      public void removeAttribute(String name)
      {
         delegate.removeAttribute(name);
      }

      public Locale getLocale()
      {
         return delegate.getLocale();
      }

      public Enumeration getLocales()
      {
         return delegate.getLocales();
      }

      public boolean isSecure()
      {
         return delegate.isSecure();
      }

      public RequestDispatcher getRequestDispatcher(String path)
      {
         return delegate.getRequestDispatcher(path);
      }

      public String getRealPath(String path)
      {
         return delegate.getRealPath(path);
      }

      public int getRemotePort()
      {
         return delegate.getRemotePort();
      }

      public String getLocalName()
      {
         return delegate.getLocalName();
      }

      public String getLocalAddr()
      {
         return delegate.getLocalAddr();
      }

      public int getLocalPort()
      {
         return delegate.getLocalPort();
      }

      // ---- org.apache.catalina.HttpRequest / Request ----

      public void addCookie(Cookie cookie)
      {
         delegate.addCookie(cookie);
      }

      public void addHeader(String name, String value)
      {
         delegate.addHeader(name, value);
      }

      public void addLocale(Locale locale)
      {
         delegate.addLocale(locale);
      }

      public void addParameter(String name, String[] values)
      {
         delegate.addParameter(name, values);
      }

      public void clearCookies()
      {
         delegate.clearCookies();
      }

      public void clearHeaders()
      {
         delegate.clearHeaders();
      }

      public void clearLocales()
      {
         delegate.clearLocales();
      }

      public void clearParameters()
      {
         delegate.clearParameters();
      }

      public void setAuthType(String type)
      {
         delegate.setAuthType(type);
      }

      public MessageBytes getContextPathMB()
      {
         return delegate.getContextPathMB();
      }

      public void setContextPath(String path)
      {
         delegate.setContextPath(path);
      }

      public void setMethod(String method)
      {
         delegate.setMethod(method);
      }

      public void setQueryString(String query)
      {
         delegate.setQueryString(query);
      }

      public MessageBytes getPathInfoMB()
      {
         return delegate.getPathInfoMB();
      }

      public void setPathInfo(String path)
      {
         delegate.setPathInfo(path);
      }

      public MessageBytes getRequestPathMB()
      {
         return delegate.getRequestPathMB();
      }

      public void setRequestedSessionCookie(boolean flag)
      {
         delegate.setRequestedSessionCookie(flag);
      }

      public void setRequestedSessionId(String id)
      {
         delegate.setRequestedSessionId(id);
      }

      public void setRequestedSessionURL(boolean flag)
      {
         delegate.setRequestedSessionURL(flag);
      }

      public void setRequestURI(String uri)
      {
         delegate.setRequestURI(uri);
      }

      public void setDecodedRequestURI(String uri)
      {
         delegate.setDecodedRequestURI(uri);
      }

      public String getDecodedRequestURI()
      {
         return delegate.getDecodedRequestURI();
      }

      public MessageBytes getDecodedRequestURIMB()
      {
         return delegate.getDecodedRequestURIMB();
      }

      public MessageBytes getServletPathMB()
      {
         return delegate.getServletPathMB();
      }

      public void setServletPath(String path)
      {
         delegate.setServletPath(path);
      }

      /** Updates the delegate's principal only. getUserPrincipal() on this
       * wrapper still returns the principal supplied at construction.
       */
      public void setUserPrincipal(Principal principal)
      {
         delegate.setUserPrincipal(principal);
      }

      public String getAuthorization()
      {
         return delegate.getAuthorization();
      }

      public void setAuthorization(String authorization)
      {
         delegate.setAuthorization(authorization);
      }

      public Connector getConnector()
      {
         return delegate.getConnector();
      }

      public void setConnector(Connector connector)
      {
         delegate.setConnector(connector);
      }

      public Context getContext()
      {
         return delegate.getContext();
      }

      public void setContext(Context context)
      {
         delegate.setContext(context);
      }

      public FilterChain getFilterChain()
      {
         return delegate.getFilterChain();
      }

      public void setFilterChain(FilterChain filterChain)
      {
         delegate.setFilterChain(filterChain);
      }

      public Host getHost()
      {
         return delegate.getHost();
      }

      public void setHost(Host host)
      {
         delegate.setHost(host);
      }

      public String getInfo()
      {
         return delegate.getInfo();
      }

      /** Returns the facade that reports the wrapped principal, rather than
       * the delegate's own facade.
       */
      public ServletRequest getRequest()
      {
         return httpRequest;
      }

      public Response getResponse()
      {
         return delegate.getResponse();
      }

      public void setResponse(Response response)
      {
         delegate.setResponse(response);
      }

      public Socket getSocket()
      {
         return delegate.getSocket();
      }

      public void setSocket(Socket socket)
      {
         delegate.setSocket(socket);
      }

      public InputStream getStream()
      {
         return delegate.getStream();
      }

      public void setStream(InputStream stream)
      {
         delegate.setStream(stream);
      }

      public ValveContext getValveContext()
      {
         return delegate.getValveContext();
      }

      public void setValveContext(ValveContext valveContext)
      {
         delegate.setValveContext(valveContext);
      }

      public Wrapper getWrapper()
      {
         return delegate.getWrapper();
      }

      public void setWrapper(Wrapper wrapper)
      {
         delegate.setWrapper(wrapper);
      }

      public ServletInputStream createInputStream() throws IOException
      {
         return delegate.createInputStream();
      }

      public void finishRequest() throws IOException
      {
         delegate.finishRequest();
      }

      public Object getNote(String name)
      {
         return delegate.getNote(name);
      }

      public Iterator getNoteNames()
      {
         return delegate.getNoteNames();
      }

      public void recycle()
      {
         delegate.recycle();
      }

      public void removeNote(String name)
      {
         delegate.removeNote(name);
      }

      public void setContentLength(int length)
      {
         delegate.setContentLength(length);
      }

      public void setContentType(String type)
      {
         delegate.setContentType(type);
      }

      public void setNote(String name, Object value)
      {
         delegate.setNote(name, value);
      }

      public void setProtocol(String protocol)
      {
         delegate.setProtocol(protocol);
      }

      public void setRemoteAddr(String remote)
      {
         delegate.setRemoteAddr(remote);
      }

      public void setScheme(String scheme)
      {
         delegate.setScheme(scheme);
      }

      public void setSecure(boolean secure)
      {
         delegate.setSecure(secure);
      }

      public void setServerName(String name)
      {
         delegate.setServerName(name);
      }

      public void setServerPort(int port)
      {
         delegate.setServerPort(port);
      }

   }

   /** A CoyoteRequestFacade that overrides getUserPrincipal() to return the
    * principal supplied at construction. CustomPrincipalValve.invoke passes it
    * the caller principal taken from the JBossGenericPrincipal set by the
    * Realm. All other methods come from the facade unchanged.
    */
   static class UserPrinicipalServletRequest extends CoyoteRequestFacade
   {
      /** The principal reported by getUserPrincipal(). */
      private final Principal userPrincipal;

      /**
       * @param request the request the facade wraps
       * @param userPrincipal the principal to report from getUserPrincipal()
       */
      UserPrinicipalServletRequest(CoyoteRequest request, Principal userPrincipal)
      {
         super(request);
         this.userPrincipal = userPrincipal;
      }

      public Principal getUserPrincipal()
      {
         return userPrincipal;
      }
   }
}
