{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Transform.FirstOrderTransform
( transformFunDef,
transformConsts,
FirstOrderLore,
Transformer,
transformStmRecursively,
transformLambda,
transformSOAC,
)
where
import Control.Monad.Except
import Control.Monad.State
import Data.List (zip4)
import qualified Data.Map.Strict as M
import qualified Futhark.IR as AST
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (chunks, splitAt3)
-- | Constraints on a target lore @lore@ that the first-order
-- transformation can produce: its let-bound decoration and its
-- lambda-parameter decoration must coincide with those of 'SOACS', and
-- it must have 'Bindable' and 'BinderOps' instances.
type FirstOrderLore lore =
  ( LetDec SOACS ~ LetDec lore,
    LParamInfo SOACS ~ LParamInfo lore,
    Bindable lore,
    BinderOps lore
  )
-- | First-order transform a SOACS function definition into the target
-- lore @tolore@.
--
-- The body is transformed with the function's parameters in scope, on
-- top of the supplied @consts_scope@ (presumably the scope of the
-- program's top-level constants; confirm at call sites).  The entry
-- point, attributes, name, return types and parameters are reused
-- unchanged.  NOTE(review): this only type-checks because the target
-- lore's parameter and return decorations coincide with those of
-- 'SOACS'; presumably this is enforced by the superclass constraints of
-- 'Bindable'.
transformFunDef ::
  (MonadFreshNames m, FirstOrderLore tolore) =>
  Scope tolore ->
  FunDef SOACS ->
  m (AST.FunDef tolore)
transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do
  -- 'insertStmsM' presumably places the statements emitted by the action
  -- into the body it returns, so the statement list that 'runBinderT'
  -- hands back alongside the body is ignored.
  (body', _) <- modifyNameSource $ runState $ runBinderT m consts_scope
  return $ FunDef entry attrs fname rettype params body'
  where
    m = localScope (scopeOfFParams params) $ insertStmsM $ transformBody body
-- | First-order transform a sequence of top-level SOACS statements
-- (presumably the program's constant definitions, going by the name)
-- into the target lore @tolore@.
--
-- Transformation starts from an empty scope.  Each input statement is
-- transformed in order, and the result is the sequence of statements
-- emitted while doing so.
transformConsts ::
  (MonadFreshNames m, FirstOrderLore tolore) =>
  Stms SOACS ->
  m (AST.Stms tolore)
transformConsts stms =
  fmap snd $ modifyNameSource $ runState $ runBinderT m mempty
  where
    m = mapM_ transformStmRecursively stms
-- | Constraints on a monad @m@ in which SOACS statements can be
-- transformed.
--
-- * @m@ must be a 'MonadBinder' with a local scope over its own lore.
-- * That lore must have 'Bindable' and 'BinderOps' instances.
-- * Its lambda-parameter decoration must match that of 'SOACS'.
type Transformer m =
  ( LParamInfo SOACS ~ LParamInfo (Lore m),
    MonadBinder m,
    LocalScope (Lore m) m,
    BinderOps (Lore m),
    Bindable (Lore m)
  )
-- | First-order transform a SOACS body.
--
-- Every statement is transformed recursively with
-- 'transformStmRecursively', and the original result is kept.  The
-- whole action is wrapped in 'insertStmsM' so that the emitted
-- statements end up in the returned body (per its
-- @m (Body (Lore m)) -> m (Body (Lore m))@ type; confirm against the
-- Binder API).
transformBody ::
  (Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
  Body ->
  m (AST.Body (Lore m))
transformBody (Body () bnds res) = insertStmsM $ do
  mapM_ transformStmRecursively bnds
  return $ resultBody res
-- | First-order transform a single SOACS statement, emitting the
-- resulting statement(s) in the current binder.
--
-- * For a SOAC statement, the lambdas inside the SOAC are first
--   transformed with 'transformLambda'.  The SOAC is then lowered by
--   'transformSOAC' into the statement's pattern.
-- * For any other expression, the expression is rebuilt with 'mapExpM'.
--   Nested bodies are transformed recursively in the scope supplied by
--   the mapper.  Return, branch and parameter decorations are passed
--   through unchanged.
--
-- In both cases the emitting action is run under 'auxing' with the
-- statement's 'StmAux'.
transformStmRecursively ::
  (Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
  Stm ->
  m ()
transformStmRecursively (Let pat aux (Op soac)) =
  auxing aux $ transformSOAC pat =<< mapSOACM soacTransform soac
  where
    soacTransform = identitySOACMapper {mapOnSOACLambda = transformLambda}
transformStmRecursively (Let pat aux e) =
  auxing aux $ letBind pat =<< mapExpM transform e
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = return,
          mapOnBranchType = return,
          mapOnFParam = return,
          mapOnLParam = return,
          -- Top-level 'Op' expressions are matched by the equation above.
          -- Nested ones are reached through 'mapOnBody'.  So this is
          -- presumably only a safeguard.
          mapOnOp = error "Unhandled Op in first order transform"
        }
transformSOAC ::
Transformer m =>
AST.Pattern (Lore m) ->
SOAC (Lore m) ->
m ()
transformSOAC :: Pattern (Lore m) -> SOAC (Lore m) -> m ()
transformSOAC Pattern (Lore m)
pat (Screma SubExp
w [VName]
arrs form :: ScremaForm (Lore m)
form@(ScremaForm [Scan (Lore m)]
scans [Reduce (Lore m)]
reds Lambda (Lore m)
map_lam)) = do
let Reduce Commutativity
_ Lambda (Lore m)
red_lam Result
red_nes = [Reduce (Lore m)] -> Reduce (Lore m)
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce (Lore m)]
reds
Scan Lambda (Lore m)
scan_lam Result
scan_nes = [Scan (Lore m)] -> Scan (Lore m)
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan (Lore m)]
scans
([Type]
scan_arr_ts, [Type]
_red_ts, [Type]
map_arr_ts) =
Int -> Int -> [Type] -> ([Type], [Type], [Type])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
scan_nes) (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
red_nes) ([Type] -> ([Type], [Type], [Type]))
-> [Type] -> ([Type], [Type], [Type])
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm (Lore m) -> [Type]
forall lore. SubExp -> ScremaForm lore -> [Type]
scremaType SubExp
w ScremaForm (Lore m)
form
[VName]
scan_arrs <- [Type] -> m [VName]
forall (m :: * -> *). Transformer m => [Type] -> m [VName]
resultArray [Type]
scan_arr_ts
[VName]
map_arrs <- [Type] -> m [VName]
forall (m :: * -> *). Transformer m => [Type] -> m [VName]
resultArray [Type]
map_arr_ts
[Param DeclType]
scanacc_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanacc" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
scan_lam
[Param DeclType]
scanout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
scan_arr_ts
[Param DeclType]
redout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"redout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
red_lam
[Param DeclType]
mapout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"mapout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
map_arr_ts
let merge :: [(Param DeclType, SubExp)]
merge =
[[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat
[ [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanacc_params Result
scan_nes,
[Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanout_params (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_arrs,
[Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
redout_params Result
red_nes,
[Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_arrs
]
VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
let loopform :: LoopForm (Lore m)
loopform = VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
Int64 SubExp
w []
Body (Lore m)
loop_body <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) (Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
LoopForm (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf LoopForm (Lore m)
loopform (Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
[(Param Type, VName)]
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda (Lore m) -> [LParam (Lore m)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Lore m)
map_lam) [VName]
arrs) (((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ())
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) -> do
Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
[VName]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ())
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]
(Stm (Lore m) -> BinderT (Lore m) (State VNameSource) ())
-> Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Lore m) -> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ())
-> Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body (Lore m) -> Seq (Stm (Lore m))
forall lore. BodyT lore -> Stms lore
bodyStms (Body (Lore m) -> Seq (Stm (Lore m)))
-> Body (Lore m) -> Seq (Stm (Lore m))
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
map_lam
let (Result
scan_res, Result
red_res, Result
map_res) =
Int -> Int -> Result -> (Result, Result, Result)
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
scan_nes) (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
red_nes) (Result -> (Result, Result, Result))
-> Result -> (Result, Result, Result)
forall a b. (a -> b) -> a -> b
$
Body (Lore m) -> Result
forall lore. BodyT lore -> Result
bodyResult (Body (Lore m) -> Result) -> Body (Lore m) -> Result
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
map_lam
Result
scan_res' <-
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
MonadBinder m =>
Lambda (Lore m) -> [m (Exp (Lore m))] -> m Result
eLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
scan_lam ([BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result)
-> [BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
(SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> (SubExp -> ExpT (Lore m))
-> SubExp
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))])
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> a -> b
$
(Param DeclType -> SubExp) -> [Param DeclType] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
scanacc_params Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
scan_res
Result
red_res' <-
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
MonadBinder m =>
Lambda (Lore m) -> [m (Exp (Lore m))] -> m Result
eLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
red_lam ([BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result)
-> [BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
(SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> (SubExp -> ExpT (Lore m))
-> SubExp
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))])
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> a -> b
$
(Param DeclType -> SubExp) -> [Param DeclType] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
redout_params Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
red_res
[VName]
scan_outarrs <-
[VName]
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> m (Exp (Lore m)) -> [Exp (Lore m)] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
scanout_params) (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) lore. Applicative f => SubExp -> f (Exp lore)
pexp (VName -> SubExp
Var VName
i)) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName])
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
(SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
scan_res'
[VName]
map_outarrs <-
[VName]
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> m (Exp (Lore m)) -> [Exp (Lore m)] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
mapout_params) (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) lore. Applicative f => SubExp -> f (Exp lore)
pexp (VName -> SubExp
Var VName
i)) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName])
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
(SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
map_res
Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$
[Result] -> Result
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat
[ Result
scan_res',
(VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_outarrs,
Result
red_res',
(VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_outarrs
]
[VName]
names <-
([VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Pattern (Lore m) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern (Lore m)
pat)
([VName] -> [VName]) -> m [VName] -> m [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([Param DeclType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param DeclType]
scanacc_params) ([Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"discard")
[VName] -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName]
names (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge LoopForm (Lore m)
loopform Body (Lore m)
loop_body
transformSOAC Pattern (Lore m)
pat (Stream SubExp
w [VName]
arrs StreamForm (Lore m)
_ Result
nes Lambda (Lore m)
lam) = do
let (Param Type
chunk_size_param, [Param Type]
fold_params, [Param Type]
chunk_params) =
Int -> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall dec.
Int -> [Param dec] -> (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) ([Param Type] -> (Param Type, [Param Type], [Param Type]))
-> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [LParam (Lore m)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Lore m)
lam
[(Param DeclType, SubExp)]
mapout_merge <- [Type]
-> (Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
lam) ((Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)])
-> (Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
let t' :: Type
t' = Type
t Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w
scratch :: ExpT (Lore m)
scratch = BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ PrimType -> Result -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t') (Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
t')
in (,)
(Param DeclType -> SubExp -> (Param DeclType, SubExp))
-> m (Param DeclType) -> m (SubExp -> (Param DeclType, SubExp))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"stream_mapout" (Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Type
t' Uniqueness
Unique)
m (SubExp -> (Param DeclType, SubExp))
-> m SubExp -> m (Param DeclType, SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Char] -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"stream_mapout_scratch" ExpT (Lore m)
scratch
let merge :: [(Param DeclType, SubExp)]
merge =
[Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Param Type -> Param DeclType) -> [Param Type] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map ((Type -> DeclType) -> Param Type -> Param DeclType
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
`toDecl` Uniqueness
Nonunique)) [Param Type]
fold_params) Result
nes
[(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
mapout_merge
merge_params :: [Param DeclType]
merge_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
mapout_params :: [Param DeclType]
mapout_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
mapout_merge
VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
let loop_form :: LoopForm (Lore m)
loop_form = VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
Int64 SubExp
w []
[VName] -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_size_param] (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
Body (Lore m)
loop_body <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
( LoopForm (Lore m) -> Scope (Lore m)
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm (Lore m)
loop_form
Scope (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall a. Semigroup a => a -> a -> a
<> [Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param DeclType]
merge_params
)
(Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
let slice :: Slice SubExp
slice =
[SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (VName -> SubExp
Var VName
i) (VName -> SubExp
Var (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_size_param)) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)]
[(Param Type, VName)]
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
chunk_params [VName]
arrs) (((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ())
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
[VName]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ())
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> Slice SubExp -> Slice SubExp
fullSlice (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) Slice SubExp
slice
(Result
res, Result
mapout_res) <- Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) (Result -> (Result, Result))
-> BinderT (Lore m) (State VNameSource) Result
-> BinderT (Lore m) (State VNameSource) (Result, Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m Result
bodyBind (Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
lam)
Result
mapout_res' <- [(Param DeclType, SubExp)]
-> ((Param DeclType, SubExp)
-> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params Result
mapout_res) (((Param DeclType, SubExp)
-> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> ((Param DeclType, SubExp)
-> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \(Param DeclType
p, SubExp
se) ->
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"mapout_res" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> SubExp -> BasicOp
Update
(Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p)
(Type -> Slice SubExp -> Slice SubExp
fullSlice (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
p) Slice SubExp
slice)
SubExp
se
Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result
res Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
mapout_res'
Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge LoopForm (Lore m)
loop_form Body (Lore m)
loop_body
transformSOAC Pattern (Lore m)
pat (Scatter SubExp
len Lambda (Lore m)
lam [VName]
ivs [(Shape, Int, VName)]
as) = do
VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"write_iter"
let ([Shape]
as_ws, [Int]
as_ns, [VName]
as_vs) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
[Type]
ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
as_vs
[Ident]
asOuts <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"write_out") [Type]
ts
let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> Result -> [(Param DeclType, SubExp)]
loopMerge [Ident]
asOuts (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
as_vs
Body (Lore m)
loopBody <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
( VName -> NameInfo (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Lore m)
forall lore. IntType -> NameInfo lore
IndexName IntType
Int64) (Scope (Lore m) -> Scope (Lore m))
-> Scope (Lore m) -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$
[Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
)
(Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
Result
ivs' <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
ivs ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
iv -> do
Type
iv_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
iv
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"write_iv" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
iv (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
iv_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
Result
ivs'' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
lam ((SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
ivs')
let indexes :: [(Shape, VName, [(Result, SubExp)])]
indexes = [(Shape, Int, VName)]
-> Result -> [(Shape, VName, [(Result, SubExp)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults ([Shape] -> [Int] -> [VName] -> [(Shape, Int, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Shape]
as_ws [Int]
as_ns ([VName] -> [(Shape, Int, VName)])
-> [VName] -> [(Shape, Int, VName)]
forall a b. (a -> b) -> a -> b
$ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
asOuts) Result
ivs''
[VName]
ress <- [(Shape, VName, [(Result, SubExp)])]
-> ((Shape, VName, [(Result, SubExp)])
-> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Shape, VName, [(Result, SubExp)])]
indexes (((Shape, VName, [(Result, SubExp)])
-> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName])
-> ((Shape, VName, [(Result, SubExp)])
-> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(Shape
_, VName
arr, [(Result, SubExp)]
indexes') -> do
let saveInArray :: VName -> (Result, SubExp) -> m VName
saveInArray VName
arr' (Result
indexCur, SubExp
valueCur) =
[Char] -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m VName
letExp [Char]
"write_out" (Exp (Lore m) -> m VName) -> m (Exp (Lore m)) -> m VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> [m (Exp (Lore m))] -> m (Exp (Lore m)) -> m (Exp (Lore m))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
VName -> [m (Exp (Lore m))] -> m (Exp (Lore m)) -> m (Exp (Lore m))
eWriteArray VName
arr' ((SubExp -> m (Exp (Lore m))) -> Result -> [m (Exp (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> m (Exp (Lore m))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp Result
indexCur) (SubExp -> m (Exp (Lore m))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
valueCur)
(VName
-> (Result, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> VName
-> [(Result, SubExp)]
-> BinderT (Lore m) (State VNameSource) VName
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM VName
-> (Result, SubExp) -> BinderT (Lore m) (State VNameSource) VName
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
VName -> (Result, SubExp) -> m VName
saveInArray VName
arr [(Result, SubExp)]
indexes'
Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ress)
Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge (VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
iter IntType
Int64 SubExp
len []) Body (Lore m)
loopBody
transformSOAC Pattern (Lore m)
pat (Hist SubExp
len [HistOp (Lore m)]
ops Lambda (Lore m)
bucket_fun [VName]
imgs) = do
VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"iter"
[Type]
hists_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType ([VName] -> m [Type]) -> [VName] -> m [Type]
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> [VName]) -> [HistOp (Lore m)] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest [HistOp (Lore m)]
ops
[Ident]
hists_out <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"dests") [Type]
hists_ts
let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> Result -> [(Param DeclType, SubExp)]
loopMerge [Ident]
hists_out (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> Result) -> [HistOp (Lore m)] -> Result
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result)
-> (HistOp (Lore m) -> [VName]) -> HistOp (Lore m) -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp (Lore m)]
ops
Body (Lore m)
loopBody <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
( VName -> NameInfo (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Lore m)
forall lore. IntType -> NameInfo lore
IndexName IntType
Int64) (Scope (Lore m) -> Scope (Lore m))
-> Scope (Lore m) -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$
[Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
)
(Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
Result
imgs' <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
imgs ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
img -> do
Type
img_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
img
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"pixel" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
img (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
img_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
Result
imgs'' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
bucket_fun ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result)
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
imgs'
let lens :: Int
lens = [HistOp (Lore m)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Lore m)]
ops
inds :: Result
inds = Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
lens Result
imgs''
vals :: [Result]
vals = [Int] -> Result -> [Result]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Lore m) -> Int) -> [HistOp (Lore m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Lore m) -> [Type]) -> HistOp (Lore m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda (Lore m) -> [Type])
-> (HistOp (Lore m) -> Lambda (Lore m))
-> HistOp (Lore m)
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp (Lore m)]
ops) (Result -> [Result]) -> Result -> [Result]
forall a b. (a -> b) -> a -> b
$ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
lens Result
imgs''
hists_out' :: [[VName]]
hists_out' =
[Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Lore m) -> Int) -> [HistOp (Lore m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Lore m) -> [Type]) -> HistOp (Lore m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda (Lore m) -> [Type])
-> (HistOp (Lore m) -> Lambda (Lore m))
-> HistOp (Lore m)
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp (Lore m)]
ops) ([VName] -> [[VName]]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> a -> b
$
(Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
hists_out
[[VName]]
hists_out'' <- [([VName], HistOp (Lore m), SubExp, Result)]
-> (([VName], HistOp (Lore m), SubExp, Result)
-> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[VName]]
-> [HistOp (Lore m)]
-> Result
-> [Result]
-> [([VName], HistOp (Lore m), SubExp, Result)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [[VName]]
hists_out' [HistOp (Lore m)]
ops Result
inds [Result]
vals) ((([VName], HistOp (Lore m), SubExp, Result)
-> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]])
-> (([VName], HistOp (Lore m), SubExp, Result)
-> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]]
forall a b. (a -> b) -> a -> b
$ \([VName]
hist, HistOp (Lore m)
op, SubExp
idx, Result
val) -> do
let outside_bounds_branch :: BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
outside_bounds_branch = BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
hist
oob :: BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
oob = case [VName]
hist of
[] -> SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp (SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource)))))
-> SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
VName
arr : [VName]
_ -> VName
-> [BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
VName -> [m (Exp (Lore m))] -> m (Exp (Lore m))
eOutOfBounds VName
arr [SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
idx]
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m [VName]
letTupExp [Char]
"new_histo"
(ExpT (Lore m) -> BinderT (Lore m) (State VNameSource) [VName])
-> (Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
oob BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
outside_bounds_branch
(Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName])
-> Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
Result
h_val <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
hist ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
arr -> do
Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"read_hist" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]
Result
h_val' <-
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda (HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp HistOp (Lore m)
op) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result)
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
(SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [ExpT (Lore m)]) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> a -> b
$ Result
h_val Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
val
[VName]
hist' <- [(VName, SubExp)]
-> ((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> Result -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
hist Result
h_val') (((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName])
-> ((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
v) -> do
Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
[Char]
-> VName
-> Slice SubExp
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> VName -> Slice SubExp -> Exp (Lore m) -> m VName
letInPlace [Char]
"hist_out" VName
arr (Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]) (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v
Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
hist'
Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
hists_out''
Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge (VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
iter IntType
Int64 SubExp
len []) Body (Lore m)
loopBody
-- | Rewrite a SOACS lambda into a lambda over @lore@.
--
-- The body is rebuilt with 'transformBody' inside 'runBodyBinder', with the
-- lambda's own parameters added to the scope while transforming. The
-- parameters and the return types are reused unchanged in the result.
transformLambda ::
  ( MonadFreshNames m,
    Bindable lore,
    BinderOps lore,
    LocalScope somelore m,
    SameScope somelore lore,
    LetDec lore ~ LetDec SOACS
  ) =>
  Lambda ->
  m (AST.Lambda lore)
transformLambda (Lambda params body rettype) =
  (\newBody -> Lambda params newBody rettype)
    <$> runBodyBinder (localScope (scopeOfLParams params) (transformBody body))
-- | For every given type, bind a fresh 'Scratch' array whose element type
-- and dimensions are taken from that type. The binding name is based on
-- @"result"@. The new array names are returned in input order.
resultArray :: Transformer m => [Type] -> m [VName]
resultArray ts =
  traverse
    (\t -> letExp "result" $ BasicOp $ Scratch (elemType t) (arrayDims t))
    ts
-- | Write several values into several arrays at one shared index.
--
-- The value expressions are bound first (named after @"values"@), then the
-- index computation is run and bound (named after @"i"@). Afterwards the
-- @j@-th value is written in place into the @j@-th array at that index via
-- 'letInPlace'. The names of the updated arrays are returned. Because the
-- pairing uses 'zipWith', surplus arrays or values are ignored.
letwith ::
  Transformer m =>
  [VName] ->
  m (AST.Exp (Lore m)) ->
  [AST.Exp (Lore m)] ->
  m [VName]
letwith ks i vs = do
  boundVals <- letSubExps "values" vs
  idx <- i >>= letSubExp "i"
  let writeAt dest newVal = do
        dest_t <- lookupType dest
        letInPlace "lw_dest" dest (fullSlice dest_t [DimFix idx]) $
          BasicOp (SubExp newVal)
  sequence (zipWith writeAt ks boundVals)
-- | Wrap a 'SubExp' as a trivial 'BasicOp' expression, lifted into any
-- 'Applicative'.
pexp :: Applicative f => SubExp -> f (AST.Exp lore)
pexp se = pure (BasicOp (SubExp se))
-- | Inline a lambda applied to the given argument expressions.
--
-- Each parameter is bound to its corresponding argument. Arguments for
-- parameters of primitive type are bound directly; all others are bound
-- to an 'eCopy' of the argument. Parameters and arguments are paired with
-- 'zip', so surplus entries on either side are ignored. Finally the
-- lambda body is bound with 'bodyBind', and its result is returned.
-- The declared return types of the lambda are not consulted.
bindLambda ::
  Transformer m =>
  AST.Lambda (Lore m) ->
  [AST.Exp (Lore m)] ->
  m [SubExp]
bindLambda (Lambda params body _) args = do
  mapM_ bindParam (zip params args)
  bodyBind body
  where
    bindParam (param, arg) =
      letBindNames [paramName param]
        =<< if primType (paramType param)
          then pure arg
          else eCopy (pure arg)
-- | Build loop merge parameters from identifiers, marking every one of them
-- 'Unique'. The initial values are supplied as the second argument.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' [(var, Unique) | var <- vars]
-- | Pair each identifier and its uniqueness with an initial value, producing
-- a declared-type loop parameter for every pair. Pairing truncates to the
-- shorter of the two lists.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' = zipWith mkMerge
  where
    mkMerge (Ident pname ptype, u) initial = (Param pname (toDecl ptype u), initial)