module Types where

import FiniteMap 
import List(partition, nub)
import Char(toLower)
import Set

-- | A runtime value of the object language. The Eq, Ord and Show
-- instances are written by hand because the 'Function' constructor holds
-- a Haskell function, for which none of them can be derived.
data Object
    = Boolean Bool                           -- ^ a truth value
    | Obj BasicType                          -- ^ an element of a basic domain
    | SetOf (Set Object)                     -- ^ a finite set of objects
    | TupleOf [Object]                       -- ^ a tuple of objects
    | Function FuncInternals (Object -> Object)
      -- ^ a function: its type and domain, together with the Haskell
      --   function that computes its value at each argument

-- | Metadata carried by every 'Function' object: its declared type and the
-- finite set of arguments it is defined on.
data FuncInternals = FuncInternals
    { getType   :: FunType     -- ^ (argument type, result type)
    , getDomain :: Set Object  -- ^ the arguments the function is defined on
    }
    deriving (Show)

-- | A function type, as the pair @(argType, resType)@.
type FunType = (ObjectType, ObjectType)

-- | Variables introduced together, each with an optional type. A variable
-- paired with 'Nothing' carries no type here; its type has to be looked up
-- in the environment.
type TypeSpecification = [(VarName, Maybe ObjectType)]

-- | Types of the object language.
data ObjectType
    = Booleans               -- ^ the truth values
    | Domain VarName         -- ^ a basic domain, identified by name
    | Power ObjectType       -- ^ the powerset of the given type
    | Product [ObjectType]   -- ^ all tuples over the given component types
    | FunctionType FunType   -- ^ functions from argument type to result type
    deriving (Show, Eq)

-- | Elements of basic domains are represented as plain numbers.
type BasicType = Int

-- | Identifier names.
type VarName = String

-- | Environment: the types and the values of identifiers, keyed by name.
data Env = Env
    { identTypes :: FiniteMap VarName ObjectType
    , identVals  :: FiniteMap VarName Object
    }
    deriving (Show, Eq)

-- | Abstract syntax of expressions.
data Expr
    = Basic Object                             -- ^ a constant
    | Var VarName                              -- ^ a variable reference
    | Quantifier Quan TypeSpecification Expr   -- ^ a quantifier over the listed variables
    | Apply Expr Expr                          -- ^ function application
    | Builtin Builtin [Expr]                   -- ^ a builtin operator: equality, connectives, etc.
    | PrintString String                       -- ^ rendered as @printstr("...")@
    | Print Expr                               -- ^ rendered as @print(...)@
    | MakeSetting Setting Expr                 -- ^ a setting in force for the body (@let ... in ...@)
    deriving (Show, Eq)


-- | A setting introduced by 'MakeSetting'.
data Setting
    = Assign VarName Expr             -- ^ bind a name to an expression
    | DefaultType TypeSpecification   -- ^ a type specification for variables
    deriving (Show, Eq)

-- | Builtin operators; each is applied to a list of argument expressions.
data Builtin
    = MkTuple      -- ^ form a tuple out of the arguments
    | IfThenElse   -- ^ conditional, with or without an else branch
    | Equal        -- ^ equality (a simple relation)
    | Elem         -- ^ membership (a simple relation)
    | And          -- ^ connective
    | Or           -- ^ connective
    | Not          -- ^ connective
    | Iff          -- ^ connective
    | Imp          -- ^ connective
    | Compose      -- ^ rendered with @.@ between the operands
    deriving (Show, Eq)

-- | Quantifier keywords.
data Quan
    = All
    | Exists
    | Find
    | FindAll
    | ForEach
    deriving (Show, Eq)

{- Classes and instances -}

instance NiceShow Integer where
    -- Group the digits in threes from the right, e.g. 1234567 -> "1,234,567".
    -- The sign is kept outside the grouping so it can never become a group of
    -- its own: grouping the characters of "-123" directly would give "-,123".
    niceShow n
        | n < 0     = '-' : groupDigits (show (negate n))
        | otherwise = groupDigits (show n)
        where
        groupDigits = reverse . commasInsert . reverse
        -- works on the reversed digit string, so a comma goes after every
        -- third digit that still has more digits behind it
        commasInsert (x:y:z:w:xs) = x:y:z:',':commasInsert (w:xs)
        commasInsert xs = xs

instance NiceShow Quan where
    -- The keyword is the constructor name in lower case, e.g. "forall"-style
    -- names such as "findall" for FindAll.
    niceShow q = [toLower c | c <- show q]

instance NiceShow Setting where
    -- Assignments print as "name = expr"; default types print as the
    -- bracketed type specification.
    niceShow setting = case setting of
        Assign name body -> name ++ " = " ++ niceShow body
        DefaultType spec -> niceShowTypeSpec spec

-- | Render a type specification as "[x:T,y:U,z]". Variables with an explicit
-- type are listed first (as name:type), followed by the untyped ones, so the
-- printed order can differ from the order in the specification.
niceShowTypeSpec :: TypeSpecification -> String
niceShowTypeSpec typespec = listWithBrackets "[" (typed ++ untyped) "]"
    where
    typed   = [name ++ ":" ++ niceShow ty | (name, Just ty) <- typespec]
    untyped = [name | (name, Nothing) <- typespec]

-- | Wrap a string in round parentheses.
brack :: String -> String
brack s = '(' : s ++ ")"

instance NiceShow Expr where
    niceShow (Basic obj) = niceShow obj
    niceShow (Var vname) = vname
    -- 'all' and 'exists' wrap their body in parentheses on the same line;
    -- the other quantifiers print it as a braced block on separate lines.
    niceShow (Quantifier q typespec expr)
        = niceShow q ++ " " ++ niceShowTypeSpec typespec ++
          (if q == All || q == Exists
              then " (" ++ niceShow expr ++ ") "
              else " {\n" ++ niceShow expr ++ "\n}\n")
    niceShow (Apply x1 x2) = niceShow x1 ++ brack (niceShow x2)
    -- Builtins applied to their expected number of arguments get infix or
    -- mixfix notation; anything else reaches the generic prefix form at the end.
    niceShow builtin@(Builtin _ _)
        = case builtin of
               Builtin IfThenElse [a,b] -> "if " ++ brack (niceShow a) ++
                                         " then " ++ brack (niceShow b)
               Builtin IfThenElse [a,b,c] -> "if " ++ brack (niceShow a) ++
                                           " then " ++ brack (niceShow b) ++
                                           " else " ++ brack (niceShow c)
               Builtin MkTuple xs -> listWithBrackets "(" (map niceShow xs) ")"
               -- Equal and Compose are variadic infix chains. They need at
               -- least one operand: an empty argument list falls through to
               -- the generic form below instead of crashing in foldr1.
               Builtin Equal xs@(_:_) -> foldr1 (\x y -> x ++ "=" ++ y) (map niceShow xs)
               Builtin Elem [a,b] -> brack (niceShow a) ++ " elem-of "
                               ++ brack (niceShow b)
               Builtin And [a,b] -> brack (niceShow a) ++ "&" ++ brack (niceShow b)
               Builtin Or [a,b] -> brack (niceShow a) ++ "|" ++ brack (niceShow b)
               Builtin Not [a] -> "~" ++ brack (niceShow a)
               Builtin Iff [a,b] -> brack (niceShow a) ++ " <-> " ++ brack (niceShow b)
               Builtin Imp [a,b] -> brack (niceShow a) ++ " -> " ++ brack (niceShow b)
               Builtin Compose xs@(_:_) -> foldr1 (\x y -> x ++ "." ++ y) (map niceShow xs)
               -- Generic prefix form: the lower-cased builtin name followed by
               -- the bracketed arguments, e.g. "not(a,b)" for a two-argument Not.
               Builtin b exprs ->
                      let exprs' = map niceShow exprs
                          exprs'' = listWithBrackets "(" exprs' ")"
                      in map toLower (show b) ++ exprs''

    niceShow (PrintString str)
        = "printstr(" ++ show str ++ ")"

    niceShow (Print expr)
        = "print(" ++ niceShow expr ++ ")"

    niceShow (MakeSetting setting expr)
        = "let (" ++ niceShow setting ++ ") in "
           ++ niceShow expr

instance NiceShow ObjectType where
    -- Surface syntax for types: Bool, domain names, Pow(T), A*B*..., (A->B).
    niceShow ty = case ty of
        Booleans -> "Bool"
        Domain name -> name
        Power inner -> "Pow(" ++ niceShow inner ++ ")"
        -- NOTE(review): foldr1 is partial, so Product [] is not printable;
        -- nested products are also not bracketed.
        Product parts -> foldr1 (\l r -> l ++ "*" ++ r) (map niceShow parts)
        FunctionType (argTy, resTy) ->
            "(" ++ niceShow argTy ++ "->" ++ niceShow resTy ++ ")"

instance Show Object where
    show (SetOf objs) = listWithBrackets "{" (map show $ setToList objs) "}"
    show (Boolean b) = if b then "T" else "F"
    show (Obj b)  = show b
    show (TupleOf xs) = listWithBrackets "(" (map show xs) ")"
    -- A function is shown extensionally as "arg -> result" pairs separated
    -- by ", ". A function with an empty domain shows as the empty string.
    show f@(Function _ _) = joinComma (map showPair (functionToPairs f))
        where
        showPair (arg, ret) = show arg ++ " -> " ++ show ret
        -- right-nested (++) keeps the join linear in the output length
        joinComma [] = ""
        joinComma ps = foldr1 (\x y -> x ++ ", " ++ y) ps

-- | Human-oriented rendering. Instances that do not override 'niceShow'
-- fall back to their 'Show' output.
class (Show a) => NiceShow a where
    niceShow :: a -> String
    niceShow = show

instance NiceShow Object where
    -- Functions get the layout chosen by showFunction; every other object
    -- is rendered with its Show instance.
    niceShow (Function internals fun) = showFunction internals fun
    niceShow other = show other

-- | Render a function object for display. The layout is chosen from the
-- function's type:
--
--   * argument type @Product [Domain _, Domain _]@: a grid with one row per
--     first component of the argument and one column per second component;
--     boolean results are drawn as @*@ (True) or a blank (False);
--   * result type 'Booleans': the arguments mapped to True, printed as
--     @extension = {...}@;
--   * anything else: the plain 'show' of the function (its argument/result
--     pairs).
showFunction :: FuncInternals -> (Object -> Object) -> String
showFunction int f
    = case (fromType, toType) of
          (Product [Domain _, Domain _], _) -> drawTable
          (_, Booleans)                     -> drawExtension
          _                                 -> show function
    where
    function = Function int f
    (fromType, toType) = getType int
    dom = getDomain int
    pairs = functionToPairs function

    drawExtension
        = listWithBrackets "extension = {" (map niceShow ext) "}"
        where ext = [x | (x, fx) <- pairs, fx == Boolean True]

    -- A header row of column labels, a separator as wide as the widest body
    -- row, then one body row per first component.
    -- NOTE(review): labels and cells are not padded, so the columns only line
    -- up when every label and cell renders as a single character.
    drawTable = init (unlines (firstRows ++ otherRows))
        where
        firstRows = ["  | " ++ unwords (map show (nub (map snd' (setToList dom)))),
                     "--+" ++ replicate (width - 3) '-']
        otherRows = map drawRow (splitUp pairs)
        -- The 0 seed stops an empty domain (no body rows) from crashing
        -- 'maximum'; the separator then degenerates to just "--+".
        width = maximum (0 : map length otherRows)
        drawRow xs = show (fst' (head (map fst xs))) ++ " | " ++ unwords (map show' (map snd xs))
        fst' (TupleOf [x,_]) = x
        snd' (TupleOf [_,x]) = x
        show' (Boolean True) = "*"
        show' (Boolean False) = " "
        show' x = show x

    -- Group the (argument, result) pairs into rows, one per distinct first
    -- component of the argument tuple. Rows appear in order of first
    -- appearance and each row keeps its pairs in their original order.
    splitUp []   = []
    splitUp list = before : splitUp after
        where
        firstPart (TupleOf [x,_],_) = x
        firstPart _ = error "firstPart _"
        rowKey = firstPart (head list)
        (before, after) = List.partition (\x -> firstPart x == rowKey) list

-- | Join the strings with commas and wrap the result in the given opening and
-- closing brackets, e.g. @listWithBrackets "[" ["a","b"] "]" == "[a,b]"@.
-- An empty list yields just the two brackets.
listWithBrackets :: String -> [String] -> String -> String
listWithBrackets lbrack xs rbrack = lbrack ++ commaSep xs ++ rbrack
    where
    -- Right-nested (++) keeps the join linear in the output length; a left
    -- fold would re-copy the accumulated prefix at every step.
    commaSep []     = ""
    commaSep [y]    = y
    commaSep (y:ys) = y ++ "," ++ commaSep ys


-- | Structural equality. Two functions are equal when their argument/result
-- pair lists are equal; objects built with different constructors are never
-- equal.
instance Eq Object where
    a == b = case (a, b) of
        (Boolean p, Boolean q)       -> p == q
        (SetOf s, SetOf t)           -> s == t
        (TupleOf ps, TupleOf qs)     -> ps == qs
        (Obj m, Obj k)               -> m == k
        (Function _ _, Function _ _) -> functionToPairs a == functionToPairs b
        _                            -> False

-- | A total order on objects that is consistent with the 'Eq' instance.
-- Objects built with the same constructor compare by their contents
-- (functions by their argument/result pairs). Objects built with different
-- constructors are ordered by constructor, in declaration order:
-- Boolean < Obj < SetOf < TupleOf < Function. The order must be total because
-- objects are kept in 'Set's (SetOf, function domains).
instance Ord Object where
    compare (Boolean x) (Boolean y) = compare x y
    compare (Obj x) (Obj y) = compare x y
    compare (TupleOf x) (TupleOf y) = compare x y
    compare (SetOf x) (SetOf y) = compare x y  -- via the Ord (Set a) instance in this module
    compare f@(Function _ _) g@(Function _ _)
        = compare (functionToPairs f) (functionToPairs g)
    -- Different constructors: compare constructor positions. (Making this
    -- case answer False for <= in both directions would break antisymmetry
    -- and the default compare would return GT either way.)
    compare x y = compare (rank x) (rank y)
        where
        rank :: Object -> Int
        rank (Boolean _)    = 0
        rank (Obj _)        = 1
        rank (SetOf _)      = 2
        rank (TupleOf _)    = 3
        rank (Function _ _) = 4

instance (Show a, Show b) => Show (FiniteMap a b) where
    -- shown as its association list, as returned by fmToList
    show fm = show (fmToList fm)

instance (Ord a, Ord b) => Ord (FiniteMap a b) where
    -- lexicographic comparison of the association lists from fmToList
    lhs <= rhs = fmToList lhs <= fmToList rhs

instance (Show a) => Show (Set a) where
    -- shown as the list of its elements, as returned by setToList
    show s = show (setToList s)

instance (Ord a) => Ord (Set a) where
    -- lexicographic comparison of the element lists from setToList
    lhs <= rhs = setToList lhs <= setToList rhs


-- | The graph of a function object: each domain element, in 'setToList'
-- order, paired with the function's value at it. Calling this on any other
-- kind of object is an error.
functionToPairs :: Object -> [(Object, Object)]
functionToPairs obj = case obj of
    Function int fun -> map (\x -> (x, fun x)) (setToList (getDomain int))
    _ -> error "functionToPairs applied to non-Function"

-- | Build a function object of the given type from its (argument, result)
-- pairs. The domain is the set of arguments; applying the function to
-- anything outside it is an error.
pairsToFunction :: FunType -> [(Object, Object)] -> Object
pairsToFunction ftype vals = Function internals apply
    where
    internals = FuncInternals { getType = ftype, getDomain = mkSet (map fst vals) }
    -- results are looked up in a FiniteMap keyed by argument
    table = listToFM vals
    apply x = maybe (error "undefined application of function") id (lookupFM table x)

