{-------------------------------------------------------------------------------

        Copyright:              Bernie Pope 2003

        Module:                 TransMonad

        Description:            A simple monad to support the transformation
                                code (threading state).

        Primary Authors:        Bernie Pope

-------------------------------------------------------------------------------}

{-
    This file is part of buddha.

    buddha is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    buddha is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with buddha; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
-}


module TransMonad 
   ( Trans
   , getSigMap
   , freshVar
   , nFreshVars
   , runTrans
   , isFlagSet
   , getModName
   , context
   , getTransOpt
   , pushConstantStack
   , popConstantStack
   , recordConstant
   , getRecentSrcLoc
   , location
   , lookupTypeSigT
   , incPartialStatByT
   , incSaturatedStatByT
   , incOverSatStatByT
   , incLambdaStatByT
   , incConstantsStatByT
   , incMiscApsStatByT
   , incPatVarApStatByT
   ) where

import Rename     
   ( defaultVarPrefix )

import Language.Haskell.Syntax  
   ( Module
   , HsExp (..)
   , HsPat (..)
   , HsName (..)
   , HsQName (..) 
   , HsQualType
   , SrcLoc 
   ) 

import TransOpts  
   ( ContextMap
   , lookupContextMap
   , TransOpt
   , defaultTransOpt 
   , Context 
   ) 

import Data.Set  
   ( Set
   , emptySet
   , addToSet 
   , elementOf 
   )

import Stack     
   ( emptyStack
   , pushStack
   , popStack
   , Stack
   , peekStack
   , modifyTop 
   )

import SyntaxUtils
   ( madeUpSrcLoc )

import TypeSigMap 
   ( TypeSigMap
   , lookupTypeSig 
   )

import Statistics

import Opts      
   ( CmdLine (..) ) 

--------------------------------------------------------------------------------

data State 
   = State 
     { state_modName       :: Module        -- name of the current module 
     , state_cmdLine       :: CmdLine       -- the whole command line args 
     , state_varCount      :: Int           -- to make each new variable unique
     , state_context       :: Context       -- the context of code that we are in
     , state_contextMap    :: ContextMap    -- mapping from context to transform option
     , state_defTransOpt   :: TransOpt      -- the default trans option (how to make EDTs)
     , state_constantStack :: ConstantStack -- a stack of sets of constant identifiers
     , state_srcLocStack   :: Stack SrcLoc  -- stack of src locations (for error msgs etc)
     , state_typeSigMap    :: TypeSigMap    -- type sigs in this module
     , state_stats         :: Statistics    -- a record of stats about the module
     }

newtype Trans a = Trans (State -> Either String (a, State))

-- instantiate the Monad class so we can use do notation
instance Monad Trans where 
    return a
        = Trans (\state -> Right (a, state))

    Trans comp >>= fun
        = Trans (\state -> 
                     case comp state of
                        Left s -> Left s
                        Right (result, newState)
                           -> case fun result of
                                 Trans comp' -> comp' newState)

    fail s = Trans (\_ -> Left s)

-- run a transformation
runTrans :: Module -> CmdLine -> Int -> ContextMap -> TypeSigMap -> Trans a  
                                     -> Either String (a, Int, Statistics)
runTrans modName cmdLine count contextMap sigMap (Trans comp)
   = case comp initState of
        Left err -> Left err
        Right (result, newState) -> Right ( result 
                                          , state_varCount newState
                                          , state_stats newState
                                          )
   where
   initState = State { state_modName       = modName
                     , state_cmdLine       = cmdLine 
                     , state_varCount      = count
                     , state_contextMap    = contextMap 
                     , state_context       = []
                     , state_defTransOpt   = defTransOpt 
                     , state_constantStack = emptyStack 
                     , state_srcLocStack   = emptyStack 
                     , state_typeSigMap    = sigMap
                     , state_stats         = initStats
                     }
   -- set the default trans opt (how to make EDTs)
   defTransOpt :: TransOpt
   defTransOpt = case lookupContextMap contextMap ["_"] of
                    Nothing -> defaultTransOpt
                    Just opt -> opt
 
-- the set of all pattern-bound (constant) identifiers used in an expression
type ConstantSet = Set HsQName

-- a stack of ConstantSets
type ConstantStack = Stack ConstantSet

-- select a component of the state
select :: (State -> a) -> Trans a
select selector = Trans (\state -> Right (selector state, state))

getConstantStack :: Trans ConstantStack
getConstantStack = select state_constantStack

getVarCount :: Trans Int
getVarCount = select state_varCount

getModName :: Trans Module
getModName = select state_modName

getContext :: Trans Context
getContext = select state_context

getContextMap :: Trans ContextMap
getContextMap = select state_contextMap

getDefaultTransOpt :: Trans TransOpt
getDefaultTransOpt = select state_defTransOpt 

getSigMap :: Trans TypeSigMap
getSigMap = select state_typeSigMap

getStats :: Trans Statistics
getStats = select state_stats

getRecentSrcLoc :: Trans SrcLoc
getRecentSrcLoc 
   = do stack <- select state_srcLocStack
        case peekStack stack of
           Nothing -> return madeUpSrcLoc
           Just top -> return top

lookupTypeSigT :: HsName -> Trans (Maybe HsQualType)
lookupTypeSigT name
   = do sigMap <- getSigMap
        return $ lookupTypeSig sigMap name

incVarCount :: Trans ()
incVarCount = Trans (\state -> 
                          let oldVarCount = state_varCount state
                          in Right ((), state {state_varCount = oldVarCount + 1}))

-- compute a fresh variable
freshVar :: Trans (String, HsPat, HsExp)
freshVar 
   = do currentCount <- getVarCount
        incVarCount
        let varStr   = defaultVarPrefix ++ show currentCount 
            varIdent = HsIdent varStr 
        return (varStr, HsPVar varIdent, HsVar $ UnQual $ varIdent) 

nFreshVars :: Int -> Trans ([String], [HsPat], [HsExp])
nFreshVars n
   = do newVars <- sequence $ replicate n freshVar 
        return $ unzip3 newVars

-- check whether a flag is set
isFlagSet :: String -> Trans Bool
isFlagSet flag
   = do cmdLine <- select state_cmdLine
        return $ flag `elem` (transFlags cmdLine) 

context :: [String] -> Trans a -> Trans a 
context strings trans
   = do mapM_ pushContext strings
        t <- trans
        -- make sure we pop the same number of things that were pushed
        sequence_ $ replicate (length strings) popContext
        return t

location :: SrcLoc -> Trans a -> Trans a 
location sloc trans
   = do pushSrcLoc sloc 
        t <- trans
        popSrcLoc
        return t

pushContext :: String -> Trans ()
pushContext str 
   = Trans (\state -> 
              let oldContext = state_context state
              in Right ((), state { state_context = str : oldContext })
           )

popContext :: Trans ()
popContext 
   = Trans (\state -> 
              let oldContext = state_context state
              in case oldContext of
                    [] -> Right ((), state {state_context = oldContext })
                    (_:t) -> Right ((), state { state_context = t })
           )

pushSrcLoc :: SrcLoc -> Trans ()
pushSrcLoc sloc 
   = Trans (\state -> 
              let oldSrcLocStack = state_srcLocStack state
              in Right ((), state { state_srcLocStack = pushStack sloc oldSrcLocStack })
           )

popSrcLoc :: Trans ()
popSrcLoc 
   = Trans (\state -> 
              let oldStack = state_srcLocStack state
              in case popStack oldStack of
                    Nothing -> Right ((), state {state_srcLocStack = oldStack })
                    Just (_, bottom) -> Right ((), state { state_srcLocStack = bottom })
           )

getTransOpt :: Trans TransOpt
getTransOpt 
   = do contextMap <- getContextMap
        context <- getContext
        case lookupContextMap contextMap context of
           Nothing -> do let wildContext = "_" : tail context
                         case lookupContextMap contextMap wildContext of
                            Nothing -> getDefaultTransOpt 
                            Just opt -> return opt 
           Just opt -> return opt 

pushConstantStack :: Trans ()
pushConstantStack
   = Trans (\state -> Right ( () , let oldStack = state_constantStack state 
                                       newStack = pushStack emptySet oldStack 
                             in state { state_constantStack = newStack } )) 

popConstantStack :: Trans ()
popConstantStack
   = Trans (\state -> Right ( (), 
                             let oldStack = state_constantStack state 
                                 newStack = case popStack oldStack of 
                                               Nothing -> oldStack
                                               Just (_top, bottom) -> bottom 
                             in state { state_constantStack = newStack } )) 

-- check to see if a name is already recorded, if not add it to the set
recordConstant :: HsQName -> Trans Bool
recordConstant name
   = do constantSet <- getConstants
        if elementOf name constantSet 
           then return True
           else do addConstant name
                   return False

addConstant :: HsQName -> Trans ()
addConstant name
   = Trans (\state -> Right (() , 
                             let oldStack = state_constantStack state 
                                 newStack = case modifyTop (flip addToSet name) oldStack of
                                               Nothing -> oldStack
                                               Just modStack -> modStack
                             in state { state_constantStack = newStack } )) 

getConstants :: Trans (Set HsQName)
getConstants 
   = Trans (\state -> let stack = state_constantStack state 
                          constants = case peekStack stack of
                                         Nothing  -> emptySet 
                                         Just set -> set
                      in Right ( constants, state ) )

incStat :: (Int -> Statistics -> Statistics) -> Int -> Trans ()
incStat incFun increment
   = Trans (\state -> let oldStats = state_stats state
                          newStats = incFun increment oldStats
                      in Right ((), state { state_stats = newStats }))

incPartialStatByT :: Int -> Trans ()
incPartialStatByT = incStat incPartialStatBy

incSaturatedStatByT :: Int -> Trans ()
incSaturatedStatByT = incStat incSaturatedStatBy

incOverSatStatByT :: Int -> Trans ()
incOverSatStatByT = incStat incOverSatStatBy

incLambdaStatByT :: Int -> Trans ()
incLambdaStatByT = incStat incLambdaStatBy

incConstantsStatByT :: Int -> Trans ()
incConstantsStatByT = incStat incConstantsStatBy

incMiscApsStatByT :: Int -> Trans ()
incMiscApsStatByT = incStat incMiscApsStatBy

incPatVarApStatByT :: Int -> Trans ()
incPatVarApStatByT = incStat incPatVarApStatBy
