-- | Symbolic STG Execution Engine
module SSTG.Core.Execution.Engine
    ( loadState
    , loadStateEntry
    , LoadResult (..)
    , RunFlags (..)
    , StepType (..)
    , execute
    , execute1
    ) where

import Data.List (mapAccumL)

import SSTG.Core.Execution.Stepping
import SSTG.Core.Language

-- | Outcome of trying to build an initial 'State' from a 'Program'.
data LoadResult = LoadOkay State
                  -- ^ Exactly one bind matched the entry name; carries the
                  -- state loaded from it.
                | LoadGuess State [Bind]
                  -- ^ Several binds matched the entry name. The state is
                  -- loaded from the first match; the list holds the other
                  -- matching binds.
                | LoadError String
                  -- ^ No bind matched the entry name; carries a message.
                deriving (Show, Eq, Read)

-- | Load a program using @"main"@ as the entry point. The name is a guess
-- that held up on a handful of experimental programs.
loadState :: Program -> LoadResult
loadState prog = loadStateEntry default_entry prog
  where
    -- Entry name observed in a few experimental programs.
    default_entry = "main"

-- | Load an initial 'State' whose code starts at the given entry point.
--
-- Every bind of the program is allocated on the heap and registered as a
-- global. The binds that define a variable whose occurrence name equals
-- @entry@ are then collected:
--
--   * none: 'LoadError' with a descriptive message;
--
--   * exactly one: 'LoadOkay' with the state built from it;
--
--   * several: 'LoadGuess' with the state built from the first one, paired
--     with the remaining candidate binds.
loadStateEntry :: String -> Program -> LoadResult
loadStateEntry entry (Program bnds) =
    case entryMatches entry bnd_locs of
        [] -> no_candidates
        (tgt_bnd, tgt_loc) : others ->
            case lhsMatches entry tgt_bnd of
                -- Cannot happen: 'entryMatches' only keeps binds for which
                -- 'lhsMatches' is non-empty. Handled so the function is total.
                [] -> no_candidates
                (tgt_var, tgt_rhs) : _ ->
                    let state = mkState tgt_var tgt_rhs tgt_loc
                    in if null others
                        then LoadOkay state
                        else LoadGuess state (map fst others)
  where
    no_candidates =
        LoadError ("No entry candidates found for: [" ++ entry ++ "]")
    -- Globals and Heap are loaded together. They are still beta forms now.
    heap0 = empty_heap
    (glist, heap1, bnd_addrss) = initGlobals bnds heap0
    globals0 = insertGlobalsList glist empty_globals
    (heap2, localss) = liftBinds bnd_addrss globals0 heap1
    -- Each bind paired with the locals produced while lifting it.
    bnd_locs = zip bnds localss
    -- Build the state for a chosen entry variable, its right-hand side and
    -- the locals of its bind.
    mkState tgt_var tgt_rhs tgt_loc = state0 { state_names = allNames state0 }
      where
        -- Code loading. Completes heap and globals with symbolic injection.
        (code, globals, heap) = loadCode tgt_var tgt_rhs tgt_loc globals0 heap2
        -- Names start empty and are filled in from the assembled state.
        state0 = State { state_status = init_status
                       , state_stack = empty_stack
                       , state_heap = heap
                       , state_globals = globals
                       , state_code = code
                       , state_names = []
                       , state_paths = empty_pathcons }

-- | Allocate one heap slot per definition in a 'Bind', each filled with a
-- 'Blackhole' placeholder. Returns the extended heap and the new addresses.
allocBind :: Bind -> Heap -> (Heap, [MemAddr])
allocBind (Bind _ pairs) heap = allocHeapList placeholders heap
  where
    -- One placeholder object for every (variable, rhs) pair.
    placeholders = [Blackhole | _ <- pairs]

-- | Allocate heap slots for a list of 'Bind's, threading the heap from left
-- to right. Returns the final heap and, for each bind in order, the
-- addresses that 'allocBind' handed out for it.
allocBindList :: [Bind] -> Heap -> (Heap, [[MemAddr]])
allocBindList bnds heap = mapAccumL (flip allocBind) heap bnds

-- | Pair each variable defined by a 'Bind' with a 'MemVal' pointing at the
-- address in the same position of the address list.
bndAddrsToVarVals :: (Bind, [MemAddr]) -> [(Var, Value)]
bndAddrsToVarVals (Bind _ rhss, addrs) =
    [(var, MemVal addr) | ((var, _), addr) <- zip rhss addrs]

-- | Allocate heap slots for all binds and produce the variable-to-value
-- list for the globals, the resulting heap, and each bind paired with the
-- addresses allocated for it.
initGlobals :: [Bind] -> Heap -> ([(Var, Value)], Heap, [(Bind, [MemAddr])])
initGlobals bnds heap =
    let (heap', addrss) = allocBindList bnds heap
        bnd_addrss = zip bnds addrss
    in (concatMap bndAddrsToVarVals bnd_addrss, heap', bnd_addrss)

-- | Resolve an 'Atom' to a 'Value'. Literals map to 'LitVal' directly;
-- variables are looked up in locals and globals.
forceLookupValue :: Atom -> Locals -> Globals -> Value
forceLookupValue atom locals globals = case atom of
    LitAtom lit -> LitVal lit
    VarAtom var
      | Just val <- lookupValue var locals globals -> val
        -- A failed lookup is an error, but a blank address is returned
        -- instead of crashing.
      | otherwise -> LitVal BlankAddr

-- | Turn a bind right-hand side into a heap object. A function form keeps
-- the given locals inside the 'FunObj'; a constructor form has each argument
-- resolved via 'forceLookupValue'.
forceRhsObj :: BindRhs -> Locals -> Globals -> HeapObj
forceRhsObj rhs locals globals = case rhs of
    FunForm prms expr -> FunObj prms expr locals
    ConForm dcon args ->
        ConObj dcon [forceLookupValue arg locals globals | arg <- args]

-- | Write the real heap objects of a 'Bind' into its pre-allocated
-- addresses. Returns the updated heap and the locals used to build the
-- objects.
liftBind :: (Bind, [MemAddr]) -> Globals -> Heap -> (Heap, Locals)
liftBind (Bind rec_flag pairs, addrs) globals heap =
    (insertHeapList (zip addrs objs) heap, scope)
  where
    -- Locals mapping each variable of the bind to its own address.
    self_scope = insertLocalsList
        [(var, MemVal addr) | ((var, _), addr) <- zip pairs addrs]
        empty_locals
    -- Only recursive binds get to see their own variables.
    scope = case rec_flag of
        Rec -> self_scope
        NonRec -> empty_locals
    objs = [forceRhsObj rhs scope globals | (_, rhs) <- pairs]

-- | Lift a list of binds with 'liftBind', threading the heap from left to
-- right against a fixed set of globals. Returns the final heap and the
-- locals produced for each bind, in order.
liftBinds :: [(Bind, [MemAddr])] -> Globals -> Heap -> (Heap, [Locals])
liftBinds bms globals heap = mapAccumL lift_one heap bms
  where
    -- 'liftBind' with its arguments reordered to the accumulator shape.
    lift_one h bm = liftBind bm globals h

-- | Keep only the binds (with their locals) that define the entry candidate.
entryMatches :: String -> [(Bind, Locals)] -> [(Bind, Locals)]
entryMatches entry bnd_locs =
    [bnd_loc | bnd_loc <- bnd_locs, bindFilter entry bnd_loc]

-- | True when the bind defines at least one variable whose occurrence name
-- equals the entry string. The paired locals are ignored.
bindFilter :: String -> (Bind, Locals) -> Bool
bindFilter entry (bnd, _) = not (null (lhsMatches entry bnd))

-- | Select the definitions of a 'Bind' whose variable has an occurrence
-- name equal to the given string.
lhsMatches :: String -> Bind -> [(Var, BindRhs)]
lhsMatches st (Bind _ pairs) =
    [pair | pair@(var, _) <- pairs, st == nameOccStr (varName var)]

-- | Build the initial 'Code' for an entry variable.
--
-- A constructor form is evaluated as a plain variable atom, leaving globals
-- and heap untouched. A function form gets one symbolic heap object per
-- traced parameter, and the code applies the entry to those symbolic
-- arguments. Globals are returned unchanged in both cases.
loadCode :: Var -> BindRhs -> Locals -> Globals -> Heap -> (Code,Globals,Heap)
loadCode ent rhs locals globals heap = case rhs of
    ConForm _ _ ->
        (Evaluate (Atom (VarAtom ent)) locals, globals, heap)
    FunForm params expr ->
        let formals = traceArgs params expr locals globals heap
            seeds = map varName formals
            fresh_names = freshSeededNameList seeds seeds
            -- Freshly named copies of the formals, keeping their types.
            sym_vars = zipWith Var fresh_names (map varType formals)
            -- Throw the parameters on heap as symbolic objects.
            sym_objs = [SymObj (Symbol v Nothing) | v <- sym_vars]
            (sym_heap, sym_addrs) = allocHeapList sym_objs heap
            -- Bind every symbolic variable to its heap address in locals.
            sym_locals =
                insertLocalsList (zip sym_vars (map MemVal sym_addrs)) locals
            -- Apply the entry to the symbolic arguments.
            entry_code = Evaluate (FunApp ent (map VarAtom sym_vars)) sym_locals
        in (entry_code, globals, sym_heap)

-- | Find the parameters to use for an entry point.
--
-- When the entry has no parameters of its own (@base@ is empty) and its body
-- is a bare application of a variable to no arguments, that variable is
-- looked up on the heap; if it resolves to a 'FunObj' with a non-empty
-- parameter list, those parameters are returned. In every other case @base@
-- is returned unchanged. This handles entries that are thunk'd by default.
traceArgs :: [Var] -> Expr -> Locals -> Globals -> Heap -> [Var]
traceArgs base expr locals globals heap
  | null base
  , FunApp var [] <- expr
  , Just (_, hobj) <- vlookupHeap var locals globals heap
  , FunObj params _ _ <- hobj
  , not (null params) = params

  | otherwise = base

-- | Run flags: the options consumed by 'execute'.
data RunFlags = RunFlags { flag_step_count :: Int
                           -- ^ Step bound passed to the selected runner.
                         , flag_step_type :: StepType
                           -- ^ Selects which runner 'execute' dispatches to.
                         , flag_dump_dir :: Maybe FilePath
                           -- ^ Not read in this module; presumably an
                           -- optional directory for dumped output
                           -- (TODO confirm against callers).
                         } deriving (Show, Eq, Read)

-- | Step execution type, used by 'execute' to pick a runner. 'BFS' and 'DFS'
-- call @runBoundedBFS@ / @runBoundedDFS@ and wrap the single result in a
-- one-element list; 'BFSLogged' and 'DFSLogged' call @runBoundedBFSLogged@ /
-- @runBoundedDFSLogged@, whose list result is returned as is.
data StepType = BFS | DFS | BFSLogged | DFSLogged deriving (Show, Eq, Read)

-- | Run a 'State' with the strategy and step bound given by the run flags.
-- The plain strategies yield a single-element list; the logged strategies
-- return whatever list their runner produces.
execute :: RunFlags -> State -> [([LiveState], [DeadState])]
execute flags state = case flag_step_type flags of
    BFS -> [runBoundedBFS bound state]
    BFSLogged -> runBoundedBFSLogged bound state
    DFS -> [runBoundedDFS bound state]
    DFSLogged -> runBoundedDFSLogged bound state
  where
    bound :: Int
    bound = flag_step_count flags

-- | Bounded `BFS` execution on a single state. With a bound below one the
-- state is returned as the only live state, with an empty first component
-- and no dead states.
execute1 :: Int -> State -> ([LiveState], [DeadState])
execute1 n state =
    if n >= 1
        then runBoundedBFS n state
        else ([([], state)], [])