{-# LANGUAGE OverloadedStrings #-}
module Clash.Normalize.Transformations.SeparateArgs
( separateArguments
) where
import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.List as List
import qualified Data.Monoid as Monoid
import GHC.Stack (HasCallStack)
import Clash.Core.HasType
import Clash.Core.Name (Name(..))
import Clash.Core.Subst (extendIdSubst, mkSubst, substTm)
import Clash.Core.Term (Term(..), collectArgsTicks, mkApps, mkLams, mkTicks)
import Clash.Core.Type (Type, mkPolyFunTy, splitFunForallTy)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Util (Projections (..), shouldSplit)
import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId, mkLocalId)
import Clash.Core.VarEnv (extendInScopeSet, uniqAway)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types (TransformContext(..), tcCache)
import Clash.Rewrite.Util (changed, mkDerivedName)
-- | Split apart arguments whose type 'shouldSplit' says should be split.
--
-- * For a lambda @\\b -> eb@, the binder is handed to 'separateLambda'; if
--   that succeeds the rewritten lambda is reported via 'changed'.
--
-- * For a (possibly ticked) application of a global variable @g@ to
--   arguments, every term argument whose inferred type 'shouldSplit' accepts
--   is replaced by its projections. The type stored in the @g@ reference is
--   rebuilt from the new argument list so that it stays consistent with it.
--
-- Any other term is returned unchanged.
separateArguments :: HasCallStack => NormRewrite
separateArguments ctx e0@(Lam b eb) = do
  tcm <- Lens.view tcCache
  case separateLambda tcm ctx b eb of
    Just e1 -> changed e1
    Nothing -> return e0

separateArguments (TransformContext is0 _) e@(collectArgsTicks -> (Var g, args, ticks))
  | isGlobalId g = do
      -- Both the type of the global variable reference and the arguments it
      -- is applied to are updated, so the reference matches the split-apart
      -- argument list.
      let (argTys0, resTy) = splitFunForallTy (coreTypeOf g)
      -- 'listen' observes whether any 'splitArg' call reported a change.
      (concat -> args1, Monoid.getAny -> hasChanged)
        <- listen (mapM (uncurry splitArg) (zip argTys0 args))
      if hasChanged
        then do
          let (argTys1, args2) = unzip args1
              gTy = mkPolyFunTy resTy argTys1
          return (mkApps (mkTicks (Var g {varType = gTy}) ticks) args2)
        else
          return e
  where
    -- Split a single argument, pairing every resulting argument with the
    -- quantifier/function-argument type of @g@ it is associated with.
    splitArg
      :: Either TyVar Type
      -- ^ The quantifier/function argument type of the global variable
      -> Either Term Type
      -- ^ The applied type argument or term argument
      -> NormalizeSession [(Either TyVar Type, Either Term Type)]
    -- Type arguments are never split.
    splitArg tv arg@(Right _) = return [(tv, arg)]
    splitArg ty arg@(Left tmArg) = do
      tcm <- Lens.view tcCache
      let argTy = inferCoreTypeOf tcm tmArg
      case shouldSplit tcm argTy of
        Just (_, Projections projections, _) -> do
          tmArgs <- projections is0 tmArg
          -- NOTE(review): every projection is paired with @ty@, the type of
          -- the original (unsplit) argument, so the rebuilt 'gTy' repeats that
          -- aggregate type once per projection. Confirm whether the split
          -- types returned by 'shouldSplit' should be used here instead.
          changed (map ((ty,) . Left) tmArgs)
        _ ->
          return [(ty, arg)]

separateArguments _ e = return e
{-# SCC separateArguments #-}
-- | Try to split the binder of the lambda @\\b -> eb0@.
--
-- When 'shouldSplit' accepts the type of @b@, the lambda is rewritten to take
-- one fresh binder per type in the list returned by 'shouldSplit'. In the
-- body, @b@ is substituted by the rebuilding function from 'shouldSplit'
-- (presumably the data constructor) applied to those fresh binders.
--
-- Returns 'Nothing' when @b@ should not be split.
separateLambda
  :: TyConMap
  -> TransformContext
  -> Id
  -> Term
  -> Maybe Term
separateLambda tcm ctx@(TransformContext is0 _) b eb0 =
  case shouldSplit tcm (coreTypeOf b) of
    Just (dc, _, argTys) ->
      let
        -- Every fresh binder starts out with the same name derived from @b@.
        nm = mkDerivedName ctx (nameOcc (varName b))
        bs0 = map (`mkLocalId` nm) argTys
        -- Threading the in-scope set through 'newBinder' makes each binder
        -- unique with respect to the context and to the previous binders.
        (is1, bs1) = List.mapAccumL newBinder is0 bs0
        subst = extendIdSubst (mkSubst is1) b (dc (map Var bs1))
        eb1 = substTm "separateArguments" subst eb0
      in
        Just (mkLams eb1 bs1)
    _ ->
      Nothing
  where
    -- Rename @x@ away from @isN0@ and return the in-scope set extended with
    -- the renamed variable.
    newBinder isN0 x =
      let
        x' = uniqAway isN0 x
        isN1 = extendInScopeSet isN0 x'
      in
        (isN1, x')
{-# SCC separateLambda #-}