#!/usr/local/bin/ruby
#
# DNS Balance --- a DNS server that performs dynamic load balancing
#
# By: YOKOTA Hiroshi <yokota@netlab.is.tsukuba.ac.jp>

# $Id: dns_balance.rb,v 1.17 2000/12/22 01:28:36 elca Exp $

require 'socket'
require 'thread'
require 'getopts'

$:.unshift "/usr/local/etc/dns_balance", "/usr/local/lib/dns_balance"

require 'datatype.rb'
require 'log_writer.rb'
require 'util.rb'

require 'namespace.rb'
require 'addrdb.rb'

#####################################################################
# Exceptions used by the query loop to choose which kind of DNS reply
# to send.
#
# NotImplementedError is intentionally NOT redefined here.  Ruby already
# provides it (as a ScriptError subclass), and reopening it with another
# superclass (`class NotImplementedError < StandardError`) raises
# "TypeError: superclass mismatch" on Ruby >= 1.8, aborting the script at
# load time.  check_packet raises it and the main loop rescues it by
# name, so the built-in class serves the same purpose.
class TruncatedError      < StandardError ; end   # reply longer than 512 bytes
class NoQueryError        < StandardError ; end   # query is refused
class NoMoreResourceError < StandardError ; end   # no host could be selected


###################################################################
# Functions

# Extract the question section of a DNS query packet.
#
# packet - the raw query as received from the socket (String).
#
# The packet is split into the 12-byte header (ID, flags, QDCOUNT,
# ANCOUNT, NSCOUNT, ARCOUNT as 2-byte fields) and the remaining body.
#
# Returns [q, q_type, q_class]: q is the query name in wire format
# (everything up to and including its first "\0"), q_type and q_class
# are the raw 2-byte fields.  When the packet does not carry exactly one
# question, or the body holds anything besides that single question, the
# placeholder ["\0", "\0\0", "\0\0"] is returned instead.
def parse_packet(packet)
  (_id, _flags, num_q, _ancount, _nscount, _arcount, str) =
    packet.unpack("a2 a2 a2 a2 a2 a2 a*")
  bad = ["\0", "\0\0", "\0\0"]

  # Only single-question queries are handled.
  return bad if num_q != "\0\1"

  # Reject odd packets: with exactly one question the body must be the
  # name up to its first "\0", then the 2-byte type and 2-byte class, and
  # nothing else.  String#index is used instead of split("\0")[0], which
  # is nil (-> NoMethodError) for an empty body or one made only of "\0".
  name_end = str.index("\0")
  return bad if name_end.nil? || name_end + 1 + 2 + 2 != str.length

  (q, q_type, q_class) = str.unpack("a#{str.length - 4} a2 a2")
  return [q, q_type, q_class]
end

# Return the client's IPv4 address from the sender address that
# Socket#recvfrom reported.
#
# str - packed socket address: 2 bytes (family), 2 bytes (port), then
#       the 4 address bytes, possibly followed by padding.
#
# Returns the 4 raw address bytes (shorter, or "", if str is truncated).
def parse_client_addr(str)
  fields = str.unpack("a2 a2 a4 a*")
  fields[2]
end

# Choose the namespace used to answer `name` for a given client.
#
# addrstr - the client's IPv4 address as 4 raw bytes.
# name    - the queried host name, used as a key into $addr_db[ns].
#
# Address prefixes are tried from most to least specific ("a.b.c.d",
# "a.b.c", "a.b", "a").  First, a prefix that $namespace_db maps to a
# namespace defining `name` in $addr_db wins; failing that, a prefix that
# is itself a key of $addr_db defining `name`.  Otherwise "default" is
# returned (even when "default" has no entry for `name`).
def select_namespace(addrstr, name)
  # unpack("C4") yields the octets as Integers on every Ruby version;
  # addrstr[i] only does so on Ruby <= 1.8 (on 1.9+ it is a 1-char String,
  # which makes "%d" raise or misread the byte).
  octets = addrstr.unpack("C4")
  prefixes = [
    sprintf("%d.%d.%d.%d", octets[0], octets[1], octets[2], octets[3]),
    sprintf("%d.%d.%d",    octets[0], octets[1], octets[2]),
    sprintf("%d.%d",       octets[0], octets[1]),
    sprintf("%d",          octets[0])
  ]

  # custom namespace: prefix -> namespace name through $namespace_db
  prefixes.each do |prefix|
    ns = $namespace_db[prefix]
    if ns != nil && $addr_db[ns] != nil && $addr_db[ns][name] != nil
      return ns
    end
  end

  # address number namespace: the prefix itself names the namespace
  prefixes.each do |prefix|
    if $addr_db[prefix] != nil && $addr_db[prefix][name] != nil
      return prefix
    end
  end

  return "default"
end

# Build the cumulative weight table for weighted random host selection.
#
# Each host entry in $addr_db[namespace][name] carries its badness at
# index 1.  A host's weight is 10000 minus its badness, with the badness
# capped at 10000 through min() (assumed to come from util.rb and return
# the smaller argument), so a weight never drops below zero.
#
# Returns [total, cumulative]: cumulative[j] is the summed weight of
# hosts 0..j and total is the sum of all weights (0 for an empty list).
def make_rand_array(namespace, name)
  total = 0
  cumulative = []
  $addr_db[namespace][name].each do |entry|
    total += 10000 - min(10000, entry[1])   # badness is capped at 10000
    cumulative << total
  end
  [total, cumulative]
end

# Pick `size` host indexes at random, weighted by make_rand_array.
#
# namespace, name - key into $addr_db, passed to make_rand_array.
# size            - number of independent draws; the same index may be
#                   returned more than once.
#
# Returns an Array of indexes into $addr_db[namespace][name], or [] when
# the total weight is 0 (no host has any weight left).
def select_rand_array(namespace, name, size)
  (rnd_max, rnd_slesh) = make_rand_array(namespace, name)

  return [] if rnd_max == 0

  arr = []
  size.times do
    rnd = rand(rnd_max)   # 0 <= rnd < rnd_max
    # Host j owns the half-open interval [rnd_slesh[j-1], rnd_slesh[j]).
    # The comparison must be strict: with "<=" a zero-weight host at
    # index 0 could still be chosen (rnd == 0), every host stole one slot
    # from its successor, and the last host could become unreachable.
    rnd_slesh.each_with_index do |upper, j|
      if rnd < upper
        arr.push(j)
        break
      end
    end
  end

  arr
end

# Validate a parsed question before it is answered.
#
# q       - query name in DNS wire format.
# q_type  - raw 2-byte query type.
# q_class - raw 2-byte query class.
#
# Raises NotImplementedError for zone-transfer (AXFR) requests.  Raises
# NoQueryError when the class is neither DnsClass::INET nor DnsClass::ANY,
# when the type is neither DnsType::A nor DnsType::ANY, or when the name
# contains any of ( ) < > @ , ; : \ " . [ ].  The checks run in that order.
# Returns nil when the query is acceptable.
def check_packet(q, q_type, q_class)
  # Zone transfers are not supported.
  raise NotImplementedError if q_type == DnsType::AXFR

  # Only Internet-class (or ANY-class) queries are served.
  class_ok = (q_class == DnsClass::INET) || (q_class == DnsClass::ANY)
  raise NoQueryError unless class_ok

  # Only A (or ANY) records are served.
  type_ok = (q_type == DnsType::A) || (q_type == DnsType::ANY)
  raise NoQueryError unless type_ok

  # Reject names containing characters that are not allowed.
  raise NoQueryError if q =~ /[()<>@,;:\\\"\.\[\]]/
end

# Print the command-line synopsis and option list to standard output,
# then terminate the process with exit status 111.
def usage
  help_lines = [
    "Usage: dns_balance [-h] [-i ipaddr] [-l logfile] [-p pidfile]\n",
    "       -l logfile print log to logfile\n",
    "       -i ipaddr  listen IP address (default:0.0.0.0)\n",
    "       -p pidfile record PID in pidfile\n",
    "       -h         help message\n"
  ]
  help_lines.each { |line| print line }
  exit(111)
end

######################################################################
# main

# Seed the generator behind rand(), used by select_rand_array.
srand()
# Parse options into $OPT_h, $OPT_i, $OPT_l and $OPT_p (old 'getopts'
# library): -h is a flag, -i takes a value defaulting to "0.0.0.0",
# -l and -p take a value.
getopts("h", "i:0.0.0.0", 'l:', 'p:')

# usage prints the help text and exits with status 111.
usage() if $OPT_h

# Daemonize: fork and end the parent at once (exit! skips exit handlers),
# start a new session, fork again and end the session leader, then close
# the standard streams.  Nothing can be written to STDOUT/STDERR after
# this point.
exit! if fork
Process::setsid
exit! if fork
STDIN.close
STDOUT.close
STDERR.close

# Record the PID.  This runs after both forks, so $$ is the PID of the
# process that keeps running.
$pidfile = nil
if $OPT_p
  $pidfile = $OPT_p
  File::open($pidfile, 'w') { |f| f.puts $$ }
end

# Open the log file for appending; sync = true flushes every write.
$logout = nil
if $OPT_l
  $logout = File::open($OPT_l, 'a+')
  $logout.sync = true
end

# On these signals (0 is Ruby's EXIT pseudo-signal), remove the PID file,
# close the log and exit.  NOTE(review): apart from 0, 2 (INT) and
# 15 (TERM), numeric signal numbers are platform-dependent -- e.g. 10 is
# SIGUSR1 on Linux but SIGBUS on BSD; symbolic names would be portable.
[0, 2, 3, 5, 10, 13, 15].each do |sig|
  trap(sig) {
    File::unlink($pidfile) if $pidfile
    $logout.close if $logout
    exit
  }
end

# put_log is defined outside this file (presumably log_writer.rb, writing
# to $logout -- TODO confirm).
put_log("start\n") if $OPT_l

#
# Reload the address database in the background: every 5 minutes, if the
# file "addr" (relative to the working directory) is readable, load it.
#
Thread.start do
  loop do
    if test(?r, "addr")
      begin
        load("addr")
      rescue ScriptError, StandardError => e
        # Without this rescue a broken "addr" file (load raises SyntaxError,
        # a ScriptError, on a syntax error) would silently end this thread
        # and stop all future reloads.  Keep the database currently in
        # effect, log the failure and try again on the next cycle.
        put_log("reload failed: #{e.message}\n") if $OPT_l
      else
        put_log("reload\n") if $OPT_l
      end
    end
    sleep(5 * 60) # reload every 5 minutes
  end
end

# UDP socket that receives the queries.  The socket address is packed by
# hand: family, port 53, the -i address (str_to_ipstr, presumably a
# dotted quad -> 4 raw bytes; TODO confirm) and 8 NUL bytes of padding.
# NOTE(review): "n" packs the family as a big-endian 16-bit value, which
# matches the BSD layout (sin_len byte + sin_family byte) but not a
# little-endian Linux sockaddr_in, whose sin_family is a host-order
# short; verify on the target platform.
gs = Socket.open(Socket::AF_INET, Socket::SOCK_DGRAM, 0)
gs.bind([Socket::AF_INET, 53, str_to_ipstr($OPT_i), ""].pack("n n a4 a8"))

#
# Main loop: receive one query (at most 512 bytes, the plain DNS-over-UDP
# message size) and answer it from a new thread.
#
loop do
  (packet, client_addr) = gs.recvfrom(512)
  Thread.start do
    begin
      client = parse_client_addr(client_addr)
      (q, q_type, q_class) = parse_packet(packet)
      check_packet(q, q_type, q_class) # -> NoQuery, NotImpl

      # Lowercase the name before looking it up.
      name = dnsstr_to_str(q).downcase
      namespace = select_namespace(client, name)

      # Answer with up to 3 addresses.  For an unknown name the lookup
      # yields nil and .size raises a NameError (NoMethodError on newer
      # Rubies), which the NoQueryError clause below turns into REFUSED.
      size = min($addr_db[namespace][name].size, 3)  # -> NameError -> NoQuery
      a_array = select_rand_array(namespace, name, size)

      # select_rand_array found no selectable host.
      if a_array.size == 0
	raise NoMoreResourceError
      end

      # Build the reply: copy the 12-byte query header and echo the
      # question.  NOTE(review): r[n] |= / &= need String#[] to return the
      # byte as an Integer, which holds only on Ruby <= 1.8; on 1.9+ r[2]
      # is a String and these lines raise NoMethodError.
      r = packet[0,12] + q + q_type + q_class
      r[2] |= 0x84   # set QR (response) and AA (authoritative answer)
      r[3] &= ~0x0f  # clear RCODE -> no error

      ans_addrs = []
      for i in a_array
	addr = $addr_db[namespace][name][i][0]
	ans_addrs.push(addr) if $OPT_l   # collected only for the log

	# A single candidate gets a long TTL; with several candidates a
	# short TTL makes clients ask again soon.
	if ($addr_db[namespace][name].size == 1)
	  ttl = "\0\0\x0e\x10" # 3600 seconds (1 hour)
	else
	  ttl = "\0\0\0\5"     # 5 seconds
	end

	# Answer RR: name = compression pointer to offset 0x000c (the
	# question name), type A, class IN, TTL, RDLENGTH 4, address bytes.
	r += "\xc0\x0c" + DnsType::A + DnsClass::INET + ttl + "\0\4" + addr.pack("CCCC")
      end

      # Set ANCOUNT (header bytes 6-7) to the number of answers.
      r[6..7] = [a_array.size].pack("n")

      # Too long for a UDP reply.
      if r.length > 512
	raise TruncatedError
      end

      status = "ok"

    rescue NotImplementedError
      # NOTIMP reply: QR set, AA cleared, RCODE 4.
      r = packet[0,12] + q + q_type + q_class
      r[2] |= 0x80  # answer
      r[2] &= ~0x04 # not authenticated
      r[3] &= ~0x0f
      r[3] |= 0x04  # not implemented error
      status = "NotImpl"
    rescue TruncatedError
      # Cut the reply to 512 bytes and set TC (0x02).  NOTE(review): this
      # may split a record and leaves ANCOUNT as set above.
      r = r[0,512]
      r[2] |= 0x02
      status = "Truncated"
    rescue NoMoreResourceError
      # NXDOMAIN reply: QR and AA set, RCODE 3.
      r = packet[0,12] + q + q_type + q_class
      r[2] |= 0x84  # answer & authenticated
      r[3] &= ~0x0f
      r[3] |= 0x03  # name error
      status = "NoRes"
    rescue NoQueryError,NameError,StandardError
      # REFUSED reply: QR set, AA cleared, RCODE 5.  NOTE(review): if
      # parse_packet itself raised, q is nil, "packet[0,12] + q" raises
      # TypeError here and the thread ends without sending any reply.
      r = packet[0,12] + q + q_type + q_class
      r[2] |= 0x80  # answer
      r[2] &= ~0x04 # not authenticated
      r[3] &= ~0x0f
      r[3] |= 0x05  # query refused error
      status = "NoQuery"
    rescue
      # NOTE(review): unreachable -- a bare rescue only catches
      # StandardError, which the clause above already handles.
      r = packet[0,12] + q + q_type + q_class
      r[2] |= 0x80  # answer
      r[2] &= ~0x04 # not authenticated
      r[3] &= ~0x0f
      r[3] |= 0x05  # query refused error
      status = "other"
    end

    #print packet.dump, "\n"
    #print r.dump, "\n"
    #p q

    # Send the reply back to the address the query came from.
    gs.send(r, 0, client_addr)

    # name, namespace and ans_addrs are nil when processing failed before
    # they were assigned.  logger is defined outside this file.
    logger(client, status, name, namespace, ans_addrs) if $OPT_l # && status == "ok"

  end
end

# end
