{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Transform.FirstOrderTransform
( transformFunDef,
transformConsts,
FirstOrderLore,
Transformer,
transformStmRecursively,
transformLambda,
transformSOAC,
)
where
import Control.Monad.Except
import Control.Monad.State
import Data.List (find, zip4)
import qualified Data.Map.Strict as M
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.IR as AST
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (chunks, splitAt3)
-- | Constraints a target lore must satisfy so that 'SOACS' code can be
-- lowered into it by this module:
--
-- * statements can be constructed in it ('Bindable', 'BinderOps');
-- * its 'Op' can be run through alias analysis ('CanBeAliased');
-- * its let-pattern and lambda-parameter decorations coincide with
--   those of 'SOACS', so patterns and parameters can be reused as-is.
type FirstOrderLore lore =
  ( Bindable lore,
    BinderOps lore,
    CanBeAliased (Op lore),
    LetDec lore ~ LetDec SOACS,
    LParamInfo lore ~ LParamInfo SOACS
  )
-- | Transform a 'SOACS' function definition into the target lore
-- @tolore@ by rewriting its body with 'transformBody'.
--
-- The given scope (e.g. that of the top-level constants) is in effect
-- while the body is transformed, extended with the function's own
-- parameters.  Entry-point information, attributes, name, return type
-- and parameters are carried over unchanged.
transformFunDef ::
  (MonadFreshNames m, FirstOrderLore tolore) =>
  Scope tolore ->
  FunDef SOACS ->
  m (AST.FunDef tolore)
transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do
  -- 'transformBody' builds its result with 'buildBody_', so the
  -- statements collected at this outer binder level are presumably
  -- empty; they are discarded.
  (body', _) <- modifyNameSource $ runState $ runBinderT m consts_scope
  return $ FunDef entry attrs fname rettype params body'
  where
    m = localScope (scopeOfFParams params) $ transformBody body
-- | Transform top-level statements (constants) from 'SOACS' into the
-- target lore @tolore@.
--
-- Each statement is rewritten with 'transformStmRecursively' in an
-- initially empty scope. The statements produced by the binder are
-- returned.
transformConsts ::
  (MonadFreshNames m, FirstOrderLore tolore) =>
  Stms SOACS ->
  m (AST.Stms tolore)
transformConsts stms =
  fmap snd $ modifyNameSource $ runState $ runBinderT m mempty
  where
    m = mapM_ transformStmRecursively stms
-- | Constraints on a monad @m@ in which 'SOACS' statements can be
-- transformed: @m@ must be a binder that can locally extend its own
-- scope, and its lore must support statement construction
-- ('Bindable', 'BinderOps'), alias analysis of its 'Op'
-- ('CanBeAliased'), and share lambda-parameter decorations with 'SOACS'.
type Transformer m =
  ( MonadBinder m,
    LocalScope (Lore m) m,
    Bindable (Lore m),
    BinderOps (Lore m),
    CanBeAliased (Op (Lore m)),
    LParamInfo (Lore m) ~ LParamInfo SOACS
  )
-- | Transform a 'SOACS' body into the lore of the monad.
--
-- Every statement is rewritten with 'transformStmRecursively'. The
-- body's result is kept unchanged, and the generated statements are
-- collected into the new body with 'buildBody_'.
transformBody ::
  (Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
  Body ->
  m (AST.Body (Lore m))
transformBody (Body () stms res) = buildBody_ $ do
  mapM_ transformStmRecursively stms
  pure res
-- | Transform a single 'SOACS' statement into the lore of the monad,
-- preserving its auxiliary information via 'auxing'.
--
-- * A SOAC first has its nested lambdas transformed with
--   'transformLambda', and is then lowered by 'transformSOAC'.
-- * Any other expression is mapped structurally: nested bodies are
--   transformed with 'transformBody' inside the scope they bind, and
--   the result is bound to the original pattern.
transformStmRecursively ::
  (Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
  Stm ->
  m ()
transformStmRecursively (Let pat aux (Op soac)) =
  auxing aux $ transformSOAC pat =<< mapSOACM soacTransform soac
  where
    soacTransform = identitySOACMapper {mapOnSOACLambda = transformLambda}
transformStmRecursively (Let pat aux e) =
  auxing aux $ letBind pat =<< mapExpM transform e
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = return,
          mapOnBranchType = return,
          mapOnFParam = return,
          mapOnLParam = return,
          -- Top-level 'Op' expressions are caught by the equation above,
          -- so this is only reached for an 'Op' nested somewhere the
          -- mapper visits directly.
          mapOnOp = error "Unhandled Op in first order transform"
        }
-- | Produce one array name per requested type, to be used as the
-- initial value of an array-typed loop result.
--
-- For an accumulator ('Acc') type, the first array in @arrs@ whose type
-- is equal to it is reused. Every other type, including an accumulator
-- type with no matching input, gets a fresh blank array bound under the
-- name @result@.
resultArray :: Transformer m => [VName] -> [Type] -> m [VName]
resultArray arrs ts = do
  arrs_ts <- mapM lookupType arrs
  let oneArray t@Acc {}
        | Just (v, _) <- find ((== t) . snd) (zip arrs arrs_ts) =
          pure v
      oneArray t =
        letExp "result" =<< eBlank t
  mapM oneArray ts
-- | Lower a single SOAC, bound to the given pattern, into sequential
-- code in the lore of the monad.
transformSOAC ::
  Transformer m =>
  AST.Pattern (Lore m) ->
  SOAC (Lore m) ->
  m ()
-- A 'Screma' becomes one counted 'DoLoop' of @w@ iterations.  All of
-- its scans and all of its reductions are first combined into a single
-- scan operator and a single reduction operator.
--
-- The loop carries, in this order:
--
-- * the scan accumulators;
-- * the scan output arrays;
-- * the reduction accumulators;
-- * the map output arrays.
--
-- Each iteration:
--
-- 1. binds the map lambda's parameters to element @i@ of the inputs;
-- 2. runs the lambda body;
-- 3. applies the scan and reduction operators to its results;
-- 4. writes the scan and map results into the output arrays at @i@.
transformSOAC pat (Screma w arrs form@(ScremaForm scans reds map_lam)) = do
  let Reduce _ red_lam red_nes = singleReduce reds
      Scan scan_lam scan_nes = singleScan scans
      (scan_arr_ts, _red_ts, map_arr_ts) =
        splitAt3 (length scan_nes) (length red_nes) $ scremaType w form

  -- Initial values for the array-typed loop parameters.
  scan_arrs <- resultArray [] scan_arr_ts
  map_arrs <- resultArray arrs map_arr_ts

  scanacc_params <-
    mapM (newParam "scanacc" . flip toDecl Nonunique) $
      lambdaReturnType scan_lam
  scanout_params <-
    mapM (newParam "scanout" . flip toDecl Unique) scan_arr_ts
  redout_params <-
    mapM (newParam "redout" . flip toDecl Nonunique) $
      lambdaReturnType red_lam
  mapout_params <-
    mapM (newParam "mapout" . flip toDecl Unique) map_arr_ts

  arr_ts <- mapM lookupType arrs
  -- An accumulator input is not indexed.  Instead, the lambda parameter
  -- is bound to the map-output loop parameter that carries the same
  -- accumulator.
  let paramForAcc (Acc c _ _ _) = find (f . paramType) mapout_params
        where
          f (Acc c2 _ _ _) = c == c2
          f _ = False
      paramForAcc _ = Nothing

  let merge =
        concat
          [ zip scanacc_params scan_nes,
            zip scanout_params $ map Var scan_arrs,
            zip redout_params red_nes,
            zip mapout_params $ map Var map_arrs
          ]
  i <- newVName "i"
  let loopform = ForLoop i Int64 w []
      -- Names consumed by the map lambda.
      lam_cons = consumedByLambda $ Alias.analyseLambda mempty map_lam

  loop_body <- runBodyBinder $
    localScope (scopeOfFParams (map fst merge) <> scopeOf loopform) $ do
      -- Bind the map lambda's parameters for iteration i.
      forM_ (zip3 (lambdaParams map_lam) arrs arr_ts) $ \(p, arr, arr_t) ->
        case paramForAcc arr_t of
          Just acc_out_p ->
            letBindNames [paramName p] $
              BasicOp $ SubExp $ Var $ paramName acc_out_p
          Nothing
            | paramName p `nameIn` lam_cons -> do
              -- The lambda consumes this parameter, so bind it to a copy
              -- of the element rather than directly to a slice of the input.
              p' <-
                letExp (baseString (paramName p)) $
                  BasicOp $ Index arr $ fullSlice arr_t [DimFix $ Var i]
              letBindNames [paramName p] $ BasicOp $ Copy p'
            | otherwise ->
              letBindNames [paramName p] $
                BasicOp $ Index arr $ fullSlice arr_t [DimFix $ Var i]

      -- Inline the lambda's statements; its parameters are bound above.
      mapM_ addStm $ bodyStms $ lambdaBody map_lam

      let (scan_res, red_res, map_res) =
            splitAt3 (length scan_nes) (length red_nes) $
              bodyResult $ lambdaBody map_lam

      -- Combine the running accumulators with this iteration's results.
      scan_res' <-
        eLambda scan_lam $
          map (pure . BasicOp . SubExp) $
            map (Var . paramName) scanacc_params ++ scan_res
      red_res' <-
        eLambda red_lam $
          map (pure . BasicOp . SubExp) $
            map (Var . paramName) redout_params ++ red_res

      -- Store the new scan values and the map results at index i.
      scan_outarrs <-
        letwith (map paramName scanout_params) (Var i) scan_res'
      map_outarrs <-
        letwith (map paramName mapout_params) (Var i) map_res

      return $
        resultBody $
          concat
            [ scan_res',
              map Var scan_outarrs,
              red_res',
              map Var map_outarrs
            ]

  -- The loop also returns the final scan accumulators.  The original
  -- pattern has no names for them, so they are bound to fresh names
  -- that are never used.
  names <-
    (++ patternNames pat)
      <$> replicateM (length scanacc_params) (newVName "discard")
  letBindNames names $ DoLoop [] merge loopform loop_body
transformSOAC Pattern (Lore m)
pat (Stream SubExp
w [VName]
arrs StreamForm (Lore m)
_ Result
nes Lambda (Lore m)
lam) = do
let (Param Type
chunk_size_param, [Param Type]
fold_params, [Param Type]
chunk_params) =
Int -> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall dec.
Int -> [Param dec] -> (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) ([Param Type] -> (Param Type, [Param Type], [Param Type]))
-> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [LParam (Lore m)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Lore m)
lam
[(Param DeclType, SubExp)]
mapout_merge <- [Type]
-> (Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
lam) ((Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)])
-> (Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
let t' :: Type
t' = Type
t Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w
scratch :: ExpT (Lore m)
scratch = BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ PrimType -> Result -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t') (Type -> Result
forall u. TypeBase (ShapeBase SubExp) u -> Result
arrayDims Type
t')
in (,)
(Param DeclType -> SubExp -> (Param DeclType, SubExp))
-> m (Param DeclType) -> m (SubExp -> (Param DeclType, SubExp))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"stream_mapout" (Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Type
t' Uniqueness
Unique)
m (SubExp -> (Param DeclType, SubExp))
-> m SubExp -> m (Param DeclType, SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Char] -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"stream_mapout_scratch" ExpT (Lore m)
scratch
let onType :: TypeBase shape NoUniqueness -> TypeBase shape Uniqueness
onType t :: TypeBase shape NoUniqueness
t@Acc {} = TypeBase shape NoUniqueness
t TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
`toDecl` Uniqueness
Unique
onType TypeBase shape NoUniqueness
t = TypeBase shape NoUniqueness
t TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
`toDecl` Uniqueness
Nonunique
merge :: [(Param DeclType, SubExp)]
merge = [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Param Type -> Param DeclType) -> [Param Type] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map ((Type -> DeclType) -> Param Type -> Param DeclType
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Type -> DeclType
forall {shape}.
TypeBase shape NoUniqueness -> TypeBase shape Uniqueness
onType) [Param Type]
fold_params) Result
nes [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
mapout_merge
merge_params :: [Param DeclType]
merge_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
mapout_params :: [Param DeclType]
mapout_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
mapout_merge
VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
let loop_form :: LoopForm (Lore m)
loop_form = VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
Int64 SubExp
w []
[VName] -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_size_param] (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
Body (Lore m)
loop_body <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
( LoopForm (Lore m) -> Scope (Lore m)
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm (Lore m)
loop_form
Scope (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall a. Semigroup a => a -> a -> a
<> [Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param DeclType]
merge_params
)
(Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
let slice :: Slice SubExp
slice =
[SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (VName -> SubExp
Var VName
i) (VName -> SubExp
Var (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_size_param)) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)]
[(Param Type, VName)]
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
chunk_params [VName]
arrs) (((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ())
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
[VName]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ())
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> Slice SubExp -> Slice SubExp
fullSlice (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) Slice SubExp
slice
(Result
res, Result
mapout_res) <- Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) (Result -> (Result, Result))
-> BinderT (Lore m) (State VNameSource) Result
-> BinderT (Lore m) (State VNameSource) (Result, Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m Result
bodyBind (Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
lam)
Result
mapout_res' <- [(Param DeclType, SubExp)]
-> ((Param DeclType, SubExp)
-> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params Result
mapout_res) (((Param DeclType, SubExp)
-> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> ((Param DeclType, SubExp)
-> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \(Param DeclType
p, SubExp
se) ->
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"mapout_res" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> SubExp -> BasicOp
Update
(Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p)
(Type -> Slice SubExp -> Slice SubExp
fullSlice (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
p) Slice SubExp
slice)
SubExp
se
Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result
res Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
mapout_res'
Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge LoopForm (Lore m)
loop_form Body (Lore m)
loop_body
transformSOAC Pattern (Lore m)
pat (Scatter SubExp
len Lambda (Lore m)
lam [VName]
ivs [(ShapeBase SubExp, Int, VName)]
as) = do
VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"write_iter"
let ([ShapeBase SubExp]
as_ws, [Int]
as_ns, [VName]
as_vs) = [(ShapeBase SubExp, Int, VName)]
-> ([ShapeBase SubExp], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(ShapeBase SubExp, Int, VName)]
as
[Type]
ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
as_vs
[Ident]
asOuts <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"write_out") [Type]
ts
let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> Result -> [(Param DeclType, SubExp)]
loopMerge [Ident]
asOuts (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
as_vs
Body (Lore m)
loopBody <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
( VName -> NameInfo (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Lore m)
forall lore. IntType -> NameInfo lore
IndexName IntType
Int64) (Scope (Lore m) -> Scope (Lore m))
-> Scope (Lore m) -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$
[Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
)
(Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
Result
ivs' <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
ivs ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
iv -> do
Type
iv_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
iv
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"write_iv" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
iv (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
iv_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
Result
ivs'' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
lam ((SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
ivs')
let indexes :: [(ShapeBase SubExp, VName, [(Result, SubExp)])]
indexes = [(ShapeBase SubExp, Int, VName)]
-> Result -> [(ShapeBase SubExp, VName, [(Result, SubExp)])]
forall array a.
[(ShapeBase SubExp, Int, array)]
-> [a] -> [(ShapeBase SubExp, array, [([a], a)])]
groupScatterResults ([ShapeBase SubExp]
-> [Int] -> [VName] -> [(ShapeBase SubExp, Int, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [ShapeBase SubExp]
as_ws [Int]
as_ns ([VName] -> [(ShapeBase SubExp, Int, VName)])
-> [VName] -> [(ShapeBase SubExp, Int, VName)]
forall a b. (a -> b) -> a -> b
$ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
asOuts) Result
ivs''
[VName]
ress <- [(ShapeBase SubExp, VName, [(Result, SubExp)])]
-> ((ShapeBase SubExp, VName, [(Result, SubExp)])
-> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(ShapeBase SubExp, VName, [(Result, SubExp)])]
indexes (((ShapeBase SubExp, VName, [(Result, SubExp)])
-> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName])
-> ((ShapeBase SubExp, VName, [(Result, SubExp)])
-> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(ShapeBase SubExp
_, VName
arr, [(Result, SubExp)]
indexes') -> do
let saveInArray :: VName -> (Result, SubExp) -> m VName
saveInArray VName
arr' (Result
indexCur, SubExp
valueCur) =
[Char] -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m VName
letExp [Char]
"write_out" (Exp (Lore m) -> m VName) -> m (Exp (Lore m)) -> m VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> [m (Exp (Lore m))] -> m (Exp (Lore m)) -> m (Exp (Lore m))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
VName -> [m (Exp (Lore m))] -> m (Exp (Lore m)) -> m (Exp (Lore m))
eWriteArray VName
arr' ((SubExp -> m (Exp (Lore m))) -> Result -> [m (Exp (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> m (Exp (Lore m))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp Result
indexCur) (SubExp -> m (Exp (Lore m))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
valueCur)
(VName
-> (Result, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> VName
-> [(Result, SubExp)]
-> BinderT (Lore m) (State VNameSource) VName
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM VName
-> (Result, SubExp) -> BinderT (Lore m) (State VNameSource) VName
forall {m :: * -> *}.
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
VName -> (Result, SubExp) -> m VName
saveInArray VName
arr [(Result, SubExp)]
indexes'
Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ress)
Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge (VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
iter IntType
Int64 SubExp
len []) Body (Lore m)
loopBody
transformSOAC Pattern (Lore m)
pat (Hist SubExp
len [HistOp (Lore m)]
ops Lambda (Lore m)
bucket_fun [VName]
imgs) = do
VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"iter"
[Type]
hists_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType ([VName] -> m [Type]) -> [VName] -> m [Type]
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> [VName]) -> [HistOp (Lore m)] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest [HistOp (Lore m)]
ops
[Ident]
hists_out <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"dests") [Type]
hists_ts
let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> Result -> [(Param DeclType, SubExp)]
loopMerge [Ident]
hists_out (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> Result) -> [HistOp (Lore m)] -> Result
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result)
-> (HistOp (Lore m) -> [VName]) -> HistOp (Lore m) -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp (Lore m)]
ops
let iter_scope :: Scope (Lore m)
iter_scope = VName -> NameInfo (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Lore m)
forall lore. IntType -> NameInfo lore
IndexName IntType
Int64) (Scope (Lore m) -> Scope (Lore m))
-> Scope (Lore m) -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ [Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
Body (Lore m)
loopBody <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> (Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> m (Body (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope Scope (Lore m)
iter_scope (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
Result
imgs' <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
imgs ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
img -> do
Type
img_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
img
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"pixel" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
img (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
img_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
Result
imgs'' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
bucket_fun ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result)
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
imgs'
let lens :: Int
lens = [HistOp (Lore m)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Lore m)]
ops
inds :: Result
inds = Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
lens Result
imgs''
vals :: [Result]
vals = [Int] -> Result -> [Result]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Lore m) -> Int) -> [HistOp (Lore m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Lore m) -> [Type]) -> HistOp (Lore m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda (Lore m) -> [Type])
-> (HistOp (Lore m) -> Lambda (Lore m))
-> HistOp (Lore m)
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp (Lore m)]
ops) (Result -> [Result]) -> Result -> [Result]
forall a b. (a -> b) -> a -> b
$ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
lens Result
imgs''
hists_out' :: [[VName]]
hists_out' =
[Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Lore m) -> Int) -> [HistOp (Lore m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Lore m) -> [Type]) -> HistOp (Lore m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda (Lore m) -> [Type])
-> (HistOp (Lore m) -> Lambda (Lore m))
-> HistOp (Lore m)
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp (Lore m)]
ops) ([VName] -> [[VName]]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> a -> b
$
(Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
hists_out
[[VName]]
hists_out'' <- [([VName], HistOp (Lore m), SubExp, Result)]
-> (([VName], HistOp (Lore m), SubExp, Result)
-> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[VName]]
-> [HistOp (Lore m)]
-> Result
-> [Result]
-> [([VName], HistOp (Lore m), SubExp, Result)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [[VName]]
hists_out' [HistOp (Lore m)]
ops Result
inds [Result]
vals) ((([VName], HistOp (Lore m), SubExp, Result)
-> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]])
-> (([VName], HistOp (Lore m), SubExp, Result)
-> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]]
forall a b. (a -> b) -> a -> b
$ \([VName]
hist, HistOp (Lore m)
op, SubExp
idx, Result
val) -> do
let outside_bounds_branch :: BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
outside_bounds_branch = BinderT (Lore m) (State VNameSource) Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
m Result -> m (Body (Lore m))
buildBody_ (BinderT (Lore m) (State VNameSource) Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> BinderT (Lore m) (State VNameSource) Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result -> BinderT (Lore m) (State VNameSource) Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> BinderT (Lore m) (State VNameSource) Result)
-> Result -> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
hist
oob :: BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
oob = case [VName]
hist of
[] -> SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp (SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource)))))
-> SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
VName
arr : [VName]
_ -> VName
-> [BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
VName -> [m (Exp (Lore m))] -> m (Exp (Lore m))
eOutOfBounds VName
arr [SubExp
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
idx]
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m [VName]
letTupExp [Char]
"new_histo" (ExpT (Lore m) -> BinderT (Lore m) (State VNameSource) [VName])
-> (Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf BinderT
(Lore m)
(State VNameSource)
(Exp (Lore (BinderT (Lore m) (State VNameSource))))
oob BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
outside_bounds_branch (Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName])
-> Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
BinderT (Lore m) (State VNameSource) Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
m Result -> m (Body (Lore m))
buildBody_ (BinderT (Lore m) (State VNameSource) Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> BinderT (Lore m) (State VNameSource) Result
-> BinderT
(Lore m)
(State VNameSource)
(Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ do
Result
h_val <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
hist ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
arr -> do
Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
[Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"read_hist" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]
Result
h_val' <-
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda (HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp HistOp (Lore m)
op) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result)
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
(SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [ExpT (Lore m)]) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> a -> b
$ Result
h_val Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
val
[VName]
hist' <- [(VName, SubExp)]
-> ((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> Result -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
hist Result
h_val') (((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName])
-> ((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
v) -> do
Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
[Char]
-> VName
-> Slice SubExp
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> VName -> Slice SubExp -> Exp (Lore m) -> m VName
letInPlace [Char]
"hist_out" VName
arr (Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]) (Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v
Result -> BinderT (Lore m) (State VNameSource) Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> BinderT (Lore m) (State VNameSource) Result)
-> Result -> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
hist'
Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
hists_out''
Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge (VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
iter IntType
Int64 SubExp
len []) Body (Lore m)
loopBody
-- | Translate a SOACS lambda into a lambda of the target lore.
--
-- The body is rewritten by 'transformBody' inside a fresh body binder
-- ('runBodyBinder'), with the lambda's own parameters brought into scope
-- first so that the transformed statements can look up their types.
-- Parameters and return types are carried over untouched.
transformLambda ::
  ( MonadFreshNames m,
    Bindable lore,
    BinderOps lore,
    LocalScope somelore m,
    SameScope somelore lore,
    LetDec lore ~ LetDec SOACS,
    CanBeAliased (Op lore)
  ) =>
  Lambda ->
  m (AST.Lambda lore)
transformLambda (Lambda params body rettype) =
  rebuild <$> runBodyBinder (localScope (scopeOfLParams params) (transformBody body))
  where
    -- Reassemble the lambda around the freshly transformed body.
    rebuild newBody = Lambda params newBody rettype
-- | Store each value of @vs@ into the matching array named in @ks@ at
-- index @i@, returning the names that hold the updated results.
--
-- For ordinary arrays the write is an in-place update ('letInPlace')
-- through @'fullSlice' t ['DimFix' i]@, which fixes the leading index to
-- @i@.  When the target has an accumulator type ('Acc') no slicing is
-- done; the value is simply rebound under a fresh @lw_acc@ name.
--
-- Pairing is done with 'zipWithM', so any surplus elements in the longer
-- of @ks@ / @vs@ are ignored.
letwith :: Transformer m => [VName] -> SubExp -> [SubExp] -> m [VName]
letwith ks i vs = zipWithM writeOne ks vs
  where
    writeOne dest v = do
      dest_t <- lookupType dest
      let rhs = BasicOp $ SubExp v
      case dest_t of
        Acc {} -> letExp "lw_acc" rhs
        _ -> letInPlace "lw_dest" dest (fullSlice dest_t [DimFix i]) rhs
-- | Inline a lambda application into the current binder.
--
-- Each lambda parameter is bound, in order, to the corresponding argument
-- expression.  Parameters of primitive type bind the argument directly.
-- All other parameters bind a copy ('eCopy') of it; presumably this is so
-- that in-place updates inside the body cannot touch the caller's array.
-- TODO confirm.
-- Afterwards the body's statements are bound and its result is returned.
-- The lambda's declared return types are not consulted, and any surplus
-- parameters or arguments (beyond the shorter list) are silently dropped
-- by 'zip'.
bindLambda ::
  Transformer m =>
  AST.Lambda (Lore m) ->
  [AST.Exp (Lore m)] ->
  m [SubExp]
bindLambda (Lambda params body _) args = do
  mapM_ (uncurry bindParam) (zip params args)
  bodyBind body
  where
    bindParam param arg
      | primType (paramType param) = letBindNames [paramName param] arg
      | otherwise = letBindNames [paramName param] =<< eCopy (pure arg)
-- | Build loop merge parameters from @vars@, every one of them declared
-- 'Unique', each paired with the matching initial value.
-- See 'loopMerge'' for how the pairing is done.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' (map (\v -> (v, Unique)) vars)
-- | Turn identifiers with explicit uniqueness into loop merge parameters.
--
-- Each identifier becomes a 'Param' with the same name and with its
-- 'Type' upgraded to a 'DeclType' carrying the given uniqueness.  The
-- parameter is then paired positionally with an initial value from
-- @vals@; elements beyond the shorter list are dropped.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' = zipWith mkMerge
  where
    mkMerge (Ident pname ptype, u) val = (Param pname (toDecl ptype u), val)