package test;

import junit.framework.AssertionFailedError;
import junit.framework.Test;
import junit.framework.TestListener;
import junit.framework.TestResult;
import org.jgroups.*;
import org.jgroups.blocks.GroupRequest;
import org.jgroups.blocks.MethodCall;
import org.jgroups.blocks.RpcDispatcher;
import org.jboss.logging.Logger;

import java.util.Vector;

/**
 * Base class for distributed test classes, or for any other class that needs a
 * remote barrier so that all participating instances start (and stop) at the
 * same time.  Note: to run continually, set the number of instances to 1; the
 * shutdown barrier then never times out, so {@link #shutdown()} blocks
 * indefinitely.
 * <br>
 * This class requires JGroups 2.0 to run.
 * <br>
 * Because this class builds on the JUnit framework, tests can be written just
 * as for a normal JUnit test.  If a subclass overrides setUp() or tearDown(),
 * it must call the super implementation as well, since these methods drive the
 * remote synchronization.
 *
 * @author <a href="mailto:telrod@vocalocity.net">Tom Elrod</a>
 * @version $Revision: 1.6.2.1 $
 */
public class DistributedTestCase extends MultipleTestCase implements MembershipListener
{
    private static final Logger log = Logger.getLogger(DistributedTestCase.class);

    // Total number of instances (including this one) that must join before startup() returns.
    // Defaults to 2 since that is the most common number of instances (always > 1).
    private int parties = 2;

    private Channel channel;
    private RpcDispatcher disp;
    private Address localAddress;
    //TODO: May want to use TCP instead of UDP -TME
    private String props = "UDP(mcast_recv_buf_size=64000;mcast_send_buf_size=32000;" +
            "mcast_port=45566;use_packet_handler=false;ucast_recv_buf_size=64000;" +
            "mcast_addr=228.8.8.8;loopback=false;ucast_send_buf_size=32000;ip_ttl=32):" +
            "PING(timeout=2000;num_initial_members=3):" +
            "MERGE2(max_interval=10000;min_interval=5000):" +
            "FD(timeout=2000;max_tries=3;shun=true):" +
            "VERIFY_SUSPECT(timeout=1500):" +
            "pbcast.NAKACK(max_xmit_size=8192;gc_lag=50;retransmit_timeout=600,1200,2400,4800):" +
            "pbcast.STABLE(desired_avg_gossip=20000):" +
            "UNICAST(timeout=1200,2400,3600):" +
            "FRAG(frag_size=8192;down_thread=false;up_thread=false):" +
            "pbcast.GMS(print_local_addr=true;join_timeout=3000;join_retry_timeout=2000;shun=true)";


    //TODO: Probably want to allow the timeout period to be a parameter -TME
    // How long to wait for everyone to be ready to start.  Default 60 seconds.
    private long startupTimeout = 60000;
    // How long to wait for everyone to be ready to shut down.
    // Default 2 minutes since the tests could take a long time to run.
    private long shutdownTimeout = 120000;

    // False when only one instance participates; shutdown() then never times out.
    private boolean shouldShutdown = true;

    // Number of shutdown notifications still expected (including our own). Guarded by waitObj.
    private int shutdownCount;

    // These flags are written by JGroups callback threads and read by the test
    // thread, so they must be volatile to guarantee visibility.
    private volatile boolean startupWaitFlag = false;
    private volatile boolean shutdownWaitFlag = false;

    private boolean startupCalledFlag = false;
    private boolean shutdownCalledFlag = false;

    //JUnit related variables
    // Used to indicate the number of runs since have to be sure to not call
    // startup() for each test method run.
    private int testRunCount = 0;
    // flag to indicate if should disconnect from JG using shutdown() or endTest()
    private boolean runningAsUnitTest = false;

    // Monitor used for the startup/shutdown barriers and to guard shutdownCount.
    private final Object waitObj = new Object();
    private DistributedTestListener testListener = new DistributedTestListener();

    public DistributedTestCase(String name)
    {
        super(name);
    }

    /**
     * Sets the total number of remote instances (including this instance).
     * @param numOfInstances total number of participating instances
     */
    protected void init(int numOfInstances)
    {
        parties = numOfInstances;
    }

    /**
     * Gets the total number of instances running in the test case.
     * @return the number of participating instances, including this one
     */
    public int getNumberOfInstances()
    {
        return parties;
    }

    /**
     * Joins the JGroups channel to let the other remote test instances know this
     * instance is ready to run.  Blocks until all instances are ready.
     * @param numOfInstances - indicates total number of instances for remote test.
     * @throws Exception if the channel cannot be created or the other instances
     *         do not join before the startup timeout (or the wait is interrupted)
     */
    public void startup(int numOfInstances) throws Exception
    {
        init(numOfInstances);
        startup();
    }

    /**
     * Joins the JGroups channel to let the other remote test instances know this
     * instance is ready to run.  Blocks until the channel view contains at least
     * {@link #getNumberOfInstances()} members, or until the startup timeout expires.
     * @throws Exception if the channel cannot be created or the other instances
     *         do not join before the startup timeout (or the wait is interrupted)
     */
    public void startup() throws Exception
    {
        synchronized(waitObj)
        {
            shutdownCount = parties;
        }
        // Only enforce the shutdown timeout when there is more than one party.
        shouldShutdown = parties > 1;

        startupWaitFlag = true;
        startupCalledFlag = true;
        sendStartupNotification();

        long startTime = System.currentTimeMillis();
        boolean interrupted = false;
        synchronized(waitObj)
        {
            // Testing the flag while holding the monitor means a notify from
            // viewAccepted() cannot be lost between the check and the wait.
            while(startupWaitFlag && !timeoutExpired(startTime, startupTimeout))
            {
                try
                {
                    waitObj.wait(1000);
                }
                catch(InterruptedException e)
                {
                    Thread.currentThread().interrupt();
                    interrupted = true;
                    break;
                }
            }
        }

        if(startupWaitFlag)
        {
            // we timed out (or were interrupted) and still not everyone joined
            disp.stop();
            channel.disconnect();
            if(interrupted)
            {
                throw new Exception("Interrupted while waiting for other instances to start.");
            }
            throw new Exception("Timed out waiting for other instances to start.");
        }
    }

    /**
     * Should be called when ready to shut down.  Notifies all other remote test
     * instances and then blocks until all instances have made the same call.
     * When only one instance participates, this blocks until interrupted.
     * Unless running as a unit test, the channel is disconnected afterwards,
     * even when an exception is thrown.
     * @throws Exception if the other instances do not stop before the shutdown
     *         timeout, or the wait is interrupted
     */
    public void shutdown() throws Exception
    {
        try
        {
            shutdownWaitFlag = true;
            shutdownCalledFlag = true;
            // NOTE(review): pause kept from the original implementation; presumably
            // lets in-flight messages settle before announcing shutdown.
            Thread.sleep(1000);
            sendShutdownNotification();

            long startTime = System.currentTimeMillis();
            boolean interrupted = false;
            synchronized(waitObj)
            {
                while(shutdownWaitFlag)
                {
                    try
                    {
                        waitObj.wait(1000);
                    }
                    catch(InterruptedException e)
                    {
                        Thread.currentThread().interrupt();
                        interrupted = true;
                        break;
                    }
                    // With a single instance (shouldShutdown == false) keep waiting indefinitely.
                    if(shouldShutdown && timeoutExpired(startTime, shutdownTimeout))
                    {
                        break;
                    }
                }
            }

            if(shutdownWaitFlag)
            {
                if(interrupted)
                {
                    throw new Exception("Interrupted while waiting for other instances to stop.");
                }
                // we timed out
                throw new Exception("Timed out waiting for other instances to stop.");
            }
        }
        finally
        {
            // if not running as unit test, can disconnect now.
            // otherwise, need to wait till test has ended.
            if(!runningAsUnitTest)
            {
                log.debug("calling disconnect. runningAsUnitTest = " + runningAsUnitTest);
                disconnect();
            }
        }
    }

    /**
     * Disconnects from JGroups after a short grace period.  Does nothing if this
     * instance never connected.  Any exception raised while stopping the
     * dispatcher or disconnecting the channel is logged, not rethrown.
     */
    protected void disconnect()
    {
        if(disp == null || channel == null)
        {
            log.debug("Not connected to JGroups; nothing to disconnect.");
            return;
        }
        //need to give JG a few seconds to send test report
        try
        {
            Thread.sleep(5000);
        }
        catch(InterruptedException e)
        {
            // Preserve the interrupt status and disconnect immediately.
            Thread.currentThread().interrupt();
            log.debug("Interrupted while waiting to disconnect; disconnecting now.");
        }
        try
        {
            log.debug("Disconnecting from JGroups.  Will not be able to send any more messages.");
            disp.stop();
            channel.disconnect();
            /**
             * Can not call close since it will prevent any of the other
             * instances from receiving or sending shutdown notifications.
             */
            //channel.close();
        }
        catch(Exception e)
        {
            log.warn("Exception in disconnect() when stopping and closing channel.", e);
        }
    }

    /**
     * @return true if strictly more than {@code timeout} milliseconds have
     *         elapsed since {@code startTime}
     */
    private boolean timeoutExpired(long startTime, long timeout)
    {
        return System.currentTimeMillis() - startTime > timeout;
    }

    /**
     * Creates the JGroups channel and RPC dispatcher and connects to the
     * "DistributedTestCase" group.  Joining the group triggers the view changes
     * that {@link #viewAccepted(View)} uses to release the startup barrier.
     */
    private void sendStartupNotification() throws ChannelException
    {
        //JGroups code
        channel = new JChannel(props);
        disp = new RpcDispatcher(channel, null, this, this);
        channel.connect("DistributedTestCase");
        localAddress = channel.getLocalAddress();
    }

    /**
     * Invokes receiveShutdownNotification(localAddress) on every group member
     * (null destination list) without waiting for responses.
     */
    private void sendShutdownNotification()
    {
        MethodCall call = new MethodCall("receiveShutdownNotification",
              new Object[]{localAddress}, new Class[]{Address.class});
        disp.callRemoteMethods(null, call, GroupRequest.GET_NONE, 0);
        log.debug("sent shutdown notification " + call);
    }

    /**
     * Used to indicate when members have joined the JGroups channel for this
     * test case run.  Releases the startup barrier once the view contains at
     * least the expected number of members.
     * @param view the new group membership view
     */
    public void viewAccepted(View view)
    {
        // has everyone joined
        Vector members = view.getMembers();
        int numOfMembers = members.size();
        synchronized(waitObj)
        {
            if(numOfMembers >= parties && startupWaitFlag) // waiting for everyone to start
            {
                startupWaitFlag = false;
                waitObj.notifyAll();
            }
        }
    }

    /**
     * Called using JGroups by each instance (including this one) when it is
     * ready to shut down.  Once every expected notification has arrived while
     * this instance is itself waiting in {@link #shutdown()}, the shutdown
     * barrier is released (only when more than one instance participates).
     * @param address address of the instance that is shutting down
     */
    public void receiveShutdownNotification(Address address)
    {
        synchronized(waitObj)
        {
            shutdownCount--;
            log.debug("receiveShutdownNotification() from " + address);
            log.debug("shutdownCount = " + shutdownCount +
                      " and shutdownWaitFlag = " + shutdownWaitFlag);
            // "<= 0" rather than "== 0": if the count passes zero before this
            // instance starts waiting, the barrier must still release.
            if(shutdownCount <= 0 && shutdownWaitFlag && shouldShutdown) // waiting for everyone to stop
            {
                shutdownWaitFlag = false;
                waitObj.notifyAll();
            }
        }
    }

    /**
     * Invokes {@code methodName} on every group member with this instance's
     * address prepended to {@code params}.  Does nothing if the dispatcher was
     * never created (for example, startup() failed).
     * @param methodName remote method to invoke
     * @param params additional arguments; may be null
     */
    private void callRemoteAssert(String methodName, Object[] params)
    {
        if(disp == null)
        {
            log.debug("Not connected to JGroups; cannot send " + methodName);
            return;
        }
        int len = params != null ? params.length : 0;
        Object[] newArgs = new Object[len + 1];
        newArgs[0] = localAddress;
        if(len > 0)
        {
            System.arraycopy(params, 0, newArgs, 1, len);
        }
        MethodCall call = new MethodCall(methodName, newArgs);
        disp.callRemoteMethods(null, call, GroupRequest.GET_NONE, 0);
    }

    /*************************************
     * Driver callback for JUnit asserts *
     *************************************/
    /**
     * JGroups callback when a test fails on some instance.
     * @param source address of the instance reporting the failure
     * @param message failure or error message
     */
    public void receiveAssert(Address source, String message)
    {
        log.warn("Assert source: " + source + "\tmessage = " + message);
    }


    /**************************
     * JUnit methods          *
     **************************/
    /**
     * Runs the test with a listener attached that forwards failures and errors
     * to the other instances.  The listener is removed afterwards so listeners
     * do not accumulate on a shared TestResult across test instances.
     */
    public void run(TestResult testResult)
    {
        log.debug("DistributedTestCase::run(TestResult testResult) called.");
        log.debug("countTestCases() = " + countTestCases());
        testResult.addListener(testListener);
        try
        {
            super.run(testResult);
        }
        finally
        {
            testResult.removeListener(testListener);
        }
    }

    /**
     * Checks whether this is the first test method to be run.  If it is, and
     * startup() has not already been called explicitly, calls startup() to let
     * the other instances know this one is ready to run.
     * @throws Exception if startup() fails
     */
    protected void setUp() throws Exception
    {

        log.debug("setUp() - testRunCount = " + testRunCount);
        if(testRunCount == 0)
        {
            // have to make sure startup not already explicitly called
            if(!startupCalledFlag)
            {
                log.debug("calling startup()");
                startup(getNumberOfInstances());
            }
        }
        testRunCount++;
    }

    /**
     * Calls shutdown() if this is the last test method to be run and shutdown()
     * has not already been called explicitly.
     * @throws Exception if shutdown() fails
     */
    protected void tearDown() throws Exception
    {
        log.debug("tearDown() - testRunCount = " + testRunCount);
        log.debug("tearDown() - countTestCases() = " + countTestCases());
        if(testRunCount == countTestCases())
        {
            // need to make sure shutdown not already explicitly called
            if(!shutdownCalledFlag)
            {
                log.debug("calling shutdown()");
                shutdown();
            }
        }
    }

    /************************
     * JGroups methods   *
     ************************/
    public void suspect(Address address)
    {
    }

    public void block()
    {
    }

    /**
     * Listener for test results.  It forwards failures and errors to the other
     * remote instances via JG so they can report them, and disconnects from
     * JGroups when a test ends.
     */
    public class DistributedTestListener implements TestListener
    {
        public void addError(Test test, Throwable throwable)
        {
            String message = describe(throwable);
            String methodName = "receiveAssert";
            log.debug("addError() called with " + message);
            callRemoteAssert(methodName, new Object[]{message});
        }

        public void addFailure(Test test, AssertionFailedError assertionFailedError)
        {
            String message = describe(assertionFailedError);
            String methodName = "receiveAssert";
            log.debug("addFailure() called with " + message);
            callRemoteAssert(methodName, new Object[]{message});
        }

        public void endTest(Test test)
        {
            log.debug("endTest() called.  Calling disconnect().");
            disconnect();
        }

        public void startTest(Test test)
        {
            runningAsUnitTest = true;
            log.debug("startTest() called");
        }

        /**
         * Returns the throwable's message.  Falls back to toString() so that a
         * throwable without a message (for example, a bare NullPointerException)
         * still identifies its type remotely.
         */
        private String describe(Throwable throwable)
        {
            String message = throwable.getMessage();
            return message != null ? message : throwable.toString();
        }

    }


    /**
     * Stand-alone demo: joins the barrier, runs one passing and one deliberately
     * failing assertion, then shuts down.  The optional first argument is the
     * number of instances.  Exits with status 1 if an assertion fails or any
     * step throws an exception, otherwise 0.
     */
    public static void main(String[] args)
    {
        DistributedTestCase testCase = new DistributedTestCase(DistributedTestCase.class.getName());
        int exitCode = 0;
        try
        {
            if(args.length > 0)
            {
                int num = Integer.parseInt(args[0]);
                testCase.startup(num);
            }
            else
            {
                testCase.startup();
            }

            try
            {
                testCase.assertTrue(true);
                testCase.assertTrue("test message", false);
            }
            catch(AssertionFailedError e)
            {
                // AssertionFailedError is an Error, not an Exception.  Catch it here
                // so shutdown() still runs and the other instances are not left
                // waiting for the shutdown timeout.
                log.warn("Assertion failed: " + e.getMessage());
                exitCode = 1;
            }
            testCase.shutdown();
        }
        catch(Exception e)
        {
            log.error("Distributed test run failed.", e);
            exitCode = 1;
        }
        System.exit(exitCode);
    }

}
