{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
module Clash.Normalize.Transformations.MultiPrim
( setupMultiResultPrim
) where
import qualified Control.Lens as Lens
import qualified Data.Either as Either
import Data.Text.Extra (showt)
import GHC.Stack (HasCallStack)
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.Name (mkUnsafeInternalName)
import Clash.Core.Term
( IsMultiPrim(..), MultiPrimInfo(..), PrimInfo(..), Term(..), WorkInfo(..)
, mkAbstraction, mkApps, mkTmApps, mkTyApps, PrimUnfolding(..))
import Clash.Core.TermInfo (multiPrimInfo')
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (Type(..), mkPolyFunTy, splitFunForallTy)
import Clash.Core.Util (listToLets)
import Clash.Core.Var (mkLocalId)
import Clash.Normalize.Types (NormRewrite)
import Clash.Primitives.Types (Primitive(..))
import Clash.Rewrite.Types (tcCache, primitives)
import Clash.Rewrite.Util (changed)
-- | Rewrite a primitive that is still marked 'SingleResult' into its
-- multi-result form when its blackbox definition asks for it.
--
-- The primitive is looked up by 'primName' in the 'primitives' map of the
-- rewrite environment. If the looked-up entry (after 'extractPrim') is a
-- 'BlackBoxHaskell' or a 'BlackBox' whose @multiResult@ field is 'True', the
-- term is replaced by the result of @setupMultiResultPrim'@ and passed to
-- 'changed'. In every other case (no entry, 'extractPrim' yields 'Nothing',
-- @multiResult@ is 'False', another 'Primitive' constructor, or the term is
-- not a 'SingleResult' 'Prim' at all) the term is returned untouched.
setupMultiResultPrim :: HasCallStack => NormRewrite
setupMultiResultPrim _ctx e@(Prim pInfo@PrimInfo{primMultiResult=SingleResult}) = do
  tcm <- Lens.view tcCache
  prim <- Lens.view (primitives . Lens.at (primName pInfo))

  -- The two constructors are matched separately on purpose: @multiResult@ is
  -- only matched on the constructors listed here, everything else falls
  -- through to the wildcard.
  case prim >>= extractPrim of
    Just BlackBoxHaskell{multiResult=True} ->
      changed (setupMultiResultPrim' tcm pInfo)
    Just BlackBox{multiResult=True} ->
      changed (setupMultiResultPrim' tcm pInfo)
    _ ->
      return e

setupMultiResultPrim _ e = return e
-- | Build the multi-result wrapper term for a primitive.
--
-- The primitive's type ('primType') is split with 'splitFunForallTy' into its
-- arguments (@pArgs@) and result type (@pResTy@); the result data constructor
-- and the individual result types come from @multiPrimInfo'@.
--
-- The produced term is an abstraction (via 'mkAbstraction') over all type
-- variables of the primitive followed by one fresh identifier per term
-- argument. Its body is built by 'listToLets' from these bindings:
--
-- * @resId@ is bound to the primitive, now marked 'MultiResult', applied to
--   the type variables, the argument identifiers and the result identifiers;
-- * every identifier in @resIds@ is bound to the @c$multiPrimSelect@
--   primitive applied to that same identifier and @resId@;
--
-- and the final expression is the result data constructor applied to the
-- result types and the result identifiers.
setupMultiResultPrim' :: HasCallStack => TyConMap -> PrimInfo -> Term
setupMultiResultPrim' tcm primInfo@PrimInfo{primType} =
  mkAbstraction letTerm (map Right typeVars <> map Left argIds)
 where
  -- Type variables bound by the primitive's type.
  typeVars = Either.lefts pArgs

  -- Names are built from a prefix plus the number, with the same number
  -- passed as the second argument of 'mkUnsafeInternalName'.
  internalNm prefix n = mkUnsafeInternalName (prefix <> showt n) n
  internalId prefix typ n = mkLocalId typ (internalNm prefix n)

  -- Numbering scheme: arguments take 1..nTermArgs, the per-result identifiers
  -- take the next @length resTypes@ numbers, and 'resId' takes the number
  -- after that.
  nTermArgs = length (Either.rights pArgs)
  argIds = zipWith (internalId "a") (Either.rights pArgs) [1..nTermArgs]
  resIds = zipWith (internalId "r") resTypes [nTermArgs+1..nTermArgs+length resTypes]
  -- Unlike 'resIds', the name text here is the bare "r" without the number
  -- appended.
  resId = mkLocalId pResTy (mkUnsafeInternalName "r" (nTermArgs+length resTypes+1))

  (pArgs, pResTy) = splitFunForallTy primType
  MultiPrimInfo{mpi_resultDc=tupTc, mpi_resultTypes=resTypes} =
    multiPrimInfo' tcm primInfo

  -- One binding per result: @r = c$multiPrimSelect r resId@.
  multiPrimSelect r t = (r, mkTmApps (Prim (multiPrimSelectInfo t)) [Var r, Var resId])
  multiPrimSelectBinds = zipWith multiPrimSelect resIds resTypes

  -- Arguments of the multi-result primitive: type variables first, then the
  -- original term arguments, then the result identifiers.
  multiPrimTermArgs = map (Left . Var) (argIds <> resIds)
  multiPrimTypeArgs = map (Right . VarTy) typeVars
  multiPrimBind =
    mkApps
      (Prim primInfo{primMultiResult=MultiResult})
      (multiPrimTypeArgs <> multiPrimTermArgs)

  -- NOTE(review): the argument types are listed as [pResTy, t], while
  -- 'multiPrimSelect' applies this primitive to [Var r, Var resId], where r
  -- was created with type t and resId with type pResTy. Confirm that the
  -- declared type is intended before relying on it.
  multiPrimSelectInfo t = PrimInfo
    { primName = "c$multiPrimSelect"
    , primType = mkPolyFunTy pResTy [Right pResTy, Right t]
    , primWorkInfo = WorkAlways
    , primMultiResult = SingleResult
    , primUnfolding = NoUnfolding
    }

  letTerm =
    listToLets
      ((resId, multiPrimBind) : multiPrimSelectBinds)
      (mkTmApps (mkTyApps (Data tupTc) resTypes) (map Var resIds))