{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
( lowerUpdateKernels
, LowerUpdate
, DesiredUpdate (..)
) where
import Control.Monad
import Control.Monad.Writer
import Data.List (find)
import Data.Maybe (mapMaybe)
import Data.Either
import qualified Data.Map as M
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
import Futhark.Representation.Kernels
import Futhark.Construct
import Futhark.Optimise.InPlaceLowering.SubstituteIndices
-- | Description of an in-place update that a lowering may try to
-- absorb into the statement producing the written value.  The field
-- order matters: the constructor is also matched positionally.
data DesiredUpdate attr = DesiredUpdate
  { updateName :: VName
    -- ^ Name that the result of the update is bound to.
  , updateType :: attr
    -- ^ Attribute of the result; the result type is read from it.
  , updateCertificates :: Certificates
    -- ^ Certificates that are attached to the lowered statement.
  , updateSource :: VName
    -- ^ The array that the update writes into.
  , updateIndices :: Slice SubExp
    -- ^ Where in the source the value is written.
  , updateValue :: VName
    -- ^ The variable whose value is written.
  }
  deriving (Show)
-- | Mapping touches only the attribute kept in 'updateType'; all
-- other fields are carried over unchanged.
instance Functor DesiredUpdate where
  fmap f update = update { updateType = f (updateType update) }
-- | Is the given name the value written by this update?
updateHasValue :: VName -> DesiredUpdate attr -> Bool
updateHasValue name update = name == updateValue update
-- | The shape of a lowering function.  It is handed a scope, a
-- statement and a list of desired updates, and answers either
-- 'Nothing' or an action that produces a list of statements.
type LowerUpdate lore m =
     Scope (Aliases lore)
  -> Stm (Aliases lore)
  -> [DesiredUpdate (LetAttr (Aliases lore))]
  -> Maybe (m [Stm (Aliases lore)])
-- | Lore-independent lowering.  Two statement forms are handled:
--
-- * a @DoLoop@, which is delegated to 'lowerUpdateIntoLoop' and then
--   reassembled with the certificates of the original statement;
--
-- * a plain variable binding whose pattern binds exactly the source
--   of a single desired update, which is replaced by one @Update@
--   statement on the bound variable.
--
-- Anything else yields 'Nothing'.
lowerUpdate :: (MonadFreshNames m, Bindable lore,
                LetAttr lore ~ Type, CanBeAliased (Op lore)) => LowerUpdate lore m
lowerUpdate scope (Let pat aux (DoLoop ctx val form body)) updates =
  rebuildLoop <$> lowerUpdateIntoLoop scope updates pat ctx val form body
  where
    -- Put the rewritten loop between the statements that must come
    -- before it and those that must come after it.
    rebuildLoop mkParts = do
      (before, after, ctxIdents, valIdents, ctx', val', body') <- mkParts
      let loopStm =
            certify (stmAuxCerts aux) $
              mkLet ctxIdents valIdents (DoLoop ctx' val' form body')
      return (before ++ [loopStm] ++ after)
lowerUpdate _ (Let pat aux (BasicOp (SubExp (Var v)))) [update]
  | patternNames pat == [updateSource update] =
      Just $ return [certify certs $ mkLet [] [resIdent] updateExp]
  where
    resType = typeOf (updateType update)
    resIdent = Ident (updateName update) resType
    -- The new statement carries the certificates of both the original
    -- statement and the update.
    certs = stmAuxCerts aux <> updateCertificates update
    slice = fullSlice resType (updateIndices update)
    updateExp = BasicOp $ Update v slice $ Var (updateValue update)
lowerUpdate _ _ _ =
  Nothing
-- | Lowering for the kernels representation.  A single-result
-- @SegMap@ whose result is the value of the one desired update is
-- rewritten via 'lowerUpdateIntoKernel'; the result is the modified
-- @SegMap@ bound to the update's name, followed by a statement that
-- rebinds the original name by indexing that result.  If
-- 'lowerUpdateIntoKernel' declines, so does this function.  All other
-- inputs are passed on to 'lowerUpdate'.
lowerUpdateKernels :: MonadFreshNames m => LowerUpdate Kernels m
lowerUpdateKernels
  _
  (Let (Pattern [] [PatElem v v_attr]) aux (Op (SegOp (SegMap lvl space ts kbody))))
  [update]
  | v == updateValue update =
      mkStms <$> lowerUpdateIntoKernel update space kbody
  where
    resType = typeOf (updateType update)
    resIdent = Ident (updateName update) resType
    slice = fullSlice resType (updateIndices update)
    certs = stmAuxCerts aux <> updateCertificates update
    mkStms kbody' =
      return
        [ certify certs $
            mkLet [] [resIdent] $ Op $ SegOp $ SegMap lvl space ts kbody'
        , mkLet [] [Ident v (typeOf v_attr)] $
            BasicOp $ Index (updateName update) slice
        ]
lowerUpdateKernels scope stm updates =
  lowerUpdate scope stm updates
-- | Turn the result of a kernel body into a @WriteReturns@ that
-- writes directly into the source of the update.  This succeeds only
-- if the body has exactly one result, that result is a @Returns@, and
-- every index of the update is fixed (as judged by 'dimFix').  The
-- write position is the fixed indices followed by the variables of
-- the kernel space.
lowerUpdateIntoKernel :: DesiredUpdate (LetAttr (Aliases Kernels))
                      -> SegSpace -> KernelBody (Aliases Kernels)
                      -> Maybe (KernelBody (Aliases Kernels))
lowerUpdateIntoKernel update kspace kbody =
  case kernelBodyResult kbody of
    [Returns _ se] -> do
      fixed <- mapM dimFix (updateIndices update)
      let spaceVars = map (Var . fst) (unSegSpace kspace)
          dims = arrayDims (snd (updateType update))
          ret = WriteReturns dims (updateSource update) [(fixed ++ spaceVars, se)]
      return kbody { kernelBodyResult = [ret] }
    _ ->
      Nothing
lowerUpdateIntoLoop :: (Bindable lore, BinderOps lore,
Aliased lore, LetAttr lore ~ (als, Type),
MonadFreshNames m) =>
Scope lore
-> [DesiredUpdate (LetAttr lore)]
-> Pattern lore
-> [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> Body lore
-> Maybe (m ([Stm lore],
[Stm lore],
[Ident],
[Ident],
[(FParam lore, SubExp)],
[(FParam lore, SubExp)],
Body lore))
lowerUpdateIntoLoop :: Scope lore
-> [DesiredUpdate (LetAttr lore)]
-> Pattern lore
-> [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> Body lore
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(FParam lore, SubExp)], [(FParam lore, SubExp)], Body lore))
lowerUpdateIntoLoop Scope lore
scope [DesiredUpdate (LetAttr lore)]
updates Pattern lore
pat [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form Body lore
body = do
[((Param DeclType, SubExp), Names)]
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Param DeclType, SubExp)]
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val ([Names] -> [((Param DeclType, SubExp), Names)])
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. (a -> b) -> a -> b
$ Body lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases Body lore
body) ((((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ())
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall a b. (a -> b) -> a -> b
$ \((Param DeclType
p, SubExp
_), Names
als) ->
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param DeclType -> VName
forall attr. Param attr -> VName
paramName Param DeclType
p VName -> Names -> Bool
`nameIn` Names
als
m [LoopResultSummary (als, Type)]
mk_in_place_map <- [DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
forall (m :: * -> *) als.
MonadFreshNames m =>
[DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop [DesiredUpdate (als, Type)]
[DesiredUpdate (LetAttr lore)]
updates Names
usedInBody [(SubExp, Ident)]
resmap [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val
m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a. a -> Maybe a
Just (m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)],
Body lore)))
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
(m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a b. (a -> b) -> a -> b
$ do
[LoopResultSummary (als, Type)]
in_place_map <- m [LoopResultSummary (als, Type)]
mk_in_place_map
([(Param DeclType, SubExp)]
val',[Stm lore]
prebnds,[Stm lore]
postbnds) <- [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) lore als.
(MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
in_place_map
let ([Ident]
ctxpat,[Ident]
valpat) = [LoopResultSummary (als, Type)] -> ([Ident], [Ident])
mkResAndPat [LoopResultSummary (als, Type)]
in_place_map
idxsubsts :: IndexSubstitutions (als, Type)
idxsubsts = [LoopResultSummary (als, Type)] -> IndexSubstitutions (als, Type)
forall attr. [LoopResultSummary attr] -> IndexSubstitutions attr
indexSubstitutions [LoopResultSummary (als, Type)]
in_place_map
(IndexSubstitutions (als, Type)
idxsubsts', Stms lore
newbnds) <- IndexSubstitutions (als, Type)
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall (m :: * -> *) lore attr.
(MonadFreshNames m, BinderOps lore, Bindable lore, Aliased lore,
LetAttr lore ~ attr) =>
IndexSubstitutions attr
-> Stms lore -> m (IndexSubstitutions attr, Stms lore)
substituteIndices IndexSubstitutions (als, Type)
idxsubsts (Stms lore -> m (IndexSubstitutions (als, Type), Stms lore))
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall a b. (a -> b) -> a -> b
$ Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
body
([SubExp]
body_res, Stms lore
res_bnds) <- [LoopResultSummary (LetAttr lore)]
-> IndexSubstitutions (LetAttr lore) -> m ([SubExp], Stms lore)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[LoopResultSummary (LetAttr lore)]
-> IndexSubstitutions (LetAttr lore) -> m ([SubExp], Stms lore)
manipulateResult [LoopResultSummary (als, Type)]
[LoopResultSummary (LetAttr lore)]
in_place_map IndexSubstitutions (als, Type)
IndexSubstitutions (LetAttr lore)
idxsubsts'
let body' :: Body lore
body' = Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (Stms lore
newbndsStms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<>Stms lore
res_bnds) [SubExp]
body_res
([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
[(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm lore]
prebnds, [Stm lore]
postbnds, [Ident]
ctxpat, [Ident]
valpat, [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
ctx, [(Param DeclType, SubExp)]
val', Body lore
body')
where usedInBody :: Names
usedInBody = [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (VName -> Names) -> [VName] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Names
expandAliases ([VName] -> [Names]) -> [VName] -> [Names]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Body lore -> Names
forall a. FreeIn a => a -> Names
freeIn Body lore
body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form
expandAliases :: VName -> Names
expandAliases VName
v = case VName -> Scope lore -> Maybe (NameInfo lore)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Scope lore
scope of
Just (LetInfo LetAttr lore
attr) -> VName -> Names
oneName VName
v Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> (als, Type) -> Names
forall a. AliasesOf a => a -> Names
aliasesOf (als, Type)
LetAttr lore
attr
Maybe (NameInfo lore)
_ -> VName -> Names
oneName VName
v
resmap :: [(SubExp, Ident)]
resmap = [SubExp] -> [Ident] -> [(SubExp, Ident)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Body lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult Body lore
body) ([Ident] -> [(SubExp, Ident)]) -> [Ident] -> [(SubExp, Ident)]
forall a b. (a -> b) -> a -> b
$ PatternT (als, Type) -> [Ident]
forall attr. Typed attr => PatternT attr -> [Ident]
patternValueIdents PatternT (als, Type)
Pattern lore
pat
mkMerges :: (MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges :: [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
summaries = do
(([(Param DeclType, SubExp)]
origmerge, [(Param DeclType, SubExp)]
extramerge), ([Stm lore]
prebnds, [Stm lore]
postbnds)) <-
WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm lore], [Stm lore]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm lore], [Stm lore])))
-> WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm lore], [Stm lore]))
forall a b. (a -> b) -> a -> b
$ [Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]))
-> WriterT
([Stm lore], [Stm lore])
m
[Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> WriterT
([Stm lore], [Stm lore])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (LoopResultSummary (als, Type)
-> WriterT
([Stm lore], [Stm lore])
m
(Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> [LoopResultSummary (als, Type)]
-> WriterT
([Stm lore], [Stm lore])
m
[Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM LoopResultSummary (als, Type)
-> WriterT
([Stm lore], [Stm lore])
m
(Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) lore lore a.
(MonadFreshNames m, MonadWriter ([Stm lore], [Stm lore]) m,
Bindable lore, Bindable lore) =>
LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge [LoopResultSummary (als, Type)]
summaries
([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) a. Monad m => a -> m a
return ([(Param DeclType, SubExp)]
origmerge [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
extramerge, [Stm lore]
prebnds, [Stm lore]
postbnds)
mkMerge :: LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge LoopResultSummary (a, Type)
summary
| Just (DesiredUpdate (a, Type)
update, VName
mergename, (a, Type)
mergeattr) <- LoopResultSummary (a, Type)
-> Maybe (DesiredUpdate (a, Type), VName, (a, Type))
forall attr.
LoopResultSummary attr -> Maybe (DesiredUpdate attr, VName, attr)
relatedUpdate LoopResultSummary (a, Type)
summary = do
VName
source <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"modified_source"
let source_t :: Type
source_t = (a, Type) -> Type
forall a b. (a, b) -> b
snd ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall attr. DesiredUpdate attr -> attr
updateType DesiredUpdate (a, Type)
update
elmident :: Ident
elmident = VName -> Type -> Ident
Ident (DesiredUpdate (a, Type) -> VName
forall attr. DesiredUpdate attr -> VName
updateValue DesiredUpdate (a, Type)
update) (Type -> Ident) -> Type -> Ident
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
source_t
([Stm lore], [Stm lore]) -> m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([[Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [VName -> Type -> Ident
Ident VName
source Type
source_t] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> SubExp -> BasicOp lore
forall lore. VName -> Slice SubExp -> SubExp -> BasicOp lore
Update
(DesiredUpdate (a, Type) -> VName
forall attr. DesiredUpdate attr -> VName
updateSource DesiredUpdate (a, Type)
update)
(Type -> Slice SubExp -> Slice SubExp
fullSlice Type
source_t (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall attr. DesiredUpdate attr -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update) (SubExp -> BasicOp lore) -> SubExp -> BasicOp lore
forall a b. (a -> b) -> a -> b
$
(Param DeclType, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ((Param DeclType, SubExp) -> SubExp)
-> (Param DeclType, SubExp) -> SubExp
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall attr. LoopResultSummary attr -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary],
[[Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident
elmident] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp lore
forall lore. VName -> Slice SubExp -> BasicOp lore
Index
(DesiredUpdate (a, Type) -> VName
forall attr. DesiredUpdate attr -> VName
updateName DesiredUpdate (a, Type)
update) (Type -> Slice SubExp -> Slice SubExp
fullSlice ((a, Type) -> Type
forall t. Typed t => t -> Type
typeOf ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall attr. DesiredUpdate attr -> attr
updateType DesiredUpdate (a, Type)
update) (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall attr. DesiredUpdate attr -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update)])
Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. b -> Either a b
Right (VName -> DeclType -> Param DeclType
forall attr. VName -> attr -> Param attr
Param
VName
mergename
(Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl ((a, Type) -> Type
forall t. Typed t => t -> Type
typeOf (a, Type)
mergeattr) Uniqueness
Unique),
VName -> SubExp
Var VName
source)
| Bool
otherwise = Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. a -> Either a b
Left ((Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp))
-> (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall attr. LoopResultSummary attr -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary
-- Build the identifiers bound by the rewritten pattern.
--
-- The first component is the list of context identifiers of @pat@.
-- The second component lists the value identifiers: first those of
-- the summaries that carry no related update, then those that stand
-- for an update (see mkResAndPat' for how each one is chosen).
--
-- NOTE(review): @pat@ is not a parameter of this function; it is
-- captured from the enclosing definition, so this must stay a local
-- binding of that definition (re-indent accordingly).
mkResAndPat :: [LoopResultSummary (als, Type)] -> ([Ident], [Ident])
mkResAndPat summaries =
  (patternContextIdents pat, origpat ++ extrapat)
  where
    -- 'Left's are untouched results, 'Right's are lowered updates.
    (origpat, extrapat) = partitionEithers $ map mkResAndPat' summaries
-- Classify a single summary for the rewritten pattern.
--
-- A summary with a related update yields 'Right' with an identifier
-- that is named after the update ('updateName') and typed by the
-- second component of the update's attribute ('updateType').  A
-- summary without one yields 'Left' with the identifier recorded in
-- 'inPatternAs'.
mkResAndPat' :: LoopResultSummary (a, Type) -> Either Ident Ident
mkResAndPat' summary =
  case relatedUpdate summary of
    Just (update, _, _) ->
      Right $ Ident (updateName update) (snd $ updateType update)
    Nothing ->
      Left $ inPatternAs summary
-- | Summarise every result of a loop with respect to the desired updates.
--
-- The entries of @resmap@ (a result together with the identifier it is
-- bound to in the pattern) are paired positionally with the merge
-- parameters in @merge@.  The whole summary is 'Nothing' as soon as a
-- single pair cannot be summarised, which is the case when
--
--   * no update in @updates@ satisfies 'updateHasValue' for the name of
--     the pattern identifier,
--
--   * the source of the matching update ('updateSource') is a member of
--     @usedInBody@, or
--
--   * some dimension of the merge parameter's type is a variable that
--     names one of the merge parameters in @merge@.
--
-- Otherwise the result is an action that draws one fresh name (based on
-- @"lowered_array"@) per pair and returns the summaries in order.
summariseLoop :: MonadFreshNames m =>
                 [DesiredUpdate (als, Type)]
              -> Names
              -> [(SubExp, Ident)]
              -> [(Param DeclType, SubExp)]
              -> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop updates usedInBody resmap merge =
  sequence <$> zipWithM summariseLoopResult resmap merge
  where
    -- All three conditions must hold for a result to be summarised;
    -- every other combination is rejected.
    summariseLoopResult (se, v) (fparam, mergeinit)
      | Just update <- find (updateHasValue $ identName v) updates
      , not (updateSource update `nameIn` usedInBody)
      , hasLoopInvariantShape fparam = Just $ do
          lowered_array <- newVName "lowered_array"
          return LoopResultSummary
            { resultSubExp = se
            , inPatternAs = v
            , mergeParam = (fparam, mergeinit)
            , relatedUpdate = Just (update, lowered_array, updateType update)
            }
      | otherwise = Nothing

    -- True when no dimension of the parameter's type refers to a merge
    -- parameter of this loop.
    hasLoopInvariantShape :: Param DeclType -> Bool
    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names :: [VName]
    merge_param_names = map (paramName . fst) merge

    -- Constants never vary; a variable is invariant unless it names a
    -- merge parameter.
    loopInvariant :: SubExp -> Bool
    loopInvariant (Var v) = v `notElem` merge_param_names
    loopInvariant Constant{} = True
-- | What is known about one result of a loop while trying to lower
-- updates into it.
data LoopResultSummary attr =
  LoopResultSummary { resultSubExp :: SubExp
                      -- ^ The result as it appears in the loop body.
                    , inPatternAs :: Ident
                      -- ^ The identifier the result is bound to in the
                      -- pattern.
                    , mergeParam :: (Param DeclType, SubExp)
                      -- ^ The merge parameter paired with this result,
                      -- together with its initial value.
                    , relatedUpdate :: Maybe (DesiredUpdate attr, VName, attr)
                      -- ^ The update associated with this result, if
                      -- any, along with a name and an attribute stored
                      -- for it.
                    }
  deriving (Show)
-- | Collect the index substitutions implied by a list of summaries.
--
-- Summaries without a related update contribute nothing.  Every other
-- summary contributes one entry, keyed by the name of its merge
-- parameter, that carries the update's certificates, the name and
-- attribute stored next to the update, and the update's slice.  The
-- order of the summaries is preserved.
indexSubstitutions :: [LoopResultSummary attr]
                   -> IndexSubstitutions attr
indexSubstitutions = mapMaybe getSubstitution
  where
    getSubstitution res = do
      (update, nm, attr) <- relatedUpdate res
      let name = paramName $ fst $ mergeParam res
      return ( name
             , (updateCertificates update, nm, attr, updateIndices update)
             )
-- | Compute the new loop result and the statements needed to produce it.
--
-- Results whose summary has no related update are passed through
-- unchanged and are placed first in the returned result.  The remaining
-- results are paired positionally with @substs@ and follow afterwards:
--
--   * a result that is a variable equal to the key of its substitution
--     is replaced by the name stored in that substitution, and no
--     statement is produced;
--
--   * any other result is written into the substitution's array with an
--     'Update' over the full slice, under the substitution's
--     certificates.  The statement binds a fresh identifier derived
--     from the array name via @(++"_updated")@, and that identifier
--     becomes the result.
--
-- The produced statements are returned in the order they were emitted.
manipulateResult :: (Bindable lore, MonadFreshNames m) =>
                    [LoopResultSummary (LetAttr lore)]
                 -> IndexSubstitutions (LetAttr lore)
                 -> m (Result, Stms lore)
manipulateResult summaries substs = do
  let (orig_ses, updated_ses) = partitionEithers $ map unchangedRes summaries
  (subst_ses, res_bnds) <- runWriterT $ zipWithM substRes updated_ses substs
  return (orig_ses ++ subst_ses, stmsFromList res_bnds)
  where
    -- 'Left' for results that are left alone, 'Right' for those that
    -- belong to an update.
    unchangedRes summary =
      maybe Left (const Right) (relatedUpdate summary) $ resultSubExp summary

    substRes (Var res_v) (subst_v, (_, nm, _, _))
      | res_v == subst_v =
          return $ Var nm
    substRes res_se (_, (cs, nm, attr, is)) = do
      let arr_t = typeOf attr
      v' <- newIdent' (++"_updated") $ Ident nm arr_t
      tell [ certify cs $ mkLet [] [v'] $ BasicOp $
             Update nm (fullSlice arr_t is) res_se
           ]
      return $ Var $ identName v'