package org.jboss.remoting;

import junit.framework.*;
import org.jgroups.*;
import org.jgroups.blocks.GroupRequest;
import org.jgroups.blocks.MethodCall;
import org.jgroups.blocks.RpcDispatcher;

import java.util.Vector;

/**
 * Base class for distributed test classes, or any other classes, that need a
 * remote barrier so that all participating instances start at the same time
 * and wait for each other before shutting down.
 * <br>
 * This class requires JGroups 2.0 to run.
 *
 * @author <a href="mailto:telrod@vocalocity.net">Tom Elrod</a>
 * @version $Revision: 1.2 $
 */
public class DistributedTest
   extends TestCase
   implements MembershipListener
{
   // Number of instances that must be in the group before the startup barrier releases.
   // Defaults to 2 since that is the most common number of instances (there will always be more than one).
   private int parties = 2;

   private Channel channel;
   private RpcDispatcher disp;
   private Address localAddress;
   //TODO: May want to use TCP instead of UDP -TME
   // JGroups protocol stack configuration passed to the JChannel constructor.
   private String props = "UDP(mcast_recv_buf_size=64000;mcast_send_buf_size=32000;" +
      "mcast_port=45566;use_packet_handler=false;ucast_recv_buf_size=64000;" +
      "mcast_addr=228.8.8.8;loopback=false;ucast_send_buf_size=32000;ip_ttl=32):" +
      "PING(timeout=2000;num_initial_members=3):" +
      "MERGE2(max_interval=10000;min_interval=5000):" +
      "FD(timeout=2000;max_tries=3;shun=true):" +
      "VERIFY_SUSPECT(timeout=1500):" +
      "pbcast.STABLE(desired_avg_gossip=20000):" +
      "pbcast.NAKACK(max_xmit_size=8192;gc_lag=50;retransmit_timeout=600,1200,2400,4800):" +
      "UNICAST(timeout=1200,2400,3600):" +
      "FRAG(frag_size=8192;down_thread=false;up_thread=false):" +
      "pbcast.GMS(print_local_addr=true;join_timeout=3000;join_retry_timeout=2000;shun=true)";


   //TODO: Probably want to allow the timeout period to be a parameter -TME
   // How long (ms) to wait for everyone to be ready to start. Default 20 seconds.
   private long startupTimeout = 20000;
   // How long (ms) to wait for everyone to be ready to shut down.
   // Default 2 minutes since the tests could take a long time to run.
   private long shutdownTimeout = 120000;

   // Number of shutdown notifications still expected. Starts at 'parties', so it presumably
   // counts this instance's own notification too. Guarded by waitObj.
   private int shutdownCount;

   // These flags may be written from JGroups callback threads (viewAccepted /
   // receiveShutdownNotification) while the test thread polls them, hence volatile.
   private volatile boolean startupWaitFlag = false;
   private volatile boolean shutdownWaitFlag = false;

   // Monitor used to wake the waiting test thread as soon as a barrier condition is met.
   private final Object waitObj = new Object();

   public DistributedTest(String name)
   {
      super(name);
      parties = Integer.getInteger("jboss.test.distributed.instancecount", 2).intValue();
   }

   /**
    * Returns the number of instances the barrier waits for. The value comes from the
    * {@code jboss.test.distributed.instancecount} system property and defaults to 2.
    *
    * @return the expected number of participating instances
    */
   public int getNumberOfInstances()
   {
      return parties;
   }

   /**
    * Joins the JGroups test group and blocks until at least {@link #getNumberOfInstances()}
    * members are present, or until the startup timeout expires.
    * <p>
    * NOTE: despite the name, this is not JUnit's {@code setUp()} hook. The framework will
    * not call it automatically, so subclasses must invoke it explicitly.
    *
    * @throws Exception if the channel cannot be created, or if not every instance joined
    *                   before the startup timeout (the channel is disconnected first)
    */
   protected void settUp() throws Exception
   {
      shutdownCount = parties;
      startupWaitFlag = true;
      sendStartupNotification();

      awaitBarrier(true, startupTimeout);

      if(startupWaitFlag)
      {
         // we timed out (or were interrupted) and still not everyone joined
         disp.stop();
         channel.disconnect();
         throw new Exception("Timed out waiting for other instances to start.");
      }
   }

   /**
    * Notifies the group that this instance is ready to stop. It then blocks until every
    * instance has sent its shutdown notification, or until the shutdown timeout expires.
    * The dispatcher is stopped and the channel disconnected in every case, including timeout.
    * <p>
    * NOTE: this is not JUnit's {@code tearDown()} hook; subclasses must invoke it explicitly.
    *
    * @throws Exception if not every instance signalled shutdown before the timeout
    */
   protected void shutDown() throws Exception
   {
      shutdownWaitFlag = true;
      try
      {
         sendShutdownNotification();
         awaitBarrier(false, shutdownTimeout);

         if(shutdownWaitFlag)
         {
            // we timed out (or were interrupted)
            throw new Exception("Timed out waiting for other instances to stop.");
         }
      }
      finally
      {
         // Always release the channel, even on timeout, so it is not leaked.
         disp.stop();
         channel.disconnect();
      }
   }

   /**
    * Blocks the calling thread until the selected wait flag is cleared or the timeout expires.
    * The flag is re-checked at least once a second, and immediately whenever waitObj is
    * notified. If the thread is interrupted, its interrupt status is restored and the
    * method returns early. Callers must re-check the flag afterwards.
    *
    * @param startup {@code true} to wait on the startup flag, {@code false} for the shutdown flag
    * @param timeout maximum time to wait, in milliseconds
    */
   private void awaitBarrier(boolean startup, long timeout)
   {
      long startTime = System.currentTimeMillis();
      synchronized(waitObj)
      {
         // Checking the flag while holding the lock means a notify cannot slip in
         // between the check and the wait.
         while(isWaiting(startup) && !timeoutExpired(startTime, timeout))
         {
            try
            {
               waitObj.wait(1000);
            }
            catch(InterruptedException e)
            {
               // preserve the interrupt for the caller and stop waiting
               Thread.currentThread().interrupt();
               return;
            }
         }
      }
   }

   // Reads whichever barrier flag awaitBarrier() was asked to wait on.
   private boolean isWaiting(boolean startup)
   {
      return startup ? startupWaitFlag : shutdownWaitFlag;
   }

   // True when strictly more than 'timeout' ms have elapsed since 'startTime'.
   private boolean timeoutExpired(long startTime, long timeout)
   {
      return System.currentTimeMillis() - startTime > timeout;
   }

   /**
    * Creates the channel and RPC dispatcher, with this object as both the RPC target and
    * the membership listener, and connects to the "DistributedTestCase" group. There is
    * no explicit message: joining the group presumably acts as the startup notification,
    * because the other members observe the new view through viewAccepted().
    */
   private void sendStartupNotification() throws ChannelException
   {
      //JGroups code
      channel = new JChannel(props);
      disp = new RpcDispatcher(channel, null, this, this);
      channel.connect("DistributedTestCase");
      localAddress = channel.getLocalAddress();
   }

   /**
    * Invokes receiveShutdownNotification(localAddress) on the group members (null
    * destination list). It does not wait for any responses (GET_NONE, timeout 0).
    */
   private void sendShutdownNotification()
   {
      MethodCall call = new MethodCall();
      call.setName("receiveShutdownNotification");
      call.addArg(localAddress);
      disp.callRemoteMethods(null, call, GroupRequest.GET_NONE, 0);
   }

   /**
    * Membership callback. Releases the startup barrier once the view contains at least
    * {@link #getNumberOfInstances()} members while this instance is waiting to start.
    */
   public void viewAccepted(View view)
   {
      // has everyone joined?
      Vector members = view.getMembers();
      if(members.size() >= parties && startupWaitFlag) // waiting for everyone to start
      {
         synchronized(waitObj)
         {
            startupWaitFlag = false;
            waitObj.notifyAll();
         }
      }
   }

   /**
    * RPC target invoked by sendShutdownNotification() on each instance. Counts down the
    * expected notifications. When the count reaches zero while this instance is waiting
    * to shut down, it releases the shutdown barrier.
    *
    * @param address address of the instance that is shutting down (currently unused)
    */
   public void receiveShutdownNotification(Address address)
   {
      synchronized(waitObj)
      {
         if(--shutdownCount == 0 && shutdownWaitFlag) // waiting for everyone to stop
         {
            shutdownWaitFlag = false;
            waitObj.notifyAll();
         }
      }
   }

   // MembershipListener callback; suspicion of members is not relevant to the barrier.
   public void suspect(Address address)
   {
   }

   // MembershipListener callback; nothing to do when the channel blocks.
   public void block()
   {
   }
}
