/*
 * Copyright (c) 1998-2003 Caucho Technology -- all rights reserved
 *
 * Caucho Technology permits modification and use of this file in
 * source and binary form ("the Software") subject to the Caucho
 * Developer Source License 1.1 ("the License") which accompanies
 * this file.  The License is also available at
 *   http://www.caucho.com/download/cdsl1-1.xtp
 *
 * In addition to the terms of the License, the following conditions
 * must be met:
 *
 * 1. Each copy or derived work of the Software must preserve the copyright
 *    notice and this notice unmodified.
 *
 * 2. Each copy of the Software in source or binary form must include 
 *    an unmodified copy of the License in a plain ASCII text file named
 *    LICENSE.
 *
 * 3. Caucho reserves all rights to its names, trademarks and logos.
 *    In particular, the names "Resin" and "Caucho" are trademarks of
 *    Caucho and may not be used to endorse products derived from
 *    this software.  "Resin" and "Caucho" may not appear in the names
 *    of products derived from this software.
 *
 * This Software is provided "AS IS," without a warranty of any kind. 
 * ALL EXPRESS OR IMPLIED REPRESENTATIONS AND WARRANTIES, INCLUDING ANY
 * IMPLIED WARRANTY OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
 * OR NON-INFRINGEMENT, ARE HEREBY EXCLUDED.
 *
 * CAUCHO TECHNOLOGY AND ITS LICENSORS SHALL NOT BE LIABLE FOR ANY DAMAGES
 * SUFFERED BY LICENSEE OR ANY THIRD PARTY AS A RESULT OF USING OR
 * DISTRIBUTING SOFTWARE. IN NO EVENT WILL CAUCHO OR ITS LICENSORS BE LIABLE
 * FOR ANY LOST REVENUE, PROFIT OR DATA, OR FOR DIRECT, INDIRECT, SPECIAL,
 * CONSEQUENTIAL, INCIDENTAL OR PUNITIVE DAMAGES, HOWEVER CAUSED AND
 * REGARDLESS OF THE THEORY OF LIABILITY, ARISING OUT OF THE USE OF OR
 * INABILITY TO USE SOFTWARE, EVEN IF HE HAS BEEN ADVISED OF THE POSSIBILITY
 * OF SUCH DAMAGES.      
 *
 * @author Scott Ferguson
 */

#include <stdio.h>
#ifdef WIN32
#include <windows.h>
#else
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/time.h>
#endif
#include <stdlib.h>
#include <string.h>
#include <memory.h>
#include <errno.h>
#include <sys/types.h>

#include <fcntl.h>

/* SSLeay stuff */
#include <openssl/ssl.h>
#include <openssl/rsa.h>       
#include <openssl/err.h>

#ifdef SSL_ENGINE
#include <openssl/engine.h>
#endif
#include <jni.h>

#include "../common/cse.h"
#include "resin.h"
#include "resin_jni.h"

/* Forward declarations for the connection callbacks and accept hook defined below. */
static int ssl_open(connection_t *conn, int fd);
static int ssl_read(connection_t *conn, char *buf, int len);
static int ssl_write(connection_t *conn, char *buf, int len);
static int ssl_close(connection_t *conn);
static void ssl_free(connection_t *conn);
static int ssl_read_client_certificate(connection_t *conn, char *buf, int len);

static connection_t *ssl_accept(server_socket_t *ss);

/*
 * Operation table that ssl_accept() installs on every accepted connection
 * (conn->ops = &ssl_ops).
 * NOTE(review): positional initializer -- the order must match the field
 * order of struct connection_ops_t, which is declared in a project header
 * not visible here.
 */
struct connection_ops_t ssl_ops = {
  ssl_open,
  ssl_read,
  ssl_write,
  ssl_close,
  ssl_free,
  ssl_read_client_certificate,
};

/*
 * Temporary RSA keys returned by ssl_get_temporary_RSA_key().  They are
 * generated by ssl_create_context() the first time it finds them unset.
 */
static RSA *g_rsa_512 = 0;
static RSA *g_rsa_1024 = 0;

/*
 * Maps a socket errno value to the connection exception codes.
 *
 * EINTR/EAGAIN are transient: returns INTERRUPT_EXN and leaves the
 * connection open.  Any other error closes the connection through
 * conn->ops->close() and returns DISCONNECT_EXN for EPIPE/ECONNRESET, or
 * -1 otherwise.
 */
static int
exception_status(connection_t *conn, int error)
{
  if (error == EINTR || error == EAGAIN)
    return INTERRUPT_EXN;

  conn->ops->close(conn);

  /* Test the caller-supplied code, not the live errno, which may have changed. */
  if (error == EPIPE || error == ECONNRESET)
    return DISCONNECT_EXN;

  return -1;
}

/*
 * OpenSSL pem_password_cb: supplies the configured key password.
 *
 * buf      - OpenSSL-owned buffer of `size' bytes that receives the password.
 * size     - capacity of buf in bytes.
 * rwflag   - unused (this callback is only used for decrypting the key).
 * userdata - NUL-terminated password registered with
 *            SSL_CTX_set_default_passwd_cb_userdata().
 *
 * The password is truncated to size - 1 bytes so buf is never overrun and
 * is always NUL-terminated.  Returns the number of password bytes copied,
 * or -1 (OpenSSL's error convention) when there is no password or no buffer.
 */
static int
password_callback(char *buf, int size, int rwflag, void *userdata)
{
  const char *password = userdata;
  size_t len;

  (void) rwflag;

  if (! buf || size <= 0 || ! password)
    return -1;

  len = strlen(password);
  if (len > (size_t) size - 1)
    len = (size_t) size - 1;

  memcpy(buf, password, len);
  buf[len] = '\0';

  return (int) len;
}

/*
 * This OpenSSL callback function is called when OpenSSL
 * does client authentication and verifies the certificate chain.
 *
 * ok  - OpenSSL's own verdict for the current certificate.
 * ctx - the X509 store context being verified.
 *
 * Returns 1 to accept the certificate, 0 to reject it.  A certificate that
 * OpenSSL rejected is accepted only when the server socket's verify_client
 * mode is Q_VERIFY_OPTIONAL_NO_CA and the failure is one of the "unknown or
 * self-signed issuer" errors listed below.
 */
static int
ssl_verify_callback(int ok, X509_STORE_CTX *ctx)
{
  SSL *ssl;
  connection_t *conn;
  int error_code;

  /* If openssl's check was okay, then the verify is okay. */
  if (ok)
    return 1;

  /*
   * libssl stores the SSL handle in the store context under the index
   * returned by SSL_get_ex_data_X509_STORE_CTX_idx().  The app-data slot
   * (index 0) is not guaranteed to be that index.
   */
  ssl = (SSL *) X509_STORE_CTX_get_ex_data(ctx,
                                           SSL_get_ex_data_X509_STORE_CTX_idx());

  /* If the user data is missing, then it's a failure. */
  if (! ssl)
    return 0;

  /* ssl_open() registers the connection with SSL_set_app_data(). */
  conn = (connection_t *) SSL_get_app_data(ssl);

  if (! conn || ! conn->ss)
    return 0;

  /* optional and required modes do require valid certificates */
  if (conn->ss->verify_client != Q_VERIFY_OPTIONAL_NO_CA)
    return 0;

  error_code = X509_STORE_CTX_get_error(ctx);

  switch (error_code) {
  case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
  case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
  case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY:
  case X509_V_ERR_CERT_UNTRUSTED:
  case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE:
    return 1;

  default:
    return 0;
  }
}

static SSL_CTX *
ssl_create_context(ssl_config_t *config)
{
  SSL_CTX *ctx;
  SSL_METHOD *meth;
  int skip_first;

  SSL_load_error_strings();
  SSL_library_init();
  SSLeay_add_ssl_algorithms();
  meth = SSLv23_server_method();

#ifdef SSL_ENGINE
  if (config->crypto_device && ! strcmp(config->crypto_device, "builtin")) {
    ENGINE *e = ENGINE_by_id(config->crypto_device);

    if (! e) {
      printf("unknown crypto device `%s'\n", config->crypto_device);
      return 0;
    }
    
    if (! ENGINE_set_default(e, ENGINE_METHOD_ALL)) {
      printf("Can't initialize crypto device `%s'\n", config->crypto_device);
      return 0;
    }

    printf("using crypto-device `%s'\n", config->crypto_device);

    ENGINE_free(e);
  }
#endif
  
  ctx = SSL_CTX_new(meth);
  
  if (! ctx) {
    /* ERR_print_errors_fp(stderr); */
    printf("can't allocate context\n");
    return 0;
  }

  SSL_CTX_set_options(ctx, SSL_OP_ALL);
  if (! (config->alg_flags & ALG_SSL2))
    SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2);
  if (! (config->alg_flags & ALG_SSL3))
    SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv3);
  if (! (config->alg_flags & ALG_TLS1))
    SSL_CTX_set_options(ctx, SSL_OP_NO_TLSv1);

  if (! config->key_file) {
    fprintf(stderr, "Can't find certificate-key-file in SSL configuration\n");
    return 0;
  }
  
  if (! config->password) {
    fprintf(stderr, "Can't find key-store-password in SSL configuration\n");
    return 0;
  }

  SSL_CTX_set_default_passwd_cb(ctx, password_callback);
  SSL_CTX_set_default_passwd_cb_userdata(ctx, config->password);
  if (SSL_CTX_use_certificate_file(ctx, config->certificate_file,
                                   SSL_FILETYPE_PEM) != 1) {
    fprintf(stderr, "Can't open certificate file %s\n",
            config->certificate_file);
    ERR_print_errors_fp(stderr);
    return 0;
  }
  
  if (SSL_CTX_use_PrivateKey_file(ctx,
                                  config->key_file,
                                  SSL_FILETYPE_PEM) != 1) {
    ERR_print_errors_fp(stderr);
    return 0;
  }

  if (! SSL_CTX_check_private_key(ctx)) {
    fprintf(stderr, "Private key does not match the certificate public key\n");
    return 0;
  }

  if (config->certificate_chain_file &&
      SSL_CTX_use_certificate_chain_file(ctx, config->certificate_chain_file) != 1) {
    ERR_print_errors_fp(stderr);
    fprintf(stderr, "Can't open certificate chain file %s\n",
            config->certificate_chain_file);
    return 0;
  }

  if (config->verify_client != Q_VERIFY_NONE) {
    int nVerify = SSL_VERIFY_NONE|SSL_VERIFY_PEER;

    if (config->verify_client == Q_VERIFY_REQUIRE)
      nVerify |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT;

    SSL_CTX_set_verify(ctx, nVerify, ssl_verify_callback);
  }

  if (! g_rsa_512) {
    g_rsa_512 = RSA_generate_key(512, RSA_F4, NULL, NULL);
    
    if (! g_rsa_512) {
      fprintf(stderr, "OpenSSL failed generating 512 bit RSA key.\n");

      return 0;
    }
  }

  if (! g_rsa_1024) {
    g_rsa_1024 = RSA_generate_key(1024, RSA_F4, NULL, NULL);
    
    if (! g_rsa_1024) {
      fprintf(stderr, "OpenSSL failed generating 1024 bit RSA key.\n");

      return 0;
    }
  }
  
  if (config->ca_certificate_file || config->ca_certificate_path) {
    if (! SSL_CTX_load_verify_locations(ctx,
					config->ca_certificate_file,
					config->ca_certificate_path)) {
      fprintf(stderr, "Can't find CA certificates for client authentication.\n");
      return 0;
    }
  }

  return ctx;
}

/*
 * Accept hook installed by ssl_create(): accepts a plain connection via
 * std_accept() and turns it into an SSL connection.
 *
 * The connection reuses the server socket's shared SSL_CTX (ss->context)
 * when std_accept() did not already assign one.  A fresh context is built
 * only when no shared context exists.  The SSL handshake itself is deferred
 * to the first read/write (see ssl_open).
 *
 * Returns the connection, or 0 when std_accept() fails.
 */
static connection_t *
ssl_accept(server_socket_t *ss)
{
  connection_t *conn;

  conn = std_accept(ss);

  if (! conn)
    return 0;

  /* Prefer the shared context; building one per connection leaks an SSL_CTX each time. */
  if (! conn->context)
    conn->context = ss->context;

  if (! conn->context)
    conn->context = ssl_create_context(ss->ssl_config);

  conn->ops = &ssl_ops;
  conn->ssl_lock = &ss->ssl_lock;

  return conn;
}

/*
 * Shuts down and frees an SSL handle; a NULL ssl is ignored.
 *
 * fd is the underlying socket.  When it is valid (>= 0), it is first put
 * back into blocking mode (non-WIN32 only).  The session is marked as
 * having already received the peer's shutdown, so SSL_shutdown() only sends
 * our close_notify.  SSL_shutdown() is tried at most four times and stops
 * once it reports a completed shutdown.  The socket itself is NOT closed
 * here.
 */
static void
ssl_safe_free(int fd, SSL *ssl)
{
  int count;

  if (! ssl)
    return;

#ifndef WIN32
  /* clear non-blocking i/o; skip invalid descriptors and failed F_GETFL */
  if (fd >= 0) {
    int flags = fcntl(fd, F_GETFL);

    if (flags >= 0)
      fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);
  }
#endif

  SSL_set_shutdown(ssl, SSL_RECEIVED_SHUTDOWN);

  for (count = 4; count > 0; count--) {
    if (SSL_shutdown(ssl) == 1)
      break;
  }

  /* A peer that already hung up leaves errors queued; don't let them leak into later reports. */
  ERR_clear_error();

  SSL_free(ssl);
}

/*
 * OpenSSL temporary-RSA callback (registered with SSL_set_tmp_rsa_callback).
 * Hands out the pre-generated 512-bit key only for an export request of
 * exactly 512 bits; every other request gets the 1024-bit key.
 */
static RSA *
ssl_get_temporary_RSA_key(SSL *ssl, int isExport, int keyLen)
{
  (void) ssl;

  return (isExport && keyLen == 512) ? g_rsa_512 : g_rsa_1024;
}

static int
ssl_open(connection_t *conn, int fd)
{
  int result;
  SSL_CTX *ctx;
  SSL *ssl;
  SSL_CIPHER *cipher;
  int algbits;
  int retry = 10;

  ctx = conn->context;

  if (! ctx) {
    ERR(("missing SSL context\n"));
    return -1;
  }
  
  ssl = conn->sock;
  if (! ssl) {
    ssl = SSL_new(ctx);

    conn->sock = ssl;
  }

  if (! ssl) {
    closesocket(fd);
    conn->fd = -1;
    ERR(("can't allocate ssl\n"));
    return -1;
  }
  
  SSL_set_fd(ssl, fd);
  SSL_set_app_data(ssl, conn);

  SSL_set_tmp_rsa_callback(ssl, ssl_get_temporary_RSA_key);

  /* set non-blocking i/o */
#ifndef WIN32
  {
    int flags;
    flags = fcntl(fd, F_GETFL);
    fcntl(fd, F_SETFL, O_NONBLOCK|flags);
  }
#endif

  while (retry-- >= 0 && ! SSL_is_init_finished(ssl)) {
    fd_set read_mask;
    struct timeval timeout;
    
    timeout.tv_sec = 30;
    timeout.tv_usec = 0;
    
    FD_ZERO(&read_mask);
    FD_SET(fd, &read_mask);
    
    result = select(fd + 1, &read_mask, 0, 0, &timeout);
    
    if (result < 0 && (errno == EINTR || errno == EAGAIN))
      continue;
    
    if (result <= 0) {
      ERR(("select timeout %d err:%d\n", result, errno));
    }
    
    if (result > 0) {
#ifdef THREADS
      result = SSL_accept(ssl);
#else
      pthread_mutex_lock(conn->ssl_lock);
      result = SSL_accept(ssl);
      pthread_mutex_unlock(conn->ssl_lock);
#endif      
    }

    if (result > 0) {
    }
    else if (SSL_get_error(ssl, result) == SSL_ERROR_WANT_READ)
      continue;
    else if (result < 0) {
      ERR(("can't accept ssl %d response %d\n",
           result, SSL_get_error(ssl, result)));
      ERR_print_errors_fp(stderr);
      conn->sock = 0;
      conn->fd = -1;
      ssl_safe_free(fd, ssl);
      closesocket(fd);
      return -1;
    }
  }
  
  conn->fd = fd;

  cipher = SSL_get_current_cipher(ssl);

  if (cipher) {
    conn->ssl_cipher = (void *) SSL_CIPHER_get_name(cipher);
    conn->ssl_bits = SSL_CIPHER_get_bits(cipher, &algbits);
  }
  
  return 0;
}

/*
 * connection_ops_t read hook: reads up to len decrypted bytes into buf.
 *
 * Performs the deferred SSL handshake (ssl_open) on first use.  While
 * OpenSSL reports SSL_ERROR_WANT_READ, it waits on the socket with select()
 * in timeout_chunk-second slices.  The slice count is derived from
 * conn->timeout in ms (30 s when unset), with a minimum of one slice.
 *
 * Returns:
 *  - the SSL_read() result when it is >= 0;
 *  - -1 when the connection is unusable or the handshake fails;
 *  - DISCONNECT_EXN on any other SSL error;
 *  - TIMEOUT_EXN when the slices run out;
 *  - exception_status() for select() failures or retry exhaustion.
 */
static int
ssl_read(connection_t *conn, char *buf, int len)
{
  fd_set read_mask;
  struct timeval timeout;
  int fd;
  int ms;
  SSL *ssl;
  int result;
  int retry = 100;
  int timeout_chunk = 5;
  int timeout_count;

  /* Check conn before dereferencing it. */
  if (! conn || conn->fd < 0)
    return -1;

  fd = conn->fd;
  ms = conn->timeout;

  if (ms <= 0)
    timeout_count = (30 + timeout_chunk - 1) / timeout_chunk;
  else
    timeout_count = (ms / 1000 + timeout_chunk - 1) / timeout_chunk;

  if (timeout_count <= 0)
    timeout_count = 1;

  if (! conn->is_init) {
    conn->is_init = 1;

    if (ssl_open(conn, conn->fd) < 0) {
      conn->ops->close(conn);
      return -1;
    }
  }

  ssl = conn->sock;
  if (! ssl)
    return -1;

  while (retry-- > 0) {
    result = SSL_read(ssl, buf, len);

    if (result >= 0)
      return result;
    else if (SSL_get_error(ssl, result) != SSL_ERROR_WANT_READ)
      return DISCONNECT_EXN;

    /* No complete record yet: wait for more bytes from the peer. */
    do {
      FD_ZERO(&read_mask);
      FD_SET(fd, &read_mask);

      timeout.tv_sec = timeout_chunk;
      timeout.tv_usec = 0;

      result = select(fd + 1, &read_mask, 0, 0, &timeout);
    } while (result < 0 && (errno == EINTR || errno == EAGAIN) && retry-- > 0);

    if (result == 0 && --timeout_count <= 0)
      return TIMEOUT_EXN;
    else if (result < 0)
      return exception_status(conn, errno);
  }

  return exception_status(conn, errno);
}

/*
 * connection_ops_t write hook: encrypts and sends len bytes from buf.
 *
 * Performs the deferred SSL handshake (ssl_open) on first use.  When
 * SSL_write() cannot make progress, it waits with select() for the
 * direction OpenSSL asked for (SSL_ERROR_WANT_WRITE or SSL_ERROR_WANT_READ).
 * It then retries with the same buf/len, as OpenSSL requires.
 *
 * Returns:
 *  - the positive SSL_write() byte count;
 *  - 0 for an empty write;
 *  - -1 when the connection is unusable or the handshake fails;
 *  - TIMEOUT_EXN (connection closed) when select() times out after
 *    conn->timeout ms (30 s when unset);
 *  - DISCONNECT_EXN (connection closed) on a fatal SSL error;
 *  - exception_status() for select() failures or retry exhaustion.
 */
static int
ssl_write(connection_t *conn, char *buf, int len)
{
  SSL *ssl;
  fd_set mask;
  struct timeval timeout;
  int fd;
  int ms;
  int result;
  int ssl_error;
  int want_read;
  int retry = 100;

  /* Check conn before dereferencing it. */
  if (! conn || conn->fd < 0)
    return -1;

  if (! conn->is_init) {
    conn->is_init = 1;

    if (ssl_open(conn, conn->fd) < 0) {
      conn->ops->close(conn);
      return -1;
    }
  }

  fd = conn->fd;
  ssl = conn->sock;

  if (! ssl)
    return -1;

  /* SSL_write() reports an error for zero bytes; there is nothing to send. */
  if (len <= 0)
    return 0;

  for (;;) {
    result = SSL_write(ssl, buf, len);

    if (result > 0)
      return result;

    ssl_error = SSL_get_error(ssl, result);

    if (ssl_error == SSL_ERROR_WANT_WRITE)
      want_read = 0;
    else if (ssl_error == SSL_ERROR_WANT_READ)
      want_read = 1;
    else {
      /* Fatal error: OpenSSL forbids further I/O on this session. */
      conn->ops->close(conn);
      return DISCONNECT_EXN;
    }

    do {
      if (retry-- <= 0)
        return exception_status(conn, errno);

      ms = conn->timeout;

      if (ms <= 0) {
        timeout.tv_sec = 30;
        timeout.tv_usec = 0;
      }
      else {
        timeout.tv_sec = ms / 1000;
        timeout.tv_usec = ms % 1000 * 1000;
      }

      FD_ZERO(&mask);
      FD_SET(fd, &mask);

      if (want_read)
        result = select(fd + 1, &mask, 0, 0, &timeout);
      else
        result = select(fd + 1, 0, &mask, 0, &timeout);
    } while (result < 0 && (errno == EINTR || errno == EAGAIN));

    if (result == 0) {
      conn->ops->close(conn);
      return TIMEOUT_EXN;
    }
    else if (result < 0)
      return exception_status(conn, errno);
  }
}

/*
 * connection_ops_t close hook: detaches the SSL handle and descriptor from
 * conn, shuts the SSL session down via ssl_safe_free(), closes the socket
 * and finally calls conn_close().  A NULL conn is ignored.  Always
 * returns 0.
 */
static int
ssl_close(connection_t *conn)
{
  int fd;
  SSL *ssl;

  if (! conn)
    return 0;

  fd = conn->fd;
  conn->fd = -1;

  ssl = conn->sock;
  conn->sock = 0;

  ssl_safe_free(fd, ssl);

  /* 0 is a valid descriptor; only negative values mean "no socket". */
  if (fd >= 0)
    closesocket(fd);

  conn_close(conn);

  return 0;
}

/*
 * connection_ops_t free hook: releases any SSL handle still attached to
 * conn (under conn->ssl_lock) and then frees the connection with
 * std_free().  A NULL conn is ignored, matching ssl_close().
 */
static void
ssl_free(connection_t *conn)
{
  SSL *ssl;

  if (! conn)
    return;

  ssl = conn->sock;
  conn->sock = 0;

  if (ssl) {
    pthread_mutex_lock(conn->ssl_lock);
    ssl_safe_free(-1, ssl);
    pthread_mutex_unlock(conn->ssl_lock);
  }

  std_free(conn);
}

/**
 * Placeholder for installing a certificate chain on ctx from config.
 *
 * Not implemented.  Chain files are currently loaded directly in
 * ssl_create_context() via SSL_CTX_use_certificate_chain_file().  Always
 * returns 0 ("nothing done") so callers never read an indeterminate value.
 */
static int
set_certificate_chain(SSL_CTX *ctx, ssl_config_t *config)
{
  (void) ctx;
  (void) config;

  return 0;
}

/*
 * Configures ss as an SSL server socket.  It stores config and its
 * verify_client mode on ss, installs ssl_accept() as the accept hook, and
 * builds the shared SSL_CTX from config.
 *
 * Always returns 1.  If ssl_create_context() fails, ss->context is left 0.
 * NOTE(review): callers cannot detect that failure from the return value --
 * confirm that relying on ssl_accept()'s per-connection fallback is intended.
 */
int
ssl_create(server_socket_t *ss, ssl_config_t *config)
{
  SSL_CTX *shared_ctx;

  ss->ssl_config = config;
  ss->verify_client = config->verify_client;
  ss->accept = ssl_accept;

  shared_ctx = ssl_create_context(config);
  ss->context = shared_ctx;

  return 1;
}

/*
 * connection_ops_t hook: writes the peer's certificate, PEM-encoded, into
 * buffer (capacity `length' bytes).  Performs the deferred SSL handshake
 * first if needed.
 *
 * Returns:
 *  - the number of bytes copied;
 *  - -1 when there is no connection, handshake, SSL handle or peer
 *    certificate, or when encoding fails.
 *
 * When the PEM text does not fit, nothing is copied and the required size
 * is returned.  NOTE(review): presumably callers compare the result against
 * `length' to detect this -- confirm.
 */
static int
ssl_read_client_certificate(connection_t *conn, char *buffer, int length)
{
  BIO *bio;
  X509 *cert;
  int n;

  if (! conn)
    return -1;

  if (! conn->is_init) {
    conn->is_init = 1;

    if (ssl_open(conn, conn->fd) < 0) {
      conn->ops->close(conn);
      return -1;
    }
  }

  /* A previously failed or closed connection has no SSL handle. */
  if (! conn->sock)
    return -1;

  /* SSL_get_peer_certificate() returns a new reference; free it below. */
  cert = SSL_get_peer_certificate(conn->sock);

  if (! cert)
    return -1;

  bio = BIO_new(BIO_s_mem());

  if (! bio) {
    X509_free(cert);
    return -1;
  }

  if (PEM_write_bio_X509(bio, cert) != 1)
    n = -1;
  else {
    n = BIO_pending(bio);

    if (n <= length)
      n = BIO_read(bio, buffer, n);
  }

  BIO_free(bio);
  X509_free(cert);

  return n;
}
