-- | Symbolic STG Execution Engine
module SSTG.Core.Execution.Engine
    ( loadState
    , loadStateEntry
    , LoadResult (..)
    , RunFlags (..)
    , StepType (..)
    , execute
    , execute1
    ) where

import Data.List (mapAccumL)
import qualified Data.Map as M

import SSTG.Core.Syntax
import SSTG.Core.Execution.Models
import SSTG.Core.Execution.Naming
import SSTG.Core.Execution.Stepping

-- | Outcome of loading a 'Program' into an initial 'State'.
data LoadResult
    = LoadOkay  State            -- ^ Exactly one binding matched the entry
                                 --   name; the state starts from it.
    | LoadGuess State [Binding]  -- ^ Several bindings matched: the state
                                 --   starts from the first one, and the
                                 --   other matching bindings are listed.
    | LoadError String           -- ^ No binding matched; carries a message.
    deriving (Show, Eq, Read)

-- | Load a 'Program' using @"main"@ as the entry point name.
--
-- The name @"main"@ is a guess based on a few experimental programs, not a
-- guarantee about how entry points are named; use 'loadStateEntry' to pick
-- a different entry.
loadState :: Program -> LoadResult
loadState = loadStateEntry "main"

-- | Load a 'State' whose code starts from the top-level variable whose
-- occurrence name is @entry@.
--
-- Loading proceeds in two phases:
--
--   * every top-level binding first receives 'Blackhole' placeholder slots
--     on the heap, and its names are registered in the globals;
--   * each binding's right-hand sides are then lifted into heap objects.
--
-- The first binding that defines @entry@ is chosen as the entry point:
--
--   * no binding defines it: 'LoadError';
--   * exactly one does: 'LoadOkay';
--   * several do: 'LoadGuess', listing the matching bindings not chosen.
loadStateEntry :: String -> Program -> LoadResult
loadStateEntry entry (Program bnds) =
    case entryMatches entry bnd_locs of
        [] -> no_entry
        ((tgt_bnd, tgt_loc) : others) ->
            case lhsMatches entry tgt_bnd of
                -- Not expected: 'entryMatches' only keeps bindings for which
                -- 'lhsMatches' is non-empty. Handled to stay total.
                [] -> no_entry
                ((tgt_var, tgt_rhs) : _) ->
                    let state = buildState tgt_var tgt_rhs tgt_loc
                    in  if null others
                            then LoadOkay  state
                            else LoadGuess state (map fst others)
  where
    no_entry =
        LoadError ("No entry candidates found for: [" ++ entry ++ "]")

    -- Globals and heap are loaded together. After 'initGlobals' the heap
    -- only holds placeholders; 'liftBindings' fills in the real objects.
    heap0    = Heap M.empty (MemAddr 0)
    (glist, heap1, bnd_addrss) = initGlobals bnds heap0
    globals0 = Globals (M.fromList glist)
    (heap2, localss) = liftBindings bnd_addrss globals0 heap1
    bnd_locs = zip bnds localss

    -- Assemble the initial state around the chosen entry variable. Code
    -- loading may add symbolic objects to the heap.
    buildState tgt_var tgt_rhs tgt_loc =
        state0 { state_names = allNames state0 }
      where
        (code, globals, heap) =
            loadCode tgt_var tgt_rhs tgt_loc globals0 heap2
        state0 = State { state_status  = Status { steps = 0 }
                       , state_stack   = Stack []
                       , state_heap    = heap
                       , state_globals = globals
                       , state_code    = code
                       , state_names   = []
                       , state_paths   = []
                       , state_links   = SymLinks M.empty }

-- | Reserve heap space for a binding: one 'Blackhole' placeholder is
-- allocated per right-hand side. Returns the updated heap together with the
-- allocated addresses.
allocBinding :: Binding -> Heap -> (Heap, [MemAddr])
allocBinding (Binding _ pairs) heap =
    let placeholders   = Blackhole <$ pairs
        (heap', addrs) = allocHeapList placeholders heap
    in  (heap', addrs)

-- | Allocate placeholders for a list of bindings, threading the heap from
-- left to right through 'allocBinding'.
--
-- Returns the final heap and, for each binding in input order, the
-- addresses that were allocated for it.
allocBindingList :: [Binding] -> Heap -> (Heap, [[MemAddr]])
allocBindingList bnds heap = mapAccumL (flip allocBinding) heap bnds

-- | Pair each variable name defined by a binding with a 'MemVal' pointer to
-- its heap address.
--
-- Names and addresses are matched positionally; surplus entries on either
-- side are dropped, as with 'zip'.
bndAddrsToNameVals :: (Binding, [MemAddr]) -> [(Name, Value)]
bndAddrsToNameVals (Binding _ rhss, addrs) =
    zipWith (\(var, _) addr -> (varName var, MemVal addr)) rhss addrs

-- | Allocate placeholder slots for all top-level bindings.
--
-- Returns three things:
--
--   * the name-to-pointer associations used to build the globals;
--   * the updated heap;
--   * each binding paired with its allocated addresses.
initGlobals :: [Binding] -> Heap ->
               ([(Name, Value)], Heap, [(Binding, [MemAddr])])
initGlobals bnds heap =
    let (heap', addrss) = allocBindingList bnds heap
        pairedBnds      = zip bnds addrss
    in  (concatMap bndAddrsToNameVals pairedBnds, heap', pairedBnds)

-- | Resolve an 'Atom' to a 'Value'.
--
-- Literals are wrapped directly. Variables are looked up in the locals and
-- globals; a failed lookup yields @LitVal BlankAddr@ instead of crashing.
forceLookupValue :: Atom -> Locals -> Globals -> Value
forceLookupValue atom locals globals = case atom of
    LitAtom lit -> LitVal lit
    -- A missing variable is an error, but we deliberately avoid crashing.
    VarAtom var -> maybe (LitVal BlankAddr) id (lookupValue var locals globals)

-- | Build the heap object for a binding right-hand side.
--
-- A function captures the given locals as its environment; the globals are
-- not consulted. For a constructor, each argument atom is resolved to a
-- value with 'forceLookupValue'.
forceRhsObj :: BindRhs -> Locals -> Globals -> HeapObj
forceRhsObj rhs locals globals = case rhs of
    FunForm prms expr -> FunObj prms expr locals
    ConForm dcon args ->
        ConObj dcon [forceLookupValue arg locals globals | arg <- args]

-- | Turn a binding's right-hand sides into heap objects and store them at
-- the given addresses.
--
-- For a recursive binding, the binding's own variables are put into the
-- locals as pointers to their addresses, so the resulting closures can
-- refer to one another. A non-recursive binding gets empty locals.
-- Returns the updated heap and the locals that were used.
liftBinding :: (Binding, [MemAddr]) -> Globals -> Heap -> (Heap, Locals)
liftBinding (Binding rec_flag pairs, addrs) globals heap =
    (insertHeapList (zip addrs objs) heap, locals)
  where
    no_locals = Locals M.empty
    self_refs = zip (map fst pairs) (map MemVal addrs)
    locals    = case rec_flag of
                    Rec    -> insertLocalsList self_refs no_locals
                    NonRec -> no_locals
    objs      = map (\(_, rhs) -> forceRhsObj rhs locals globals) pairs

-- | Lift every binding into heap objects with 'liftBinding', threading the
-- heap from left to right. The globals are shared by all bindings.
--
-- Returns the final heap and, in input order, the locals used for each
-- binding.
liftBindings :: [(Binding, [MemAddr])] -> Globals -> Heap -> (Heap, [Locals])
liftBindings bms globals heap = mapAccumL step heap bms
  where step h bm = liftBinding bm globals h

-- | Keep, in their original order, the bindings (with their locals) that
-- define a variable whose occurrence name is @entry@.
entryMatches :: String -> [(Binding, Locals)] -> [(Binding, Locals)]
entryMatches entry bnd_locs = [bl | bl <- bnd_locs, bindFilter entry bl]

-- | Does the binding define a variable whose occurrence name is @entry@?
-- The accompanying locals are ignored.
--
-- Emptiness is tested with 'null' rather than @/= []@. This avoids
-- requiring an 'Eq' instance on the matched pairs just to check for
-- emptiness.
bindFilter :: String -> (Binding, Locals) -> Bool
bindFilter entry (bnd, _) = not (null (lhsMatches entry bnd))

-- | All @(variable, right-hand side)@ pairs of a binding whose variable has
-- occurrence name @st@, in binding order.
lhsMatches :: String -> Binding -> [(Var, BindRhs)]
lhsMatches st (Binding _ pairs) =
    [pair | pair@(var, _) <- pairs, nameOccStr (varName var) == st]

-- | Build the initial code for the chosen entry variable.
--
-- A constructor entry is evaluated as a plain variable atom.
--
-- A function entry is applied to symbolic arguments, built in four steps:
--
--   1. obtain its parameters (see 'traceArgs');
--   2. rename them via 'freshSeededNameList', keeping their types;
--   3. allocate a 'SymObj' on the heap for each renamed parameter;
--   4. bind each renamed parameter in the locals to a pointer to its
--      object.
--
-- The globals are returned unchanged.
loadCode :: Var -> BindRhs -> Locals -> Globals -> Heap -> (Code,Globals,Heap)
loadCode ent rhs locals globals heap = case rhs of
    ConForm _ _ -> (Evaluate (Atom (VarAtom ent)) locals, globals, heap)
    FunForm params expr ->
        let actuals   = traceArgs params expr locals globals heap
            old_names = map varName actuals
            -- Renamed parameters, keeping each original type.
            sym_vars  = zipWith Var (freshSeededNameList old_names old_names)
                                    (map varType actuals)
            -- Each parameter becomes a symbolic object on the heap.
            sym_objs  = [SymObj (Symbol v Nothing) | v <- sym_vars]
            (heap', addrs) = allocHeapList sym_objs heap
            locals'   = insertLocalsList (zip sym_vars (map MemVal addrs))
                                         locals
            call      = FunApp ent (map VarAtom sym_vars)
        in  (Evaluate call locals', globals, heap')

-- | Determine the parameters an entry function should be applied to.
--
-- Normally these are just @base@, the parameters of the entry's own
-- right-hand side. The exception is an entry with no parameters whose body
-- is a nullary application of a variable that resolves to a 'FunObj' with
-- parameters. This is presumably a definition like @main = f@ that was
-- compiled to a thunk. In that case the parameters of that 'FunObj' are
-- returned instead.
--
-- The cheap @null base@ test runs first, so the heap lookup is skipped
-- whenever the entry already has parameters.
traceArgs :: [Var] -> Expr -> Locals -> Globals -> Heap -> [Var]
traceArgs base expr locals globals heap
  | null base
  , FunApp var []     <- expr
  , Just (_, hobj)    <- vlookupHeap var locals globals heap
  , FunObj params _ _ <- hobj
  , not (null params) = params
  | otherwise         = base

-- | Options controlling 'execute'.
data RunFlags = RunFlags
    { step_count :: Int             -- ^ Bound passed to the stepping functions.
    , step_type  :: StepType        -- ^ Search strategy to use.
    , dump_dir   :: Maybe FilePath  -- ^ Optional dump directory; not read by
                                    --   'execute' itself.
    } deriving (Show, Eq, Read)

-- | Search strategy used by 'execute': breadth-first or depth-first, each
-- also available in a logged variant.
data StepType
    = BFS
    | DFS
    | BFSLogged
    | DFSLogged
    deriving (Show, Eq, Read)

-- | Run symbolic execution on a 'State' as configured by the 'RunFlags'.
--
-- 'step_count' is passed as the bound to the selected stepping function.
-- The plain strategies yield a single @(live, dead)@ result, wrapped in a
-- singleton list. The logged strategies return the list produced by their
-- stepping function directly.
execute :: RunFlags -> State -> [([LiveState], [DeadState])]
execute flags state =
    case step_type flags of
        BFS       -> [runBoundedBFS bound state]
        BFSLogged -> runBoundedBFSLogged bound state
        DFS       -> [runBoundedDFS bound state]
        DFSLogged -> runBoundedDFSLogged bound state
  where bound = step_count flags

-- | Breadth-first execution bounded by @n@.
--
-- When @n < 1@ no stepping is done. The result then has a single live
-- state, @([], state)@, and no dead states.
execute1 :: Int -> State -> ([LiveState], [DeadState])
execute1 n state
    | n >= 1    = runBoundedBFS n state
    | otherwise = ([([], state)], [])