{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.EtaExpand
( etaExpandSyn
, etaExpansionTL
) where
import qualified Control.Lens as Lens
import qualified Data.Maybe as Maybe
import GHC.Stack (HasCallStack)
import Clash.Core.HasType
import Clash.Core.Term (Bind(..), CoreContext(..), Term(..), collectArgs, mkLams)
import Clash.Core.TermInfo (isFun)
import Clash.Core.Type (splitFunTy)
import Clash.Core.Util (mkInternalVar)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (elemVarSet, extendInScopeSet, extendInScopeSetList)
import Clash.Normalize.Types (NormRewrite)
import Clash.Rewrite.Types (TransformContext(..), tcCache, topEntities)
import Clash.Rewrite.Util (changed)
import Clash.Util (curLoc)
-- | Eta-expand an application whose head is a top entity.
--
-- A term @f a1 .. an@, where @f@ is a variable in the 'topEntities' set and
-- the term has a function type, is rewritten to @\\arg -> f a1 .. an arg@
-- (with a fresh @arg@). The rewrite is skipped when the term already sits in
-- the function position of an application (looking through ticks), because
-- it is then applied anyway. Terms whose head is not a variable are returned
-- unchanged.
--
-- NOTE(review): presumably this exists so that top entities passed as
-- arguments (e.g. to higher-order primitives) show up as lambdas; confirm.
etaExpandSyn :: HasCallStack => NormRewrite
etaExpandSyn (TransformContext is0 ctx) e@(collectArgs -> (Var f, _)) = do
  topEnts <- Lens.view topEntities
  tcm <- Lens.view tcCache
  -- Test the cheap, purely syntactic conditions first so that type inference
  -- only runs for terms that are actual candidates for expansion.
  if (f `elemVarSet` topEnts) && not (isAppFunCtx ctx)
    then case splitFunTy tcm (inferCoreTypeOf tcm e) of
      Just (argTy, _) -> do
        newId <- mkInternalVar is0 "arg" argTy
        changed (Lam newId (App e (Var newId)))
      Nothing -> return e
    else return e
  where
  -- Is the term in the function position of an application? Tick frames
  -- between the term and that application are skipped.
  isAppFunCtx = \case
    AppFun : _ -> True
    TickC _ : c -> isAppFunCtx c
    _ -> False
etaExpandSyn _ e = return e
{-# SCC etaExpandSyn #-}
-- | Split a term into its leading term-level lambda binders and the body
-- underneath them.
--
-- Binders are returned outermost first: @Lam x (Lam y b)@ yields
-- @([x, y], b)@. Any term that is not a 'Lam' ends the binder list and is
-- returned as the body (so a non-lambda term yields @([], term)@).
stripLambda :: Term -> ([Id], Term)
stripLambda (Lam bndr e) =
  let (bndrs, body) = stripLambda e
   in (bndr : bndrs, body)
stripLambda e = ([], e)
-- | Fully eta-expand a top-level term.
--
-- * Leading lambdas are kept and their bodies are processed recursively.
-- * For a let-expression the body is processed recursively. Any lambdas at
--   the head of the processed body are floated out over the let:
--   @let b in \\x -> e@ becomes @\\x -> let b in e@.
-- * Any other term with a function type is wrapped in a fresh lambda
--   (@e@ becomes @\\arg -> e arg@). The new application is expanded again
--   until it no longer has a function type.
--
-- The function drives its own recursion instead of relying on a traversal.
-- It is therefore presumably meant to be applied once to a top-level
-- binding, not used inside a traversal (NOTE(review): confirm with callers).
etaExpansionTL :: HasCallStack => NormRewrite
etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do
  let ctx' = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr : ctx)
  e' <- etaExpansionTL ctx' e
  return (Lam bndr e')

etaExpansionTL (TransformContext is0 ctx) (Let bind@(NonRec i _) e) =
  let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [i] : ctx)
   in etaExpansionLetBody bind ctx' e

etaExpansionTL (TransformContext is0 ctx) (Let bind@(Rec xes) e) =
  let bndrs = map fst xes
      ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs : ctx)
   in etaExpansionLetBody bind ctx' e

etaExpansionTL (TransformContext is0 ctx) e = do
  tcm <- Lens.view tcCache
  if isFun tcm e
    then do
      -- 'isFun' has just established that @e@ has a function type, so the
      -- 'error' branch below is not expected to be reached.
      let argTy =
            ( fst
            . Maybe.fromMaybe (error $ $(curLoc) ++ "etaExpansion splitFunTy")
            . splitFunTy tcm
            . inferCoreTypeOf tcm
            ) e
      newId <- mkInternalVar is0 "arg" argTy
      let ctx' = TransformContext (extendInScopeSet is0 newId) (LamBody newId : ctx)
      -- Recurse on the applied term: it may still have a function type.
      e' <- etaExpansionTL ctx' (App e (Var newId))
      changed (Lam newId e')
    else return e
{-# SCC etaExpansionTL #-}

-- | Shared implementation of the two let-cases of 'etaExpansionTL'.
--
-- Processes the let body with 'etaExpansionTL' under the given context, which
-- the caller has already extended with the let's binders. Any leading lambdas
-- of the result are then floated out over the binding group @bind@. When
-- lambdas were floated, the result is reported as 'changed'. Otherwise the
-- let is rebuilt around the processed body.
etaExpansionLetBody :: HasCallStack => Bind Term -> NormRewrite
etaExpansionLetBody bind ctxBody body = do
  body' <- etaExpansionTL ctxBody body
  case stripLambda body' of
    (bs@(_ : _), inner) -> changed (mkLams (Let bind inner) bs)
    _ -> return (Let bind body')