{-# LANGUAGE DeriveFunctor, GADTs, PatternGuards #-}
module DecisionTreeSimplify (
decisionTreeSimple,
decisionStepWithTime,
simplifyWait
) where
import Contract
import Observable (Steps(..))
import qualified Observable as Obs
import DecisionTree
import Display
import Prelude hiding (product, until, and)
import Data.List hiding (and)
import Data.Ord
-- | Build the 'DecisionTree' for a contract by unfolding from its initial
-- process state.
--
-- The seed is @initialProcessState startTime contract@ and every unfolding
-- step is produced by 'decisionStepWithTime', so each step has already been
-- simplified against the time of the state it came from.
decisionTreeSimple :: Time -> Contract -> DecisionTree
decisionTreeSimple startTime contract =
  unfoldDecisionTree decisionStepWithTime seed
  where
    -- The process state the unfolding starts from.
    seed = initialProcessState startTime contract
-- | Compute the next decision step for a process state, resolving whatever
-- can already be resolved at the state's own time.
--
-- The step returned by 'decisionStep' is handled as follows:
--
-- * 'Done', 'Trade' and 'Choose' are returned unchanged.
--
-- * 'ObserveCond' and 'ObserveValue': the observable is evaluated with
--   'Obs.eval' at the state's time. If that yields a 'Result', the matching
--   successor state is stepped immediately; otherwise the observation step
--   is returned unchanged.
--
-- * 'Wait': the conditions go through 'simplifyWait'. A 'Left' state is
--   stepped immediately, an empty condition list becomes 'Done', and
--   otherwise a 'Wait' on the simplified conditions (with the original
--   options) is returned.
--
-- The 'Time' in the result is the time carried by the state that finally
-- produced the returned step, which may be a successor of the input state.
decisionStepWithTime :: ProcessState -> (DecisionStep ProcessState, Time)
decisionStepWithTime st@(PSt time _ _) =
  case decisionStep st of
    Done          -> stay Done
    step@Trade{}  -> stay step
    step@Choose{} -> stay step

    step@(ObserveCond cond whenTrue whenFalse) ->
      case Obs.eval time cond of
        -- The condition is already decided: follow the chosen branch.
        Result holds ->
          decisionStepWithTime (if holds then whenTrue else whenFalse)
        _ -> stay step

    step@(ObserveValue obs cont) ->
      case Obs.eval time obs of
        -- The value is already known: feed it to the continuation.
        Result v -> decisionStepWithTime (cont v)
        _        -> stay step

    Wait conds opts ->
      let -- No conditions left to wait on means there is nothing more to do.
          resume []     = stay Done
          resume conds' = stay (Wait conds' opts)
      in either decisionStepWithTime resume
                (simplifyWait time conds (not (null opts)))
  where
    -- Report a step that cannot be resolved any further at this state's time.
    stay step = (step, time)
-- | Simplify the conditions of a wait, as seen from the given time.
--
-- Arguments: the current time; the wait conditions, each paired with the
-- continuation to run at the time it fires; and a flag that
-- 'decisionStepWithTime' sets when the wait also has a non-empty option list.
--
-- Result:
--
-- * @Left st@ when the wait can be left straight away. Either some condition
--   satisfies 'Obs.isTrue' now (the first such continuation is applied to the
--   current time), or nothing can fire before the horizon and the flag is
--   unset (the horizon continuation is applied to the horizon time).
--
-- * @Right conds@ with the conditions still worth waiting on. Conditions for
--   which 'Obs.evermoreFalse' holds are dropped first. If none remain, the
--   result is empty when the flag is unset; when it is set, the result is a
--   single @konst False@ condition whose continuation yields a state with
--   two empty lists.
simplifyWait :: Time
             -> [(Obs Bool, Time -> ProcessState)]
             -> Bool
             -> Either ProcessState
                       [(Obs Bool, Time -> ProcessState)]
simplifyWait time conds opts =
  case firedNow of
    -- A condition is true right now: run its continuation immediately.
    (_, k) : _ -> Left (k time)
    []
      | null pending -> Right (if opts then [idleCond] else [])
      | otherwise ->
          maybe (Right pending) upToHorizon
                (Obs.earliestTimeHorizon time pending)
  where
    -- Split the conditions into those true now and the rest, then discard
    -- the ones reported as false for evermore.
    (firedNow, notFired) = partition (Obs.isTrue time . fst) conds
    pending = filter (not . Obs.evermoreFalse time . fst) notFired

    -- Placeholder condition used when no real condition is left but the
    -- caller's flag is set.
    idleCond :: (Obs Bool, Time -> ProcessState)
    idleCond = (konst False, \time' -> PSt time' [] [])

    -- Handle the case where 'Obs.earliestTimeHorizon' reports a horizon time
    -- together with the continuation that belongs to it.
    upToHorizon (horizon, k)
      | not (null stillLive) = Right (horizonCond : stillLive)
      | opts                 = Right [horizonCond]
      | otherwise            = Left (k horizon)
      where
        horizonCond = (at horizon, k)
        -- Rewrite each pending condition with 'Obs.simplifyWithinHorizon' and
        -- drop those that 'Obs.isFalse' then reports as false.
        stillLive =
          filter (not . Obs.isFalse time . fst)
                 [ (Obs.simplifyWithinHorizon time horizon obs, k')
                 | (obs, k') <- pending ]