module Clash.Rewrite.Combinators
( allR
, (!->)
, (>-!)
, (>-!->)
, (>->)
, bottomupR
, repeatR
, topdownR
, whenR
, bottomupWhenR
)
where
import Control.DeepSeq (deepseq)
import Control.Monad ((>=>))
import qualified Control.Monad.Writer as Writer
import qualified Data.Monoid as Monoid
import Clash.Core.Term (Term (..), CoreContext (..), primArg, patIds)
import Clash.Core.VarEnv
(extendInScopeSet, extendInScopeSetList)
import Clash.Rewrite.Types
-- | Apply a transformation to each immediate subterm of a term, rebuilding
-- the term around the results with the same outer constructor.
--
-- Every subterm is visited with a 'TransformContext' whose in-scope set is
-- extended with the binders visible in that subterm (the lambda-bound
-- variable, all let-bound identifiers, or the type and term variables bound
-- by a case pattern). The context list of that 'TransformContext' has a
-- 'CoreContext' frame describing the subterm's position pushed onto its head.
-- Terms of any shape not listed below are returned unchanged. The
-- transformation is never applied to the term itself, only to its children.
allR
  :: forall m
   . Monad m
  => Transform m
  -> Transform m
allR trans (TransformContext is c) tm = case tm of
  Lam v e ->
    Lam v <$> descend (extendInScopeSet is v) (LamBody v) e

  TyLam tv e ->
    TyLam tv <$> descend (extendInScopeSet is tv) (TyLamBody tv) e

  App e1 e2 -> do
    -- The function position is rewritten first, because the argument's
    -- context frame is built from the *rewritten* function via 'primArg'.
    e1' <- descend is AppFun e1
    e2' <- descend is (AppArg (primArg e1')) e2
    pure (App e1' e2')

  TyApp e ty ->
    TyApp <$> descend is TyAppC e <*> pure ty

  Cast e ty1 ty2 ->
    Cast <$> descend is CastBody e <*> pure ty1 <*> pure ty2

  Letrec xes e -> do
    -- The let-bound identifiers are in scope in every right-hand side as
    -- well as in the body.
    let bndrs    = map fst xes
        letScope = extendInScopeSetList is bndrs
    xes' <- traverse
              (\(b, rhs) -> (b,) <$> descend letScope (LetBinding b bndrs) rhs)
              xes
    e' <- descend letScope (LetBody bndrs) e
    return (Letrec xes' e')

  Case scrut ty alts ->
    Case <$> descend is CaseScrut scrut
         <*> pure ty
         <*> traverse rewriteAlt alts

  Tick sp e ->
    Tick sp <$> descend is (TickC sp) e

  _ -> pure tm
 where
  -- Run the transformation on one subterm under the given in-scope set,
  -- with the given frame pushed onto the current context list.
  descend inScope frame = trans (TransformContext inScope (frame : c))

  -- The scrutinee is visited in the outer scope; each alternative also sees
  -- the type and term variables bound by its pattern.
  rewriteAlt (p, rhs) =
    let (tvs, ids) = patIds p
        altScope   = extendInScopeSetList (extendInScopeSetList is tvs) ids
    in  (p,) <$> descend altScope (CaseAlt p) rhs
infixr 6 >->
-- | Left-to-right sequencing of two transformations. Both receive the same
-- 'TransformContext', and the term produced by the first one is fed to the
-- second (Kleisli composition via '>=>').
(>->) :: Monad m => Transform m -> Transform m -> Transform m
(>->) = \first second ctx -> first ctx >=> second ctx
{-# INLINE (>->) #-}
infixr 6 >-!->
-- | Like '>->', except that the intermediate term produced by the first
-- transformation is forced to normal form with 'deepseq'. This happens before
-- the second transformation's action on that term is produced.
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
(>-!->) = \r1 r2 ctx e ->
  r1 ctx e >>= \e' -> e' `deepseq` r2 ctx e'
{-# INLINE (>-!->) #-}
-- | Apply a rewrite top-down. At each node the rewrite is first applied
-- repeatedly (via 'repeatR') until it no longer records a change. The
-- traversal then continues into the immediate subterms of the result (via
-- 'allR').
topdownR :: Rewrite m -> Rewrite m
topdownR r = go
 where
  -- The traversal refers to itself directly instead of rebuilding
  -- @topdownR r@ at every level.
  go = repeatR r >-> allR go
-- | Apply a transformation bottom-up. All immediate subterms are processed
-- recursively first (via 'allR'), and then the transformation is applied once
-- to the rebuilt term.
bottomupR :: Monad m => Transform m -> Transform m
bottomupR r = go
 where
  go = allR go >-> r
infixr 5 !->
-- | Conditional sequencing: run the second rewrite only if the first one
-- recorded a change (its 'Monoid.Any' writer output, observed with
-- 'Writer.listen', is 'True'). Otherwise return the first rewrite's result.
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
(!->) = \r1 r2 c expr -> do
  (expr', Monoid.Any changed) <- Writer.listen (r1 c expr)
  if changed then r2 c expr' else return expr'
{-# INLINE (!->) #-}
infixr 5 >-!
-- | Fallback sequencing: run the second rewrite, on the first one's result,
-- only if the first rewrite recorded no change (its 'Monoid.Any' writer
-- output, observed with 'Writer.listen', is 'False'). If the first rewrite
-- did record a change, its result is returned without running the second.
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
(>-!) = \r1 r2 c expr -> do
  (expr', Monoid.Any changed) <- Writer.listen (r1 c expr)
  if not changed then r2 c expr' else return expr'
{-# INLINE (>-!) #-}
-- | Keep applying a rewrite, via '!->', for as long as each application
-- records a change. The result of the first application that records no
-- change is returned.
repeatR :: Rewrite m -> Rewrite m
repeatR r = go
 where
  -- The loop is tied locally so that 'repeatR' itself is not recursive.
  -- A self-recursive 'repeatR' would be chosen by GHC as a loop breaker,
  -- making the INLINE pragma below ineffective. It would also rebuild
  -- @repeatR r@ on every iteration.
  go = r !-> go
{-# INLINE repeatR #-}
-- | Run a transformation only when a monadic predicate on the current
-- context and term returns 'True'; otherwise return the term unchanged.
whenR :: Monad m
      => (TransformContext -> Term -> m Bool)
      -> Transform m
      -> Transform m
whenR f r1 ctx expr =
  f ctx expr >>= \ok -> if ok then r1 ctx expr else return expr
-- | A bottom-up traversal that only descends where a predicate allows it.
--
-- At each visited node the predicate is evaluated first.
--
-- * If it returns 'True', the immediate subterms are processed recursively
--   (via 'allR') before the transformation is applied to the node.
-- * Otherwise the transformation is applied to the node directly, and its
--   subterms are not visited.
bottomupWhenR
  :: Monad m
  => (TransformContext -> Term -> m Bool)
  -> Transform m
  -> Transform m
bottomupWhenR f r = go
 where
  go ctx expr = do
    descend <- f ctx expr
    let step = if descend then allR go >-> r else r
    step ctx expr