-- | Symbolic STG Execution Models
module SSTG.Core.Execution.Models
    ( module SSTG.Core.Execution.Models
    ) where

import SSTG.Core.Syntax

import qualified Data.Map as M

-- | Symbolic Transformation
--   A state-threading computation: given an input state @s@, it produces an
--   updated state together with a result of type @a@ (a state monad whose
--   pair is stored as @(state, result)@).
newtype SymbolicT s a = SymbolicT { run :: s -> (s, a) }

-- | Functor instance of Symbolic Transformation
--   Apply a pure function to the result; the state is threaded through
--   unchanged.
instance Functor (SymbolicT s) where
    fmap f st = SymbolicT (\s0 -> let (s1, a1) = (run st) s0 in (s1, f a1))

-- | Applicative instance of Symbolic Transformation
--   Chains step-wise execution left to right: the function-producing
--   computation runs first, and the argument-producing computation runs on
--   the state it leaves behind. This is the same order '>>=' uses, so
--   @(<*>) == ap@ holds as the Monad laws require, and Applicative-based
--   combinators such as 'traverse' sequence effects like do-notation does.
instance Applicative (SymbolicT s) where
    pure a    = SymbolicT (\s -> (s, a))
    sf <*> st = SymbolicT (\s0 -> let (s1, f1) = (run sf) s0
                                      (s2, a2) = (run st) s1 in (s2, f1 a2))

-- | Monad instance of Symbolic Transformation
--   Run the first computation, pass its result to the continuation, and run
--   the computation it returns on the updated state.
instance Monad (SymbolicT s) where
    return a  = pure a
    st >>= fs = SymbolicT (\s0 -> let (s1, a1) = (run st) s0
                                      (s2, a2) = (run (fs a1)) s1 in (s2, a2))

-- | Complete machine state of the symbolic STG evaluator.
data State = State
    { state_status  :: Status    -- ^ Execution bookkeeping (holds the step counter).
    , state_stack   :: Stack     -- ^ Stack of pending frames.
    , state_heap    :: Heap      -- ^ Heap of allocated objects.
    , state_globals :: Globals   -- ^ Global name bindings.
    , state_code    :: Code      -- ^ Evaluating an expression or returning a value.
    , state_names   :: [Name]    -- ^ NOTE(review): presumably names already in use
                                 --   (e.g. for fresh-name generation) — confirm.
    , state_paths   :: PathCons  -- ^ Path conditions accumulated so far.
    , state_links   :: SymLinks  -- ^ Symbolic link table.
    } deriving (Show, Eq, Read)

-- | A symbolic value: the variable that names it, optionally paired with an
--   expression and the local environment that expression is interpreted in.
data Symbol
    = Symbol Var (Maybe (Expr, Locals))
    deriving (Show, Eq, Read)

-- | Execution status bookkeeping.
data Status = Status
    { steps :: Int  -- ^ Step counter (presumably evaluation steps taken).
    }
    deriving (Show, Eq, Read)

-- | The execution stack, represented as a list of 'Frame's.
newtype Stack
    = Stack [Frame]
    deriving (Show, Eq, Read)

-- | A single frame on the execution 'Stack'.
data Frame
    = CaseFrame   Var [Alt] Locals  -- ^ Pending case: a variable (presumably the
                                    --   case binder), alternatives, environment.
    | ApplyFrame  [Atom]    Locals  -- ^ Pending application to argument atoms.
    | UpdateFrame MemAddr           -- ^ Heap address awaiting an update.
    deriving (Show, Eq, Read)

-- | An address into the 'Heap', wrapping a plain 'Int'.
newtype MemAddr
    = MemAddr Int
    deriving (Show, Eq, Read, Ord)

-- | A value: either a literal or the address of an object on the 'Heap'.
data Value
    = LitVal Lit      -- ^ A literal value.
    | MemVal MemAddr  -- ^ A reference to a heap object.
    deriving (Show, Eq, Read)

-- | Local environment: a map from names to their values.
newtype Locals
    = Locals (M.Map Name Value)
    deriving (Show, Eq, Read)

-- | The heap: a map from addresses to objects, together with the last
--   address handed out ('allocHeap' allocates at one past it).
data Heap
    = Heap (M.Map MemAddr HeapObj) MemAddr
    deriving (Show, Eq, Read)

-- | An object stored on the 'Heap'.
data HeapObj
    = LitObj Lit                -- ^ A literal.
    | SymObj Symbol             -- ^ A symbolic value.
    | ConObj DataCon [Value]    -- ^ A data constructor with its argument values.
    | FunObj [Var] Expr Locals  -- ^ A function: parameters, body, environment.
    | Blackhole                 -- ^ Placeholder object; presumably marks a
                                --   thunk under evaluation — confirm.
    deriving (Show, Eq, Read)

-- | Global environment: a map from names to their values.
newtype Globals
    = Globals (M.Map Name Value)
    deriving (Show, Eq, Read)

-- | Evaluation state: what the machine is currently doing.
data Code
    = Evaluate Expr Locals  -- ^ Evaluating an expression in a local environment.
    | Return   Value        -- ^ Returning an already computed value.
    deriving (Show, Eq, Read)

-- | Path constraints: the list of path conditions collected so far.
type PathCons
    = [PathCond]

-- | A single path condition.
--   NOTE(review): presumably records whether the expression, in the given
--   locals, did ('True') or did not ('False') match the alternative
--   constructor with its binders — confirm against the evaluator.
data PathCond
    = PathCond (AltCon, [Var]) Expr Locals Bool
    deriving (Show, Eq, Read)

-- | Symbolic link table: a mapping from names to names.
newtype SymLinks
    = SymLinks (M.Map Name Name)
    deriving (Show, Eq, Read)

--   Simple helpers that operate on a single data structure.

-- | The occurrence string of a 'Name' (its first field).
nameOccStr :: Name -> String
nameOccStr name = case name of
    Name occ _ _ _ -> occ

-- | The unique identifier of a 'Name' (its fourth field).
nameUnique :: Name -> Int
nameUnique name = case name of
    Name _ _ _ unique -> unique

-- | The 'Name' carried by a 'Var'.
varName :: Var -> Name
varName var = case var of
    Var nm _ -> nm

-- | Unwrap a 'MemAddr' to its underlying 'Int'.
memAddrInt :: MemAddr -> Int
memAddrInt = \(MemAddr n) -> n

-- | Look up a variable in the local environment, keyed by the variable's
--   'Name'. Returns 'Nothing' if it is not bound locally.
lookupLocals :: Var -> Locals -> Maybe Value
lookupLocals var (Locals env) = M.lookup key env
  where key = varName var

-- | Bind a variable (by its 'Name') in the local environment, replacing any
--   existing binding for that name.
insertLocals :: Var -> Value -> Locals -> Locals
insertLocals var val (Locals env) = Locals (M.insert (varName var) val env)

-- | Insert Locals List
--   Bind several variables in the local environment at once. The result is
--   the same as calling 'insertLocals' on each pair in list order: a later
--   pair wins over an earlier pair for the same name, and every pair
--   overrides an existing binding.
--   The bindings are built with one strict 'M.fromList' and merged with the
--   left-biased 'M.union'. This avoids threading a lazy accumulator, which
--   built a chain of unevaluated inserts as long as the input list.
insertLocalsList :: [(Var, Value)] -> Locals -> Locals
insertLocalsList vvs (Locals lmap) = Locals (M.union new lmap)
  where new = M.fromList [ (varName var, val) | (var, val) <- vvs ]

-- | Fetch the object stored at an address, or 'Nothing' if the address is
--   not present in the heap.
lookupHeap :: MemAddr -> Heap -> Maybe HeapObj
lookupHeap addr heap = case heap of
    Heap objs _ -> M.lookup addr objs

-- | Store an object at the address one past the last allocated address.
--   Returns the updated heap, which records the new address as the last
--   allocated one, together with the new address.
--   NOTE(review): the target address is not checked for an existing entry
--   (e.g. one written via 'insertHeap'); an occupant would be overwritten.
allocHeap :: HeapObj -> Heap -> (Heap, MemAddr)
allocHeap hobj (Heap hmap prev) =
    let next = MemAddr (memAddrInt prev + 1)
    in  (Heap (M.insert next hobj hmap) next, next)

-- | Allocate each object in order with 'allocHeap', threading the heap
--   through. Returns the final heap and the addresses in input order.
allocHeapList :: [HeapObj] -> Heap -> (Heap, [MemAddr])
allocHeapList hobjs heap = case hobjs of
    []           -> (heap, [])
    (obj : rest) ->
        let (heapNext, addr)  = allocHeap obj heap
            (heapLast, addrs) = allocHeapList rest heapNext
        in  (heapLast, addr : addrs)

-- | Write an object at a specific address, replacing any existing object.
--   The last-allocated address is left unchanged.
insertHeap :: MemAddr -> HeapObj -> Heap -> Heap
insertHeap addr hobj (Heap objs lastAddr) = Heap (M.insert addr hobj objs) lastAddr

-- | Insert Heap List
--   Write several objects at specific addresses. The result is the same as
--   calling 'insertHeap' on each pair in list order: a later pair wins for
--   the same address, and every pair overrides an existing object. The
--   last-allocated address is left unchanged.
--   The entries are built with one strict 'M.fromList' and merged with the
--   left-biased 'M.union'. This avoids threading a lazy accumulator through
--   the recursion.
insertHeapList :: [(MemAddr, HeapObj)] -> Heap -> Heap
insertHeapList ahs (Heap hmap prev) = Heap (M.union (M.fromList ahs) hmap) prev

-- | Look up a variable in the global environment, keyed by the variable's
--   'Name'. Returns 'Nothing' if it has no global binding.
lookupGlobals :: Var -> Globals -> Maybe Value
lookupGlobals var (Globals env) = M.lookup key env
  where key = varName var

-- | Bind a variable (by its 'Name') in the global environment, replacing
--   any existing binding for that name.
insertGlobals :: Var -> Value -> Globals -> Globals
insertGlobals var val (Globals env) = Globals (M.insert (varName var) val env)

-- | Insert Globals List
--   Bind several variables in the global environment at once. The result is
--   the same as calling 'insertGlobals' on each pair in list order: a later
--   pair wins over an earlier pair for the same name, and every pair
--   overrides an existing binding.
--   The bindings are built with one strict 'M.fromList' and merged with the
--   left-biased 'M.union'. This avoids threading a lazy accumulator through
--   the recursion.
insertGlobalsList :: [(Var, Value)] -> Globals -> Globals
insertGlobalsList vvs (Globals gmap) = Globals (M.union new gmap)
  where new = M.fromList [ (varName var, val) | (var, val) <- vvs ]

--   Compound helpers that combine several data structures.

-- | Resolve a variable: a local binding takes precedence, and the global
--   environment is consulted only when there is no local binding.
lookupValue :: Var -> Locals -> Globals -> Maybe Value
lookupValue var locals globals =
    maybe (lookupGlobals var globals) Just (lookupLocals var locals)

-- | Resolve a variable to the heap object it points to, paired with its
--   address. Yields 'Nothing' in three cases: the variable is unbound, it
--   holds a literal rather than an address, or the address is not in the
--   heap.
vlookupHeap :: Var -> Locals -> Globals -> Heap -> Maybe (MemAddr, HeapObj)
vlookupHeap var locals globals heap =
    lookupValue var locals globals >>= \found -> case found of
        LitVal _     -> Nothing
        MemVal maddr -> (,) maddr <$> lookupHeap maddr heap

-- | MemAddr Type
--   Compute the type of the object stored at a heap address.
--   Returns 'Nothing' when the address is not present in the heap.
memAddrType :: MemAddr -> Heap -> Maybe Type
memAddrType addr heap = do
    hobj <- lookupHeap addr heap
    Just $ case hobj of
        -- A black-holed object is reported as 'Bottom'.
        Blackhole           -> Bottom
        LitObj lit          -> litType lit
        -- A symbolic object takes the type of the variable that names it.
        SymObj (Symbol s _) -> varType s
        -- NOTE(review): uses the constructor's own type regardless of the
        -- stored argument values — confirm this is the intended type.
        ConObj dcon _       -> dataConType dcon
        -- Nest the parameter types, in order, with 'FunTy', ending in the
        -- type of the body expression.
        FunObj prms expr _  -> foldr FunTy (exprType expr) (map varType prms)