module Lang.CPS.Semantics where

import FP
import MAAM
import Lang.CPS.Syntax
import Lang.Common
import Lang.CPS.StateSpace

-- Ψ is the type supplied as the final parameter of the time, value, address
-- and closure types used throughout this module; it is an alias for LocNum.
type Ψ = LocNum

-- These are the raw constraints that must hold for:
-- - time lτ and dτ
-- - values val
-- - the monad m

-- Constraints required of a time abstraction τ: the Time interface itself,
-- plus an Initial element, an ordering and a pretty-printer at the
-- instantiation (τ Ψ).
type TimeC τ =
  ( Time τ
  , Initial (τ Ψ)
  , Ord (τ Ψ)
  , Pretty (τ Ψ)
  )

-- Constraints required of a value abstraction val, relative to the two time
-- abstractions lτ and dτ: the Val interface, plus ordering, partial-order,
-- join-lattice and pretty-printing instances at (val lτ dτ Ψ).
type ValC lτ dτ val =
  ( Val lτ dτ Ψ (val lτ dτ Ψ)
  , Ord (val lτ dτ Ψ)
  , PartialOrder (val lτ dτ Ψ)
  , JoinLattice (val lτ dτ Ψ)
  , Pretty (val lτ dτ Ψ)
  )

-- Constraints required of the analysis monad m: it is a Monad with zero and
-- plus, and it carries a state of type (𝒮 val lτ dτ Ψ).
type MonadC val lτ dτ m =
  ( Monad m, MonadZero m, MonadPlus m
  , MonadState (𝒮 val lτ dτ Ψ) m
  )

-- This type class aids type inference. It bundles the time, value and monad
-- constraints, and its functional dependencies tell the type checker that the
-- choices for val, lτ and dτ are unique for a given m.
class
  ( TimeC lτ
  , TimeC dτ
  , ValC lτ dτ val
  , MonadC val lτ dτ m
  ) => Analysis val lτ dτ m | m -> val , m -> lτ , m -> dτ where

-- Some helper types

-- A garbage-collection action, run in the monad m on a call expression.
type GC m = SGCall -> m ()
-- A closure-creation strategy: builds a closure in m from a lambda's location
-- number, its formal parameters and its body.
type CreateClo lτ dτ m = LocNum -> [SGName] -> SGCall -> m (Clo lτ dτ Ψ)
-- A predicate on call expressions, used to decide whether a time component is
-- ticked at that call.
type TimeFilter = SGCall -> Bool

-- Generate a new address for the name x by pairing it with the two time
-- components currently held in the state (read through 𝓈lτL and 𝓈dτL).
new :: (Analysis val lτ dτ m) => SGName -> m (Addr lτ dτ Ψ)
new x = do
  lτ <- getL 𝓈lτL
  dτ <- getL 𝓈dτL
  return $ Addr x lτ dτ

-- Bind a name to a value in an environment: obtain an address for x from
-- `new`, join vD into the store entry at that address (mapInsertWith (\/)),
-- and return ρ extended so that x maps to that address.
bind :: (Analysis val lτ dτ m) => SGName -> val lτ dτ Ψ -> Map SGName (Addr lτ dτ Ψ) -> m (Map SGName (Addr lτ dτ Ψ))
bind x vD ρ = do
  l <- new x
  modifyL 𝓈σL $ mapInsertWith (\/) l vD
  return $ mapInsert x l ρ

-- Bind a name to a value in _the_ environment: read the environment held in
-- the state (𝓈ρL), extend it with `bind`, and write the result back.
bindM :: (Analysis val lτ dτ m) => SGName -> val lτ dτ Ψ -> m ()
bindM x vD = do
  ρ <- getL 𝓈ρL
  ρ' <- bind x vD ρ
  putL 𝓈ρL ρ'

-- The denotation for variables: look x up in the current environment to get
-- an address, then look that address up in the current store. The Maybe
-- result of the chained lookups is handed to liftMaybeZero (presumably a
-- failed lookup becomes the monad's zero -- confirm against FP).
var :: (Analysis val lτ dτ m) => SGName -> m (val lτ dτ Ψ)
var x = do
  ρ <- getL 𝓈ρL
  σ <- getL 𝓈σL
  liftMaybeZero $ index σ *$ index ρ $ x

-- The denotation for lambdas: build a closure with the supplied createClo
-- strategy and inject it into the value domain with `clo`.
-- NOTE(review): (^..:) is assumed to map `clo` over the monadic result of the
-- three-argument createClo -- confirm against FP.
lam :: (Analysis val lτ dτ m) => CreateClo lτ dτ m -> LocNum -> [SGName] -> SGCall -> m (val lτ dτ Ψ)
lam createClo = clo ^..: createClo

-- The denotation for the pico syntactic category: literals are injected with
-- `lit`, variables are resolved with `var`.
pico :: (Analysis val lτ dτ m) => SGPico -> m (val lτ dτ Ψ)
pico (Lit l) = return $ lit l
pico (Var x) = var x

-- The denotation for the atom syntactic category. Picos are delegated to
-- `pico`, primitives apply `op o` to the denotation of their argument, and
-- both lambda forms become closures via `lam`, using the atom's stamp ID as
-- the closure's location number.
atom :: (Analysis val lτ dτ m) => CreateClo lτ dτ m -> SGAtom -> m (val lτ dτ Ψ)
atom createClo a = case stamped a of
  Pico p -> pico p
  Prim o ax -> op o ^$ pico ax
  LamF x kx c -> lam createClo (stampedID a) [x, kx] c
  LamK x c -> lam createClo (stampedID a) [x] c

-- Apply the function value fv to the argument values avs at call site c and
-- return the body of the closure that was entered.
--
-- For each closure obtained from fv (via elimClo and mset) this:
--   * pairs the closure's formals with the arguments (the Maybe produced by
--     zip goes through liftMaybeZero, presumably failing on a length
--     mismatch -- confirm against FP),
--   * installs the closure's environment and binds each formal in it,
--   * restores the lτ component captured in the closure, and
--   * ticks lτ with the closure's ID when timeFilter accepts the call site.
apply :: (Analysis val lτ dτ m) => TimeFilter -> SGCall -> val lτ dτ Ψ -> [val lτ dτ Ψ] -> m SGCall
apply timeFilter c fv avs = do
  Clo cid' xs c' ρ lτ <- mset $ elimClo fv
  xvs <- liftMaybeZero $ zip xs avs
  putL 𝓈ρL ρ
  traverseOn xvs $ uncurry bindM
  putL 𝓈lτL lτ
  when (timeFilter c) $
    modifyL 𝓈lτL $ tick cid'
  return c'

-- Take one step of the semantics from the call expression c and return the
-- next call expression.
--
-- The dτ component is ticked with c's stamp ID first, when dtimeFilter accepts
-- c. The expression is then dispatched on its form:
--   * Let evaluates the atom, binds the result and continues with the body;
--   * If picks a branch for each boolean obtained from the condition;
--   * AppF and AppK evaluate operator and operands and defer to `apply`;
--   * Halt returns c itself.
-- Finally the supplied gc action is run on the resulting expression.
call :: (Analysis val lτ dτ m) => GC m -> CreateClo lτ dτ m -> TimeFilter -> TimeFilter -> SGCall -> m SGCall
call gc createClo ltimeFilter dtimeFilter c = do
  when (dtimeFilter c) $
    modifyL 𝓈dτL $ tick $ stampedFixID c
  c' <- case stampedFix c of
    Let x a c' -> do
      v <- atom createClo a
      bindM x v
      return c'
    If ax tc fc -> do
      b <- mset . elimBool *$ pico ax
      return $ if b then tc else fc
    AppF fx ax ka -> do
      fv <- pico fx
      av <- pico ax
      kv <- pico ka
      apply ltimeFilter c fv [av, kv]
    AppK kx ax -> do
      kv <- pico kx
      av <- pico ax
      apply ltimeFilter c kv [av]
    Halt _ -> return c
  gc c'
  return c'

-- GC {{{

-- The trivial collector: ignores the call expression and does nothing.
nogc :: (Monad m) => SGCall -> m ()
nogc = \ _ -> return ()

-- The addresses a closure touches: each free variable of the closure's lambda
-- (formals xs, body c) is looked up in the closure's environment ρ, and the
-- addresses found are collected into a set.
closureTouched :: (TimeC lτ, TimeC dτ) => Clo lτ dτ Ψ -> Set (Addr lτ dτ Ψ)
closureTouched (Clo _ xs c ρ _) = liftMaybeSet . index ρ *$ freeVarsLam xs $ stampedFix c

-- The addresses touched by an address under the store σ: look the address up
-- in σ, extract the closures from the value found (elimClo), and take the
-- addresses touched by each of those closures.
addrTouched :: (TimeC lτ, TimeC dτ, ValC lτ dτ val) => Map (Addr lτ dτ Ψ) (val lτ dτ Ψ) -> Addr lτ dτ Ψ -> Set (Addr lτ dτ Ψ)
addrTouched σ = closureTouched *. elimClo *. liftMaybeSet . index σ

-- Package the current environment and lτ component together with the call c
-- as a closure with no formals. The location number -1 is a placeholder ID;
-- the closure is only used to compute the root set in `yesgc`.
currClosure :: (Analysis val lτ dτ m) => SGCall -> m (Clo lτ dτ Ψ)
currClosure c = do
  ρ <- getL 𝓈ρL
  lτ <- getL 𝓈lτL
  return $ Clo (LocNum (-1)) [] c ρ lτ

-- Garbage-collect the store with respect to the call c: start from the
-- addresses touched by the current closure, grow that set with `addrTouched`
-- (via collect/extend, presumably to a fixed point -- confirm against FP), and
-- restrict the store to the resulting keys.
yesgc :: (Analysis val lτ dτ m) => SGCall -> m ()
yesgc c = do
  σ <- getL 𝓈σL
  live0 <- closureTouched ^$ currClosure c
  let live = collect (extend $ addrTouched $ σ) live0
  modifyL 𝓈σL $ onlyKeys live

-- }}}

-- CreateClo {{{

-- Closure creation by linking: the closure captures the current environment
-- as-is, together with the current lτ component.
linkClo :: (Analysis val lτ dτ m) => LocNum -> [SGName] -> SGCall -> m (Clo lτ dτ Ψ)
linkClo cid xs c = do
  ρ <- getL 𝓈ρL
  lτ <- getL 𝓈lτL
  return $ Clo cid xs c ρ lτ

-- Closure creation by copying: the closure gets a new environment, built from
-- the empty map, containing only the lambda's free variables. Each free
-- variable's current value is read with `var` and re-bound with `bind`; the
-- binds are accumulated as Kleisli endomorphisms and then run on mapEmpty.
-- The current lτ component is captured alongside.
copyClo :: (Analysis val lτ dτ m) => LocNum -> [SGName] -> SGCall -> m (Clo lτ dτ Ψ)
copyClo cid xs c = do
  let ys = toList $ freeVarsLam xs $ stampedFix c
  vs <- var ^*$ ys
  yvs <- liftMaybeZero $ zip ys vs
  ρ <- runKleisliEndo mapEmpty *$ execWriterT $ do
    traverseOn yvs $ tell . KleisliEndo . uncurry bind
  lτ <- getL 𝓈lτL
  return $ Clo cid xs c ρ lτ

-- }}}

-- Execution {{{

-- type StateSpaceC ς =
--   ( PartialOrder (ς SGCall)
--   , JoinLattice (ς SGCall)
--   , Pretty (ς SGCall)
--   , Inject ς
--   , MonadStep ς m
--   )

  -- , Isomorphism (ς SGCall) (ς' SGCall)
  -- , StateSpaceC ς'

-- Constraints tying the monad m to its state spaces: m steps over ς
-- (MonadStep), programs can be injected into ς (Inject), and ς and ς' are
-- isomorphic at SGCall.
type MonadStateSpaceC ς ς' m =
  ( MonadStep ς m
  , Inject ς
  , Isomorphism (ς SGCall) (ς' SGCall)
  )
-- Constraints required of the state space ς' at SGCall: a partial order, a
-- join lattice and a pretty-printer.
type StateSpaceC ς' =
  ( PartialOrder (ς' SGCall)
  , JoinLattice (ς' SGCall)
  , Pretty (ς' SGCall)
  )

-- Bundles the state-space constraints for a monad m. The functional
-- dependencies make both state spaces ς and ς' determined by m.
class (MonadStateSpaceC ς ς' m, StateSpaceC ς') => Execution ς ς' m | m -> ς, m -> ς'

-- Run the analysis on a program. The program is injected into ς and converted
-- to ς'; `poiter` is then driven by the step function, which converts back to
-- ς, takes one `call` step through mstepγ, and converts to ς' again.
-- NOTE(review): poiter is assumed to iterate the step until the partial order
-- on ς' reports no change -- confirm against FP.
exec ::
  forall val lτ dτ ς ς' m. (Analysis val lτ dτ m, Execution ς ς' m)
  => GC m -> CreateClo lτ dτ m -> TimeFilter -> TimeFilter -> SGCall -> ς' SGCall
exec gc createClo ltimeFilter dtimeFilter =
  poiter (isoto . mstepγ (call gc createClo ltimeFilter dtimeFilter) . isofrom)
  . isoto
  . (inj :: SGCall -> ς SGCall)

-- Like `exec`, but drives the step function with `collect` instead of
-- `poiter`. The program is injected into ς and converted to ς'; each step
-- converts back to ς, takes one `call` step through mstepγ, and converts to
-- ς' again.
-- NOTE(review): collect is assumed to accumulate (join) the states reached
-- across iterations -- confirm against FP.
execCollect ::
  forall val lτ dτ ς ς' m. (Analysis val lτ dτ m, Execution ς ς' m)
  => GC m -> CreateClo lτ dτ m -> TimeFilter -> TimeFilter -> SGCall -> ς' SGCall
execCollect gc createClo ltimeFilter dtimeFilter =
  collect (isoto . mstepγ (call gc createClo ltimeFilter dtimeFilter) . isofrom)
  . isoto
  . (inj :: SGCall -> ς SGCall)

-- }}}

-- Parametric Execution {{{

-- Evidence (a W witness) that τ satisfies TimeC.
type UniTime τ = W (TimeC τ)
-- An existentially packed time abstraction: hides τ, keeping only its
-- TimeC evidence.
data ExTime where ExTime :: forall τ. UniTime τ -> ExTime

-- Evidence that val satisfies ValC for every pair of time abstractions lτ and
-- dτ that satisfy TimeC.
type UniVal val = forall lτ dτ. (TimeC lτ, TimeC dτ) => W (ValC lτ dτ val)
-- An existentially packed value abstraction: hides val, keeping only its
-- universally quantified ValC evidence.
data ExVal where ExVal :: forall val. UniVal val -> ExVal

-- Evidence that, for every val, lτ and dτ satisfying the time and value
-- constraints, the monad (m val lτ dτ Ψ) is an Analysis and has an Execution
-- over the state spaces (ς val lτ dτ Ψ) and (ς' val lτ dτ Ψ).
type UniMonad ς ς' m =
  forall val lτ dτ. (TimeC lτ, TimeC dτ, ValC lτ dτ val)
  => W (Analysis val lτ dτ (m val lτ dτ Ψ), Execution (ς val lτ dτ Ψ) (ς' val lτ dτ Ψ) (m val lτ dτ Ψ))
-- An existentially packed monad choice: hides ς, ς' and m, keeping only the
-- universally quantified evidence.
data ExMonad where ExMonad :: forall ς ς' m. UniMonad ς ς' m -> ExMonad

-- A garbage-collection strategy that works for every Analysis monad.
newtype AllGC = AllGC { runAllGC :: forall val lτ dτ m. (Analysis val lτ dτ m) => GC m }
-- A closure-creation strategy that works for every Analysis monad.
newtype AllCreateClo = AllCreateClo { runAllCreateClo :: forall val lτ dτ m. (Analysis val lτ dτ m) => CreateClo lτ dτ m }

-- The full set of choices that parameterize a run of the analysis; consumed
-- by runWithOptions.
data Options = Options
  { ltimeOp :: ExTime             -- time abstraction used for the lτ component
  , dtimeOp :: ExTime             -- time abstraction used for the dτ component
  , valOp :: ExVal                -- value abstraction
  , monadOp :: ExMonad            -- analysis monad and its state spaces
  , gcOp :: AllGC                 -- garbage-collection strategy
  , createCloOp :: AllCreateClo   -- closure-creation strategy
  , ltimeFilterOp :: TimeFilter   -- passed to execCollect as ltimeFilter
  , dtimeFilterOp :: TimeFilter   -- passed to execCollect as dtimeFilter
  }

-- An analysis result with its state-space type hidden; only the StateSpaceC
-- instances (partial order, join lattice, pretty-printer) remain usable.
data ExSigma where
  ExSigma :: (StateSpaceC ς) => ς SGCall -> ExSigma

-- Run the analysis on program e with the choices packed in o.
--
-- Pattern matching on Options unpacks each existential. The type annotations
-- on the W witnesses bring lτ, dτ, val, m, ς and ς' into scope and
-- instantiate the universally quantified evidence at those types, so that
-- execCollect can be called. The result is repacked as an ExSigma.
runWithOptions :: Options -> SGCall -> ExSigma
runWithOptions o e = case o of
  Options (ExTime (W :: UniTime lτ))
          (ExTime (W :: UniTime dτ))
          (ExVal (W :: W (ValC lτ dτ val)))
          (ExMonad (W :: W (Analysis val lτ dτ (m val lτ dτ Ψ), Execution (ς val lτ dτ Ψ) (ς' val lτ dτ Ψ) (m val lτ dτ Ψ))))
          (AllGC (gc :: GC (m val lτ dτ Ψ)))
          (AllCreateClo (createClo :: CreateClo lτ dτ (m val lτ dτ Ψ)))
          (ltimeFilter :: TimeFilter)
          (dtimeFilter :: TimeFilter) ->
    ExSigma $ execCollect gc createClo ltimeFilter dtimeFilter e

-- }}}