{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
( lowerUpdateKernels,
lowerUpdate,
LowerUpdate,
DesiredUpdate (..),
)
where
import Control.Monad
import Control.Monad.Writer
import Data.Either
import Data.List (find, unzip4)
import Data.Maybe (isNothing, mapMaybe)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.Optimise.InPlaceLowering.SubstituteIndices
-- | A pending in-place update: bind 'updateName' to 'updateSource' with
-- 'updateValue' written at 'updateIndices'.  The parameter @dec@ is the
-- binding decoration of the result (a type, possibly paired with alias
-- information).
data DesiredUpdate dec = DesiredUpdate
  { -- | Name bound to the result of the update.
    updateName :: VName,
    -- | Binding decoration (type information) of the result.
    updateType :: dec,
    -- | Certificates to attach to the statement(s) produced by lowering.
    updateCertificates :: Certificates,
    -- | The array being updated.
    updateSource :: VName,
    -- | Where in 'updateSource' the value is written.
    updateIndices :: Slice SubExp,
    -- | The name of the value being written.
    updateValue :: VName
  }
  deriving (Show)
-- | Maps over the 'updateType' decoration only; every other field is
-- carried over unchanged.
instance Functor DesiredUpdate where
  fmap f u = u {updateType = f (updateType u)}
-- | Is the given name the value written by this update ('updateValue')?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name = (name ==) . updateValue
-- | A lowering strategy.  It receives the scope, the statement under
-- consideration and the desired updates; it either declines ('Nothing')
-- or yields an action that builds the replacement statements.
type LowerUpdate lore m =
  Scope (Aliases lore)
  -> Stm (Aliases lore)
  -> [DesiredUpdate (LetDec (Aliases lore))]
  -> Maybe (m [Stm (Aliases lore)])
-- | Generic lowering, handling two kinds of statements:
--
-- * a @DoLoop@, which 'lowerUpdateIntoLoop' rewrites so that the desired
--   updates happen inside the loop; and
--
-- * a plain variable binding @src = v@ with exactly one desired update
--   whose 'updateSource' is @src@; this becomes a single @Update@ of @v@
--   (slice completed against the result type with 'fullSlice') bound to
--   'updateName'.
--
-- Any other statement yields 'Nothing'.
lowerUpdate ::
  ( MonadFreshNames m,
    Bindable lore,
    LetDec lore ~ Type,
    CanBeAliased (Op lore)
  ) =>
  LowerUpdate lore m
-- Loop case: let 'lowerUpdateIntoLoop' rewrite the loop, then wrap the new
-- loop between the statements it wants placed before and after it.
-- NOTE(review): only the loop's own certificates are kept here; unlike the
-- other cases, the updates' 'updateCertificates' are not added -- confirm
-- that 'lowerUpdateIntoLoop' accounts for them.
lowerUpdate scope (Let pat aux (DoLoop ctx val form body)) updates = do
  canDo <- lowerUpdateIntoLoop scope updates pat ctx val form body
  Just $ do
    (prebnds, postbnds, ctxpat, valpat, ctx', val', body') <- canDo
    return $
      prebnds
        ++ [ certify (stmAuxCerts aux) $
               mkLet ctxpat valpat $ DoLoop ctx' val' form body'
           ]
        ++ postbnds
-- Copy case: @src = v@ followed by a single update of @src@ becomes
-- @bindee_nm = v with [is'] = val@.
lowerUpdate
  _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src is val]
    | patternNames pat == [src] =
      let is' = fullSlice (typeOf bindee_dec) is
       in Just $
            return
              [ certify (stmAuxCerts aux <> cs) $
                  mkLet [] [Ident bindee_nm $ typeOf bindee_dec] $
                    BasicOp $ Update v is' $ Var val
              ]
lowerUpdate _ _ _ =
  Nothing
-- | Lowering for the Kernels representation.
--
-- A @SegMap@ is rewritten with 'lowerUpdatesIntoSegMap' when two
-- conditions hold:
--
-- * every desired update writes a value bound by the statement's
--   pattern; and
--
-- * nothing the kernel body refers to may alias an update source.
--
-- The rewritten statement carries the statement's certificates plus those
-- of all updates.  Every other statement is delegated to 'lowerUpdate'.
lowerUpdateKernels :: MonadFreshNames m => LowerUpdate Kernels m
lowerUpdateKernels
  scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody))))
  updates
    | all ((`elem` patternNames pat) . updateValue) updates,
      not source_used_in_kbody = do
      mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
      Just $ do
        (pat', kbody', poststms) <- mk
        -- The rewritten kernel now performs every update, so it must
        -- carry all of their certificates.
        let cs = stmAuxCerts aux <> foldMap updateCertificates updates
        return $
          certify cs (Let pat' aux $ Op $ SegOp $ SegMap lvl space ts kbody') :
          stmsToList poststms
    where
      -- Deliberately conservative: reject the lowering if anything the
      -- kernel body refers to may alias an array being updated
      -- (presumably so that no thread reads an array that the kernel is
      -- now writing in place).
      source_used_in_kbody =
        foldMap (`lookupAliases` scope) (namesToList (freeIn kbody))
          `namesIntersect` foldMap ((`lookupAliases` scope) . updateSource) updates
lowerUpdateKernels scope stm updates = lowerUpdate scope stm updates
-- | Try to make a @SegMap@ (pattern @pat@, space @kspace@, body @kbody@)
-- write directly into the sources of the desired updates.
--
-- Each kernel result whose name is the 'updateValue' of some update is
-- handled as follows:
--
-- * it must be a plain @Returns@;
--
-- * the part of the update slice not indexed by the thread ids must
--   cover its dimensions fully;
--
-- * it is turned into a @WriteReturns@ into 'updateSource', at the slice
--   indexed by the thread ids;
--
-- * a statement re-binding the original name by indexing the updated
--   array is emitted to run after the kernel.
--
-- Results without a matching update are left untouched.  The updates
-- are all-or-nothing: if any matching result cannot be lowered, the
-- whole function returns 'Nothing'.
--
-- On success it yields the new pattern, the new kernel body, and the
-- statements to place after the kernel.
lowerUpdatesIntoSegMap ::
  MonadFreshNames m =>
  Scope (Aliases Kernels) ->
  Pattern (Aliases Kernels) ->
  [DesiredUpdate (LetDec (Aliases Kernels))] ->
  SegSpace ->
  KernelBody (Aliases Kernels) ->
  Maybe
    ( m
        ( Pattern (Aliases Kernels),
          KernelBody (Aliases Kernels),
          Stms (Aliases Kernels)
        )
    )
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  -- zipWithM in Maybe: a single result that cannot be lowered aborts all.
  mk <- zipWithM onRet (patternElements pat) (kernelBodyResult kbody)
  return $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    return
      ( Pattern [] pes,
        kbody
          { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms,
            kernelBodyResult = krets
          },
        mconcat poststms
      )
  where
    -- Only the thread indices are needed; the segment sizes are not.
    gtids = map fst $ unSegSpace kspace

    onRet (PatElem v v_dec) ret
      | Just (DesiredUpdate bindee_nm bindee_dec _cs src slice _val) <-
          find (updateHasValue v) updates = do
        -- Only plain returns can be redirected; anything else fails.
        Returns _ se <- Just ret

        -- After skipping the fixed indices and the dimensions indexed by
        -- the thread ids, the remaining slice must be a full slice.
        guard $
          let (dims', slice') =
                unzip . drop (length gtids) . filter (isNothing . dimFix . snd) $
                  zip (arrayDims (typeOf bindee_dec)) slice
           in isFullSlice (Shape dims') slice'

        Just $ do
          -- Per-thread indices: the update slice with its leading
          -- non-fixed dimensions instantiated by the thread ids.
          (idxs, bodystms) <-
            flip runBinderT scope $
              traverse (toSubExp "index") $
                fixSlice (map (fmap pe64) slice) $ map (pe64 . Var) gtids

          let res_dims = arrayDims $ snd bindee_dec
              ret' = WriteReturns (Shape res_dims) src [(map DimFix idxs, se)]

          return
            ( PatElem bindee_nm bindee_dec,
              bodystms,
              ret',
              -- Re-bind the original result by reading it back out of
              -- the updated array.
              oneStm $
                mkLet [] [Ident v $ typeOf v_dec] $
                  BasicOp $ Index bindee_nm slice
            )
    onRet pe ret =
      Just $ return (pe, mempty, ret, mempty)
lowerUpdateIntoLoop ::
( Bindable lore,
BinderOps lore,
Aliased lore,
LetDec lore ~ (als, Type),
MonadFreshNames m
) =>
Scope lore ->
[DesiredUpdate (LetDec lore)] ->
Pattern lore ->
[(FParam lore, SubExp)] ->
[(FParam lore, SubExp)] ->
LoopForm lore ->
Body lore ->
Maybe
( m
( [Stm lore],
[Stm lore],
[Ident],
[Ident],
[(FParam lore, SubExp)],
[(FParam lore, SubExp)],
Body lore
)
)
lowerUpdateIntoLoop :: forall lore als (m :: * -> *).
(Bindable lore, BinderOps lore, Aliased lore,
LetDec lore ~ (als, Type), MonadFreshNames m) =>
Scope lore
-> [DesiredUpdate (LetDec lore)]
-> Pattern lore
-> [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> Body lore
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(FParam lore, SubExp)], [(FParam lore, SubExp)], Body lore))
lowerUpdateIntoLoop Scope lore
scope [DesiredUpdate (LetDec lore)]
updates Pattern lore
pat [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form Body lore
body = do
[((Param DeclType, SubExp), Names)]
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Param DeclType, SubExp)]
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val ([Names] -> [((Param DeclType, SubExp), Names)])
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. (a -> b) -> a -> b
$ Body lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases Body lore
body) ((((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ())
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall a b. (a -> b) -> a -> b
$ \((Param DeclType
p, SubExp
_), Names
als) ->
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p VName -> Names -> Bool
`nameIn` Names
als
m [LoopResultSummary (als, Type)]
mk_in_place_map <- Scope lore
-> [DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
forall lore (m :: * -> *) als.
(Aliased lore, MonadFreshNames m) =>
Scope lore
-> [DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop Scope lore
scope [DesiredUpdate (als, Type)]
[DesiredUpdate (LetDec lore)]
updates Names
usedInBody [(SubExp, Ident)]
resmap [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val
m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a. a -> Maybe a
Just (m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)],
Body lore)))
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a b. (a -> b) -> a -> b
$ do
[LoopResultSummary (als, Type)]
in_place_map <- m [LoopResultSummary (als, Type)]
mk_in_place_map
([(Param DeclType, SubExp)]
val', [Stm lore]
prebnds, [Stm lore]
postbnds) <- [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) lore als.
(MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
in_place_map
let ([Ident]
ctxpat, [Ident]
valpat) = [LoopResultSummary (als, Type)] -> ([Ident], [Ident])
mkResAndPat [LoopResultSummary (als, Type)]
in_place_map
idxsubsts :: IndexSubstitutions (als, Type)
idxsubsts = [LoopResultSummary (als, Type)] -> IndexSubstitutions (als, Type)
forall dec. [LoopResultSummary dec] -> IndexSubstitutions dec
indexSubstitutions [LoopResultSummary (als, Type)]
in_place_map
(IndexSubstitutions (als, Type)
idxsubsts', Stms lore
newbnds) <- IndexSubstitutions (als, Type)
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall (m :: * -> *) lore dec.
(MonadFreshNames m, BinderOps lore, Bindable lore, Aliased lore,
LetDec lore ~ dec) =>
IndexSubstitutions dec
-> Stms lore -> m (IndexSubstitutions dec, Stms lore)
substituteIndices IndexSubstitutions (als, Type)
idxsubsts (Stms lore -> m (IndexSubstitutions (als, Type), Stms lore))
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall a b. (a -> b) -> a -> b
$ Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
body
([SubExp]
body_res, Stms lore
res_bnds) <- [LoopResultSummary (LetDec lore)]
-> IndexSubstitutions (LetDec lore) -> m ([SubExp], Stms lore)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[LoopResultSummary (LetDec lore)]
-> IndexSubstitutions (LetDec lore) -> m ([SubExp], Stms lore)
manipulateResult [LoopResultSummary (als, Type)]
[LoopResultSummary (LetDec lore)]
in_place_map IndexSubstitutions (als, Type)
IndexSubstitutions (LetDec lore)
idxsubsts'
let body' :: Body lore
body' = Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (Stms lore
newbnds Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stms lore
res_bnds) [SubExp]
body_res
([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm lore]
prebnds, [Stm lore]
postbnds, [Ident]
ctxpat, [Ident]
valpat, [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
ctx, [(Param DeclType, SubExp)]
val', Body lore
body')
where
usedInBody :: Names
usedInBody =
[Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (VName -> Names) -> [VName] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> Scope lore -> Names
forall lore.
AliasesOf (LetDec lore) =>
VName -> Scope lore -> Names
`lookupAliases` Scope lore
scope) ([VName] -> [Names]) -> [VName] -> [Names]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Body lore -> Names
forall a. FreeIn a => a -> Names
freeIn Body lore
body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form
resmap :: [(SubExp, Ident)]
resmap = [SubExp] -> [Ident] -> [(SubExp, Ident)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Body lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult Body lore
body) ([Ident] -> [(SubExp, Ident)]) -> [Ident] -> [(SubExp, Ident)]
forall a b. (a -> b) -> a -> b
$ PatternT (als, Type) -> [Ident]
forall dec. Typed dec => PatternT dec -> [Ident]
patternValueIdents PatternT (als, Type)
Pattern lore
pat
mkMerges ::
(MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)] ->
m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges :: forall (m :: * -> *) lore als.
(MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
summaries = do
(([(Param DeclType, SubExp)]
origmerge, [(Param DeclType, SubExp)]
extramerge), ([Stm lore]
prebnds, [Stm lore]
postbnds)) <-
WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm lore], [Stm lore]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm lore], [Stm lore])))
-> WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm lore], [Stm lore]))
forall a b. (a -> b) -> a -> b
$ [Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]))
-> WriterT
([Stm lore], [Stm lore])
m
[Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (LoopResultSummary (als, Type)
-> WriterT
([Stm lore], [Stm lore])
m
(Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> [LoopResultSummary (als, Type)]
-> WriterT
([Stm lore], [Stm lore])
m
[Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM LoopResultSummary (als, Type)
-> WriterT
([Stm lore], [Stm lore])
m
(Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall {m :: * -> *} {lore} {lore} {a}.
(MonadFreshNames m, MonadWriter ([Stm lore], [Stm lore]) m,
Bindable lore, Bindable lore) =>
LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge [LoopResultSummary (als, Type)]
summaries
([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) a. Monad m => a -> m a
return ([(Param DeclType, SubExp)]
origmerge [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
extramerge, [Stm lore]
prebnds, [Stm lore]
postbnds)
mkMerge :: LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge LoopResultSummary (a, Type)
summary
| Just (DesiredUpdate (a, Type)
update, VName
mergename, (a, Type)
mergedec) <- LoopResultSummary (a, Type)
-> Maybe (DesiredUpdate (a, Type), VName, (a, Type))
forall dec.
LoopResultSummary dec -> Maybe (DesiredUpdate dec, VName, dec)
relatedUpdate LoopResultSummary (a, Type)
summary = do
VName
source <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"modified_source"
let source_t :: Type
source_t = (a, Type) -> Type
forall a b. (a, b) -> b
snd ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall dec. DesiredUpdate dec -> dec
updateType DesiredUpdate (a, Type)
update
elmident :: Ident
elmident =
VName -> Type -> Ident
Ident
(DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateValue DesiredUpdate (a, Type)
update)
(Type
source_t Type -> [SubExp] -> Type
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims (DesiredUpdate (a, Type) -> Slice SubExp
forall dec. DesiredUpdate dec -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update))
([Stm lore], [Stm lore]) -> m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
( [ [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [VName -> Type -> Ident
Ident VName
source Type
source_t] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> SubExp -> BasicOp
Update
(DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateSource DesiredUpdate (a, Type)
update)
(Type -> Slice SubExp -> Slice SubExp
fullSlice Type
source_t (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall dec. DesiredUpdate dec -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update)
(SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ((Param DeclType, SubExp) -> SubExp)
-> (Param DeclType, SubExp) -> SubExp
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall dec. LoopResultSummary dec -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary
],
[ [Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident
elmident] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index
(DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateName DesiredUpdate (a, Type)
update)
(Type -> Slice SubExp -> Slice SubExp
fullSlice Type
source_t (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall dec. DesiredUpdate dec -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update)
]
)
Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$
(Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. b -> Either a b
Right
( VName -> DeclType -> Param DeclType
forall dec. VName -> dec -> Param dec
Param
VName
mergename
(Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl ((a, Type) -> Type
forall t. Typed t => t -> Type
typeOf (a, Type)
mergedec) Uniqueness
Unique),
VName -> SubExp
Var VName
source
)
| Bool
otherwise = Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. a -> Either a b
Left ((Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp))
-> (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall dec. LoopResultSummary dec -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary
mkResAndPat :: [LoopResultSummary (als, Type)] -> ([Ident], [Ident])
mkResAndPat [LoopResultSummary (als, Type)]
summaries =
let ([Ident]
origpat, [Ident]
extrapat) = [Either Ident Ident] -> ([Ident], [Ident])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either Ident Ident] -> ([Ident], [Ident]))
-> [Either Ident Ident] -> ([Ident], [Ident])
forall a b. (a -> b) -> a -> b
$ (LoopResultSummary (als, Type) -> Either Ident Ident)
-> [LoopResultSummary (als, Type)] -> [Either Ident Ident]
forall a b. (a -> b) -> [a] -> [b]
map LoopResultSummary (als, Type) -> Either Ident Ident
forall {a}. LoopResultSummary (a, Type) -> Either Ident Ident
mkResAndPat' [LoopResultSummary (als, Type)]
summaries
in ( PatternT (als, Type) -> [Ident]
forall dec. Typed dec => PatternT dec -> [Ident]
patternContextIdents PatternT (als, Type)
Pattern lore
pat,
[Ident]
origpat [Ident] -> [Ident] -> [Ident]
forall a. [a] -> [a] -> [a]
++ [Ident]
extrapat
)
mkResAndPat' :: LoopResultSummary (a, Type) -> Either Ident Ident
mkResAndPat' LoopResultSummary (a, Type)
summary
| Just (DesiredUpdate (a, Type)
update, VName
_, (a, Type)
_) <- LoopResultSummary (a, Type)
-> Maybe (DesiredUpdate (a, Type), VName, (a, Type))
forall dec.
LoopResultSummary dec -> Maybe (DesiredUpdate dec, VName, dec)
relatedUpdate LoopResultSummary (a, Type)
summary =
Ident -> Either Ident Ident
forall a b. b -> Either a b
Right (VName -> Type -> Ident
Ident (DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateName DesiredUpdate (a, Type)
update) ((a, Type) -> Type
forall a b. (a, b) -> b
snd ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall dec. DesiredUpdate dec -> dec
updateType DesiredUpdate (a, Type)
update))
| Bool
otherwise =
Ident -> Either Ident Ident
forall a b. a -> Either a b
Left (LoopResultSummary (a, Type) -> Ident
forall dec. LoopResultSummary dec -> Ident
inPatternAs LoopResultSummary (a, Type)
summary)
-- | Pair every loop result with the in-place update that is desired on
-- its pattern element, and with its merge parameter.
--
-- Arguments: the scope used to look up aliases, the desired updates,
-- a set of names that must not overlap the aliases of any update source,
-- the body results zipped with their pattern elements, and the merge
-- parameters paired with their initial values.  The two lists are zipped
-- positionally.
--
-- This is deliberately conservative: the outcome is 'Nothing' (no
-- lowering at all) unless /every/ result
--
--   * has a desired update whose value is the result's pattern element,
--   * whose source's aliases do not intersect the given names, and
--   * whose merge parameter has array dimensions that do not mention any
--     merge parameter.
--
-- On success we get an action that draws one fresh @lowered_array@ name
-- per result and records it in the summary.
summariseLoop ::
  ( Aliased lore,
    MonadFreshNames m
  ) =>
  Scope lore ->
  [DesiredUpdate (als, Type)] ->
  Names ->
  [(SubExp, Ident)] ->
  [(Param DeclType, SubExp)] ->
  Maybe (m [LoopResultSummary (als, Type)])
summariseLoop scope updates usedInBody resmap merge =
  fmap sequence $ zipWithM summarise resmap merge
  where
    summarise (res_se, pat_ident) merge_pair@(fparam, _)
      | Just update <- find (updateHasValue $ identName pat_ident) updates,
        not $ usedInBody `namesIntersect` lookupAliases (updateSource update) scope,
        all loopInvariant $ arrayDims $ paramType fparam =
          Just $ do
            lowered_array <- newVName "lowered_array"
            return
              LoopResultSummary
                { resultSubExp = res_se,
                  inPatternAs = pat_ident,
                  mergeParam = merge_pair,
                  relatedUpdate = Just (update, lowered_array, updateType update)
                }
      -- No matching update, or it cannot be lowered safely: give up on
      -- the whole loop (conservative).
      | otherwise = Nothing

    merge_param_names = map (paramName . fst) merge

    -- A dimension is loop-invariant when it is a constant or a variable
    -- that is not one of the merge parameters.
    loopInvariant se = case se of
      Var v -> v `notElem` merge_param_names
      Constant {} -> True
-- | Everything known about one loop result while lowering desired
-- in-place updates into the loop.
data LoopResultSummary dec = LoopResultSummary
  { -- | The subexpression returned by the loop body for this result.
    resultSubExp :: SubExp,
    -- | The pattern element this result is bound to.
    inPatternAs :: Ident,
    -- | The merge parameter for this result, paired with its initial value.
    mergeParam :: (Param DeclType, SubExp),
    -- | The desired update associated with this result, if any.  It is
    -- paired with a fresh name for the lowered array and with the
    -- update's decoration.
    relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
  }
  deriving (Show)
-- | The index substitutions implied by a list of loop result summaries.
--
-- Each summary that carries a related update contributes one entry:
--
--   * the key is the name of the summary's merge parameter;
--   * the value holds the update's certificates, the lowered array name,
--     its decoration, and the update's indices.
--
-- Summaries without a related update are skipped, and the input order is
-- preserved.
indexSubstitutions ::
  [LoopResultSummary dec] ->
  IndexSubstitutions dec
indexSubstitutions summaries =
  [ (paramName merge_param, (cs, lowered, dec, is))
    | summary <- summaries,
      Just (DesiredUpdate _ _ cs _ is _, lowered, dec) <- [relatedUpdate summary],
      let (merge_param, _) = mergeParam summary
  ]
-- | Compute the new loop body result after lowering updates.
--
-- Summaries without a related update keep their original result
-- subexpression; these come first in the returned result list.
--
-- Summaries with a related update are paired positionally with the given
-- index substitutions and come last.  For each such pair:
--
--   * if the result is exactly the variable the substitution is keyed on,
--     the substitution's array name is returned directly;
--   * otherwise a (certified) 'Update' is emitted that writes the result
--     into that array at the substitution's indices, and the fresh
--     @_updated@ name is returned.
--
-- The emitted statements are returned alongside the result list.
manipulateResult ::
  (Bindable lore, MonadFreshNames m) =>
  [LoopResultSummary (LetDec lore)] ->
  IndexSubstitutions (LetDec lore) ->
  m (Result, Stms lore)
manipulateResult summaries substs = do
  (lowered_ses, extra_stms) <-
    runWriterT $ zipWithM lowerResult updated_ses substs
  return (kept_ses ++ lowered_ses, stmsFromList extra_stms)
  where
    (kept_ses, updated_ses) = partitionEithers $ map classify summaries

    classify summary
      | isNothing (relatedUpdate summary) = Left $ resultSubExp summary
      | otherwise = Right $ resultSubExp summary

    -- The result is the substituted variable itself: pass the array on.
    lowerResult (Var res_v) (subst_v, (_, nm, _, _))
      | res_v == subst_v =
          return $ Var nm
    -- Otherwise write the result into the array at the update indices.
    lowerResult res_se (_, (cs, nm, dec, is)) = do
      let arr_t = typeOf dec
      v' <- newIdent' (++ "_updated") $ Ident nm arr_t
      tell
        [ certify cs $
            mkLet [] [v'] $
              BasicOp $
                Update nm (fullSlice arr_t is) res_se
        ]
      return $ Var $ identName v'