# -*- Mode: Python; tab-width: 4 -*-
#
#	Author: Sam Rushing <rushing@nightmare.com>
#	$Id: msql.py,v 1.1.1.1 1999/01/08 06:58:44 rushing Exp $
#
# ===========================================================================
#					[synchronous] standalone interface to mSQL.
# ===========================================================================
#
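# packet protocol:
# every packet is a 32-bit (4-byte) length followed by that many bytes of
# data.  This module packs the length in the host's native byte order;
# whether the server always expects little-endian is unverified.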
#
# greeting:
# 0:6:1.0.16
# <== <status>:<protocol_version>:<greeting>\n
# ==> <username>\n
# <== -100:\012
#
# '-100' is a success code.
#
# query result pattern:
# <n>:xxx...
# if n is negative, it's a result code, and there's no more data.
# if it's positive, then it tells the length of the next field.
# <i>:<i chars><j>:<j chars> [...] \n
#
# ==================================================
# commands, from "msql_priv.h"

# Command codes; the session methods send them as the '<code>:' prefix
# of a request packet.
QUIT       = 1
INIT_DB    = 2
QUERY      = 3
DB_LIST    = 4
TABLE_LIST = 5
FIELD_LIST = 6
CREATE_DB  = 7
DROP_DB    = 8
RELOAD_ACL = 9
SHUTDOWN   = 10

# ==================================================

import socket
import string
import struct

# Exception classes rather than strings: raising a string exception is a
# TypeError on Python >= 2.6.  `raise ProtocolError, msg` and
# `except ProtocolError` keep working unchanged with classes.
class ProtocolError (Exception):
	"""mSQL Protocol Error: a server reply could not be parsed."""

class MsqlError (Exception):
	"""mSQL Error: the server returned a negative status for a request."""

def packet (s):
	"""Frame *s* as an mSQL packet: a 4-byte length prefix, then the data.

	The length is packed as a standard-size (4-byte) signed int in the
	host's native byte order.  Plain 'l' would use the native *size*,
	which is 8 bytes on LP64 platforms and breaks the 32-bit framing.
	"""
	return struct.pack ('=l', len(s)) + s

class msql_session:
	def __init__ (self, addr, user):
		if type(addr) == type(()):
			self.socket = socket.socket (socket.AF_INET, socket.SOCK_STREAM)
		else:
			self.socket = socket.socket (socket.AF_UNIX, socket.SOCK_STREAM)
		self.socket.connect (addr)
		greeting = self.get_result()
		self.send ('%s\n' % user)
		self.get_result()

	def send (self, data):
		self.socket.send (packet (data))

	def get_result (self):
		rl = struct.unpack ('l',self.socket.recv(4))[0]
		return string.split (self.socket.recv(rl)[:-1],':')

	def get_query_result (self):
		rl = struct.unpack ('l',self.socket.recv(4))[0]
		data = self.socket.recv(rl)
		result = []
		pos = 0
		while pos < rl and data[pos] != '\n':
			index = string.find (data, ':', pos)
			if index == -1:
				raise ProtocolError, "error parsing msql result string, '%s'" % data
			else:
				try:
					field_len = string.atoi (data[pos:index])
				except ValueError:
					raise ProtocolError, "error parsing msql result string: '%s'" % data
				if field_len < 0:
					return -1
				pos = index + field_len + 1
				result.append (data[index+1:pos])
		return result

	def init_db (self, db):
		self.send ('%d:%s\n' % (INIT_DB, db))
		return self.get_result()

	def read_query_data (self):
		result = []
		while 1:
			r = self.get_query_result()
			if r == -1:
				break
			else:
				result.append (r)
		return result

	def query (self, query):
		self.send ('%d:%s\n' % (QUERY, query))
		r = self.get_result()
		if string.atoi(r[0]) < 0:
			raise MsqlError, r
		try:
			num_fields = string.atoi (r[1])
		except ValueError:
			return r
		result = self.read_query_data()
		fields = []
		for x in range(num_fields):
			fields.append (self.get_query_result())
		self.get_query_result()
		return result, fields

	__getitem__ = query

	def list_dbs (self):
		self.send ('%d:\n' % DB_LIST)
		return self.read_query_data()

	def list_tables (self):
		self.send ('%d:\n' % TABLE_LIST)
		return self.read_query_data()

	def list_fields (self, table):
		self.send ('%d:%s\n' % (FIELD_LIST, table))
		return self.read_query_data()

	def create_db (self, name):
		self.send ('%d:%s\n' % (CREATE_DB, name))
		return self.get_result()

	def drop_db (self, name):
		self.send ('%d:%s\n' % (DROP_DB, name))
		return self.get_result()		

	def shutdown (self):
		self.send ('%d:\n' % (SHUTDOWN))
		return self.get_result()

	def reload_acls (self):
		self.send ('%d:\n' % (RELOAD_ACL))		

	def close (self):
		self.socket.close()
		self.socket = None

if __name__ == '__main__':
	import os
	import pwd
	user = pwd.getpwuid (os.getuid())[0]
	print 'user:',user
	#s = msql_session (('', 1112), user)
	s = msql_session ('/dev/msql', user)
	dbs = s.list_dbs()
	for db in dbs:
		s.init_db (db[0])
		print 'db:',db[0]
		tables = s.list_tables()
		for table in tables:
			print '\ttable:',table[0]
			for field in s.list_fields (table[0]):
				print '\t\tfield:',string.join (field[1:],':')
	s.close()