{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Clash.Rewrite.WorkFree
( isWorkFree
, isWorkFreeClockOrResetOrEnable
, isWorkFreeIsh
, isConstant
, isConstantNotClockReset
) where
import Control.Lens (Lens')
import Control.Monad.Extra (allM, andM, eitherM)
import Control.Monad.State.Class (MonadState)
import GHC.Stack (HasCallStack)
import Clash.Core.FreeVars
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term
import Clash.Core.TermInfo
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (isPolyFunTy)
import Clash.Core.Util
import Clash.Core.Var (Id, Var(..), isLocalId)
import Clash.Core.VarEnv (VarEnv, lookupVarEnv)
import Clash.Driver.Types (BindingMap, Binding(..))
import Clash.Util (makeCachedU)
-- | Determine whether the global binder @bndr@ is work free, memoising the
-- answer in the cache addressed by @cache@ (via 'makeCachedU').
--
-- A binder whose own body mentions the binder ('globalIdOccursIn') is never
-- work free. Otherwise the answer is that of 'isWorkFree' on the body.
--
-- Calls 'error' when @bndr@ has no entry in @bndrs@.
--
-- NOTE(review): only direct self-reference is detected. Mutually recursive
-- global binders would presumably re-enter this function before their
-- cache entry exists; confirm such bindings cannot reach this point.
isWorkFreeBinder
  :: (HasCallStack, MonadState s m)
  => Lens' s (VarEnv Bool)
  -> BindingMap
  -> Id
  -> m Bool
isWorkFreeBinder cache bndrs bndr =
  makeCachedU bndr cache $
    maybe (error missingMsg) (decide . bindingTerm) (lookupVarEnv bndr bndrs)
 where
  missingMsg = "isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr

  decide body
    | bndr `globalIdOccursIn` body = pure False
    | otherwise                    = isWorkFree cache bndrs body
{-# INLINABLE isWorkFree #-}
-- | Determine whether a term does any work, i.e. whether it adds to the size
-- of the circuit.
--
-- The work-freeness of global binders is looked up in @bndrs@ and memoised
-- in the cache addressed by the given lens (see 'isWorkFreeBinder').
--
-- Calls 'error' when a referenced global binder is missing from @bndrs@, or
-- when 'collectArgs' returns a 'Tick', 'App' or 'TyApp' head, which it is
-- not expected to do.
isWorkFree
  :: forall s m
   . (HasCallStack, MonadState s m)
  => Lens' s (VarEnv Bool)
  -- ^ Cache for the work-freeness of global binders
  -> BindingMap
  -- ^ Global bindings, used to look up the bodies of global binders
  -> Term
  -- ^ Term to check
  -> m Bool
isWorkFree cache bndrs = go True
 where
  -- 'isOutermost' is 'True' only for the term handed to 'isWorkFree' itself.
  -- At that level an unapplied local variable whose type satisfies
  -- 'isPolyFunTy' counts as work free. Inside a sub-term it does not, so
  -- that terms like @f x@ (with local @f@) are not treated as free to
  -- duplicate.
  go :: HasCallStack => Bool -> Term -> m Bool
  go isOutermost term =
    let (fun, args) = collectArgs term
        argsFree    = allM goArg args
    in case fun of
      Var i
        | isPolyFunTy (varType i) ->
            pure (isLocalId i && isOutermost && null args)
        | isLocalId i ->
            pure True
        | otherwise ->
            andM [isWorkFreeBinder cache bndrs i, argsFree]

      Data _ -> argsFree
      Literal _ -> pure True

      Prim pr ->
        case primWorkInfo pr of
          -- The arguments are ignored: the primitive's result is constant
          -- regardless of them.
          WorkConstant -> pure True
          WorkNever -> argsFree
          WorkVariable -> pure (all isConstantArg args)
          WorkAlways -> pure False

      Lam _ e -> andM [go False e, argsFree]
      TyLam _ e -> andM [go False e, argsFree]
      Letrec bs e ->
        andM [go False e, allM (go False . snd) bs, argsFree]

      -- Both the scrutinee and every alternative must be work free.
      -- Checking only the scrutinee for multi-alternative cases would make
      -- them look cheaper than a single-alternative case with the same
      -- body.
      Case scrut _ alts ->
        andM [go False scrut, allM (go False . snd) alts, argsFree]

      Cast e _ _ -> andM [go False e, argsFree]

      -- 'collectArgs' is expected to strip these off the head.
      Tick _ _ -> error "isWorkFree: unexpected Tick"
      App {} -> error "isWorkFree: unexpected App"
      TyApp {} -> error "isWorkFree: unexpected TyApp"

  -- A term argument must itself be work free; type arguments always are.
  goArg :: Either Term b -> m Bool
  goArg = either (go False) (const (pure True))

  -- A term argument must be constant; type arguments always are.
  isConstantArg :: Either Term b -> Bool
  isConstantArg = either isConstant (const True)
-- | Determine whether a term is a constant. After splitting off its
-- arguments with 'collectArgs', the term is constant when its head is:
--
-- * a 'Literal';
-- * a 'Lam', and the whole term has no local free variables;
-- * a 'Data' constructor or a 'Prim' applied only to constant term
--   arguments (type arguments are always accepted).
--
-- Any other head makes the term non-constant.
isConstant :: Term -> Bool
isConstant e =
  case hd of
    Data _    -> constantArgs
    Prim _    -> constantArgs
    Lam _ _   -> not (hasLocalFreeVars e)
    Literal _ -> True
    _         -> False
 where
  (hd, args)   = collectArgs e
  constantArgs = all (either isConstant (const True)) args
-- | Like 'isConstant', with one exception. A term whose type is a clock or
-- reset (according to 'isClockOrReset') only counts as constant when its
-- head, after 'collectArgs', is the primitive
-- @Clash.Transformations.removedArg@.
isConstantNotClockReset :: TyConMap -> Term -> Bool
isConstantNotClockReset tcm e =
  if isClockOrReset tcm (termType tcm e)
    then isRemovedArg (fst (collectArgs e))
    else isConstant e
 where
  isRemovedArg (Prim pr) = primName pr == "Clash.Transformations.removedArg"
  isRemovedArg _         = False
-- | Decide work-freeness for terms whose type is a clock, reset or enable
-- (according to 'isClockOrReset' / 'isEnable').
--
-- Returns 'Nothing' when the term has none of these types, leaving the
-- decision to the caller.
--
-- Otherwise returns 'Just' whether the term is work free. That holds when
-- its head, after 'collectArgs', is:
--
-- * the primitive @Clash.Transformations.removedArg@ (with any arguments);
-- * a variable or data constructor with no arguments;
-- * a literal (with any arguments).
isWorkFreeClockOrResetOrEnable
  :: TyConMap
  -> Term
  -> Maybe Bool
isWorkFreeClockOrResetOrEnable tcm e
  | isClockOrReset tcm eTy || isEnable tcm eTy =
      Just (headIsWorkFree (collectArgs e))
  | otherwise =
      Nothing
 where
  eTy = termType tcm e

  headIsWorkFree :: (Term, [a]) -> Bool
  headIsWorkFree (Prim p, _)    = primName p == "Clash.Transformations.removedArg"
  headIsWorkFree (Var _, [])    = True
  headIsWorkFree (Data _, [])   = True
  headIsWorkFree (Literal _, _) = True
  headIsWorkFree _              = False
-- | A pure, conservative approximation of 'isWorkFree' that needs neither a
-- binding map nor a cache.
--
-- Terms of clock, reset or enable type are decided by
-- 'isWorkFreeClockOrResetOrEnable'. For any other term, the head returned
-- by 'collectArgs' decides the result:
--
-- * 'Data': every term argument must satisfy 'isWorkFreeIsh';
-- * 'Prim': 'WorkAlways' is never work free, 'WorkVariable' requires
--   constant term arguments, and any other work info requires every term
--   argument to satisfy 'isWorkFreeIsh';
-- * 'Lam': the whole term must have no local free variables;
-- * 'Literal': always work free;
-- * anything else (including variables): not work free.
isWorkFreeIsh
  :: TyConMap
  -> Term
  -> Bool
isWorkFreeIsh tcm e
  | Just decided <- isWorkFreeClockOrResetOrEnable tcm e = decided
  | otherwise =
      case hd of
        Data _      -> all argIsh args
        Prim pInfo  -> primIsh (primWorkInfo pInfo)
        Lam _ _     -> not (hasLocalFreeVars e)
        Literal _   -> True
        _           -> False
 where
  (hd, args) = collectArgs e

  primIsh WorkAlways   = False
  primIsh WorkVariable = all argConstant args
  primIsh _            = all argIsh args

  argIsh :: Either Term b -> Bool
  argIsh = either (isWorkFreeIsh tcm) (const True)

  argConstant :: Either Term b -> Bool
  argConstant = either isConstant (const True)