module Lang.Hask.Semantics where

import FP

import Lang.Hask.Time
import Lang.Hask.CPS hiding (atom)
import Name
import Literal
import DataCon
import CoreSyn (AltCon(..))

-- Values

-- | A timestamp with two components: a lexical time ('timeLex') and a
-- dynamic time ('timeDyn'). 'tickLex' and 'tickDyn' advance them
-- independently.
--
-- NOTE(review): as shown, this declaration has no type parameters and both
-- field types are empty (e.g. @timeLex :: @), so it does not parse.
-- Presumably it was parameterised by a lexical and a dynamic time type, with
-- each field of one of those types — confirm against the original source.
data Moment   = Moment
  { timeLex :: 
  , timeDyn :: 
  } deriving (Eq, Ord)
makeLenses ''Moment
-- | The bottom moment has both components at 'tzero'.
instance (Time ψ , Time ψ ) => Bot (Moment  ) where bot = Moment tzero tzero

-- | An abstract address: the variable 'Name' it was allocated for, paired
-- with the 'Moment' current at allocation time (see 'alloc').
--
-- NOTE(review): the type parameters appear to be missing from this
-- declaration as shown (@data Addr   =@, @Moment  @) — confirm.
data Addr   = Addr
  { addrName :: Name
  , addrTime :: Moment  
  } deriving (Eq, Ord)

-- | Environments map variable names to their addresses.
type Env   = Map Name (Addr  )
-- | Stores map addresses to abstract values of type @ν@.
type Store ν   = Map (Addr  ) (ν  )

-- | A first-order argument: a store address, a literal, or a type argument.
-- 'picoArg' builds these from 'Pico' terms; 'argVal' gives 'TypeVal' the
-- value 'botI'.
data ArgVal   =
    AddrVal (Addr  )
  | LitVal Literal
  | TypeVal
  deriving (Eq, Ord)

-- | A constructor application: the data constructor and its argument values.
-- A @Case@ binds the alternative's variables to 'dataArgs' positionally.
data Data   = Data
  { dataCon :: DataCon
  , dataArgs :: [ArgVal  ]
  } deriving (Eq, Ord)

-- | A function closure. It holds the value parameter name, the continuation
-- parameter name, the body, the captured environment, and the lexical time
-- read when the closure was built (the @LamF@ case of 'atom'). Applying it
-- (the @AppF@ case of 'call') re-installs the environment and lexical time.
--
-- NOTE(review): the type of 'funCloTime' is missing as shown — confirm.
data FunClo   = FunClo
  { funCloLamArg :: Name
  , funCloKonArg :: Name
  , funCloBody :: Call
  , funCloEnv :: Env  
  , funCloTime :: 
  } deriving (Eq, Ord)

-- | A reference cell: the address of an updatable store slot. Thunks (the
-- @Thunk@ case of 'atom') and recursive bindings (@Rec@ / @Letrec@ in
-- 'call') are represented through these.
data Ref   = Ref
  { refAddr :: Addr  
  } deriving (Eq, Ord)

-- | A continuation closure: parameter name, body, and captured environment.
data KonClo   = KonClo
  { konCloArg :: Name
  , konCloBody :: Call
  , konCloEnv :: Env  
  } deriving (Eq, Ord)

-- | A suspended application of 'thunkCloFun' to 'thunkCloArg', together with
-- its captured environment and lexical time.
--
-- When forced ('forceThunk'), it becomes the call
-- @AppF thunkCloKonXLoc thunkCloKonXArg fun arg (Var k)@, stamped with
-- 'thunkCloKonXLoc'. Here @k@ is 'thunkCloKonKArg', which is bound to a
-- memoising continuation.
--
-- NOTE(review): the type of 'thunkCloTime' is missing as shown — confirm.
data ThunkClo   = ThunkClo
  { thunkCloKonXLoc :: Int
  , thunkCloKonXArg :: Name
  , thunkCloKonKArg :: Name
  , thunkCloFun :: Pico
  , thunkCloArg :: Pico
  , thunkCloEnv :: Env  
  , thunkCloTime :: 
  } deriving (Eq, Ord)

-- | A memoising continuation. When invoked (the @AppK@ case of 'call'), it
-- first overwrites the cell at 'konMemoCloLoc', replacing 'konMemoCloThunk'
-- with a 'Forced' value. It then continues like a 'KonClo' built from the
-- remaining fields.
data KonMemoClo   = KonMemoClo
  { konMemoCloLoc :: Addr  
  , konMemoCloThunk :: ThunkClo  
  , konMemoCloArg :: Name
  , konMemoCloBody :: Call
  , konMemoCloEnv :: Env  
  } deriving (Eq, Ord)

-- | The stored result of a thunk that has already been forced.
data Forced   = Forced
  { forcedVal :: ArgVal  
  } deriving (Eq, Ord)

-- | Interface to abstract values @αν@.
--
-- The @…I@ methods inject a concrete form into an abstract value. The @…E@
-- methods eliminate an abstract value into the container @γν@, which
-- 'ValC' instantiates to 'SetWithTop'. @dataAnyI@ and @refAnyI@ stand for
-- "any data built with this constructor" and "any reference"; 'call' negates
-- them to refine values in @Case@ fall-through branches.
--
-- NOTE(review): the class's type parameters appear to be missing as shown
-- (@class Val   γν αν@, @αν  @) — confirm against the original source.
class Val   γν αν | αν -> γν where
  botI :: αν  
  -- literals
  litI :: Literal -> αν  
  litTestE :: Literal -> αν   -> γν Bool
  -- constructor applications
  dataI :: Data   -> αν  
  dataAnyI :: DataCon -> αν  
  dataE :: αν   -> γν (Data  )
  -- function closures
  funCloI :: FunClo   -> αν  
  funCloE :: αν   -> γν (FunClo  )
  -- reference cells
  refI :: Ref   -> αν  
  refAnyI :: αν  
  refE :: αν   -> γν (Ref  )
  -- continuation closures
  konCloI :: KonClo   -> αν  
  konCloE :: αν   -> γν (KonClo  )
  konMemoCloI :: KonMemoClo   -> αν  
  konMemoCloE :: αν   -> γν (KonMemoClo  )
  -- thunks, unforced and forced
  thunkCloI :: ThunkClo   -> αν  
  thunkCloE :: αν   -> γν (ThunkClo  )
  forcedI :: Forced   -> αν  
  forcedE :: αν   -> γν (Forced  )

-- State Space

-- | Abstract machine state: the current environment, store, and moment.
-- 'makeLenses' generates the lenses ('𝓈EnvL', '𝓈StoreL', '𝓈TimeL') used by
-- the transition functions below.
data 𝒮 ν   = 𝒮
  { 𝓈Env :: Env  
  , 𝓈Store :: Store ν  
  , 𝓈Time :: Moment  
  } deriving (Eq, Ord)
-- | The initial state: every component is bottom.
instance (Time ψ , Time ψ ) => Bot (𝒮 ν  ) where bot = 𝒮 bot bot bot
makeLenses ''𝒮

-- Analysis effects and constraints

-- | Requirements on the two time types: ordered, with a 'Time' instance whose
-- tick input is 'Int'.
-- NOTE(review): the type arguments are missing from this line as shown —
-- presumably the lexical and dynamic time types; confirm.
type TimeC   = (Ord , Ord , Time Int , Time Int )
-- | Requirements on abstract values: join, meet, negation, and the 'Val'
-- interface with 'SetWithTop' as the elimination container.
type ValC ν   = (JoinLattice (ν  ), Meet (ν  ), Neg (ν  ), Val   SetWithTop ν)
-- | Requirements on the analysis monad: a state monad over '𝒮' that also
-- has 'MonadBot', 'MonadTop', and 'MonadPlus' instances.
type MonadC ν   m = (Monad m, MonadBot m, MonadTop m, MonadPlus m, MonadState (𝒮 ν  ) m)

-- | Bundles the monad, value, and time requirements. The functional
-- dependencies let the monad @m@ determine the value and time types.
class ( MonadC ν   m , ValC ν   , TimeC  ) => Analysis ν   m | m -> ν , m ->  , m -> 

-- Moment management

-- | Advance the lexical component of the current moment, ticking with the
-- fixpoint identifier stamped on the given call.
tickLex :: (Analysis ν   m) => Call -> m ()
tickLex c = modifyL (timeLexL <.> 𝓈TimeL) $ tick $ stampedFixID c

-- | Advance the dynamic component of the current moment, ticking with the
-- fixpoint identifier stamped on the given call.
tickDyn :: (Analysis ν   m) => Call -> m ()
tickDyn c =
  let stamp = stampedFixID c
  in modifyL (timeDynL <.> 𝓈TimeL) (tick stamp)

-- | The address for a variable is the variable paired with the current
-- moment. No fresh counter is involved, so two allocations of the same name
-- at the same moment yield the same address.
alloc :: (Analysis ν   m) => Name -> m (Addr  )
alloc x = Addr x ^$ getL 𝓈TimeL

-- Updating values in the store

-- | Bind a variable to a value. This allocates its address, points the
-- environment at that address, and joins the value into whatever the store
-- already holds there.
bindJoin :: (Analysis ν   m) => Name -> ν   -> m ()
bindJoin x v = do
  loc <- alloc x
  modifyL 𝓈EnvL (mapInsert x loc)
  modifyL 𝓈StoreL (mapInsertWith (\/) loc v)

-- | Replace @vOld@ by @vNew@ at an address. The existing entry is met with
-- the negation of @vOld@, and the result is joined with a singleton store
-- holding @vNew@ at that address.
updateRef :: (Analysis ν   m) => Addr   -> ν   -> ν   -> m ()
updateRef 𝓁 vOld vNew = modifyL 𝓈StoreL replaceOld
  where
    replaceOld σ = mapModify (/\ neg vOld) 𝓁 σ \/ mapSingleton 𝓁 vNew

-- Refinement and extraction

-- | Narrow what an argument is known to be by meeting @v@ into the store
-- entry behind its address. Literals and type arguments have no store
-- entry, so they are left untouched.
--
-- NOTE(review): if the address is absent from the store, 'mapInsertWith'
-- presumably inserts @v@ as-is instead of leaving the store alone — confirm
-- that is intended.
refine :: (Analysis ν   m) => ArgVal   -> ν   -> m ()
refine av v = case av of
  AddrVal 𝓁 -> modifyL 𝓈StoreL (mapInsertWith (/\) 𝓁 v)
  LitVal _ -> return ()
  TypeVal -> return ()

-- | Look at an argument through an eliminator. The steps are:
--
-- * read the argument's value;
-- * dispatch on the 'SetWithTop' result of @elim@ with 'mtop' / 'mset';
-- * refine the argument to @intro@ of the chosen element;
-- * return that element.
extract :: (Analysis ν   m) => (a -> ν  ) -> (ν   -> SetWithTop a) -> ArgVal   -> m a
extract intro elim av = do
  chosen <- (setWithTopElim mtop mset . elim) *$ argVal av
  refine av (intro chosen)
  return chosen

-- | Succeed only where the argument may equal the literal @l@, refining it
-- to that literal. Alternatives where the test yields 'False' are pruned by
-- 'guard'.
extractIsLit :: (Analysis ν   m) => Literal -> ArgVal   -> m ()
extractIsLit l av = do
  matches <- (setWithTopElim mtop mset . litTestE l) *$ argVal av
  guard matches
  refine av (litI l)

-- Denotations

-- | The value stored at an address. Fails via 'maybeZero' when the address
-- is not in the store.
addr :: (Analysis ν   m) => Addr   -> m (ν  )
addr 𝓁 = (\ store -> maybeZero $ store # 𝓁) *$ getL 𝓈StoreL

-- | The abstract value of an argument. Addresses are looked up in the store,
-- literals are injected with 'litI', and type arguments denote 'botI'.
argVal :: (Analysis ν   m) => ArgVal   -> m (ν  )
argVal = \ case
  AddrVal 𝓁 -> addr 𝓁
  LitVal l -> return $ litI l
  TypeVal -> return botI

-- | The address a variable is bound to in the current environment. Fails via
-- 'maybeZero' when the variable is unbound.
varAddr :: (Analysis ν   m) => Name -> m (Addr  )
varAddr x = (\ env -> maybeZero $ env # x) *$ getL 𝓈EnvL

-- | The value of a variable: resolve its address, then read the store.
var :: (Analysis ν   m) => Name -> m (ν  )
var x = addr *$ varAddr x

-- | The abstract value of a trivial term: variables go through the store,
-- literals are injected, and types denote 'botI'.
pico :: (Analysis ν   m) => Pico -> m (ν  )
pico (Var n) = var n
pico (Lit l) = return $ litI l
pico Type = return botI

-- | A trivial term as a first-order argument. A variable becomes its address
-- rather than its value, so the argument can later be refined in place.
picoArg :: (Analysis ν   m) => Pico -> m (ArgVal  )
picoArg = \ case
  Var x -> AddrVal ^$ varAddr x
  Lit l -> return $ LitVal l
  Type -> return TypeVal

-- | The abstract value of an atom.
--
-- * @Pico@: delegates to 'pico'.
-- * @LamF@: a 'FunClo' capturing the current environment and lexical time.
-- * @LamK@: a 'KonClo' capturing the current environment.
-- * @Thunk@: allocates a cell for @r@ and stores a 'ThunkClo' in it, capturing
--   the current environment and lexical time. The result is a reference to
--   that cell.
atom :: (Analysis ν   m) => Atom -> m (ν  )
atom = \ case
  Pico p -> pico p
  LamF x k c -> do
    ρ <- getL 𝓈EnvL
    lexTime <- getL $ timeLexL <.> 𝓈TimeL
    return $ funCloI $ FunClo x k c ρ lexTime
  LamK x c -> do
    ρ <- getL 𝓈EnvL
    return $ konCloI $ KonClo x c ρ
  Thunk r xi x k p₁ p₂ -> do
    ρ <- getL 𝓈EnvL
    lexTime <- getL $ timeLexL <.> 𝓈TimeL
    𝓁 <- alloc r
    -- Write the thunk into its fresh cell; botI is the "old" value to remove.
    updateRef 𝓁 botI $ thunkCloI $ ThunkClo xi x k p₁ p₂ ρ lexTime
    return $ refI $ Ref 𝓁

-- | Force the thunk that @av@ refers to, then continue with @c@ with the
-- thunk's value bound to @x@. Both possible states of the cell are explored
-- with 'msum':
--
-- * the cell already holds a 'Forced' value: bind it to @x@ and return @c@;
-- * the cell still holds a 'ThunkClo': switch to the thunk's environment and
--   lexical time, and bind its continuation variable to a 'KonMemoClo'. That
--   continuation memoises the result and resumes @c@ in the caller's
--   environment. Finally, return the application call that evaluates the
--   thunk.
forceThunk :: forall ν   m. (Analysis ν   m) => Name -> ArgVal   -> Call -> m Call
forceThunk x av c = do
  Ref loc <- extract refI refE av
  let cell = AddrVal loc
      useForced = do
        Forced result <- extract forcedI forcedE cell
        bindJoin x *$ argVal result
        return c
      startForcing = do
        th@(ThunkClo stamp xArg kArg fn arg ρThunk τThunk) <- extract thunkCloI thunkCloE cell
        -- Capture the caller's environment before switching to the thunk's.
        ρCaller <- getL 𝓈EnvL
        putL 𝓈EnvL ρThunk
        putL (timeLexL <.> 𝓈TimeL) τThunk
        bindJoin kArg $ konMemoCloI $ KonMemoClo loc th x c ρCaller
        return $ StampedFix stamp $ AppF stamp xArg fn arg $ Var kArg
  msum [useForced, startForcing]

-- | One transition of the abstract machine. It ticks the dynamic time,
-- interprets the call @c@, and returns the call to evaluate next.
--
-- Alternatives (e.g. the different closures a function position may denote)
-- are explored with 'msum'. A failed 'guard' or 'maybeZero' prunes the
-- current alternative.
call :: (Analysis ν   m) => Call -> m Call
call c = do
  tickDyn c
  case stampedFix c of
    -- Bind the atom's value to x and continue with the body.
    Let x a c' -> do
      v <- atom a
      bindJoin x v
      return c'
    -- Bind each x to a reference cell allocated for r.
    Rec rxs c' -> do
      traverseOn rxs $ \ (r,x) -> do
        𝓁 <- alloc r
        bindJoin x $ refI $ Ref 𝓁
      return c'
    -- Each x must already denote a reference cell; store the atom's value in it.
    Letrec xas c' -> do
      traverseOn xas $ \ (x, a) -> do
        av <- picoArg $ Var x
        Ref 𝓁 <- extract refI refE av
        updateRef 𝓁 botI *$ atom a
      return c'
    -- Continuation application: p₁ is the continuation, p₂ its argument.
    AppK p₁ p₂ -> do
      av₁ <- picoArg p₁
      v₂ <- pico p₂
      msum
        [ do
            KonClo x c' ρ <- extract konCloI konCloE av₁
            putL 𝓈EnvL ρ
            bindJoin x v₂
            return c'
        , do
            -- A memoising continuation first replaces the thunk in its cell
            -- with the forced result, then continues like a KonClo.
            KonMemoClo 𝓁 th x c' ρ <- extract konMemoCloI konMemoCloE av₁
            updateRef 𝓁 (thunkCloI th) . forcedI . Forced *$ picoArg p₂
            putL 𝓈EnvL ρ
            bindJoin x v₂
            return c'
        ]
    -- Function application: p₁ is the function, p₂ the argument, p₃ the
    -- continuation.
    AppF xi' x' p₁ p₂ p₃ -> do
      av₁ <- picoArg p₁
      v₂ <- pico p₂
      v₃ <- pico p₃
      msum
        [ do
            FunClo x k c' ρ lexTime <- extract funCloI funCloE av₁
            -- Re-enter the closure's environment and captured lexical time.
            putL 𝓈EnvL ρ
            putL (timeLexL <.> 𝓈TimeL) lexTime
            bindJoin x v₂
            bindJoin k v₃
            return c'
        -- p₁ may instead be a thunk reference: force it into x' and retry.
        , forceThunk x' av₁ $ StampedFix xi' $ AppF xi' x' (Var x') p₂ p₃
        ]
    Case xi' x' p bs0 -> do
      av <- picoArg p
      msum
        [ do
            -- loop through the alternatives
            let loop bs = do
                  (CaseBranch acon xs c', bs') <- maybeZero $ view consL bs
                  case acon of
                    DataAlt con -> msum
                      -- The alt is a Data and the value is a Data with the same
                      -- tag; jump to the alt body.
                      [ do
                          Data dcon 𝓁s <- extract dataI dataE av
                          guard $ con == dcon
                          x𝓁s <- maybeZero $ zip xs 𝓁s
                          traverseOn x𝓁s $ \ (x, av') -> do
                            v' <- argVal av'
                            bindJoin x v'
                          return c'
                      -- The alt is a Data and the value is not a Data with the
                      -- same tag; try the next branch.
                      , do
                          refine av $ neg $ dataAnyI con
                          loop bs'
                      ]
                    LitAlt l -> msum
                      -- The alt is a Lit and the value is the same lit; jump to
                      -- the alt body.
                      [ do
                          extractIsLit l av
                          return c'
                      -- The alt is a Lit and the value is not the same lit;
                      -- try the next branch.
                      , do
                          refine av $ neg $ litI l
                          loop bs'
                      ]
                    -- The alt is the default branch; jump to its body _only if
                    -- the value is not a ref_. (Returning c here would re-run
                    -- this whole Case instead of entering the default body.)
                    DEFAULT -> do
                      refine av $ neg $ refAnyI
                      return c'
            loop bs0
        -- The scrutinee may be a thunk reference: force it into x' and retry.
        , forceThunk x' av $ StampedFix xi' $ Case xi' x' (Var x') bs0
        ]
    Halt _ -> return c