#include <stdio.h>
#include "Internals.h"

/* Forward declarations for the helpers defined further down in this file. */

/* Walk the recorded graph `old` in step with the object `new`, evaluating
   `new` wherever the matching part of `old` is a node or a basic value. */
void compare (GraphNode *old, HaskellObj new);
/* Evaluate `in` with rts_eval and store the result through `out`. */
void eval_whnf (HaskellObj in, HaskellObj *out);
/* Re-apply the function behind `fun` to each argument recorded in `obj`
   and compare every result against the one recorded alongside it. */
void applyFunToArgs (HaskellObj obj, StgStablePtr fun);

/* global reference to the runApp function;
   written by reEvalC and read back by applyFunToArgs */
StgStablePtr runAppFun;

/*
 * Entry point for re-evaluation.
 *
 *   runApp - stable pointer to the runApp function; stored in the global
 *            runAppFun so that applyFunToArgs can find it
 *   old    - stable pointer to the previously recorded value; a heap
 *            graph is built from it with makeHeapGraph
 *   new    - stable pointer to the value that is to be compared against
 *            that graph
 *
 * The graph built here is released again with freeGraph before returning.
 */
void reEvalC (StgStablePtr runApp, StgStablePtr old, StgStablePtr new)
{
   HaskellObj candidate = (HaskellObj) deRefStablePtr (new);
   GraphNode *recorded  = makeHeapGraph (old);

   runAppFun = runApp;

   compare (recorded, candidate);
   freeGraph (recorded);
}

/*
 * Walk the recorded graph `old` in step with the object `new`, forcing
 * `new` only where `old` shows that evaluation had taken place:
 *
 *   Node                   - evaluate `new`, then recurse into each child
 *                            of `old` paired with the payload slot of the
 *                            same index in the evaluated object
 *   Int, Char, SmallInteger, LargeInteger, Float, Double
 *                          - evaluate `new`; the result is not inspected
 *   Thunk, Cycle           - leave `new` untouched
 *   Unknown                - print a diagnostic on stderr and exit
 *   anything else          - do nothing
 *
 * Two further handlers used to live here and are currently disabled, so
 * their tags (if they occur) fall into the "anything else" case:
 *
 *   EncodedFunctionAsTerm  - evaluated `new` and then compared
 *                            old->children[0] against payload[1] of the
 *                            result (the representation is the second
 *                            argument of the encoded function)
 *   EncodedFunctionAsMap   - took a stable pointer to `new` and passed it,
 *                            together with the dereferenced
 *                            old->val.funMap, to applyFunToArgs so that
 *                            the new function was applied to every
 *                            recorded argument and each result demanded
 *                            to the same extent as before
 */
void compare (GraphNode *old, HaskellObj new)
{
   HaskellObj forced;

   switch (old->tag)
   {
      case Node:
      {
         StgStablePtr pinned;
         int child;

         eval_whnf (new, &forced);
         pinned = (StgStablePtr) (getStablePtr ((StgPtr) forced));

         for (child = 0; child < old->numChildren; child++)
         {
            /* the object is looked up afresh through the stable pointer
               on every iteration rather than cached across the
               recursive call -- presumably because that call can run
               Haskell code and so move the object */
            HaskellObj current = (HaskellObj) deRefStablePtr (pinned);

            compare (old->children[child], current->payload[child]);
         }

         freeStablePtr (pinned);
         return;
      }

      /* old is a primitive/basic type: force new and stop there */
      case Int:
      case Char:
      case SmallInteger:
      case LargeInteger:
      case Float:
      case Double:
         eval_whnf (new, &forced);
         return;

      /* old was never evaluated, or loops back on itself: hands off new */
      case Thunk:
      case Cycle:
         return;

      /* eeps what is this? okay bomb */
      case Unknown:
         fprintf (stderr, "buddha: compare() found an unknown heap object, exiting\n");
         fprintf (stderr, "buddha: please report this as a bug to the developers\n");
         exit (-1);
         break;

      default:
         break;
   }
}

/* force a value to WHNF */
/*
 * Force a value to WHNF.
 *
 *   in  - the closure to evaluate
 *   out - receives the evaluated closure
 *
 * The scheduler status reported by rts_eval is no longer thrown away: it
 * is handed to rts_checkSchedStatus, the RTS's own helper for this, which
 * reports the problem and terminates the program unless evaluation
 * succeeded.  Without that check a failed evaluation (for example an
 * uncaught exception or an interrupt) would leave *out without a valid
 * closure, and compare() would go on to pin and dereference it.
 *
 * NOTE(review): this relies on the pre-Capability RTS API, where rts_eval
 * returns a SchedulerStatus and rts_checkSchedStatus takes
 * (site, status) -- confirm against the GHC version this is built with.
 */
void eval_whnf (HaskellObj in, HaskellObj *out)
{
   SchedulerStatus status;

   status = rts_eval (in, out);
   rts_checkSchedStatus ("eval_whnf", status);
}

/*
 * Re-apply a function to every argument recorded for it and compare each
 * fresh result with the recorded one.
 *
 *   obj - expected to be an STRef whose first field is a MutVar holding a
 *         list; each list element is expected to be a V constructor
 *         wrapping an (argument, result) pair
 *   fun - stable pointer to the (encoded) function to apply
 *
 * For every list element the recorded result is turned into a heap graph,
 * the application  runApp fun argument  is built with app3 (runApp being
 * the function behind the global runAppFun), and compare() is run on the
 * graph and that application.
 *
 * Nothing visible in this file currently calls this function from live
 * code.
 */
void applyFunToArgs (HaskellObj obj, StgStablePtr fun)
{
   StgInfoTable *listinfo;
   HaskellObj list;
   HaskellObj funarg;
   HaskellObj funresult;
   HaskellObj stref;
   HaskellObj mutvar;
   HaskellObj v;
   HaskellObj pair;
   GraphNode *resultGraph;
   HaskellObj realRunAppFun;
   StgStablePtr resultStablePtr;
   StgStablePtr listStablePointer;
   HaskellObj realFun;

   /* this should be an STRef, whose first arg is a MutVar */
   stref = removeIndirections (obj); 

   obj = stref->payload[0];
   /* this should be a MutVar, which points to a list of pairs */
   mutvar = removeIndirections (obj); 
   /* follow the mutVar */
   obj = followMutVar (mutvar);

   /* this should be a list of pairs */
   list = removeIndirections (obj); 
   listinfo = get_itbl (list);

   /* while the list is not empty: a cell with exactly two pointer fields
      is taken to be a cons cell (head in payload[0], tail in payload[1]);
      anything else ends the loop */
   while (listinfo->layout.payload.ptrs == 2)
   {
      /* this should be a V constructor applied to a pair */
      obj  = list->payload[0];
      v = removeIndirections (obj);

      /* this should be a pair of things */
      obj  = v->payload[0];
      pair = removeIndirections (obj);

      /* this should be the argument of the function */
      obj  = pair->payload[0];
      funarg = removeIndirections (obj);

      /* this should be the result of the function */
      obj  = pair->payload[1];
      funresult = removeIndirections (obj);

      /* locate the runApp function */
      realRunAppFun = (HaskellObj) deRefStablePtr (runAppFun);

      /* make a graph out of the result; makeHeapGraph takes a stable
         pointer, so one is created just for the call and released
         straight afterwards */
      resultStablePtr = getStablePtr ((StgPtr) funresult);
      resultGraph = makeHeapGraph (resultStablePtr);
      freeStablePtr (resultStablePtr);

      // save the list pointer before doing some work: the raw `list`
      // pointer is not used again after compare() below, it is read
      // back through this stable pointer instead
      listStablePointer = (StgStablePtr) (getStablePtr((StgPtr) list));

      // get the real (encoded) (F1 ... ... ...) function to apply
      realFun = (HaskellObj) deRefStablePtr (fun);

      /* apply the function to its arg and then demand the result */
      compare (resultGraph, app3 (realRunAppFun, realFun, funarg));

      /* the graph was built for this element only */
      freeGraph (resultGraph);

      /* go to the next list item: recover the current cell through the
         stable pointer, take its tail, then drop the stable pointer */
      list = (HaskellObj) deRefStablePtr (listStablePointer);
      obj  = list->payload[1];
      freeStablePtr (listStablePointer);

      list = removeIndirections (obj);
      listinfo = get_itbl (list);
   }

   return;
}      
