/* 
 * Copyright Richard Tobin 1995.
 */

#include <stdio.h>
#include <string.h>
#include <assert.h>

#include "alloc.h"
#include "block_alloc.h"
#include "hash.h"

/*
 * Internal state of a hash table.
 */
struct hash_table {
    /* Size of each key in bytes; zero means null-terminated strings. */
    int key_size;

    /* Number of entries currently stored in the table. */
    int entries;

    /* Number of slots in the bucket array. */
    int buckets;

    /* Array of chain heads, one per bucket. */
    struct hash_entry **bucket;

    /* Block allocator for fixed-size key copies; 0 for string keys. */
    Allocator key_allocator;

    /* Block allocator that hash entries are taken from. */
    Allocator entry_allocator;
};

/* Forward declarations of the file-local helpers defined later in this file. */
static unsigned int hash(char *key, int len);
static void rehash(HashTable table);
static HashEntry hash_lookup(HashTable table, char *key, int *foundp, int add);
static int key_compare(HashTable table, char *key1, char *key2);
static char *key_copy(HashTable table, char *key);

/*
 * Build a new, empty hash table.
 *
 * init_size is only a hint for the initial number of buckets: the table
 * starts with 256 buckets, doubled until the count is at least init_size.
 * key_size is the size of every key in bytes; pass zero when the keys are
 * null-terminated strings.
 */

HashTable create_hash_table(int init_size, int key_size)
{
    HashTable table;
    int nbuckets = 256;
    int b;

    /* Double the bucket count until it covers the requested size. */
    while(nbuckets < init_size)
	nbuckets *= 2;

    table = xalloc(1, struct hash_table);

    table->key_size = key_size;
    table->entries = 0;
    table->buckets = nbuckets;

    /* Every chain starts out empty. */
    table->bucket = xalloc(nbuckets, struct hash_entry *);
    for(b = 0; b < nbuckets; b++)
	table->bucket[b] = 0;

    table->entry_allocator =
	make_block_allocator(sizeof(struct hash_entry), 0);

    /* Only fixed-size keys are served by a block allocator. */
    if(key_size == 0)
	table->key_allocator = 0;
    else
	table->key_allocator = make_block_allocator(key_size, 0);

    return table;
}

/*
 * Dispose of a hash table: its keys, its entries, the bucket array and
 * the table structure itself.
 */

void free_hash_table(HashTable table)
{
    if(table->key_allocator)
    {
	/* Fixed-size keys all live in the key allocator. */
	destroy_block_allocator(table->key_allocator);
    }
    else
    {
	/* String keys were duplicated one by one, so release each of them. */
	int b;

	for(b = 0; b < table->buckets; b++)
	{
	    HashEntry e = table->bucket[b];

	    while(e)
	    {
		xfree(e->key);
		e = e->next;
	    }
	}
    }

    destroy_block_allocator(table->entry_allocator);
    xfree(table->bucket);
    xfree(table);
}

/*
 * Report how many entries the table currently holds.
 */
int hash_count(HashTable table)
{
    int count = table->entries;

    return count;
}

/*
 * Look key up in the table.
 * Returns the matching entry, or 0 if the key is not present; nothing is
 * added to the table.
 */
HashEntry hash_find(HashTable table, char *key)
{
    HashEntry match = hash_lookup(table, key, 0, 0);

    return match;
}

/*
 * Look key up in the table, adding a new entry (with value 0) if it is
 * not already there.
 * If foundp is non-null, *foundp is set to non-zero when the key already
 * existed and to zero when a new entry had to be created.
 */
HashEntry hash_find_or_add(HashTable table, char *key, int *foundp)
{
    HashEntry result = hash_lookup(table, key, foundp, 1);

    return result;
}

/*
 * Common worker behind hash_find() and hash_find_or_add().
 *
 * Searches the chain that key hashes to.  If foundp is non-null, *foundp
 * records whether the key was already present.  An existing entry is
 * returned as is.  Otherwise the result is 0 when add is zero, or a
 * freshly created entry (value 0, private copy of the key) when add is
 * non-zero.
 */
static HashEntry hash_lookup(HashTable table, char *key, int *foundp, int add)
{
    unsigned int h = hash(key, table->key_size);
    HashEntry *link = &table->bucket[h % table->buckets];
    HashEntry fresh;

    /* Stop either on the matching entry or on the null link ending the chain. */
    while(*link && key_compare(table, (*link)->key, key) != 0)
	link = &(*link)->next;

    if(foundp)
	*foundp = (*link != 0);

    if(*link)
	return *link;

    if(!add)
	return 0;

    if(table->entries > table->buckets)	/* XXX arbitrary! */
    {
	/*
	 * Grow the table first.  The link located above is not valid for
	 * the resized table, so start the lookup over.
	 */
	rehash(table);
	return hash_lookup(table, key, foundp, add);
    }

    fresh = block_alloc(table->entry_allocator);

    fresh->key = key_copy(table, key);
    fresh->value = 0;
    fresh->next = 0;

    /* Hook the new entry onto the end of its chain. */
    *link = fresh;
    table->entries++;

    return fresh;
}

/*
 * Remove entry from the table and release the storage held for it,
 * including the table's private copy of the key.
 *
 * entry must be an entry currently stored in table; if it cannot be found
 * in its chain a message is written to stderr and the program aborts.
 */
void hash_remove(HashTable table, HashEntry entry)
{
    unsigned int h = hash(entry->key, table->key_size);
    HashEntry *e;

    for(e = &table->bucket[h % table->buckets]; *e; e = &(*e)->next)
	if(*e == entry)
	{
	    /* Unlink the entry from its chain. */
	    *e = entry->next;

	    /*
	     * The key was copied when the entry was created, so it has to
	     * be released here as well.  It is returned the same way
	     * free_hash_table() disposes of keys: to the key allocator
	     * when there is one, with xfree() otherwise.
	     */
	    if(table->key_allocator)
		block_free(table->key_allocator, entry->key);
	    else
		xfree(entry->key);

	    block_free(table->entry_allocator, entry);
	    table->entries--;
	    return;
	}
    
    fprintf(stderr, "Attempt to remove non-existent entry from table\n");
    abort();
}

/*
 * Call function once for every entry in the table, bucket by bucket and
 * in chain order within each bucket.
 */
void hash_map(HashTable table, void (*function)(HashEntry))
{
    int b = 0;

    while(b < table->buckets)
    {
	HashEntry e = table->bucket[b];

	while(e)
	{
	    function(e);
	    e = e->next;
	}

	b++;
    }
}

/*
 * Double the number of buckets and redistribute the existing entries.
 *
 * A temporary table is created only to obtain a cleared bucket array of
 * the new size.  The entries themselves are relinked, not copied, so the
 * table keeps its own allocators: every existing entry and key was taken
 * from them.  The allocators that came with the temporary table are
 * destroyed unused.
 */
static void rehash(HashTable table)
{
    HashTable new;
    unsigned h;
    int i;
    HashEntry entry, next, *chain;

    /* XXX Should collect some statistics here */

    new = create_hash_table(2 * table->buckets, table->key_size);

    for(i=0; i<table->buckets; i++)
    {
	for(entry = table->bucket[i]; entry; entry = next)
	{
	    /* Remember the successor before the link is overwritten. */
	    next = entry->next;
	    h = hash(entry->key, table->key_size);
	    chain = &new->bucket[h % new->buckets];
	    entry->next = *chain;
	    *chain = entry;
	    new->entries++;
	}
    }

    assert(new->entries == table->entries);

    /*
     * Nothing was ever allocated from the temporary table's allocators.
     * Dispose of them rather than letting them replace the originals,
     * which would then be lost without ever being destroyed.
     */
    destroy_block_allocator(new->entry_allocator);
    if(new->key_allocator)
	destroy_block_allocator(new->key_allocator);

    /* Adopt the larger bucket array; key_size and entries are unchanged. */
    xfree(table->bucket);
    table->bucket = new->bucket;
    table->buckets = new->buckets;

    xfree(new);
}

/*
 * Chris Torek's hash function: h = h * 33 + c for each byte of the key.
 * Whether it is a good choice for this table has not been established.
 *
 * len is the number of bytes to hash; if it is zero the key is treated as
 * a null-terminated string and its length is measured with strlen().
 */

static unsigned int hash(char *key, int len)
{
    unsigned int h = 0;		/* should probably be 32 bits */

    if(len == 0)
	len = strlen(key);

    /* Consume the key front to back, one byte per step. */
    while(len-- > 0)
	h = h * 33 + *key++;

    return h;
}

/*
 * Compare two keys according to the table's key convention.
 * String keys (key_size zero) are compared with strcmp(), fixed-size keys
 * with memcmp() over key_size bytes.  The result is zero when they match.
 */
static int key_compare(HashTable table, char *key1, char *key2)
{
    int size = table->key_size;

    return size == 0 ? strcmp(key1, key2) : memcmp(key1, key2, size);
}

/*
 * Make the table's own copy of a key.
 * Fixed-size keys are copied into a block from the key allocator; string
 * keys are duplicated with xstrdup().
 */
static char *key_copy(HashTable table, char *key)
{
    char *stored;

    if(table->key_size <= 0)
    {
	stored = xstrdup(key);
	return stored;
    }

    stored = block_alloc(table->key_allocator);
    memcpy(stored, key, table->key_size);

    return stored;
}

