/*

  Copyright 2000, 2001, 2002 Laurent Wacrenier

  This file is part of libhome

  libhome is free software; you can redistribute it and/or modify it
  under the terms of the GNU Lesser General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  libhome is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU Lesser General Public License for more details.
  
  You should have received a copy of the GNU Lesser General Public
  License along with libhome; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
  USA

*/

/* Backlog passed to listen(2) for the proxy's UNIX-domain socket. */
#define LISTEN_BACKLOG 128

#include "config.h"

/* RCS identification string embedded in the binary.  UNUSED is presumably a
   config.h macro that silences the unused-variable warning. */
static char const rcsid[] UNUSED =
"$Id: home_proxy.c,v 1.14 2005/09/14 08:57:00 lwa Exp $";

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <sys/stat.h>

/* redefine system struct passwd  */
#define passwd system_passwd
#include <pwd.h>
#undef passwd

#include <unistd.h>
#if HAVE_GETOPT_H
#include <getopt.h>
#endif
#include <syslog.h>

#include <string.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#if HAVE_STDINT_H
#include <stdint.h>
#endif

#define HOME_DONT_SUBSTITUTE_SYSTEM

#include "hpwd.h"

#include "hparam.h"

/* One worker-process slot in the childs[] table. */
struct child {
  pid_t pid;    /* worker pid, or -1 when the slot is free and must be forked */
  time_t end;   /* NOTE(review): never read or written in this file */
};

/* Worker slot table, allocated in main() with nchild entries. */
static struct child *childs = NULL;
/* Number of worker processes to keep running (-n option, at least 1). */
static int nchild = 1;

/*
 * Flags written from signal handlers and polled by the main loop.  Only
 * objects of type volatile sig_atomic_t may portably be assigned inside a
 * signal handler and reliably observed outside it (C11 7.14.1.1).
 */
static volatile sig_atomic_t need_checkup = 1; /* a worker slot may need forking */
static volatile sig_atomic_t exiting = 0;      /* shutdown requested (INT/TERM) */
static volatile sig_atomic_t restarting = 1;   /* restart requested (HUP) */
static volatile sig_atomic_t alarmed;          /* SIGALRM: recycle one worker */
static volatile sig_atomic_t verbose = 0;      /* debug logging level */

/*
 * Return the program name: the part of argv0 after its last '/', or argv0
 * itself when it contains no '/'.  The result points into argv0.
 */
static char *progname(char *argv0) {
  char *slash = strrchr(argv0, '/');

  return slash != NULL ? slash + 1 : argv0;
}

/* Not referenced by the code in this file. */
#define BUF_INC 1000

/* 32-bit codes sent by senderr() instead of a reply. */
#define ERR_NOT_FOUND 0           /* no such entry (or refused) */
#define ERR_MEM       0xFFFFFFFF  /* transient failure, e.g. out of memory */

/*
 * Send the bare 32-bit code `err` (host byte order) to the client on `clt`
 * and close the connection.  The write result is not checked; the
 * descriptor is closed in every case.
 */
void senderr(int clt, const uint32_t err) {
  const void *payload = &err;
  size_t payload_len = sizeof err;

  write(clt, payload, payload_len);
  close(clt);
}

/* Permission bits returned by get_restrict(). */
#define ALLOW_UID       0x01   /* entry may be returned for a UID query */
#define ALLOW_PASSWORD  0x02   /* password field may be disclosed */
#define ALLOW_NAME      0x04   /* entry may be returned for a name query */

#define ALLOW_ALL       0xFF

/* Comparison operators of an ACL rule. */
#define CMP_EQ 0
#define CMP_GT 1
#define CMP_GE 2
#define CMP_LT 3
#define CMP_LE 4
#define CMP_NE 5

/*
 * Evaluate the proxy ACL and return the ALLOW_* bits granted for one reply.
 *
 * acl       NULL-terminated array of rules, or NULL (everything allowed).
 * uid, gid  credentials of the requesting peer.
 * retr_uid, retr_gid  uid/gid of the entry about to be returned.
 *
 * Rule syntax:  [selector][operator][value]:flags
 *   selector  'u' peer uid (default), 'g' peer gid, 'U' entry uid,
 *             'G' entry gid, or '*' which matches unconditionally;
 *   operator  '<', '<=', '>', '>=', '!' or '!=', '=' or '==' (default '==');
 *   value     decimal number, or '@' for the entry's uid (u/U selectors)
 *             or gid (g/G selectors);
 *   flags     after one or more ':', any of 'u' (clear ALLOW_UID),
 *             'n' (clear ALLOW_NAME), 'p' (clear ALLOW_PASSWORD),
 *             '-' (clear everything); other characters are ignored.
 * Only the first matching rule is applied.
 */
unsigned int get_restrict(char **acl, uid_t uid, gid_t gid,
                          uid_t retr_uid, int retr_gid) {
  unsigned int allowed = ALLOW_ALL;
  char **cur;

  for (cur = acl; cur != NULL && *cur != NULL; cur++) {
    char *p = *cur;
    int match;

    if (*p == '*') {
      match = 1;                 /* wildcard rule */
    } else {
      unsigned long lhs = uid;
      unsigned long rhs;
      char kind = 'u';
      int op = CMP_EQ;

      /* optional subject selector */
      switch (*p) {
      case 'u': p++; break;
      case 'g': lhs = gid;      kind = 'g'; p++; break;
      case 'U': lhs = retr_uid; kind = 'u'; p++; break;
      case 'G': lhs = retr_gid; kind = 'g'; p++; break;
      default:  break;
      }

      /* optional comparison operator */
      if (*p == '<' || *p == '>') {
        int or_equal = (p[1] == '=');
        if (*p == '<')
          op = or_equal ? CMP_LE : CMP_LT;
        else
          op = or_equal ? CMP_GE : CMP_GT;
        p += or_equal ? 2 : 1;
      } else if (*p == '!' || *p == '=') {
        op = (*p == '!') ? CMP_NE : CMP_EQ;
        p += (p[1] == '=') ? 2 : 1;
      }

      /* right-hand operand: the looked-up id, or a decimal constant */
      if (*p == '@')
        rhs = kind == 'u' ? retr_uid : retr_gid;
      else
        rhs = strtoul(p, &p, 10);

      switch (op) {
      case CMP_NE: match = (lhs != rhs); break;
      case CMP_LT: match = (lhs <  rhs); break;
      case CMP_LE: match = (lhs <= rhs); break;
      case CMP_GT: match = (lhs >  rhs); break;
      case CMP_GE: match = (lhs >= rhs); break;
      default:     match = (lhs == rhs); break;
      }
    }

    if (!match)
      continue;

    /* skip the rest of the condition and the ':' separator(s) */
    p += strcspn(p, ":");
    p += strspn(p, ":");
    for (; *p != '\0'; p++) {
      switch (*p) {
      case 'u': allowed &= ~ALLOW_UID;      break;
      case 'n': allowed &= ~ALLOW_NAME;     break;
      case 'p': allowed &= ~ALLOW_PASSWORD; break;
      case '-': allowed = 0;                break;
      }
    }
    break;                       /* first match wins */
  }
  return allowed;
}

/*
 * Receive up to `len` bytes from the connected UNIX socket `fd` into `buf`
 * and report the peer's credentials through *uid / *gid.
 *
 * See http://cr.yp.to/docs/secureipc.html for an overview of the methods
 * available to authenticate a local IPC peer.
 *
 * *uid / *gid are left as (uid_t)-1 / (gid_t)-1 when the platform offers
 * no mechanism or the credential query fails; callers must treat those
 * values as "unknown".
 *
 * Returns the byte count received (0 when the peer sent nothing before
 * closing) or -1 on error, like recv(2).
 */
ssize_t get_data_owner(int fd, void *buf, size_t len,
                       uid_t *uid, gid_t *gid) {
  *uid = (uid_t)-1;
  *gid = (gid_t)-1;

#if defined(HAVE_GETPEEREID)
  { /* FreeBSD, OpenBSD, AIX, MacOSX */
    if (getpeereid(fd, uid, gid) == -1) {
      /* do not trust whatever a failed call may have stored */
      *uid = (uid_t)-1;
      *gid = (gid_t)-1;
    }
  }

#elif defined(SO_PEERCRED)
  { /* Linux */
    struct ucred cr;
    socklen_t cr_len = sizeof(cr);   /* getsockopt() expects socklen_t * */
    if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) != -1) {
      *uid = cr.uid;
      *gid = cr.gid;
    }
  }
#elif defined(LOCAL_CREDS) && defined(HAVE_STRUCT_SOCKCRED)
  { /* NetBSD */
    int on = 1;
    if (setsockopt(fd, 0, LOCAL_CREDS, &on, sizeof(on)) != -1) {
      ssize_t ret;
      struct msghdr msg = {0,};
      /* Aligned control buffer with slack for CMSG_DATA() padding; the
         credentials must be located with CMSG_DATA(), not by overlaying a
         struct on the header, because the data may be padded. */
      union {
        struct cmsghdr hdr;
        char space[sizeof(struct cmsghdr) + sizeof(struct sockcred) + 16];
      } control;
      struct iovec iov[1];
      msg.msg_iov = iov;
      iov->iov_base = buf;
      iov->iov_len = len;
      msg.msg_iovlen = 1;
      msg.msg_control = control.space;
      msg.msg_controllen = sizeof(control.space);
      if ((ret = recvmsg(fd, &msg, 0)) >= 0) {
        /* CMSG_FIRSTHDR() is NULL when no control data arrived */
        struct cmsghdr *c = CMSG_FIRSTHDR(&msg);
        size_t need = sizeof(struct sockcred) - sizeof(gid_t);
        if (c != NULL &&
            c->cmsg_level == SOL_SOCKET &&
            c->cmsg_type == SCM_CREDS &&
            c->cmsg_len >= CMSG_LEN(need)) {
          struct sockcred cred;
          memcpy(&cred, CMSG_DATA(c), need);
          *uid = cred.sc_euid;
          *gid = cred.sc_egid;
        }
        return ret;
      } else {
        return -1;
      }
    }
  }
#else
  /* Solaris */
#warning socket privilege check are not supported
#endif
  /* fallback: plain receive; credentials stay as determined above */
  return recv(fd, buf, len, 0);
}

/*
 * Accept one connection on the listening socket `s` and answer the single
 * query it carries.
 *
 * The query is either a user name or a UID request, as recognised by the
 * IS_UID()/GET_UID() macros (defined outside this file).  The entry is
 * looked up through libhome and filtered by the params->proxy_deny ACL
 * using the peer's credentials (see get_restrict()).
 *
 * Successful reply: a uint32_t byte count in host byte order, followed by
 * that many bytes holding name, password ("*" when the ACL hides it), uid,
 * gid, class, gecos, home directory, shell, expire and quota as
 * NUL-separated fields; the last field's terminating NUL is not sent.
 * Failure reply: a single uint32_t, 0 for "not found or refused" or
 * ERR_MEM for a transient failure.
 *
 * The reply buffer is static and reused (and grown) across calls; failing
 * to allocate it terminates the process.  The client socket is closed on
 * every path.
 */
void child(int s, struct param *params) {
  /* reply buffer: sizeof(uint32_t) header bytes + outlen payload bytes */
  static char *out = NULL;
  static int outlen = 0;
  char entry[1002];
  ssize_t entrylen;
  uid_t uid;
  gid_t gid;
  unsigned int restrictions = 0;
  unsigned int needed;
  const char *what;
  struct passwd *pwd;
  int transient;
  int clt;
  int len;
  int i;

#define ZERO_LENGTH 10
  int zero[ZERO_LENGTH];   /* offsets of the ':' separators, filled by %n */

  clt = accept(s, NULL, 0);
  if (clt < 0) {
    syslog(LOG_WARNING, "accept(): %s", strerror(errno));
    return;
  }

  entrylen = get_data_owner(clt, entry, sizeof(entry), &uid, &gid);
  if (entrylen < 0) {
    /* Handle before any size check: -1 compared against a size_t would
       look "too long" and the untouched buffer would be logged. */
    if (verbose > 0)
      syslog(LOG_DEBUG, "unable to read query: %s", strerror(errno));
    senderr(clt, 0);
    return;
  }
  shutdown(clt, SHUT_RD);

  /* Keep room for the terminating NUL; a query that filled the buffer may
     have been cut and is rejected. */
  if ((size_t)entrylen > sizeof(entry) - 2) {
    if (verbose > 0)
      syslog(LOG_DEBUG, "query too long %.*s", (int)(sizeof(entry)-2),
             entry);
    senderr(clt, 0);
    return;
  }
  entry[entrylen] = 0;

  if (*entry == '\0') {
    if (verbose > 1)
      syslog(LOG_DEBUG, "\"\" void request");
    close(clt);   /* the connection must not leak */
    return;
  }

  /* Clear errno so that a stale ENOMEM is not mistaken for a transient
     lookup failure, and capture the verdict before syslog() can clobber
     errno. */
  errno = 0;
  if (IS_UID(entry)) {
    uid_t quid = (uid_t)strtoul(GET_UID(entry), NULL, 10);
    what = "UID";
    needed = ALLOW_UID;
    pwd = home_getpwuid(quid);
  } else {
    what = "name";
    needed = ALLOW_NAME;
    pwd = home_getpwnam(entry);
  }
  transient = (pwd == NULL && errno == ENOMEM);

  if (pwd != NULL) {
    restrictions = get_restrict(params->proxy_deny,
                                uid, gid, pwd->pw_uid, pwd->pw_gid);
    if (!(restrictions & needed))
      pwd = NULL;   /* refused: answered exactly like "not found" */
  }

  if (pwd == NULL) {
    if (verbose > 1)
      syslog(LOG_DEBUG, "\"%s\" %s", entry,
             transient ? "transcient failure" : "not found");
    senderr(clt, transient ? ERR_MEM : 0);
    return;
  }

  if (out == NULL) {
    outlen = 1000;
    out = malloc(outlen + sizeof(uint32_t));
    if (out == NULL) {
      syslog(LOG_WARNING, "unable to allocate %lu bytes",
             (unsigned long)(outlen + sizeof(uint32_t)));
      senderr(clt, ERR_MEM);
      exit(1);
    }
  }

  /* Format the fields separated by ':' while %n records where each
     separator lands; they become NULs below, so a ':' inside a field
     (e.g. gecos) is preserved.  Grow the buffer until everything fits,
     since %n offsets are only meaningful for an untruncated result. */
  for (;;) {
    len = snprintf(out+sizeof(uint32_t), outlen,
                   "%s%n:%s%n:%lu%n:%lu%n:%s%n:%s%n:%s%n:%s%n:%lu%n:%lu%n",
                   pwd->pw_name, zero,
                   restrictions & ALLOW_PASSWORD ? pwd->pw_passwd : "*",
                   zero+1,
                   (unsigned long)pwd->pw_uid, zero+2,
                   (unsigned long)pwd->pw_gid, zero+3,
                   pwd->pw_class, zero+4,
                   pwd->pw_gecos, zero+5,
                   pwd->pw_dir, zero+6,
                   pwd->pw_shell, zero+7,
                   (unsigned long)pwd->pw_expire, zero+8,
                   (unsigned long)pwd->pw_quota, zero+9);
    if (len < 0) {
      /* zero[] is unreliable after an output error */
      syslog(LOG_WARNING, "unable to format entry \"%s\"", entry);
      senderr(clt, ERR_MEM);
      return;
    }
    if (len < outlen)
      break;
    free(out);
    outlen = len + 10;
    out = malloc(outlen + sizeof(uint32_t));
    if (out == NULL) {
      syslog(LOG_WARNING, "unable to allocate %lu bytes",
             (unsigned long)(outlen + sizeof(uint32_t)));
      senderr(clt, ERR_MEM);
      exit(1);
    }
  }

  for (i = 0; i < ZERO_LENGTH; i++)
    out[zero[i] + sizeof(uint32_t)] = 0;
  {
    uint32_t header = (uint32_t)len;
    memcpy(out, &header, sizeof(header));   /* no type-punned store */
  }
  if (write(clt, out, len + sizeof(uint32_t)) < 0 && verbose > 0)
    syslog(LOG_DEBUG, "write(): %s", strerror(errno));
  if (verbose > 1)
    syslog(LOG_DEBUG, "%s \"%s\" found (name %s)", what,
           entry, pwd->pw_name);
  close(clt);
}

/* atexit() hook registered in every worker by run(); calls home_endpwent(),
   presumably to release libhome's lookup backend before the worker exits. */
void child_exit(void) {
  home_endpwent();
}

/*
 * SIGTERM/SIGINT handler installed in every worker by run(): exits with
 * status 0, which also runs the atexit() hook (child_exit).
 * NOTE(review): exit() is not async-signal-safe; it is used here so the
 * atexit hook runs.  A flag polled by the worker loop would be safer.
 */
void child_exit_sig(int s) {
  exit(0);
}

/*
 * Fork one worker that serves requests on the listening socket `s` forever
 * (the child branch never returns).
 *
 * Every signal except SIGTRAP is blocked across fork() so the new process
 * cannot take a signal before its own handlers are installed.
 *
 * Returns the worker's pid in the parent, or -1 when fork() fails; in that
 * case need_checkup is raised so the slot is retried later.
 */
pid_t run(int s, struct param *params) {
  pid_t pid;
  sigset_t saved_sigmask;
  sigset_t block_sigmask;

  sigfillset(&block_sigmask);
  sigdelset(&block_sigmask, SIGTRAP); /* let's debug */
  sigprocmask(SIG_BLOCK, &block_sigmask, &saved_sigmask);

  pid = fork();
  switch (pid) {
  case -1:
    need_checkup = 1;
    syslog(LOG_WARNING, "unable to fork a child: %s", strerror(errno));
    break;
  case 0:  /* child */
    atexit(child_exit);
    signal(SIGHUP, SIG_IGN);  /* let the parent handle this */
    signal(SIGTERM, child_exit_sig); /* die */
    signal(SIGINT, child_exit_sig);  /* die */
    /* A client that disconnects before reading its reply must not kill
       the worker: let write() fail with EPIPE instead. */
    signal(SIGPIPE, SIG_IGN);
    /* The parent's SIGCHLD reaper manages the parent's table; it has no
       business running in a worker. */
    signal(SIGCHLD, SIG_DFL);
    sigprocmask(SIG_SETMASK, &saved_sigmask, NULL);
    for (;;)
      child(s, params);
    /* NOTREACHED */
  default: /* parent */
    break;
  }
  sigprocmask(SIG_SETMASK, &saved_sigmask, NULL);
  return pid;
}

/*
 * SIGCHLD handler: reap every terminated child without blocking, mark its
 * childs[] slot free (pid = -1) and raise need_checkup so the main loop
 * forks a replacement.
 *
 * errno is saved and restored because the handler may interrupt code that
 * is about to inspect errno.
 */
void sigchld(int sig) {
  int saved_errno = errno;
  pid_t pid;
  int status;

  (void)sig;
  while ((pid = waitpid((pid_t)-1, &status, WNOHANG)) > 0) {
    if (WIFEXITED(status) || WIFSIGNALED(status)) {
      int i;
      need_checkup = 1;
      for (i = 0; i < nchild; i++) {
        if (childs[i].pid == pid)
          childs[i].pid = (pid_t)-1;
      }
    }
  }
  errno = saved_errno;
}

/* SIGINT/SIGTERM handler installed by main(): sets exiting, after which the
   main loop stops, terminates the workers, removes the socket and pid file
   and exits. */
void sigexit(int sig) {
  exiting = 1;
}

/* SIGHUP handler installed by main(): sets restarting, after which the main
   loop reloads the configuration (home_init), re-creates the socket if its
   path changed and terminates the workers so they are forked again. */
void sigrestart(int sig) {
  restarting = 1;
}

/* SIGALRM handler, installed by main() only when -k is given: sets alarmed
   so the main loop recycles the next worker with killone() and re-arms the
   alarm. */
void sigalarm(int sig) {
  alarmed = 1;
}

void checkup(int s, struct param *params) {
  int i;
  need_checkup = 0;
  for (i = 0; i<nchild; i++) {
    if (childs[i].pid == (pid_t)-1) {
      childs[i].pid = run(s, params);
    }
  }
}

void cleanup(void) {
  int i;
  for (i = 0; i<nchild; i++) {
    pid_t pid = childs[i].pid;
    if (pid > 0) {
      kill(pid, SIGTERM);
    }
  }
}

/*
 * Verbosity control for use as a signal handler: SIGUSR1 raises the debug
 * level by one, SIGUSR2 resets it to 0; any other signal is ignored.
 */
void verbosity(int sig) {
  if (sig == SIGUSR1)
    verbose++;
  else if (sig == SIGUSR2)
    verbose = 0;
}

/*
 * Send SIGTERM to the worker in slot n, wrapping to slot 0 when n is past
 * the end of childs[]; an empty slot is skipped silently.  Returns the slot
 * index actually used, so the caller can rotate through the workers.
 */
int killone(int n) {
  int slot = (n < nchild) ? n : 0;
  pid_t victim = childs[slot].pid;

  if (victim > 0)
    kill(victim, SIGTERM);
  return slot;
}

int main(int argc, char **argv) {
  char *path = NULL;
  char *tag;
  int s = -1;
  struct sockaddr_un sa_un;
  int ch;
  int unlink_path = 0;
  char *conffile = NULL;
  int i;
  char *pidfile = NULL;
  char *oldpath = NULL;
  char *optpath = NULL;
  char *me;
  int detach = 1;
  mode_t mode = (mode_t)-1;
  uid_t uid = getuid();
  uid_t gid = getgid();
  int killtime = 0;
  int lastkilled = 0;


  me = progname(argv[0]);
  tag = me;

  openlog(me, LOG_PID, LOG_AUTH);

  while((ch = getopt(argc, argv, "m:u:g:ds:t:n:xC:p:k:v")) != -1) {
    switch(ch) {
    case 'x':
      unlink_path = 1;
      break;
    case 's':
      optpath = optarg;
      break;
    case 'n':
      nchild = strtoul(optarg, NULL, 10);
      if (nchild < 1)
	nchild = 1;
      break;
    case 't':
      tag = optarg;
      break;
    case 'C':
      conffile = optarg;
      break;
    case 'p':
      pidfile = optarg;
      break;
    case 'd':
      detach = 0;
      break;
    case 'k':
      killtime = strtol(optarg, NULL, 10);
      break;
    case 'v':
      verbose ++;
      break;
    case 'm': {
      char *rest;
      mode = strtoul(optarg, &rest, 8);
      if (*rest || (mode & ~(S_IRWXU|S_IRWXG|S_IRWXO))) {
	fprintf(stderr, "%s: invalide mode '%s'\n", me, optarg);
	exit(1);
      }
      break;
    }
    case 'u': {
      char *rest;
      uid = strtoul(optarg, &rest, 10);
      if (*rest) {
	fprintf(stderr, "%s: non numerical UID '%s'\n", me, optarg);
	exit(1);
      }
      break;
    }
    case 'g': {
      char *rest;
      gid = strtoul(optarg, &rest, 10);
      if (*rest) {
	fprintf(stderr, "%s: non numerical GID '%s'\n", me, optarg);
	exit(1);
      }
      break;
    }
    default:
      fprintf(stderr,
	      "usage: %s [-xdv] [-s socket] [-t tag]" 
	      " [-n #] [-C conf] [-p pidfile]\n"
	      "          [-u uid] [-g gid] [-m mode] [-k seconds]\n", me);
      exit(1);
      break;
    }
  }
  childs = calloc(nchild, sizeof(struct child));
  
  for (i = 0; i<nchild; i++) {
    childs[i].pid = -1;
  } 

  signal(SIGCHLD, sigchld);
  
  signal(SIGHUP, sigrestart);
  signal(SIGINT, sigexit);
  signal(SIGTERM, sigexit);
  if (killtime >0) {
    signal(SIGALRM, sigalarm);
  }

  setpwtag(tag);

  if (detach) {
    pid_t pid = fork();
    if (pid == -1) {
      syslog(LOG_ERR, "fatal: unable to detach: %s", strerror(errno));
      exit(1);
    } else if (pid) {
      exit(0);
    }
  }

  if (pidfile) {
    FILE *f = fopen(pidfile, "w+");
    if (f) {
      fprintf(f, "%lu\n", (unsigned long)getpid());
      fclose(f);
    } else {
      syslog(LOG_WARNING, "unable to open pid file \"%s\": %s", 
	     pidfile, strerror(errno));
      pidfile = NULL;
    }
  }

  do {
    int bind_ok = 0;
    struct param *params;

    home_cleanup();
    params = home_init(conffile);
    home_setpassent(1); /* keep alive */
  
    path = optpath;

    if (path == NULL)
      path = params->proxy_socket;
    
    if (path == NULL) {
      path = "/var/run/home_proxy";
    }

    if (s == -1 || (oldpath && strcmp(oldpath, path))) {
      if (s != -1) {
	close(s);
	if (oldpath)
	  unlink(oldpath);
      }
      s = socket(AF_UNIX, SOCK_STREAM, 0);
      if (s == -1) {
	syslog(LOG_ERR, "fatal: unable to create a socket: %s",
	       strerror(errno));
	if (pidfile)
	  unlink(pidfile);
	exit(1);
      }
      memset(&sa_un, 0, sizeof(struct sockaddr_un));
      sa_un.sun_family = AF_UNIX;
      sa_un.sun_path[0] = 0;
      strncat(sa_un.sun_path, path, sizeof(sa_un.sun_path)-1);
      
      bind_ok = (bind(s, (struct sockaddr *)&sa_un, sizeof(sa_un)) != -1);
      
      if (!bind_ok && errno == EADDRINUSE && unlink_path) {
	unlink(path);
	bind_ok = (bind(s, (struct sockaddr *)&sa_un, sizeof(sa_un)) != -1);
      }
	
      if (!bind_ok) {
	syslog(LOG_ERR, "fatal: unable to bind to %s: %s",
	       path, strerror(errno));
	close(s);
	if (pidfile)
	  unlink(pidfile);
	exit(1);
      }

      if (mode != (mode_t)-1) {
	if ( chmod(path, mode) == -1 ) {
	  syslog(LOG_ERR, "unable to change mode of %s to 0%lo: %s",
		 path, (unsigned long)mode, strerror(errno));
	}
      }

      if ( chown(path, uid, gid) == -1 ) {
	syslog(LOG_ERR, "unable to change owner of %s to %lu:%lu: %s",
	       path, (unsigned long)uid, (unsigned long)gid, strerror(errno));
      }

      if (listen(s, LISTEN_BACKLOG) == -1) {
	close(s);
	unlink(path);
	syslog(LOG_ERR, "fatal: unable to listen to %s: %s",
	       path, strerror(errno));
	if (pidfile)
	  unlink(pidfile);
	exit(1);
      }
      oldpath = path;
    }
    if (restarting) {
      restarting = 0;
      cleanup(); /* kill childs */
    }
    syslog(LOG_INFO, "ready (%s%s%swaiting on %s)",
	   (tag ? "tag " : ""), tag, ", ",
	   path);

    if (killtime > 0)
      alarm(killtime);

    do {
      if (alarmed) {
	lastkilled = killone(lastkilled+1);
	alarm(killtime);
	alarmed=0;
      }
      if (need_checkup)
	checkup(s, params);
      sleep(1); /* pause */
    } while(! (exiting || restarting));
  } while(!exiting);
  
  cleanup();
  unlink(path);
  close(s);
  syslog(LOG_INFO, "exiting");
  if (pidfile)
    unlink(pidfile);
  exit(0);
}
