module Lang.LamIf.Semantics where

import FP
import MAAM
import Lang.LamIf.Syntax hiding (PreExp(..))
import Lang.LamIf.CPS
import Lang.LamIf.StateSpace

-- | The "location" type used to advance abstract times: it is the first
-- argument of the 'Time' constraint in 'TimeC'.
type Ψ = LocNum

-- The raw constraint bundles below must hold for:
-- - the two abstract time types, lτ and dτ ('TimeC')
-- - the abstract value type val ('ValC')
-- - the analysis monad m ('MonadC')

-- | Everything required of an abstract time type @τ@: it can be advanced
-- with a 'Ψ' location ('Time'), has a bottom element, is totally ordered,
-- and can be pretty-printed.
type TimeC τ = (Time Ψ τ, Bot τ, Ord τ, Pretty τ)

-- | Everything required of an abstract value type @val@: the 'Val'
-- operations used by the semantics (literals, closures, tuples, binops,
-- eliminators) plus ordering, lattice join, difference and pretty-printing.
--
-- NOTE(review): the two type parameters between @ValC@ and @val@ (and the
-- matching arguments of 'Val') appear to have been stripped from this
-- source; presumably they are the time types lτ and dτ — confirm.
type ValC   val =
  ( Val   val
  , Ord val
  , PartialOrder val
  , JoinLattice val
  , Difference val
  , Pretty val
  )

-- | Everything required of the analysis monad @m@: failure and choice
-- ('MonadBot', 'MonadPlus') and access to the analysis state '𝒮'.
--
-- NOTE(review): the type parameters between @val@ and @m@ (and the matching
-- arguments of '𝒮') appear stripped from this source; presumably lτ and
-- dτ — confirm.
type MonadC val   m =
  ( Monad m, MonadBot m, MonadPlus m
  , MonadState (𝒮 val  ) m
  )

-- This type class aids type inference. The functional dependencies tell the
-- type checker that the choices for val, lτ and dτ are uniquely determined by
-- m.
-- NOTE(review): the type-variable names in this declaration (the arguments
-- of both 'TimeC' constraints, the two class parameters between @val@ and
-- @m@, and the targets of the last two functional dependencies) appear to
-- have been stripped from this source; presumably lτ and dτ — confirm
-- against version control before editing.
class 
  ( TimeC 
  , TimeC 
  , ValC   val
  , MonadC val   m
  ) => Analysis val   m | m -> val , m ->  , m ->  where

-- Some helper types

-- | A garbage-collection hook; 'call' runs it on the successor call it is
-- about to return.
type GC m = Call -> m ()
-- | Builds a closure from the lambda's stamp id, its parameter names and its
-- body.
-- NOTE(review): the type arguments of 'CreateClo' and 'Clo' appear stripped
-- from this source — confirm.
type CreateClo   m = LocNum -> [Name] -> Call -> m (Clo  )
-- | Per-call-site predicate deciding whether a time component is ticked.
type TimeFilter = Call -> Bool

-- | Allocate an address for the name @x@.
--
-- The address is built from @x@ and the current values stored under 𝓈lτL
-- and 𝓈dτL; no freshness check is performed, so the same name at the same
-- times yields the same address.
--
-- NOTE(review): the binder names on the two @<-@ lines, the extra arguments
-- of 'Addr', and the type arguments in the signature appear stripped from
-- this source — confirm.
new :: (Analysis val   m) => Name -> m (Addr  )
new x = do
   <- getL 𝓈lτL
   <- getL 𝓈dτL
  return $ Addr x  

-- | Allocate an address for @x@, store @vD@ there, and return the
-- environment @ρ@ extended with @x@ ↦ that address.
--
-- The store update joins rather than overwrites, because 'new' can hand
-- back an address that already holds a value.
bind :: (Analysis val   m) => Name -> val -> Map Name (Addr  ) -> m (Map Name (Addr  ))
bind x vD ρ = do
  𝓁 <- new x
  modifyL 𝓈σL (mapInsertWith (\/) 𝓁 vD)
  return (mapInsert x 𝓁 ρ)

-- | Like 'bind', but operates on the environment held in the analysis
-- state (𝓈ρL) instead of one passed explicitly.
bindM :: (Analysis val   m) => Name -> val -> m ()
bindM x vD = do
  extended <- bind x vD *$ getL 𝓈ρL
  putL 𝓈ρL extended

-- | Overwrite the store entry at the address currently bound to @x@ with
-- @vD@ (a strong update).
--
-- Unlike 'bind', this neither allocates nor joins. The old value at the
-- address is replaced.
--
-- The address is looked up with the total '#' plus 'maybeZero', the same way
-- 'var' and 'picoRef' do. An unbound @x@ therefore makes this branch fail in
-- the monad instead of going through the partial '#!'.
rebind :: (Analysis val   m) => Name -> val -> m ()
rebind x vD = do
  ρ <- getL 𝓈ρL
  l <- maybeZero $ ρ # x
  modifyL 𝓈σL $ mapInsert l vD

-- | Strong-update the store entry of a pico when it is a variable.
-- Literals have no address, so they are left alone.
rebindPico :: (Analysis val   m) => PrePico Name -> val -> m ()
rebindPico p vD = case p of
  Var x -> rebind x vD
  Lit _ -> return ()

-- | Denotation of an address: the value stored at it. The branch fails
-- (via 'maybeZero') if the store has no entry for it.
addr :: (Analysis val   m) => Addr   -> m val
addr 𝓁 = maybeZero . storedAt *$ getL 𝓈σL
  where
    storedAt σ = σ # 𝓁

-- | Denotation of a variable: look its address up in the current
-- environment, then dereference that address in the store. The branch
-- fails if either lookup misses.
var :: (Analysis val   m) => Name -> m val
var x = do
  ρ <- getL 𝓈ρL
  addr *$ maybeZero (ρ # x)

-- | Denotation of a lambda: build a closure with the supplied
-- closure-creation strategy and inject it into the value domain with 'clo'.
lam :: (Analysis val   m) => CreateClo   m -> LocNum -> [Name] -> Call -> m val
lam createClo cid xs c = clo ^$ createClo cid xs c

-- | Partial denotation of a pico, suitable for storing inside values.
-- Literals are kept as literals, and variables become their address
-- without being dereferenced. An unbound variable fails the branch.
picoRef :: (Analysis val   m) => Pico -> m (PicoVal  )
picoRef p = case p of
  Lit l -> return $ LitPicoVal l
  Var x -> do
    ρ <- getL 𝓈ρL
    𝓁 <- maybeZero $ ρ # x
    return $ AddrPicoVal 𝓁

-- | Turn a stored pico reference back into a value. Literals are lifted
-- with 'lit', and addresses are dereferenced with 'addr'.
picoVal :: (Analysis val   m) => PicoVal   -> m val
picoVal pv = case pv of
  LitPicoVal l -> return (lit l)
  AddrPicoVal 𝓁 -> addr 𝓁

-- | Full denotation of a pico: compute its reference with 'picoRef', then
-- resolve it to a value with 'picoVal'.
pico :: (Analysis val   m) => Pico -> m val
pico p = picoVal *$ picoRef p

-- | Denotation of an atom.
--
-- Picos and primitive operations are evaluated directly. Lambdas become
-- closures via @createClo@, keyed by the atom's stamp id. Tuples store pico
-- references, not values. Projections pick a component of each tuple the
-- operand may denote and dereference it.
atom :: (Analysis val   m) => CreateClo   m -> Atom -> m val
atom createClo a = case stamped a of
  Pico p -> pico p
  Prim o a1 a2 -> do
    v1 <- pico a1
    v2 <- pico a2
    return $ binop (lbinOpOp o) v1 v2
  LamF x kx c -> mkLam [x, kx] c
  LamK x c -> mkLam [x] c
  Tup p1 p2 -> do
    r1 <- picoRef p1
    r2 <- picoRef p2
    return $ tup (r1, r2)
  Pi1 p -> do
    (pv, _) <- mset . elimTup *$ pico p
    picoVal pv
  Pi2 p -> do
    (_, pv) <- mset . elimTup *$ pico p
    picoVal pv
  where
    mkLam xs c = lam createClo (stampedID a) xs c

-- | Apply the function value @fv@ (denoted by the pico @fx@) to the argument
-- values @avs@ at call site @c@, returning the closure body as the next call.
--
-- NOTE(review): the last field of the 'Clo' pattern, the argument of
-- @putL 𝓈lτL@, and the type arguments in the signature appear stripped from
-- this source; presumably the closure's saved lτ time — confirm.
apply :: (Analysis val   m) => TimeFilter -> Call -> PrePico Name -> val -> [val] -> m Call
apply timeFilter c fx fv avs = do
  -- branch over every closure @fv@ may denote
  fclo@(Clo cid' xs c' ρ ) <- mset $ elimClo fv
  -- within this branch, refine the operator variable to the chosen closure
  -- (done while the caller's environment is still current)
  rebindPico fx $ clo fclo
  -- pair parameters with arguments; the branch fails if 'zip' gives nothing
  -- (presumably on an arity mismatch — confirm 'zip' semantics in FP)
  xvs <- maybeZero $ zip xs avs
  -- switch to the closure's captured environment and bind the parameters;
  -- these allocations (via 'new') run before 𝓈lτL is restored below, so
  -- they use the time current at the call site
  putL 𝓈ρL ρ
  traverseOn xvs $ uncurry $ bindM 
  -- restore the time saved in the closure, then tick it with the closure id
  -- if the filter selects this call site
  putL 𝓈lτL 
  when (timeFilter c) $
    modifyL 𝓈lτL $ tick cid'
  return c'

-- | One abstract transition: evaluate the call @c@ in the current state and
-- return the next call to evaluate.
--
-- Statement order matters:
--
--   1. When @dtimeFilter c@ holds, the time under 𝓈dτL is ticked with
--      @stampedFixID c@ first, so addresses allocated while evaluating this
--      call (via 'new', which reads 𝓈dτL) see the updated time.
--   2. The call is dispatched on its syntax.
--   3. @gc@ is run on the successor call before it is returned.
--
-- 'Halt' returns @c@ unchanged, so halted calls step to themselves.
call :: (Analysis val   m) => GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> m Call
call gc createClo ltimeFilter dtimeFilter c = do
  when (dtimeFilter c) $
    modifyL 𝓈dτL $ tick $ stampedFixID c
  c' <- case stampedFix c of
    -- evaluate the atom and bind its value to x in the current environment
    Let x a c' -> do
      v <- atom createClo a
      bindM x v
      return c'
    -- branch over every boolean the scrutinee may denote; in each branch the
    -- scrutinee (if a variable) is refined to that boolean
    If ax tc fc -> do
      b <- mset . elimBool *$ pico ax
      rebindPico ax $ lit $ B b
      return $ if b then tc else fc
    -- function call: argument and continuation are passed to the closure
    AppF fx ax ka -> do
      fv <- pico fx
      av <- pico ax
      kv <- pico ka
      apply ltimeFilter c fx fv [av, kv]
    -- continuation call: a single argument
    AppK kx ax -> do
      kv <- pico kx
      av <- pico ax
      apply ltimeFilter c kx kv [av]
    Halt _ -> return c
  gc c'
  return c'

-- | Take one 'call' step and keep only the branches whose call is unchanged
-- by it (e.g. 'Halt'). Every other branch is pruned with 'mbot'.
-- Only the calls are compared; state changes made by the step are kept.
onlyStuck :: (MonadStep ς m,  Analysis val   m) => GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> m Call
onlyStuck gc createClo ltimeFilter dtimeFilter e =
  keepIfUnchanged *$ call gc createClo ltimeFilter dtimeFilter e
  where
    keepIfUnchanged e'
      | e == e' = return e
      | otherwise = mbot

-- Execution {{{

-- | Constraints required of the state-space representation @ς'@ at 'Call'
-- (see 'Execution'): a partial order, joins, differences and pretty-printing.
type StateSpaceC ς' =
  (PartialOrder (ς' Call), JoinLattice (ς' Call), Difference (ς' Call), Pretty (ς' Call))

-- | Relates an analysis monad @m@ to its stepping functor @ς@ ('MonadStep')
-- and to an isomorphic state-space representation @ς'@ that the execution
-- drivers iterate over. Both @ς@ and @ς'@ are determined by @m@.
class
  ( MonadStep ς m
  , Inject ς
  , Isomorphism (ς Call) (ς' Call)
  , StateSpaceC ς'
  ) => Execution ς ς' m | m -> ς, m -> ς'

-- | Lift a monadic step on calls to a transition on the state space @ς'@:
-- convert to @ς@, step with 'mstepγ', and convert back.
liftς :: (Execution ς ς' m) => (Call -> m Call) -> (ς' Call -> ς' Call)
liftς f s = isoto (mstepγ f (isofrom s))

-- | Inject an initial element into @ς@ (selected by the proxy) and convert
-- it to the isomorphic representation @ς'@.
injectς :: forall ς ς' a. (Inject ς, Isomorphism (ς a) (ς' a)) => P ς -> a -> ς' a
injectς P x = isoto (inj x :: ς a)

-- | Run the analysis from the initial call with 'poiter' over the lifted
-- 'call' step.
execς :: forall val   m ς ς'. (Analysis val   m, Execution ς ς' m) => 
  GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> ς' Call
execς gc createClo ltimeFilter dtimeFilter c0 =
  poiter (liftς step) (injectς (P :: P ς) c0)
  where
    step = call gc createClo ltimeFilter dtimeFilter

-- | Run the analysis from the initial call with 'collect' over the lifted
-- 'call' step.
execCollect :: forall val   m ς ς'. (Analysis val   m, Execution ς ς' m) => 
  GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> ς' Call
execCollect gc createClo ltimeFilter dtimeFilter c0 =
  collect (liftς step) (injectς (P :: P ς) c0)
  where
    step = call gc createClo ltimeFilter dtimeFilter

-- | Like 'execCollect', but returns the list produced by 'collectHistory'
-- instead of a single state.
execCollectHistory :: forall val   m ς ς'. (Analysis val   m, Execution ς ς' m) => 
  GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> [ς' Call]
execCollectHistory gc createClo ltimeFilter dtimeFilter c0 =
  collectHistory (liftς step) (injectς (P :: P ς) c0)
  where
    step = call gc createClo ltimeFilter dtimeFilter

-- | Like 'execCollect', but returns the list produced by 'collectDiffs'
-- instead of a single state.
execCollectDiffs :: forall val   m ς ς'. (Analysis val   m, Execution ς ς' m) =>
  GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> [ς' Call]
execCollectDiffs gc createClo ltimeFilter dtimeFilter c0 =
  collectDiffs (liftς step) (injectς (P :: P ς) c0)
  where
    step = call gc createClo ltimeFilter dtimeFilter

-- | Compute the collected state space with 'execCollect', then take one more
-- 'onlyStuck' step over it. This keeps only the calls that step to
-- themselves.
execOnlyStuck :: (Analysis val   m, Execution ς ς' m) => GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> Call -> ς' Call
execOnlyStuck gc createClo ltimeFilter dtimeFilter c0 =
  liftς (onlyStuck gc createClo ltimeFilter dtimeFilter) reachable
  where
    reachable = execCollect gc createClo ltimeFilter dtimeFilter c0

-- }}}

-- GC {{{

-- | The trivial 'GC' strategy: leave the store untouched.
nogc :: (Monad m) => Call -> m ()
nogc = \ _ -> return ()

-- | Store-pruning 'GC' strategy.
--
-- The roots are the addresses of the free variables of @c@ (per
-- 'freeVarsLam' with no parameters) in the current environment. They are
-- extended with everything reachable through stored closures and tuples
-- ('addrTouched', presumably iterated to a fixpoint by 'collect'). Only
-- those addresses are kept in the store.
yesgc :: (Analysis val   m) => Call -> m ()
yesgc c = do
  ρ <- getL 𝓈ρL
  σ <- getL 𝓈σL
  let roots = callTouched ρ (freeVarsLam empty [] c)
      reachable = collect (extend (addrTouched σ)) roots
  modifyL 𝓈σL (onlyKeys reachable)

-- | Addresses bound to the given names in @ρ@. Names without a binding
-- contribute nothing.
callTouched :: (TimeC , TimeC ) => Env   -> Set Name -> Set (Addr  )
callTouched ρ xs = addressOf *$ xs
  where
    addressOf x = maybeSet (index ρ x)

-- | Addresses a closure keeps alive: those of its body's free variables
-- (excluding its parameters), resolved in its captured environment.
closureTouched :: (TimeC , TimeC ) => Clo   -> Set (Addr  )
closureTouched fclo = case fclo of
  Clo _ xs c ρ _ -> callTouched ρ (freeVarsLam empty xs c)

-- | Addresses referenced by a stored pico: the address itself, or none for
-- a literal.
picoValTouched :: (TimeC , TimeC ) => PicoVal   -> Set (Addr  )
picoValTouched pv = case pv of
  AddrPicoVal 𝓁 -> single 𝓁
  LitPicoVal _ -> empty

-- | Addresses referenced by either component of a stored tuple.
tupleTouched :: (TimeC , TimeC ) => (PicoVal  , PicoVal  ) -> Set (Addr  )
tupleTouched (pv1, pv2) = fromFirst \/ fromSecond
  where
    fromFirst = picoValTouched pv1
    fromSecond = picoValTouched pv2

-- | One reachability step from address @𝓁@ in store @σ@: the addresses
-- referenced by the closures and tuples in the value stored there. An
-- address missing from the store touches nothing.
addrTouched :: (TimeC , TimeC , ValC   val) => Map (Addr  ) val -> Addr   -> Set (Addr  )
addrTouched σ 𝓁 = referencedBy *$ maybeSet (σ # 𝓁)
  where
    referencedBy v = msum
      [ closureTouched *$ elimClo v
      , tupleTouched *$ elimTup v
      ]

-- }}}

-- CreateClo {{{

-- | 'CreateClo' strategy that links closures. The closure captures the
-- entire current environment (𝓈ρL) together with the time read from
-- 𝓈lτL.
--
-- NOTE(review): the binder of the second statement, the last 'Clo' field,
-- and the type arguments in the signature appear stripped from this source;
-- presumably the lτ time — confirm.
linkClo :: (Analysis val   m) => LocNum -> [Name] -> Call -> m (Clo  )
linkClo cid xs c = do
  ρ <- getL 𝓈ρL
   <- getL 𝓈lτL
  return $ Clo cid xs c ρ 

-- | 'CreateClo' strategy that copies values. Instead of capturing the whole
-- current environment, it builds a new environment holding only the lambda's
-- free variables. Each one is re-bound (via 'bind') to a newly allocated
-- address that holds its current value.
--
-- NOTE(review): the binder of the @getL 𝓈lτL@ statement, the last 'Clo'
-- field, and the type arguments in the signature appear stripped from this
-- source; presumably the lτ time — confirm.
copyClo :: (Analysis val   m) => LocNum -> [Name] -> Call -> m (Clo  )
copyClo cid xs c = do
  -- free variables of the lambda (presumably excluding the parameters xs;
  -- see 'freeVarsLam')
  let ys = toList $ freeVarsLam empty xs c
  -- their current values; the branch fails if any of them is unbound
  vs <- var ^*$ ys
  yvs <- maybeZero $ zip ys vs
  -- one 'bind' per variable, accumulated as Kleisli endomorphisms and then
  -- run starting from the empty environment
  ρ <- runKleisliEndo mapEmpty *$ execWriterT $ do
    traverseOn yvs $ tell . KleisliEndo . uncurry bind
   <- getL 𝓈lτL
  return $ Clo cid xs c ρ 

-- }}}

-- Parametric Execution {{{

-- | A 'W'-wrapped 'TimeC' constraint for the time type @τ@.
type UniTime τ = W (TimeC τ)

-- | A time type chosen at run time, hidden behind its 'UniTime' evidence.
data ExTime where
  ExTime :: forall τ. UniTime τ -> ExTime

-- | Evidence that the value-domain constructor @val@ satisfies 'ValC' for
-- every pair of admissible time types. 'ExVal' hides the choice of @val@.
--
-- NOTE(review): the quantified type variables (and their uses as arguments
-- of 'TimeC', 'ValC' and @val@) appear stripped from this source; presumably
-- lτ and dτ — confirm.
type UniVal val = forall  . (TimeC , TimeC ) => W (ValC   (val  ))
data ExVal where ExVal :: forall val. UniVal val -> ExVal

-- | Evidence that the monad constructor @m@ (with state-space functors @ς@
-- and @ς'@) yields an 'Analysis' and an 'Execution' for every admissible
-- choice of value domain and time types. 'ExMonad' hides the choice of
-- @ς@, @ς'@ and @m@.
--
-- NOTE(review): several quantified type variables (and their uses as type
-- arguments) appear stripped from this source; presumably lτ and dτ —
-- confirm.
type UniMonad ς ς' m = 
  forall val  . (TimeC , TimeC , ValC   val) 
  => W (Analysis val   (m val  ), Execution (ς val  ) (ς' val  ) (m val  ))
data ExMonad where 
  ExMonad :: forall ς ς' m. 
       UniMonad ς ς' m 
    -> ExMonad

-- | A 'GC' strategy that works for every analysis monad, so it can be stored
-- in 'Options' before the monad is chosen.
-- NOTE(review): some quantified type variables appear stripped from these
-- two declarations; presumably lτ and dτ — confirm.
newtype AllGC = AllGC { runAllGC :: forall val   m. (Analysis val   m) => GC m }
-- | A 'CreateClo' strategy that works for every analysis monad.
newtype AllCreateClo  = AllCreateClo { runAllCreateClo :: forall val   m. (Analysis val   m) => CreateClo   m }

-- | Run-time configuration of an analysis.
--
-- 'withOptions' matches this constructor positionally, so the field order
-- is significant.
data Options = Options
  { ltimeOp :: ExTime
    -- ^ first time type matched in 'withOptions' (presumably lτ)
  , dtimeOp :: ExTime
    -- ^ second time type matched in 'withOptions' (presumably dτ)
  , valOp :: ExVal
    -- ^ abstract value domain
  , monadOp :: ExMonad
    -- ^ analysis monad together with its state-space functors
  , gcOp :: AllGC
    -- ^ garbage-collection strategy (e.g. 'nogc' or 'yesgc')
  , createCloOp :: AllCreateClo
    -- ^ closure-creation strategy (e.g. 'linkClo' or 'copyClo')
  , ltimeFilterOp :: TimeFilter
    -- ^ passed on as @ltimeFilter@ (used by 'apply' to decide ticking)
  , dtimeFilterOp :: TimeFilter
    -- ^ passed on as @dtimeFilter@ (used by 'call' to decide ticking)
  }

-- | Unpack the existentially-packaged choices in an 'Options' value. The
-- 'W' patterns bring their constraints into scope. Then @f@ is run with the
-- GC strategy, the closure-creation strategy and both time filters, all
-- instantiated at the chosen monad.
--
-- NOTE(review): many type variables (and the quantifier for @val@, @m@, @ς@,
-- @ς'@ in the rank-2 argument) appear stripped from this source; presumably
-- lτ and dτ among them — confirm against version control before editing.
withOptions :: forall a. Options -> ((Analysis val   m, Execution ς ς' m) => GC m -> CreateClo   m -> TimeFilter -> TimeFilter -> a) -> a
withOptions o f = case o of
  Options (ExTime (W :: UniTime )) 
          (ExTime (W :: UniTime ))
          (ExVal (W :: W (ValC   (val  ))))
          (ExMonad (W :: W ( Analysis (val  )   (m (val  )  )
                           , Execution (ς (val  )  ) (ς' (val  )  ) (m (val  )  ))))
          (AllGC (gc :: GC (m (val  )  )))
          (AllCreateClo (createClo :: CreateClo   (m (val  )  )))
          (ltimeFilter :: TimeFilter)
          (dtimeFilter :: TimeFilter) -> f gc createClo ltimeFilter dtimeFilter

-- }}}