{-# LANGUAGE GeneralizedNewtypeDeriving, StandaloneDeriving,
FlexibleContexts, TypeFamilies, KindSignatures #-}
module Data.Singletons.Promote.Monad (
PrM, promoteM, promoteM_, promoteMDecs, VarPromotions,
allLocals, emitDecs, emitDecsM,
lambdaBind, LetBind, letBind, lookupVarE, forallBind, allBoundKindVars
) where
import Control.Monad.Reader
import Control.Monad.Writer
import qualified Data.Map.Strict as Map
import Data.Map.Strict ( Map )
import qualified Data.Set as Set
import Data.Set ( Set )
import Language.Haskell.TH.Syntax hiding ( lift )
import Language.Haskell.TH.Desugar
import Data.Singletons.Names
import Data.Singletons.Syntax
import Control.Monad.Fail ( MonadFail )
-- | Maps a term-level variable name to the type that a reference to that
-- variable promotes to. Entries are added by 'letBind' and 'lambdaBind' and
-- consulted by 'lookupVarE'.
type LetExpansions = Map Name DType
-- | The read-only environment threaded through promotion by 'PrM'.
data PrEnv =
  PrEnv { pr_lambda_bound :: Map Name Name
          -- ^ Lambda-bound term variables, each mapped to the type-level
          -- name it is promoted to (extended by 'lambdaBind').
        , pr_let_bound    :: LetExpansions
          -- ^ Expansions for let- and lambda-bound term variables
          -- (extended by 'letBind' and 'lambdaBind', read by 'lookupVarE').
        , pr_forall_bound :: Set Name
          -- ^ Names bound by enclosing foralls (extended by 'forallBind',
          -- read by 'allBoundKindVars').
        , pr_local_decls  :: [Dec]
          -- ^ Declarations reported by this monad's 'localDeclarations';
          -- filled in by 'promoteM'.
        }
-- | The starting environment: nothing lambda-, let-, or forall-bound, and no
-- local declarations. Arguments follow the field order of 'PrEnv'.
emptyPrEnv :: PrEnv
emptyPrEnv =
  PrEnv Map.empty  -- pr_lambda_bound
        Map.empty  -- pr_let_bound
        Set.empty  -- pr_forall_bound
        []         -- pr_local_decls
-- | The promotion monad. It reads a 'PrEnv', collects emitted 'DDec's in a
-- writer, and runs in 'Q'. Every instance listed is derived through the
-- newtype from the underlying transformer stack.
newtype PrM a = PrM (ReaderT PrEnv (WriterT [DDec] Q) a)
  deriving ( Functor, Applicative, Monad, Quasi
           , MonadReader PrEnv, MonadWriter [DDec]
           , MonadFail, MonadIO )
-- | The local declarations seen by th-desugar inside 'PrM' are the ones
-- stored in the environment's 'pr_local_decls' field.
instance DsMonad PrM where
  localDeclarations = fmap pr_local_decls ask
-- | Type-level names of the lambda-bound variables that have not been
-- shadowed.
--
-- A lambda-bound variable is kept only while its let-expansion is still
-- exactly the 'DVarT' of its promoted name, as 'lambdaBind' installs it.
-- A later 'letBind' that rebinds the term name to anything else removes it.
-- Results come out in ascending order of the term-level names.
allLocals :: MonadReader PrEnv m => m [Name]
allLocals = do
  lambdas <- asks pr_lambda_bound
  lets    <- asks pr_let_bound
  let notShadowed tmName tyName =
        case Map.lookup tmName lets of
          Just (DVarT tyName') -> tyName' == tyName
          _                    -> False
  return [ tyName
         | (tmName, tyName) <- Map.toList lambdas
         , notShadowed tmName tyName ]
-- | Append declarations to the writer output.
emitDecs :: MonadWriter [DDec] m => [DDec] -> m ()
emitDecs decls = tell decls
-- | Run an action that produces declarations, then emit its result via
-- 'emitDecs'.
emitDecsM :: MonadWriter [DDec] m => m [DDec] -> m ()
emitDecsM produce = emitDecs =<< produce
-- | Run an action with extra lambda-bound variables in scope.
--
-- Each @(termName, typeName)@ pair is recorded in 'pr_lambda_bound'. The
-- pair also becomes a let-expansion to @'DVarT' typeName@, so 'lookupVarE'
-- resolves the term variable to that type variable. New bindings take
-- precedence over existing ones ('Map.union' is left-biased).
lambdaBind :: VarPromotions -> PrM a -> PrM a
lambdaBind binds = local extend
  where
    extend env =
      env { pr_lambda_bound = Map.union (Map.fromList binds) (pr_lambda_bound env)
          , pr_let_bound    = Map.union expansions (pr_let_bound env) }

    expansions = Map.fromList (map (\(tm, ty) -> (tm, DVarT ty)) binds)
-- | A term-level name paired with the type that references to it should
-- promote to (consumed by 'letBind').
type LetBind = (Name, DType)
-- | Run an action with extra let-expansions in scope. New bindings take
-- precedence over existing ones ('Map.union' is left-biased). Only
-- 'pr_let_bound' is changed.
letBind :: [LetBind] -> PrM a -> PrM a
letBind binds = local $ \env ->
  env { pr_let_bound = Map.fromList binds `Map.union` pr_let_bound env }
-- | The type that a reference to the given term variable promotes to.
--
-- If a let- or lambda-expansion is in scope, that expansion is used.
-- Otherwise the result falls back to @'promoteValRhs' n@.
lookupVarE :: Name -> PrM DType
lookupVarE n = Map.findWithDefault (promoteValRhs n) n <$> asks pr_let_bound
-- | Run an action with additional forall-bound names added to
-- 'pr_forall_bound'.
forallBind :: Set Name -> PrM a -> PrM a
forallBind kvs1 = local extend
  where
    extend env = env { pr_forall_bound = Set.union kvs1 (pr_forall_bound env) }
-- | Every name currently bound by an enclosing 'forallBind'.
allBoundKindVars :: PrM (Set Name)
allBoundKindVars = pr_forall_bound <$> ask
-- | Run a 'PrM' computation inside any 'DsMonad'.
--
-- The computation starts from 'emptyPrEnv'. Its local declarations are the
-- caller's own 'localDeclarations' followed by @locals@. The result comes
-- back together with every declaration emitted through the writer.
promoteM :: DsMonad q => [Dec] -> PrM a -> q (a, [DDec])
promoteM locals (PrM action) = do
  outerLocals <- localDeclarations
  runQ . runWriterT . runReaderT action $
    emptyPrEnv { pr_local_decls = outerLocals ++ locals }
-- | Like 'promoteM' for a computation run only for its effects. It returns
-- just the emitted declarations.
promoteM_ :: DsMonad q => [Dec] -> PrM () -> q [DDec]
promoteM_ locals thing =
  promoteM locals thing >>= \((), emitted) -> return emitted
-- | Like 'promoteM' for a computation that itself returns declarations. The
-- returned declarations come first, followed by those emitted through the
-- writer.
promoteMDecs :: DsMonad q => [Dec] -> PrM [DDec] -> q [DDec]
promoteMDecs locals thing =
  promoteM locals thing >>= \(returned, emitted) -> return (returned ++ emitted)