/* $Id: relaydb.c,v 1.15 2003/12/18 12:17:00 dhartmei Exp $ */

/*
 * Copyright (c) 2003 Daniel Hartmeier
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    - Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *    - Redistributions in binary form must reproduce the above
 *      copyright notice, this list of conditions and the following
 *      disclaimer in the documentation and/or other materials provided
 *      with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

static const char rcsid[] = "$Id: relaydb.c,v 1.15 2003/12/18 12:17:00 dhartmei Exp $";

#include <db.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/*
 * Per-address record stored as the database value; the key is the
 * address string itself.  Records are written as the raw bytes of this
 * struct and read_data() copies them back with memcpy(), telling the
 * formats apart by size, so this layout IS the on-disk format: do not
 * reorder or resize the fields.
 */
struct data {
	int	 white;		/* incremented by -w (mail not spam) */
	int	 black;		/* incremented by -b (mail is spam) */
	time_t	 mtime;		/* last update, or newest syslog sighting */
};

/*
 * Older record format without a timestamp.  read_data() still accepts
 * values of exactly this size and converts them to struct data with
 * mtime set to 0.
 */
struct data_old {
	int	 white;
	int	 black;
};

extern char	*__progname;		/* program name, shown by usage() */

/* Fixed limits. */
const int	 bufsiz = 1024;		/* line buffer size for all readers */
const int	 factor = 3;		/* black/white ratio used to classify */

/* Command line state. */
int		 debug = 0;		/* -v, may be repeated */
int		 action = 0;		/* 'b' or 'w' (-b/-w), 0 if neither */
int		 reverse = 0;		/* -r: decrement instead of increment */
int		 traverse = 1;		/* cleared by -n */
int		 use_v4 = 1;		/* cleared by -6 */
int		 use_v6 = 1;		/* cleared by -4 */

/* Database handle and the key/data pair reused by every db call. */
BTREEINFO	 btreeinfo;
DB		*db;
DBT		 dbk;
DBT		 dbd;

/*
 * Prototypes for the functions defined below.  The names must match the
 * definitions (address_valid_v4/_v6), otherwise the declarations are
 * dead and the real functions go undeclared.
 */
int	 read_data(const DBT *, struct data *);
int	 address_valid_v4(const char *);
int	 address_valid_v6(const char *);
int	 address_private(const char *);
int	 check(const char *);
void	 read_headers(void);
time_t	 parse_syslog_time(const char *);
void	 parse_syslog(const char *);
void	 import_file(const char *);
void	 usage(void);

int
read_data(const DBT *dbd, struct data *d)
{
	if (dbd->size == sizeof(*d))
		memcpy(d, dbd->data, sizeof(*d));
	else if (dbd->size == sizeof(struct data_old)) {
		struct data_old o;

		memcpy(&o, dbd->data, sizeof(o));
		d->white = o.white;
		d->black = o.black;
		d->mtime = 0;
	} else
		return (1);
	return (0);
}

/*
 * Return 1 if 'a' is a non-empty string made up only of decimal digits
 * and dots, 0 otherwise.  This is a character-set filter, not a full
 * IPv4 parser: strings such as "1..2" or "999" also pass.
 */
int
address_valid_v4(const char *a)
{
	size_t	 span = strspn(a, "0123456789.");

	return (span > 0 && a[span] == '\0');
}

/*
 * Return 1 if 'a' is a non-empty string made up only of hexadecimal
 * digits (either case) and colons, 0 otherwise.  This is a
 * character-set filter, not a full IPv6 parser; dotted forms such as
 * "::ffff:1.2.3.4" are rejected.
 */
int
address_valid_v6(const char *a)
{
	size_t	 span = strspn(a, "0123456789abcdefABCDEF:");

	return (span > 0 && a[span] == '\0');
}

/*
 * Return 1 if 'a' is a loopback or RFC 1918 private address, which is
 * never recorded, and 0 otherwise.  Only textual prefixes of 'a' are
 * compared; the address is not parsed.
 *
 * Covered: ::1, 127/8, 10/8, 172.16/12 (172.16.x.x - 172.31.x.x) and
 * 192.168/16.
 */
int
address_private(const char *a)
{
	if (!strcmp(a, "::1"))
		return (1);
	if (!strncmp(a, "127.", 4) ||
	    !strncmp(a, "10.", 3) ||
	    !strncmp(a, "192.168.", 8))
		return (1);
	/*
	 * 172.16.0.0/12 spans second octets 16 through 31; matching only
	 * the "172.16." prefix missed most of the range.  Short-circuit
	 * evaluation stops before reading past a terminating NUL.
	 */
	if (!strncmp(a, "172.", 4) &&
	    (a[4] == '1' || a[4] == '2' || a[4] == '3') &&
	    a[5] >= '0' && a[5] <= '9' && a[6] == '.') {
		int	 octet = (a[4] - '0') * 10 + (a[5] - '0');

		if (octet >= 16 && octet <= 31)
			return (1);
	}
	return (0);
}

/*
 * Record one relay address taken from a Received: header.
 *
 * An optional "IPv6:" prefix is stripped.  Invalid addresses are
 * rejected.  Private addresses are ignored, and processing continues.
 * Otherwise the address is looked up.  Its black counter (action 'b')
 * or white counter is incremented, or decremented (not below 0) when
 * -r is given.  Its mtime is set to the current time.
 *
 * An unknown address is inserted with a count of 1.  In reverse mode
 * there is nothing to undo, so the database is left untouched.
 *
 * Returns 0 if the caller should go on to the next Received: header.
 * Returns 1 if header processing should stop:
 *   - invalid or unknown address
 *   - database error
 *   - -n was given
 *   - black >= factor * white
 */
int
check(const char *address)
{
	int		 r, found, old, *counter;
	struct data	 d;

	if (!strncmp(address, "IPv6:", 5))
		address += 5;

	if (!((use_v4 && address_valid_v4(address)) ||
	    (use_v6 && address_valid_v6(address)))) {
		if (debug)
			printf("invalid address '%s'\n", address);
		return (1);
	}
	if (address_private(address))
		return (0);

	if (debug)
		printf("checking %s\n", address);

	memset(&dbk, 0, sizeof(dbk));
	dbk.size = strlen(address);
	dbk.data = (char *)address;
	memset(&dbd, 0, sizeof(dbd));
	r = db->get(db, &dbk, &dbd, 0);
	if (r < 0) {
		fprintf(stderr, "db->get() %s\n", strerror(errno));
		return (1);
	}
	found = (r == 0);

	if (!found) {
		if (reverse) {
			/* -r undoes an earlier count; nothing to undo here */
			if (debug)
				printf("  not found, nothing to reverse\n");
			return (1);
		}
		if (debug)
			printf("  not found, inserting new host %s %s 1\n",
			    address, (action == 'b' ? "black" : "white"));
		memset(&d, 0, sizeof(d));
		if (action == 'b')
			d.black = 1;
		else
			d.white = 1;
	} else {
		if (read_data(&dbd, &d)) {
			fprintf(stderr, "db->get() invalid data\n");
			return (1);
		}
		counter = (action == 'b') ? &d.black : &d.white;
		old = *counter;
		if (!reverse)
			(*counter)++;
		else if (*counter > 0)
			(*counter)--;
		/* report the value actually stored, also when decrementing */
		if (debug) {
			printf("  found, ");
			if (action == 'b')
				printf("white %d black %d -> %d\n",
				    d.white, old, d.black);
			else
				printf("white %d -> %d black %d\n",
				    old, d.white, d.black);
		}
	}

	d.mtime = time(NULL);
	memset(&dbk, 0, sizeof(dbk));
	dbk.size = strlen(address);
	dbk.data = (char *)address;
	memset(&dbd, 0, sizeof(dbd));
	dbd.size = sizeof(d);
	dbd.data = &d;
	if (db->put(db, &dbk, &dbd, 0)) {
		fprintf(stderr, "db->put() %s\n", strerror(errno));
		return (1);
	}

	/* only the first unknown relay of a mail is recorded */
	if (!found)
		return (1);
	if (!traverse || d.black >= factor * d.white) {
		if (debug)
			printf("ignoring further headers\n");
		return (1);
	}
	if (debug)
		printf("checking next header\n");
	return (0);
}

/*
 * Read a mail message from stdin.  For every "Received:" header line,
 * pass the text between the first '[' and the first ']' to check().
 * Stop at the first empty line (end of the header section), at end of
 * input, or as soon as check() returns non-zero.  A partial line at end
 * of input is ignored.
 *
 * Lines longer than bufsiz - 1 bytes are handled in chunks.  The byte
 * that does not fit starts the next chunk, so the rest of a long line
 * can never be mistaken for the empty line that ends the headers.
 * Folded (continued) header lines are not joined.
 */
void
read_headers()
{
	char	 buf[bufsiz], c, *b, *e;
	int	 pos = 0;

	while (fread(&c, 1, 1, stdin) > 0) {
		if (c != '\n' && pos < bufsiz - 1) {
			buf[pos++] = c;
			continue;
		}
		if (c == '\n' && pos == 0)
			break;		/* empty line: end of headers */
		/* pos <= bufsiz - 1 here, so the terminator is in bounds */
		buf[pos] = 0;
		pos = 0;
		if (!strncmp(buf, "Received:", 9)) {
			b = strchr(buf, '[');
			e = strchr(buf, ']');
			if (b != NULL && e != NULL && b < e) {
				*e = 0;
				if (check(b + 1))
					break;
			}
		}
		if (c != '\n')
			buf[pos++] = c;	/* overflow byte starts next chunk */
	}
}

/*
 * Convert the "Mon dd hh:mm:ss" timestamp at the start of a syslog line
 * to a time_t in local time.  Returns 0 in two cases:
 *   - the line does not start with such a timestamp
 *   - the time cannot be represented
 *
 * Syslog omits the year, so the current year is assumed.  If that puts
 * the stamp more than a day in the future, the previous year is used
 * instead (e.g. December entries read in January).  tm_isdst is set to
 * -1 so mktime() determines daylight saving time for the stamp itself
 * rather than inheriting the current setting.
 */
time_t
parse_syslog_time(const char *s)
{
	static const char *const names[] = { "Jan", "Feb", "Mar", "Apr",
	    "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec", NULL };
	char		 mon[4];
	time_t		 now = time(NULL), t;
	struct tm	 tm, parsed, *lt;

	if ((lt = localtime(&now)) == NULL)
		return (0);
	memcpy(&tm, lt, sizeof(tm));
	if (sscanf(s, "%3s %d %d:%d:%d", mon, &tm.tm_mday,
	    &tm.tm_hour, &tm.tm_min, &tm.tm_sec) != 5)
		return (0);
	for (tm.tm_mon = 0; names[tm.tm_mon] != NULL; ++tm.tm_mon)
		if (!strcmp(names[tm.tm_mon], mon))
			break;
	if (names[tm.tm_mon] == NULL)
		return (0);
	tm.tm_isdst = -1;
	parsed = tm;		/* mktime() normalizes tm in place */
	t = mktime(&tm);
	if (t != (time_t)-1 && t > now + 24 * 60 * 60) {
		tm = parsed;
		tm.tm_year--;
		t = mktime(&tm);
	}
	return (t == (time_t)-1 ? 0 : t);
}

/*
 * Scan a syslog file for spamd "connected" lines.  For every address
 * that is already in the database, set its mtime to the time of the log
 * line when that is newer than the stored one.  Unknown addresses are
 * not added.  Prints the number of entries touched.
 *
 * Line shape relied upon:
 *   <Mon> <dd> <hh:mm:ss> ... spamd[<pid>]: <address>: connected (...
 */
void
parse_syslog(const char *filename)
{
	FILE		*f;
	char		 buf[bufsiz], address[128], c;
	int		 pos = 0, r;
	struct data	 d;
	unsigned	 count = 0;

	if (debug)
		printf("reading syslog %s\n", filename);
	f = fopen(filename, "r");
	if (f == NULL) {
		fprintf(stderr, "fopen: %s: %s\n", filename, strerror(errno));
		return;
	}

	while (fread(&c, 1, 1, f) > 0) {
		char	*p, *s;
		time_t	 mtime;

		if (pos < bufsiz - 1 && c != '\n') {
			buf[pos++] = c;
			continue;
		}
		/* pos <= bufsiz - 1 here, so the terminator is in bounds */
		buf[pos] = 0;
		pos = 0;
		p = strstr(buf, ": connected (");
		s = strstr(buf, " spamd[");
		if (p == NULL || s == NULL || s > p)
			continue;
		*p = 0;
		/*
		 * The address sits between "]: " and ": connected (".  IPv6
		 * addresses contain colons themselves, so the last colon on
		 * the line cannot serve as the delimiter.
		 */
		if ((s = strstr(s, "]: ")) == NULL)
			continue;
		strlcpy(address, s + 3, sizeof(address));
		if (!((use_v4 && address_valid_v4(address)) ||
		    (use_v6 && address_valid_v6(address))) ||
		    address_private(address))
			continue;
		if (!(mtime = parse_syslog_time(buf)))
			continue;
		memset(&dbk, 0, sizeof(dbk));
		dbk.size = strlen(address);
		dbk.data = address;
		memset(&dbd, 0, sizeof(dbd));
		r = db->get(db, &dbk, &dbd, 0);
		if (r < 0) {
			fprintf(stderr, "db->get() %s\n",
			    strerror(errno));
			goto done;
		}
		if (r)
			continue;	/* not in the database */
		if (read_data(&dbd, &d)) {
			fprintf(stderr, "db->get() invalid data\n");
			goto done;
		}
		if (d.mtime >= mtime)
			continue;
		if (debug)
			printf("touching %lu %s\n",
			    (unsigned long)mtime, address);
		d.mtime = mtime;
		memset(&dbk, 0, sizeof(dbk));
		dbk.size = strlen(address);
		dbk.data = address;
		memset(&dbd, 0, sizeof(dbd));
		dbd.size = sizeof(d);
		dbd.data = &d;
		r = db->put(db, &dbk, &dbd, 0);
		if (r) {
			fprintf(stderr, "db->put() %s\n",
			    strerror(errno));
			goto done;
		}
		count++;
	}

done:
	fclose(f);
	printf("%u entries touched\n", count);
}

/*
 * Store one import line of the form "address white black [mtime]".  A
 * missing mtime defaults to the current time.  The following lines are
 * skipped:
 *   - blank lines
 *   - private addresses
 *   - addresses of a family disabled by -4/-6
 *
 * Returns 1 if an entry was written and 0 if the line was skipped.
 * Returns -1, after printing a message, on malformed input or a
 * db->put() failure.
 */
static int
import_line(const char *buf)
{
	char		 address[128];
	struct data	 d;
	unsigned long	 mtime;
	int		 r;

	memset(&d, 0, sizeof(d));
	r = sscanf(buf, "%127s %d %d %lu", address,
	    &d.white, &d.black, &mtime);
	if (r == EOF)
		return (0);		/* blank or whitespace-only line */
	if (r == 4)
		d.mtime = mtime;
	else if (r == 3)
		d.mtime = time(NULL);
	else {
		fprintf(stderr, "sscanf() invalid input '%s'\n", buf);
		return (-1);
	}
	if (!((use_v4 && address_valid_v4(address)) ||
	    (use_v6 && address_valid_v6(address))) ||
	    address_private(address))
		return (0);
	if (debug)
		printf("adding %s %d %d %lu\n",
		    address, d.white, d.black,
		    (unsigned long)d.mtime);
	memset(&dbk, 0, sizeof(dbk));
	dbk.size = strlen(address);
	dbk.data = address;
	memset(&dbd, 0, sizeof(dbd));
	dbd.size = sizeof(d);
	dbd.data = &d;
	if (db->put(db, &dbk, &dbd, 0)) {
		fprintf(stderr, "db->put() %s\n", strerror(errno));
		return (-1);
	}
	return (1);
}

/*
 * Import entries from a text file, one "address white black [mtime]"
 * line each (the format printed by -l -v).  The first malformed line or
 * database error stops the import.  Otherwise the number of entries
 * imported is printed.
 */
void
import_file(const char *filename)
{
	FILE		*f;
	char		 buf[bufsiz], c;
	int		 pos = 0, r = 0;
	unsigned	 count = 0;

	if (debug)
		printf("importing %s\n", filename);
	f = fopen(filename, "r");
	if (f == NULL) {
		fprintf(stderr, "fopen: %s: %s\n", filename, strerror(errno));
		return;
	}

	while (fread(&c, 1, 1, f) > 0) {
		if (pos < bufsiz - 1 && c != '\n') {
			buf[pos++] = c;
			continue;
		}
		/* pos <= bufsiz - 1 here, so the terminator is in bounds */
		buf[pos] = 0;
		pos = 0;
		if ((r = import_line(buf)) < 0)
			break;
		if (r > 0)
			count++;
	}
	/* a last line without a trailing newline is still an entry */
	if (r >= 0 && pos > 0) {
		buf[pos] = 0;
		if ((r = import_line(buf)) > 0)
			count++;
	}
	fclose(f);
	if (r >= 0)
		printf("%u entries imported\n", count);
}

/*
 * Print the command line synopsis to stderr and exit with status 1.
 */
void
usage(void)
{
	fprintf(stderr,
	    "usage: %s [-46bdlnrvw] [-BW [+-]num] [-m [+-]days]\n",
	    __progname);
	fputs("\t[-f filename] [-i filename] [-t filename]\n", stderr);
	exit(1);
}

/*
 * Parse the options and open the database ($HOME/.relaydb unless -f is
 * given).  Then run exactly one mode, chosen in this order:
 *   -t file   refresh mtimes from a spamd syslog file
 *   -i file   import "address white black [mtime]" lines
 *   -l / -d   list or delete entries matching -B/-W/-m/-b/-w/-4/-6
 *   -b / -w   read a mail from stdin and count its relays as black
 *             or white (-r reverses, -n stops after the first
 *             non-private relay)
 * Returns 1 on setup errors, on database errors while listing or
 * deleting, and on sync/close failures; 0 otherwise.
 */
int
main(int argc, char *argv[])
{
	int		 list = 0, delete = 0;
	const char	*filename = NULL, *import = NULL, *syslog = NULL;
	time_t		 mtime = 0;
	int		 mtime_op = 0;
	int		 black = -1, white = -1;
	int		 black_op = 0, white_op = 0;
	int		 ch;
	int		 status = 0;
	unsigned	 count = 0;

	while ((ch = getopt(argc, argv, "46bB:df:i:lm:nrt:vwW:")) != -1) {
		switch (ch) {
		case '4':
			use_v4 = 1;
			use_v6 = 0;
			break;
		case '6':
			use_v6 = 1;
			use_v4 = 0;
			break;
		case 'b':
		case 'w':
			action = ch;
			break;
		case 'B':
			if (*optarg == '+') {
				black_op = 1;
				optarg++;
			} else if (*optarg == '-') {
				black_op = -1;
				optarg++;
			}
			black = atol(optarg);
			break;
		case 'W':
			if (*optarg == '+') {
				white_op = 1;
				optarg++;
			} else if (*optarg == '-') {
				white_op = -1;
				optarg++;
			}
			white = atol(optarg);
			break;
		case 'd':
			delete = 1;
			break;
		case 'f':
			filename = optarg;
			break;
		case 'i':
			import = optarg;
			break;
		case 'l':
			list = 1;
			break;
		case 'm':
			if (*optarg == '+') {
				mtime_op = 1;
				optarg++;
			} else if (*optarg == '-') {
				mtime_op = -1;
				optarg++;
			}
			mtime = time(NULL) - atol(optarg) * 60 * 60 * 24;
			break;
		case 'n':
			traverse = 0;
			break;
		case 'r':
			reverse = 1;
			break;
		case 't':
			syslog = optarg;
			break;
		case 'v':
			debug++;
			break;
		default:
			usage();
		}
	}

	if (!list && !delete && !action && import == NULL && syslog == NULL)
		usage();

	if (delete && !action && !mtime && black == -1 && white == -1) {
		fprintf(stderr, "to delete all entries, delete the file\n");
		return (1);
	}

	if (filename == NULL) {
		const char	*home = getenv("HOME");
		const char	*file = "/.relaydb";
		int		 len;
		char		*fn;

		if (home == NULL) {
			fprintf(stderr, "-f not specified and $HOME undefined\n");
			return (1);
		}
		len = strlen(home)+strlen(file)+1;
		fn = malloc(len);
		if (fn == NULL) {
			fprintf(stderr, "malloc: %s\n", strerror(errno));
			return (1);
		}
		strlcpy(fn, home, len);
		strlcat(fn, file, len);
		filename = fn;
	}

	memset(&btreeinfo, 0, sizeof(btreeinfo));
	db = dbopen(filename, O_CREAT|O_EXLOCK|(list ? O_RDONLY : O_RDWR),
	    0600, DB_BTREE, &btreeinfo);
	if (db == NULL) {
		fprintf(stderr, "dbopen: %s: %s\n", filename, strerror(errno));
		return (1);
	}

	if (syslog != NULL)
		parse_syslog(syslog);
	else if (import != NULL)
		import_file(import);
	else if (list || delete) {
		int		 r;
		struct data	 d;
		char		 a[128];

		memset(&dbk, 0, sizeof(dbk));
		memset(&dbd, 0, sizeof(dbd));
		for (r = db->seq(db, &dbk, &dbd, R_FIRST); !r;
		    r = db->seq(db, &dbk, &dbd, R_NEXT)) {
			if (dbk.size < 1 || dbk.size >= sizeof(a) ||
			    read_data(&dbd, &d)) {
				fprintf(stderr, "db->seq() invalid data\n");
				if (db->close(db))
					fprintf(stderr, "db->close() %s\n",
					    strerror(errno));
				return (1);
			}
			if (black != -1)
				if ((black_op == 0 && d.black != black) ||
				    (black_op == 1 && d.black < black) ||
				    (black_op == -1 && d.black > black))
					continue;
			if (white != -1)
				if ((white_op == 0 && d.white != white) ||
				    (white_op == 1 && d.white < white) ||
				    (white_op == -1 && d.white > white))
					continue;
			if (action == 'b' && d.black <= factor * d.white)
				continue;
			if (action == 'w' && d.black  > factor * d.white)
				continue;
			/* -m N: day N ago; -m +N: older; -m -N: newer */
			if (mtime)
				if ((mtime_op == 0 && (d.mtime <
				    mtime - 60 * 60 * 24 || d.mtime > mtime)) ||
				    (mtime_op == 1 && d.mtime >
				    mtime - 60 * 60 * 24) ||
				    (mtime_op == -1 && d.mtime < mtime))
					continue;
			memcpy(a, dbk.data, dbk.size);
			a[dbk.size] = 0;
			if (!((use_v4 && address_valid_v4(a)) ||
			    (use_v6 && address_valid_v6(a))))
				continue;
			if (list) {
				if (debug)
					printf("%s %d %d %lu\n",
					    a, d.white, d.black,
					    (unsigned long)d.mtime);
				else
					printf("%s\n", a);
			} else {
				if (debug)
					printf("deleting %s\n", a);
				if (db->del(db, &dbk, 0)) {
					fprintf(stderr, "db->del() %s\n",
					    strerror(errno));
					db->sync(db, 0);
					db->close(db);
					return (1);
				}
				count++;
			}
		}
		/* seq() returns 1 when the data is exhausted, -1 on error */
		if (r < 0) {
			fprintf(stderr, "db->seq() %s\n", strerror(errno));
			status = 1;
		}
	} else {
		if (debug)
			printf("reading mail headers, considering the mail "
			    "%sspam\n", (action == 'b' ? "" : "not "));
		read_headers();
	}
	if (delete && !list)
		printf("%u entries deleted\n", count);

	if (!list && db->sync(db, 0)) {
		fprintf(stderr, "db->sync() %s\n", strerror(errno));
		status = 1;
	}
	if (db->close(db)) {
		fprintf(stderr, "db->close() %s\n", strerror(errno));
		status = 1;
	}
	return (status);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */