{-------------------------------------------------------------------------------

        Copyright:              Bernie Pope 2004

        Module:                 EDT 

        Description:            Definition and construction of the EDT.

        Primary Authors:        Bernie Pope

-------------------------------------------------------------------------------}

{-
    This file is part of buddha.

    buddha is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    buddha is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with buddha; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
-}

module EDT  
   ( children
   , parents
   , result
   , name
   , rootNodes
   , derivation
   , recordToDerivation
   , mkRecordArray
   , mkNodeMap
   , size
   , depth
   , validEDT
   ) where

import Meta 
   ( Val )

import Buddha
   ( Record (..) )

import Data.FiniteMap
   ( emptyFM
   , lookupFM
   , addListToFM_C
   , fmToList
   ) 

import Data.Set
   ( Set
   , elementOf
   , addToSet
   , emptySet
   )

import Data.Array
   ( listArray
   , (!)
   , bounds
   ) 

import Data
   ( EDT
   , Record (..)
   , Derivation (..)
   , DebugState (..)
   , RecordArray
   , NodeMap
   , readGlobalState
   )

import ReifyHs
   ( reifyVal )

import Data.PackedString
   ( PackedString )

import TablesUnsafe
   ( readCallCount )

import Monad
   ( liftM )

import List
   ( nub )

--------------------------------------------------------------------------------

-- | Node number of the root pseudo-node. Its children in the 'NodeMap' are
--   what 'rootNodes' returns, and 'parents' never reports it as a parent.
rootNumber :: EDT
rootNumber = 0

-- | Build the record array indexed by node number. 'Ref' records are
--   discarded. The remaining records fill indices 1 .. (callCount - 1), where
--   callCount is the value returned by 'readCallCount'.
--   NOTE(review): assumes the call counter equals 1 + the number of non-Ref
--   records; a shorter list would leave undefined array elements — confirm.
mkRecordArray :: [Record] -> IO RecordArray
mkRecordArray recs
   = readCallCount >>= \callCount ->
        let upper = callCount - 1
            kept  = [ r | r <- recs, not (isRef r) ]
        in return (listArray (1, upper) kept)
   where
   isRef :: Record -> Bool
   isRef (Ref _ _) = True
   isRef _         = False

-- | Build the map from each node number to the list of its children.
--   Edges come from the parent and index fields of 'Rec' and 'Ref'
--   records. Any other record (e.g. 'Constant') contributes no edge.
mkNodeMap :: [Record] -> NodeMap
mkNodeMap recs
   = mkNodeMapWorker emptyFM (parentChildList recs [])
   where
   -- Each new (singleton) child list is prepended to the children already
   -- recorded for that parent. Using (++) instead of matching on a singleton
   -- list keeps the combining function total.
   mkNodeMapWorker :: NodeMap -> [(Int,[Int])] -> NodeMap
   mkNodeMapWorker = addListToFM_C (\old new -> new ++ old)
   -- Collect (parent, [child]) edges. The accumulator builds them in reverse
   -- input order.
   parentChildList :: [Record] -> [(Int,[Int])] -> [(Int,[Int])]
   parentChildList [] acc = acc
   parentChildList (Rec parent index _name _args _result _line _modName : rest) acc
      = parentChildList rest ((parent, [index]):acc)
   parentChildList (Ref parent index : rest) acc
      = parentChildList rest ((parent, [index]):acc)
   parentChildList (_ : rest) acc
      = parentChildList rest acc

-- | The children recorded under the root pseudo-node ('rootNumber'), i.e.
--   the top-level nodes of the EDT. Returns [] when none are recorded.
rootNodes :: NodeMap -> [EDT]
rootNodes nodeMap = maybe [] id (lookupFM nodeMap rootNumber)

-- | The children of a node, looked up in the node map held in the global
--   debugger state. A node with no entry has no children.
children :: EDT -> IO [EDT]
children edt
   = do nodeMap <- readGlobalState state_nodeMap
        return $ maybe [] id $ lookupFM nodeMap edt

-- | Fetch the record for a node from the global record array and turn it
--   into a 'Derivation'. The node must lie within the array's bounds.
derivation :: EDT -> IO Derivation
derivation node
   = readGlobalState state_recordArray >>= \records ->
        recordToDerivation (records ! node)

-- | Convert a record into a 'Derivation' by reifying its argument and result
--   values with 'reifyVal'. A 'Constant' yields a derivation with no
--   arguments. Any other record (for example a 'Ref') has no derivation, and
--   the call raises an error that names this function.
recordToDerivation :: Record -> IO Derivation
recordToDerivation (Rec _parent _index recName args res line modName)
   = do argGraphs   <- mapM reifyVal args
        resultGraph <- reifyVal res
        return $
           Derivation
           { deriv_name   = recName
           , deriv_args   = argGraphs
           , deriv_result = resultGraph
           , deriv_sloc   = line
           , deriv_module = modName
           }
recordToDerivation (Constant _index constName value line modName)
   = do resultGraph <- reifyVal value
        return $
           Derivation
           { deriv_name   = constName
           , deriv_args   = []
           , deriv_result = resultGraph
           , deriv_sloc   = line
           , deriv_module = modName
           }
-- 'mkRecordArray' filters 'Ref' records out of the record array. However,
-- this function is exported and may be handed any 'Record', so fail with a
-- descriptive message instead of an anonymous pattern-match failure.
recordToDerivation _
   = error "EDT.recordToDerivation: record is not a Rec or Constant (e.g. a Ref) and has no derivation"

-- | Number of nodes in the EDT rooted at the given node.
--   The visited set only covers the current path from the start node. A node
--   met again on that same path counts as 0, which guards against cycles. A
--   node reachable through several different paths is counted once per path.
--   Duplicate entries in a single child list are counted only once.
size :: EDT -> IO Integer
size start
   = go emptySet start
   where
   go :: Set EDT -> EDT -> IO Integer
   go onPath here
      | here `elementOf` onPath = return 0
      | otherwise = do
           kids <- children here
           let onPath' = addToSet onPath here
           counts <- mapM (go onPath') (nub kids)
           return (sum counts + 1)

-- | Length, in nodes, of the longest path downward from the given node.
--   A leaf has depth 1. A node already on the current path counts as 0,
--   which stops the search from looping on cycles.
depth :: EDT -> IO Integer
depth start
   = go emptySet start
   where
   go :: Set EDT -> EDT -> IO Integer
   go onPath here
      | here `elementOf` onPath = return 0
      | otherwise = do
           kids <- children here
           case nub kids of
              []       -> return 1
              distinct -> do
                 depths <- mapM (go (addToSet onPath here)) distinct
                 return (maximum depths + 1)

-- | The result value stored in a node's record (a 'Rec' or a 'Constant').
result :: EDT -> IO Val
result node
   = liftM (valueOf . (! node)) (readGlobalState state_recordArray)
   where
   valueOf :: Record -> Val
   valueOf (Rec _ _ _ _ res _ _)    = res
   valueOf (Constant _ _ value _ _) = value

-- | The function or constant name stored in a node's record.
name :: EDT -> IO PackedString
name node
   = liftM (recordName . (! node)) (readGlobalState state_recordArray)
   where
   recordName :: Record -> PackedString
   recordName (Rec _ _ recName _ _ _ _)      = recName
   recordName (Constant _ constName _ _ _)   = constName

-- | Every node whose child list in the global node map contains the given
--   node. The root pseudo-node ('rootNumber') is never reported.
parents :: EDT -> IO [EDT]
parents node
   = liftM (parentsOf . fmToList) (readGlobalState state_nodeMap)
   where
   parentsOf :: [(EDT, [EDT])] -> [EDT]
   parentsOf edges =
      [ parent
      | (parent, kids) <- edges
      , node `elem` kids
      , parent /= rootNumber
      ]

-- | True when the node number lies within the bounds of the global record
--   array, i.e. when it can be safely indexed.
validEDT :: EDT -> IO Bool
validEDT node
   = readGlobalState state_recordArray >>= \records ->
        case bounds records of
           (lowest, highest) -> return (lowest <= node && node <= highest)
