#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdio.h>
#include <signal.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <netdb.h>
#include <syslog.h>
#include <stdarg.h>
#ifdef WITH_TCPWRAPPERS
#include <tcpd.h>
#endif

/*
 * $Id: postgresql-relay.c,v 1.9 2004/03/21 04:29:04 edwin Exp $
 *
 */

/* Size of the buffer that holds a client's startup packet (length word included). */
#define MAXBUFSIZ	(64*1024)

/* Identifier passed to openlog(). */
#define PROCNAME	"PostgreSQL-relay"

/*
 * Copied from the include/libpq/pqcomm.h: a startup packet's protocol code
 * carries the major version in its upper 16 bits and the minor version in
 * its lower 16 bits.
 */
#define PG_PROTOCOL_MAJOR(v)       ((v) >> 16)
#define PG_PROTOCOL_MINOR(v)       ((v) & 0x0000ffff)


/* Path of the configuration file; overridden with -c. */
char *configurationfile="/usr/local/etc/postgresql-relay.conf";
int VERBOSE=1;		/* connection-level logging; cleared by -q, set by -v */
int VERYVERBOSE=0;	/* debug logging; set by -v, cleared by -q */
int BACKGROUND=1;	/* daemonize and log via syslog; cleared by -f */

/*
 * One configuration line: clients on local port `port` asking for database
 * `dbname` are relayed to remotehost:remoteport.
 */
struct configtype {
    int port;
    char *dbname;
    char *remotehost;
    struct in_addr ipremotehost;	/* resolved address of remotehost */
    int remoteport;
    char *remotedbname;		/* NOTE(review): stored, but not read by the relay code in this file */
    struct configtype *next;
};

/* A local port the relay listens on, with its listening socket. */
struct listenports {
    int	fd;
    int port;
    struct listenports *next;
};

/* Byte counters for one relayed session. */
struct session_statistics {
    long client_bytes_send;
    long client_bytes_received;
    long server_bytes_send;
    long server_bytes_received;
};

int	StartListening(struct listenports *port);
int	StartTalking(struct listenports *port);
void	StopTalking(int fd);
void	StopListening(void);
void	RunForestRun(int fd_client,struct sockaddr_in *sock);

struct listenports *listeners=NULL;	/* one entry per distinct local port */
struct configtype *configs=NULL;	/* one entry per configuration line */

/* printf-style logger; the attribute lets GCC/Clang check call sites. */
#if defined(__GNUC__)
void	mylog(int priority,char *fmt,...) __attribute__((format(printf,2,3)));
#else
void	mylog(int priority,char *fmt,...);
#endif

/*
 * CONFIGURATION
 *
 */

/*
 * Print the command-line synopsis to stderr and terminate with a failure
 * status.  GetArgs() calls this for any unrecognised option.
 */
void Usage(void) {
    fprintf(stderr,
"Usage: postgresql-relay [-fqv] [-c configfile]\n"
"	-f	stay in the foreground\n"
"	-q	log quiet\n"
"	-c	specify configuration file, default is /usr/local/etc/postgresql-relay.conf\n"
"	-v	log verbose\n"
    );
    exit(EXIT_FAILURE);
}

/*
 * Translate the command line into the global settings:
 *   -c file   read the configuration from `file`
 *   -f        stay in the foreground (BACKGROUND = 0)
 *   -q        quiet: VERBOSE and VERYVERBOSE both off
 *   -v        verbose: VERBOSE and VERYVERBOSE both on
 * Any other option ends in Usage().  Later options override earlier ones.
 */
void GetArgs(int argc,char **argv) {
    int opt;

    for (;;) {
	opt=getopt(argc,argv,"c:fqv");
	if (opt==-1)
	    return;

	if (opt=='c') {
	    configurationfile=optarg;
	} else if (opt=='f') {
	    BACKGROUND=0;
	} else if (opt=='q' || opt=='v') {
	    /* -q and -v toggle both verbosity levels together */
	    VERYVERBOSE=(opt=='v');
	    VERBOSE=VERYVERBOSE;
	} else {
	    Usage();
	}
    }
}

/*
 * Report a fatal error on line `lineno` of the configuration file and exit.
 * In the background the message also goes to stderr, because syslog alone
 * would hide it from whoever started the daemon.
 */
static void ConfigFatal(const char *what,int lineno) {
    if (BACKGROUND)
	fprintf(stderr,"%s:%d: %s\n",configurationfile,lineno,what);
    mylog(LOG_ERR,"%s:%d: %s",configurationfile,lineno,what);
    exit(EXIT_FAILURE);
}

/* Parse a decimal TCP port; returns it, or -1 unless `s` is exactly a number in 1..65535. */
static int ParseConfigPort(const char *s) {
    char *end;
    long v;

    errno=0;
    v=strtol(s,&end,10);
    if (end==s || *end!='\0' || errno!=0 || v<1 || v>65535)
	return -1;
    return (int)v;
}

/*
 * Read the configuration file named by `configurationfile`.
 *
 * Every line that is neither blank nor starts with '#' must hold exactly
 * five ':'-separated fields:
 *     localport:dbname:remotehost:remoteport:remotedbname
 * Each line prepends a struct configtype to `configs`.  Each distinct
 * local port prepends a struct listenports to `listeners`.  An unreadable
 * file, a malformed line, an unresolvable host or an allocation failure is
 * reported and terminates the process.
 *
 * Also opens the syslog connection when running in the background.
 */
void ReadConfiguration(void) {
    FILE *fin;
    char buffer[1024];
    int lineno=0;

    if (BACKGROUND)
	openlog(PROCNAME,LOG_NDELAY|LOG_PID,LOG_DAEMON);

    if ((fin=fopen(configurationfile,"r"))==NULL) {
	int err=errno;	/* printf() below may change errno */

	printf("Cannot open %s for reading - %s\n",configurationfile,strerror(err));
	mylog(LOG_ERR,"Cannot open %s for reading - %s",configurationfile,strerror(err));
	exit(EXIT_FAILURE);
    }

    while (fgets(buffer,sizeof(buffer),fin)!=NULL) {
	char *parts[5];
	char *rest=buffer;
	size_t len=strlen(buffer);
	int nparts=0;
	int localport;
	int remoteport;
	struct hostent *h;
	struct listenports *port;
	struct configtype *type;

	lineno++;

	/* A full buffer without a newline means the line did not fit. */
	if (len==sizeof(buffer)-1 && buffer[len-1]!='\n') {
	    int c=getc(fin);

	    if (c!=EOF && c!='\n')
		ConfigFatal("line too long",lineno);
	}
	/* Drop the line terminator (LF or CRLF); the last line may have none. */
	buffer[strcspn(buffer,"\r\n")]='\0';

	if (buffer[0]=='#' || buffer[0]=='\0')
	    continue;

	/* Exactly five fields: stop after five and treat any remainder as an error. */
	while (nparts<5 && rest!=NULL)
	    parts[nparts++]=strsep(&rest,":");
	if (nparts!=5 || rest!=NULL)
	    ConfigFatal("expected localport:dbname:remotehost:remoteport:remotedbname",lineno);

	if ((localport=ParseConfigPort(parts[0]))<0)
	    ConfigFatal("invalid local port",lineno);
	if ((remoteport=ParseConfigPort(parts[3]))<0)
	    ConfigFatal("invalid remote port",lineno);

	if ((h=gethostbyname(parts[2]))==NULL) {
	    herror("ReadConfiguration::gethostbyname");
	    ConfigFatal("cannot resolve remote host",lineno);
	}
	/* ipremotehost is a struct in_addr, so only an IPv4 address fits. */
	if (h->h_addrtype!=AF_INET || h->h_length!=(int)sizeof(struct in_addr))
	    ConfigFatal("remote host has no IPv4 address",lineno);

	for (port=listeners;port!=NULL;port=port->next)
	    if (port->port==localport)
		break;
	if (port==NULL) {
	    if ((port=calloc(1,sizeof(*port)))==NULL)
		ConfigFatal("out of memory",lineno);
	    port->port=localport;
	    port->next=listeners;
	    listeners=port;
	}

	if ((type=calloc(1,sizeof(*type)))==NULL)
	    ConfigFatal("out of memory",lineno);
	type->port=localport;
	type->dbname=strdup(parts[1]);
	type->remotehost=strdup(parts[2]);
	type->remoteport=remoteport;
	type->remotedbname=strdup(parts[4]);
	if (type->dbname==NULL || type->remotehost==NULL || type->remotedbname==NULL)
	    ConfigFatal("out of memory",lineno);
	/* Copy now: h points into static storage reused by the next lookup. */
	memcpy(&type->ipremotehost,h->h_addr_list[0],sizeof(type->ipremotehost));
	type->next=configs;
	configs=type;
    }

    fclose(fin);
}

/*
 * Log a printf-style message at syslog level `priority`.  In the background
 * it goes to syslog; otherwise it is printed on stdout followed by a newline.
 *
 * The text is formatted with vsnprintf so that an over-long message is
 * truncated instead of overflowing the buffer.  Callers log client-supplied
 * database and user names, so the length is not under our control.
 */
void mylog(int priority,char *fmt,...) {
    char buf[1024];
    va_list ap;

    va_start(ap,fmt);
    vsnprintf(buf,sizeof(buf),fmt,ap);
    va_end(ap);

    if (BACKGROUND)
	syslog(priority,"%s",buf);
    else
	printf("%s\n",buf);
}

/*
 * SERVER
 *
 * - Spawn a child when a connection comes in.
 *
 */

/*
 * SIGCHLD handler: reap every child that has already exited, so finished
 * relay processes do not linger as zombies.  When VERBOSE is set, each one
 * is logged.  `count` numbers the children reaped within one invocation.
 *
 * errno is saved and restored because the handler can run between a
 * failing call in the main code and its strerror(errno) report.  The final
 * waitpid() of the loop always leaves errno changed.
 *
 * NOTE(review): mylog() (vsnprintf/syslog/printf) is not async-signal-safe,
 * so logging from a signal handler remains a latent hazard.
 */
void killchild(int s) {
    int saved_errno=errno;
    int stat;
    int count=0;
    pid_t pid;

    (void)s;
    while ((pid=waitpid(-1,&stat,WNOHANG))>0) {
	if (VERBOSE)
	    mylog(LOG_NOTICE,"Child #%d with pid %d ended",count++,(int)pid);
    }
    errno=saved_errno;
}

int main(int argc,char **argv) {
    fd_set in_set;
    int maxdesc=0;

    struct listenports *port;

    GetArgs(argc,argv);
    ReadConfiguration();

    if (VERYVERBOSE)
	mylog(LOG_NOTICE,"Started");

    if (BACKGROUND) {
	pid_t p;
	if ((p=fork()>0)) {
	    exit(0);
	}
	if (p==-1) {
	    mylog(LOG_ERR,"Cannot fork process: %s",strerror(errno));
	    exit(0);
	}
	mylog(LOG_INFO,"Forking");
    }

    FD_ZERO(&in_set);
    port=listeners;
    while (port!=NULL) {
	if (VERYVERBOSE)
	    mylog(LOG_INFO,"Starting server for port %d",port->port);
	if (!StartListening(port)) {
	    mylog(LOG_ERR,"Starting server for port %d failed",port->port);
	    return 1;
	}
	FD_SET(port->fd,&in_set);
	maxdesc=maxdesc<port->fd?port->fd:maxdesc;
	port=port->next;
    }

    signal(SIGCHLD,killchild);

    while (1) {
	port=listeners;
	while (port!=NULL) {
	    FD_SET(port->fd,&in_set);
	    port=port->next;
	}

	if (select(maxdesc+1,&in_set,NULL,NULL,NULL)<0) {
	    if (errno!=EINTR) {
		mylog(LOG_ERR,"select() in main failed - %s",strerror(errno));
		exit(0);
	    } else
		continue;
	}

	port=listeners;
	while (port!=NULL) {
	    if (FD_ISSET(port->fd,&in_set)) {
		if (VERYVERBOSE)
		    mylog(LOG_DEBUG,"accepting on port %d",port->port);
		StartTalking(port);
		FD_CLR(port->fd,&in_set);
	    }
	    port=port->next;
	}

    }

    if (VERYVERBOSE)
	mylog(LOG_NOTICE,"Ended");
    return 0;
}

/*
 * NETWORK FUNCTIONS
 *
 */

/*
 * Create a TCP socket that listens on all IPv4 addresses at port->port,
 * and store it in port->fd.
 *
 * Returns 1 on success.  On failure it logs the reason, closes any socket
 * it created, sets port->fd to -1 and returns 0.
 */
int StartListening(struct listenports *port) {
    int one=1;
    struct sockaddr_in saddr;

    if ((port->fd=socket(AF_INET,SOCK_STREAM,0))<0) {
	mylog(LOG_ERR,"socket() in StartListening failed - %s",strerror(errno));
	port->fd=-1;
	return 0;
    }

    /* Allow an immediate restart while old connections sit in TIME_WAIT. */
    if (setsockopt(port->fd,SOL_SOCKET,SO_REUSEADDR,(char *)&one,sizeof(one))<0) {
	mylog(LOG_ERR,"setsockopt() in StartListening failed - %s",strerror(errno));
	goto fail;
    }

    memset(&saddr,0,sizeof(saddr));
    saddr.sin_family=AF_INET;
    saddr.sin_addr.s_addr=htonl(INADDR_ANY);
    saddr.sin_port=htons(port->port);

    if (bind(port->fd,(struct sockaddr *)&saddr,sizeof(saddr))<0) {
	mylog(LOG_ERR,"bind() in StartListening failed - %s",strerror(errno));
	goto fail;
    }

    /* Use the system's maximum backlog so simultaneous clients are queued, not refused, while the parent forks. */
    if (listen(port->fd,SOMAXCONN)<0) {
	mylog(LOG_ERR,"listen() in StartListening failed - %s",strerror(errno));
	goto fail;
    }

    return 1;

fail:
    close(port->fd);
    port->fd=-1;
    return 0;
}

/*
 * Decide whether the client at `sock` may use the relay.
 *
 * With WITH_TCPWRAPPERS this asks tcp_wrappers about the service
 * "postgresql", or "postgresql-<dbname>" when dbname is non-NULL.  If the
 * service name cannot be allocated, access is denied.  Without
 * WITH_TCPWRAPPERS every client is allowed.
 *
 * Returns non-zero if access is allowed.
 */
int check_access(struct sockaddr_in *sock,char *dbname) {
#ifdef WITH_TCPWRAPPERS
    struct request_info req;
    int yes;
    char *daemon;

    if (dbname==NULL) {
	daemon=strdup("postgresql");
    } else {
	size_t len=strlen("postgresql-")+strlen(dbname)+1;

	daemon=malloc(len);
	if (daemon!=NULL)
	    snprintf(daemon,len,"postgresql-%s",dbname);
    }
    if (daemon==NULL) {
	mylog(LOG_ERR,"Out of memory in check_access");
	return 0;
    }

    if (sock!=NULL)
	mylog(LOG_DEBUG,"Checking access of %s to %s",inet_ntoa(sock->sin_addr),daemon);

    request_init(&req,RQ_DAEMON,daemon,0);
    if (sock!=NULL)
	request_set(&req,RQ_CLIENT_SIN,sock,0);
    sock_methods(&req);
    yes=hosts_access(&req);
    /* Freed only after hosts_access(), so req never refers to released memory. */
    free(daemon);
    return yes;
#else
    (void)sock;
    (void)dbname;
    return 1;
#endif
}

/*
 * Accept one pending connection on port->fd and apply the host-level
 * check_access() rule ("postgresql").  If allowed, fork a child that
 * closes the listening sockets and relays the session via RunForestRun().
 * The parent then closes its copy of the client socket.
 *
 * Returns 1 if a child was started, 0 otherwise.  The client socket is
 * closed on every failure path after accept().
 */
int StartTalking(struct listenports *port) {
    static const char denied[]="Access denied by hostname\n\r";
    struct sockaddr_in sock;
    socklen_t socksize=sizeof(sock);
    int one=1;
    int fd_client;
    pid_t pid;

    if ((fd_client=accept(port->fd,(struct sockaddr *)&sock,&socksize))<0) {
	mylog(LOG_ERR,"accept() in StartTalking failed - %s",strerror(errno));
	return 0;
    }
    socksize=sizeof(sock);
    if (getpeername(fd_client,(struct sockaddr *)&sock,&socksize)<0) {
	mylog(LOG_ERR,"getpeername() in StartTalking failed - %s",strerror(errno));
	StopTalking(fd_client);
	return 0;
    }
    if (VERBOSE)
	mylog(LOG_NOTICE,"Incoming from %s on port %d",inet_ntoa(sock.sin_addr),port->port);

    /* Socket tuning is best-effort: failures are logged but not fatal. */
    if (setsockopt(fd_client,IPPROTO_TCP,TCP_NODELAY,(char *)&one,sizeof(one))<0) {
	mylog(LOG_ERR,"TCP_NODELAY in StartTalking failed - %s",strerror(errno));
    }
    if (setsockopt(fd_client,SOL_SOCKET,SO_KEEPALIVE,(char *)&one,sizeof(one))<0) {
	mylog(LOG_ERR,"SO_KEEPALIVE in StartTalking failed - %s",strerror(errno));
    }

    if (!check_access(&sock,NULL)) {
	write(fd_client,denied,sizeof(denied)-1);
	mylog(LOG_ERR,"Access denied by rule 'postgresql'");
	StopTalking(fd_client);
	return 0;
    }

    pid=fork();
    if (pid==0) { /* child */
	StopListening();
	RunForestRun(fd_client,&sock);
	exit(0);
    }
    if (pid<0) {
	mylog(LOG_ERR,"fork() in StartTalking failed - %s",strerror(errno));
	StopTalking(fd_client);
	return 0;
    }

    /* parent: the child owns the connection now */
    StopTalking(fd_client);
    return 1;
}

/*
 * Release one connected client socket.  The result of close() is
 * intentionally ignored: callers have nothing useful to do on failure.
 */
void StopTalking(int fd) {
    (void)close(fd);
}

/*
 * Close the listening socket of every entry in `listeners`.  StartTalking()
 * calls this in a freshly forked child, so only the parent keeps accepting.
 */
void StopListening() {
    struct listenports *lp;

    for (lp=listeners;lp!=NULL;lp=lp->next)
	close(lp->fd);
}

/*
 * CHILD
 *
 * - Wait for a start packet
 * - Connect to the right server
 * - Pass everything through
 *
 */

/*
 * Find the configuration entry whose dbname equals `dbname`, and open a
 * TCP connection to that entry's remote host and port.  `username` is used
 * only for the process title, when setproctitle() is available; it may be
 * NULL.
 *
 * NOTE(review): the local port the client connected to is not part of the
 * lookup; the first entry with a matching dbname wins.
 *
 * Returns the connected socket.  Returns 0 if dbname is NULL, no entry
 * matches, or the connection fails; any socket created is closed then.
 */
int ConnectToServer(char *dbname,char *username) {
    int fd_server;
    struct sockaddr_in local;
    struct sockaddr_in remote;
    struct configtype *config;

    if (dbname==NULL) {
	mylog(LOG_WARNING,"No database name given");
	return 0;
    }

    for (config=configs;config!=NULL;config=config->next)
	if (strcmp(config->dbname,dbname)==0)
	    break;
    if (config==NULL) {
	mylog(LOG_WARNING,"No configuration for database '%s'",dbname);
	return 0;
    }

    if ((fd_server=socket(AF_INET,SOCK_STREAM,0))<0) {
	mylog(LOG_ERR,"socket() in ConnectToServer failed - %s",strerror(errno));
	return 0;
    }

    memset(&local,0,sizeof(local));
    local.sin_family=AF_INET;
    local.sin_addr.s_addr=htonl(INADDR_ANY);
    local.sin_port=0;	/* any local port */

    memset(&remote,0,sizeof(remote));
    remote.sin_family=AF_INET;
    remote.sin_addr=config->ipremotehost;
    remote.sin_port=htons(config->remoteport);

    if (VERBOSE)
	mylog(LOG_INFO,"Connecting to server %s[%s] port %d",config->remotehost,inet_ntoa(config->ipremotehost),config->remoteport);

    if (bind(fd_server,(struct sockaddr *)&local,sizeof(local))<0) {
	mylog(LOG_ERR,"bind() in ConnectToServer failed - %s",strerror(errno));
	close(fd_server);
	return 0;
    }

    if (connect(fd_server,(struct sockaddr *)&remote,sizeof(remote))<0) {
	mylog(LOG_ERR,"connect() in ConnectToServer failed - %s",strerror(errno));
	close(fd_server);
	return 0;
    }

#ifdef HAVE_SETPROCTITLE
    // poor linux users...
    setproctitle("to %s:%d (%s@%s)",config->remotehost,config->remoteport,username!=NULL?username:"",dbname);
#else
    (void)username;
#endif

    return fd_server;
}

void RunForestRun(int fd_client,struct sockaddr_in *sock) {
    fd_set in_set;
    fd_set out_set;
    fd_set exc_set;

    int maxdesc;
    int fd_server=0;

    int gotheader=0;

    struct session_statistics stats={0,0,0,0};

    FD_ZERO(&in_set);
    FD_ZERO(&out_set);
    FD_ZERO(&exc_set);

    maxdesc=fd_client;

    while (1) {
	FD_SET(fd_client,&in_set);
	FD_SET(fd_client,&out_set);
	FD_SET(fd_client,&exc_set);
	if (gotheader) {
	    FD_SET(fd_server,&in_set);
	    FD_SET(fd_server,&out_set);
	    FD_SET(fd_server,&exc_set);
	}

	if (select(maxdesc+1,&in_set,NULL,&exc_set,NULL)<0) {
	    if (errno!=EINTR) {
		mylog(LOG_ERR,"select() in RunForestRun failed - %s",strerror(errno));
		exit(0);
	    } else
		continue;
	}

	if (FD_ISSET(fd_client,&exc_set)) {
	    if (VERBOSE)
		mylog(LOG_INFO,"%d: client quitting",fd_client);
	    exit(0);
	}
	if (gotheader && FD_ISSET(fd_server,&exc_set)) {
	    if (VERBOSE)
		mylog(LOG_INFO,"%d: server quitting",fd_client);
	    exit(0);
	}

	if (FD_ISSET(fd_client,&in_set)) {
	    if (gotheader==0) {
		int n;
		uint32_t m_size;
		uint32_t size;
		uint32_t m_type;
		char *dbname;
		char *username;
		char buffer[MAXBUFSIZ];

		if ((n=read(fd_client,&m_size,sizeof(m_size)))<=0) {
		    mylog(LOG_DEBUG,"%d: m_size EOF",fd_client);
		    exit(0);
		}
		m_size=ntohl(m_size);
		if (VERYVERBOSE)
		    mylog(LOG_DEBUG,"%d bytes",m_size);
		size=m_size-sizeof(m_size);

		if ((n=read(fd_client,buffer+sizeof(m_size),size))<=0) {
		    mylog(LOG_DEBUG,"%d: buffer EOF",fd_client);
		    exit(0);
		}
		stats.client_bytes_received+=m_size;
		memcpy(&m_type,buffer+sizeof(m_size),sizeof(m_type));
		m_type=ntohl(m_type);
		
		if (VERYVERBOSE)
		    mylog(LOG_DEBUG,"type 0x%08X",m_type);

		buffer[0]=((m_size&0xFF000000)>>24);
		buffer[1]=((m_size&0x00FF0000)>>16);
		buffer[2]=((m_size&0x0000FF00)>>8);
		buffer[3]=((m_size&0x000000FF)>>0);

		if (PG_PROTOCOL_MAJOR(m_type)==2) {
		    dbname=buffer+sizeof(m_size)+sizeof(m_type);
		    username=buffer+sizeof(m_size)+sizeof(m_type)+64;

		    if (!check_access(sock,dbname)) {
			write(fd_client,"Access denied to database by hostname\n\r",14);
			mylog(LOG_ERR,"Access denied by rule 'postgresql-%s'",dbname);
			exit(0);
		    }

		    if (VERYVERBOSE) {
			mylog(LOG_DEBUG,"%d. dbname=%s",fd_client,dbname);
			mylog(LOG_DEBUG,"%d. username=%s",fd_client,username);
		    }
		    if (VERBOSE)
			mylog(LOG_NOTICE,"Connecting to %s as %s",dbname,username);
		    if ((fd_server=ConnectToServer(dbname,username))==0) {
			mylog(LOG_WARNING,"RunForestRun::ConnectToServer failed");
			exit(0);
		    }
		    gotheader=1;
		    stats.server_bytes_send+=m_size;
		    write(fd_server,buffer,m_size);
		    maxdesc=fd_server>fd_client?fd_server:fd_client;

		} else if (PG_PROTOCOL_MAJOR(m_type)==3) {
		    char *b=buffer+sizeof(m_size)+sizeof(m_type);
		    char *key=NULL;

		    username=dbname=NULL;

		    while (*b) {
			if (key==NULL) {
			    key=b;
			} else {
			    if (VERYVERBOSE)
				mylog(LOG_DEBUG,"%d. key: key='%s' value='%s'",fd_client,key,b);
			    if (strcmp(key,"user")==0)
				username=b;
			    else if (strcmp(key,"database")==0)
				dbname=b;
			    else {
				if (VERYVERBOSE)
				    mylog(LOG_DEBUG,"%d. Unknown connect field: key='%s' value='%s'",fd_client,key,b);
			    }
			    key=NULL;
			}
			b+=strlen(b)+1;
		    }
		    if (VERYVERBOSE) {
			mylog(LOG_DEBUG,"%d. dbname=%s",fd_client,dbname);
			mylog(LOG_DEBUG,"%d. username=%s",fd_client,username);
		    }
		    if (!check_access(sock,dbname)) {
			write(fd_client,"Access denied to database by hostname\n\r",14);
			mylog(LOG_ERR,"Access denied by rule 'postgresql-%s'",dbname);
			exit(0);
		    }
		    if (VERBOSE)
			mylog(LOG_NOTICE,"Connecting to %s as %s",dbname,username);
		    if ((fd_server=ConnectToServer(dbname,username))==0) {
			mylog(LOG_WARNING,"RunForestRun::ConnectToServer failed");
			exit(0);
		    }
		    gotheader=1;
		    stats.server_bytes_send+=m_size;
		    write(fd_server,buffer,m_size);
		    maxdesc=fd_server>fd_client?fd_server:fd_client;

		} else if (PG_PROTOCOL_MAJOR(m_type)==0x04d2) {
		    stats.client_bytes_send++;
		    write(fd_client,"N",1);
		    mylog(LOG_NOTICE,"Requested SSL, not supporting it yet.");
		} else {
		    stats.client_bytes_send++;
		    write(fd_client,"EFATAL: unsupported relay protocol",34);
		    mylog(LOG_NOTICE,"Unknown protocol version: %0x",m_type);
		}

	    } else {
		char buf[1024];
		int n;
		if (VERYVERBOSE)
		    mylog(LOG_DEBUG,"Incoming client data",fd_client);
		if ((n=read(fd_client,&buf,sizeof(buf)))<=0) {
		    mylog(LOG_NOTICE,"Client EOF");
		    exit(0);
		}
		stats.client_bytes_received+=n;
		write(fd_server,buf,n);
		stats.server_bytes_send+=n;
	    }
	}
	if (gotheader && FD_ISSET(fd_server,&in_set)) {
	    char buf[1024];
	    int n;
	    if (VERYVERBOSE)
		mylog(LOG_DEBUG,"Incoming server data",fd_server);
	    if ((n=read(fd_server,&buf,sizeof(buf)))<=0) {
		mylog(LOG_NOTICE,"Server EOF");
		exit(0);
	    }
	    stats.server_bytes_received+=n;
	    write(fd_client,buf,n);
	    stats.client_bytes_send+=n;
	}
    }
}

