module DDC.Core.Transform.Rewrite.Rule
(
BindMode (..)
, isBMSpec
, isBMValue
, RewriteRule (..)
, NamedRewriteRule
, mkRewriteRule
, checkRewriteRule
, Error (..)
, Side (..))
where
import DDC.Core.Transform.Rewrite.Error
import DDC.Core.Transform.Reannotate
import DDC.Core.Transform.TransformX
import DDC.Core.Exp
import DDC.Core.Pretty ()
import DDC.Core.Collect
import DDC.Core.Compounds
import DDC.Type.Pretty ()
import DDC.Type.Env (KindEnv, TypeEnv)
import DDC.Base.Pretty
import Control.Monad
import qualified DDC.Core.Analysis.Usage as U
import qualified DDC.Core.Check as C
import qualified DDC.Core.Collect as C
import qualified DDC.Core.Transform.SpreadX as S
import qualified DDC.Type.Check as T
import qualified DDC.Type.Compounds as T
import qualified DDC.Type.Env as T
import qualified DDC.Type.Equiv as T
import qualified DDC.Type.Predicates as T
import qualified DDC.Type.Subsumes as T
import qualified DDC.Type.Transform.SpreadT as S
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified DDC.Type.Env as Env
-- | A rewrite rule: a left-hand side that is replaced by a right-hand side,
--   together with the binders and constraints that scope over both.
--
--   The last three fields are left empty by 'mkRewriteRule' and are filled in
--   by 'checkRewriteRule'.
data RewriteRule a n
        = RewriteRule
        { ruleBinds       :: [(BindMode, Bind n)]
          -- ^ Binders of the rule, each paired with its 'BindMode'.

        , ruleConstraints :: [Type n]
          -- ^ Constraint types. 'checkRewriteRule' rejects any that is not
          --   a witness type.

        , ruleLeft        :: Exp a n
          -- ^ Left-hand side of the rule.

        , ruleLeftHole    :: Maybe (Exp a n)
          -- ^ Optional hole. When present, 'checkRewriteRule' checks the
          --   left-hand side applied to this expression.

        , ruleRight       :: Exp a n
          -- ^ Right-hand side of the rule.

        , ruleWeakEff     :: Maybe (Effect n)
          -- ^ The effect of the left, recorded when it subsumes but is not
          --   equivalent to the effect of the right.

        , ruleWeakClo     :: [Exp a n]
          -- ^ One expression for each variable in the support of the left
          --   that is missing from the support of the right.

        , ruleFreeVars    :: [Bound n]
          -- ^ Free variables of both sides, minus the rule binders.
        } deriving (Eq, Show)
-- | A rewrite rule paired with a 'String' (presumably the name of the rule).
type NamedRewriteRule a n = (String, RewriteRule a n)
-- | Print a rule as @binders. constraints => lhs {hole} = rhs@.
--
--   'BMSpec' binders are wrapped in square brackets and 'BMValue' binders in
--   parentheses. The weakening and free-variable fields are not printed.
instance (Pretty n, Eq n) => Pretty (RewriteRule a n) where
 ppr rule
  =  binderDoc
  <> constraintDoc
  <> ppr (ruleLeft rule)
  <> holeDoc
  <> text " = "
  <> ppr (ruleRight rule)
  where
        -- All binders followed by ". ", or nothing when there are none.
        binderDoc
         = case ruleBinds rule of
                []       -> text ""
                b : rest -> foldl (<>) (pprBinder b) (map pprBinder rest)
                         <> text ". "

        pprBinder (mode, b)
         = case mode of
                BMSpec    -> text "[" <> ppr b <> text "] "
                BMValue _ -> text "(" <> ppr b <> text ") "

        -- Each constraint is followed by " => ".
        constraintDoc
         = foldr (\c rest -> ppr c <> text " => " <> rest)
                 (text "")
                 (ruleConstraints rule)

        -- The hole, when present, is printed in braces after the left.
        holeDoc
         = maybe (text "")
                 (\h -> text " {" <> ppr h <> text "}")
                 (ruleLeftHole rule)
-- | How a rule binder is bound.
data BindMode
        -- | Binder that 'extendBinds' adds to the kind environment.
        = BMSpec

        -- | Binder that 'extendBinds' adds to the type environment.
        --   'countBinderUsage' sets the 'Int' to the number of usage entries
        --   recorded for the binder in the right-hand side.
        | BMValue Int
        deriving (Eq, Show)
-- | Check whether a binding mode is 'BMSpec'.
isBMSpec :: BindMode -> Bool
isBMSpec mode
 = case mode of
        BMSpec          -> True
        BMValue _       -> False
-- | Check whether a binding mode is 'BMValue', whatever its usage count.
isBMValue :: BindMode -> Bool
isBMValue mode
 = case mode of
        BMValue _       -> True
        BMSpec          -> False
-- | Build an unchecked rewrite rule from its binders, constraints, left-hand
--   side, optional hole and right-hand side.
--
--   The effect weakening, closure weakening and free variables are left
--   empty; 'checkRewriteRule' fills them in.
mkRewriteRule
        :: Ord n
        => [(BindMode,Bind n)]
        -> [Type n]
        -> Exp a n
        -> Maybe (Exp a n)
        -> Exp a n
        -> RewriteRule a n
mkRewriteRule binds constraints left leftHole right
        = RewriteRule
        { ruleBinds             = binds
        , ruleConstraints       = constraints
        , ruleLeft              = left
        , ruleLeftHole          = leftHole
        , ruleRight             = right
        , ruleWeakEff           = Nothing
        , ruleWeakClo           = []
        , ruleFreeVars          = [] }
-- | Type check a rewrite rule and fill in the fields that 'mkRewriteRule'
--   leaves empty: the effect weakening, the closure weakening and the free
--   variables.
--
--   Arguments are the checker configuration, the kind environment, the type
--   environment and the rule to check.
--
--   The checks run in the order written, inside the 'Either' monad, so the
--   first check that fails decides which 'Error' is returned.
checkRewriteRule
:: (Ord n, Show n, Pretty n)
=> C.Config n
-> T.Env n
-> T.Env n
-> RewriteRule a n
-> Either (Error a n)
(RewriteRule (C.AnTEC a n) n)
checkRewriteRule config kenv tenv
(RewriteRule bs cs lhs hole rhs _ _ _)
= do
-- Add the rule binders to the kind and type environments.
let (kenv', tenv', bs') = extendBinds bs kenv tenv
-- Spread the constraints against the extended kind environment and check
-- each one.
let csSpread = map (S.spreadT kenv') cs
mapM_ (checkConstraint config kenv') csSpread
-- Check the left-hand side on its own; only the annotated expression is kept.
(lhs', _, _, _)
<- checkExp config kenv' tenv' Lhs lhs
-- Check the hole on its own, if there is one.
hole' <- case hole of
Just h
-> do (h',_,_,_) <- checkExp config kenv' tenv' Lhs h
return $ Just h'
Nothing -> return Nothing
-- NOTE(review): irrefutable pattern. This fails at runtime if `lhs` has no
-- annotation and `a` is forced -- TODO confirm that cannot happen for an
-- `lhs` that passed the check above.
let Just a = takeAnnotOfExp lhs
-- The full left-hand side is the left applied to the hole, when present.
let lhs_full = maybe lhs (XApp a lhs) hole
-- Check the full left-hand side and the right-hand side, keeping their
-- types, effects and closures for the comparisons below.
(lhs_full', tLeft, effLeft, cloLeft)
<- checkExp config kenv' tenv' Lhs lhs_full
(rhs', tRight, effRight, cloRight)
<- checkExp config kenv' tenv' Rhs rhs
-- The same error value is used for a type mismatch and an effect mismatch.
let err = ErrorTypeConflict
(tLeft, effLeft, cloLeft)
(tRight, effRight, cloRight)
-- Both sides must have equivalent types.
checkEquiv tLeft tRight err
-- The effect of the left must be equivalent to, or subsume, that of the right.
effWeak <- makeEffectWeakening T.kEffect effLeft effRight err
cloWeak <- makeClosureWeakening config kenv' tenv' lhs_full' rhs'
-- Structural checks on the binders and on the shape of the left-hand side.
checkUnmentionedBinders bs' lhs_full'
checkAnonymousBinders bs'
checkValidPattern lhs_full
-- Record how often each value binder is used in the right-hand side.
bs'' <- countBinderUsage bs' rhs
-- Free variables of both sides, with the rule binders removed.
-- NOTE(review): `binds` is taken from the unspread `bs`, and the second
-- `freeX` runs over the unchecked `rhs` rather than `rhs'` -- confirm this
-- is intended.
let binds = Set.fromList
$ Maybe.catMaybes
$ map (T.takeSubstBoundOfBind . snd) bs
let freeVars = Set.toList
$ (C.freeX T.empty lhs_full'
`Set.union` C.freeX T.empty rhs)
`Set.difference` binds
return $ RewriteRule
bs'' csSpread
lhs' hole' rhs'
effWeak cloWeak
freeVars
-- | Spread each rule binder against the environments built so far, then add
--   it to the kind environment ('BMSpec') or the type environment
--   ('BMValue').
--
--   Returns the extended kind environment, the extended type environment and
--   the spread binders in their original order.
extendBinds
        :: Ord n
        => [(BindMode, Bind n)]
        -> KindEnv n -> TypeEnv n
        -> (T.KindEnv n, T.TypeEnv n, [(BindMode, Bind n)])
extendBinds binds kenv tenv
 = go binds kenv tenv []
 where
        -- The accumulator is built in reverse with (:) and flipped once at
        -- the end. Appending with (++) on every step is quadratic in the
        -- number of binders.
        go [] k t acc
         = (k, t, reverse acc)

        go ((bm, b) : rest) k t acc
         = go rest k' t' ((bm, b') : acc)
         where
                -- Each binder is spread against the environments extended
                -- by the binders before it.
                b' = S.spreadX k t b

                (k', t')
                 = case bm of
                        BMSpec    -> (T.extend b' k, t)
                        BMValue _ -> (k, T.extend b' t)
-- | Spread an expression against the environments and type check it.
--
--   On success, returns the result of 'C.checkExp' unchanged. On failure,
--   wraps the checker error in 'ErrorTypeCheck' together with the side of
--   the rule and the spread expression.
checkExp
        :: (Ord n, Show n, Pretty n)
        => C.Config n
        -> KindEnv n
        -> TypeEnv n
        -> Side
        -> Exp a n
        -> Either (Error a n)
                  (Exp (C.AnTEC a n) n, Type n, Effect n, Closure n)
checkExp defs kenv tenv side xx
 = either (Left . ErrorTypeCheck side spread) Right
 $ C.checkExp defs kenv tenv spread
 where  spread = S.spreadX kenv tenv xx
-- | Check a rule constraint and return its kind.
--
--   The constraint must pass 'T.checkType' and must be a witness type;
--   anything else is reported as 'ErrorBadConstraint'.
checkConstraint
        :: (Ord n, Show n, Pretty n)
        => C.Config n
        -> KindEnv n
        -> Type n
        -> Either (Error a n) (Kind n)
checkConstraint defs kenv tt
 = case T.checkType (C.configPrimDataDefs defs) kenv tt of
        Right k
         | T.isWitnessType tt   -> return k

        -- Either the type failed to check or it is not a witness type.
        _                       -> Left $ ErrorBadConstraint tt
-- | Succeed when the two types are equivalent, otherwise fail with the
--   given error.
checkEquiv
        :: Ord n
        => Type n
        -> Type n
        -> Error a n
        -> Either (Error a n) ()
checkEquiv tLeft tRight err
 = unless (T.equivT tLeft tRight)
 $ Left err
-- | Decide whether the rule needs an effect weakening.
--
--   * Equivalent effects need none, so the result is 'Nothing'.
--
--   * If the left effect subsumes the right one at the given kind, the
--     result is the left effect.
--
--   * Otherwise the supplied error is returned.
makeEffectWeakening
        :: (Ord n, Show n)
        => Kind n
        -> Effect n
        -> Effect n
        -> Error a n
        -> Either (Error a n) (Maybe (Type n))
makeEffectWeakening k effLeft effRight onError
 = if T.equivT effLeft effRight
        then return Nothing
        else if T.subsumesT k effLeft effRight
                then return (Just effLeft)
                else Left onError
-- | Build the closure weakening for a rule.
--
--   Effect arguments are erased from both sides with 'removeEffects', then
--   the support of each side is collected. Every data, witness and spec
--   variable that is in the support of the left but not of the right becomes
--   one expression in the result. This function always returns 'Right'.
makeClosureWeakening
        :: (Ord n, Pretty n, Show n)
        => C.Config n
        -> T.Env n
        -> T.Env n
        -> Exp (C.AnTEC a n) n
        -> Exp (C.AnTEC a n) n
        -> Either (Error a n)
                  [Exp (C.AnTEC a n) n]
makeClosureWeakening config kenv tenv lhs rhs
 = Right $ weakData ++ weakWitness ++ weakSpec
 where
        supLeft
         = support Env.empty Env.empty (removeEffects config kenv tenv lhs)

        supRight
         = support Env.empty Env.empty (removeEffects config kenv tenv rhs)

        -- Variables in the chosen part of the left support that are missing
        -- from the same part of the right support.
        onlyLeft field
         = Set.toList $ field supLeft `Set.difference` field supRight

        -- Annotation reused for the data variables below. It is evaluated
        -- lazily, so a left-hand side without an annotation only fails if an
        -- 'XVar' built here has its annotation inspected, and then with a
        -- message naming the broken invariant instead of an anonymous
        -- pattern-match failure.
        annot
         = Maybe.fromMaybe
                (error "makeClosureWeakening: left-hand side has no annotation")
                (takeAnnotOfExp lhs)

        weakData        = [ XVar annot u      | u <- onlyLeft supportDaVar ]
        weakWitness     = [ XWitness (WVar u) | u <- onlyLeft supportWiVar ]
        weakSpec        = [ XType (TVar u)    | u <- onlyLeft supportSpVar ]
-- | Replace every type argument of effect kind by the bottom effect.
--
--   The expression is traversed with 'transformUpX'. An 'XType' node is
--   rewritten only when its type checks and its kind is an effect kind; all
--   other nodes are returned unchanged.
removeEffects
        :: (Ord n, Pretty n, Show n)
        => C.Config n
        -> T.Env n
        -> T.Env n
        -> Exp a n
        -> Exp a n
removeEffects config
 = transformUpX eraseEffect
 where
        primDefs = C.configPrimDataDefs config

        eraseEffect kenv _tenv x
         = case x of
                XType et
                 | Right k <- T.checkType primDefs kenv et
                 , T.isEffectKind k
                 -> XType $ T.tBot T.kEffect

                _ -> x
-- | Check that every rule binder is mentioned in the given expression.
--
--   A binder counts as mentioned when its bound occurs among the free value
--   or type variables of the expression. Binders for which
--   'T.takeSubstBoundOfBind' yields nothing are ignored. Fails with
--   'ErrorVarUnmentioned' otherwise.
checkUnmentionedBinders
        :: (Ord n, Show n)
        => [(BindMode, Bind n)]
        -> Exp (C.AnTEC a n) n
        -> Either (Error a n) ()
checkUnmentionedBinders bs expr
 | all (`Set.member` mentioned) bounds  = return ()
 | otherwise                            = Left ErrorVarUnmentioned
 where
        mentioned
         = C.freeX T.empty expr `Set.union` C.freeT T.empty expr

        bounds
         = Maybe.mapMaybe (T.takeSubstBoundOfBind . snd) bs
-- | Reject rules that have an anonymous binder.
--
--   Fails with 'ErrorAnonymousBinder' on the first binder for which
--   'T.isBAnon' holds.
checkAnonymousBinders
        :: [(BindMode, Bind n)]
        -> Either (Error a n) ()
checkAnonymousBinders bs
 = case [ b | (_, b) <- bs, T.isBAnon b ] of
        b : _   -> Left $ ErrorAnonymousBinder b
        []      -> return ()
-- | Check that an expression is usable as the left-hand side of a rule.
--
--   Variables, constructors, witnesses, applications and casts are accepted.
--   Type and value abstractions, let and case expressions, and forall types
--   are rejected with 'ErrorNotFirstOrder'.
checkValidPattern :: Exp a n -> Either (Error a n) ()
checkValidPattern expr
 = goExp expr
 where
        notFirstOrder x
         = Left (ErrorNotFirstOrder x)

        goExp xx
         = case xx of
                XVar{}          -> return ()
                XCon{}          -> return ()
                XLAM{}          -> notFirstOrder xx
                XLam{}          -> notFirstOrder xx
                XApp _ l r      -> goExp l >> goExp r
                XLet{}          -> notFirstOrder xx
                XCase{}         -> notFirstOrder xx
                XCast _ _ x     -> goExp x
                XType t         -> goType t
                XWitness{}      -> return ()

        goType tt
         = case tt of
                TVar{}          -> return ()
                TCon{}          -> return ()
                -- The offending type is reported wrapped as an expression.
                TForall{}       -> notFirstOrder (XType tt)
                TApp l r        -> goType l >> goType r
                TSum{}          -> return ()
-- | Record, in each named 'BMValue' binder, how many usage entries the given
--   expression has for that name.
--
--   Usage is taken from the top-level annotation of the expression after
--   'U.usageX'. 'BMSpec' binders and binders that are not 'BName' are
--   returned unchanged. This function always returns 'Right'.
countBinderUsage
        :: Ord n
        => [(BindMode, Bind n)]
        -> Exp a n
        -> Either (Error a n) [(BindMode, Bind n)]
countBinderUsage bs x
 = return $ map count bs
 where
        -- An expression without a top-level annotation has no usage map to
        -- read. It is treated as using nothing, so every count is 0, rather
        -- than failing on an irrefutable pattern.
        usedMap
         = case takeAnnotOfExp (U.usageX x) of
                Just (U.UsedMap um, _)  -> um
                Nothing                 -> Map.empty

        count (BMValue _, BName n t)
         = ( BMValue $ length $ Map.findWithDefault [] n usedMap
           , BName n t)

        count b
         = b
-- | Apply the annotation function to every expression held by the rule: the
--   left-hand side, the hole, the right-hand side and the closure
--   weakenings. Binders, constraints, the effect weakening and the free
--   variables carry no annotation and are copied as they are.
instance Reannotate RewriteRule where
 reannotate f (RewriteRule bs cs lhs hole rhs eff clo fv)
  = RewriteRule
        { ruleBinds             = bs
        , ruleConstraints       = cs
        , ruleLeft              = reannotate f lhs
        , ruleLeftHole          = fmap (reannotate f) hole
        , ruleRight             = reannotate f rhs
        , ruleWeakEff           = eff
        , ruleWeakClo           = map (reannotate f) clo
        , ruleFreeVars          = fv }