{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SOACS.SOAC
( SOAC (..),
StreamOrd (..),
StreamForm (..),
ScremaForm (..),
HistOp (..),
Scan (..),
scanResults,
singleScan,
Reduce (..),
redResults,
singleReduce,
scremaType,
soacType,
typeCheckSOAC,
mkIdentityLambda,
isIdentityLambda,
nilFn,
scanomapSOAC,
redomapSOAC,
scanSOAC,
reduceSOAC,
mapSOAC,
isScanomapSOAC,
isRedomapSOAC,
isScanSOAC,
isReduceSOAC,
isMapSOAC,
ppScrema,
ppHist,
groupScatterResults,
groupScatterResults',
splitScatterResults,
SOACMapper (..),
identitySOACMapper,
mapSOACM,
)
where
import Control.Category
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, removeLambdaAliases)
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Lore
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty (Doc, Pretty, comma, commasep, parens, ppr, text, (<+>), (</>))
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))
-- | A second-order array combinator (SOAC).
--
-- In each constructor the first 'SubExp' appears to be the outer width of
-- the input arrays.  NOTE(review): confirm per constructor.
--
-- The derived instances carry an inferred @Decorations lore@ context.
data SOAC lore
  = -- | @Stream w arrs form accs lam@: stream over the arrays @arrs@ with
    -- @lam@, executed according to @form@.  The 'SubExp' list is presumably
    -- the initial accumulator values (TODO confirm).
    Stream SubExp [VName] (StreamForm lore) [SubExp] (Lambda lore)
  | -- | @Scatter w lam inputs dests@.  Each destination is a triple of a
    -- 'Shape', an 'Int' and the name of the destination array.  The 'Int'
    -- is presumably the number of values written to that array; see
    -- 'groupScatterResults' and 'splitScatterResults' for how the lambda
    -- results are grouped (TODO confirm).
    Scatter SubExp (Lambda lore) [VName] [(Shape, Int, VName)]
  | -- | @Hist w ops lam arrs@: a histogram computation with one or more
    -- 'HistOp's.  The 'Lambda' is presumably the bucket function applied
    -- to the input arrays @arrs@ (TODO confirm).
    Hist SubExp [HistOp lore] (Lambda lore) [VName]
  | -- | @Screma w arrs form@: a combination of scans, reductions and a map
    -- over @arrs@, as described by the 'ScremaForm'.
    Screma SubExp [VName] (ScremaForm lore)
  deriving (Eq, Ord, Show)
-- | A single histogram operation within a 'Hist' SOAC.
--
-- The derived instances carry an inferred @Decorations lore@ context.
data HistOp lore = HistOp
  { -- | Presumably the number of bins of the destination histogram(s)
    -- (TODO confirm).
    histWidth :: SubExp,
    -- | NOTE(review): semantics not visible here; presumably a hint about
    -- expected write contention.  Confirm against the code generators.
    histRaceFactor :: SubExp,
    -- | The destination arrays of this operation.
    histDest :: [VName],
    -- | Neutral elements for 'histOp'.
    histNeutral :: [SubExp],
    -- | The operator used to combine values into bins.
    histOp :: Lambda lore
  }
  deriving (Eq, Ord, Show)
-- | The ordering requirement of a 'Parallel' stream.  Presumably
-- 'InOrder' requires chunks to be processed in element order, while
-- 'Disorder' permits any order (TODO confirm against the stream users).
data StreamOrd = InOrder | Disorder
  deriving (Eq, Ord, Show)
-- | How a 'Stream' is executed.
--
-- 'Parallel' carries an ordering requirement, a 'Commutativity' and a
-- 'Lambda'.  The lambda is presumably the operator that combines per-chunk
-- results (TODO confirm).  'Sequential' carries no extra information.
--
-- The derived instances carry an inferred @Decorations lore@ context.
data StreamForm lore
  = Parallel StreamOrd Commutativity (Lambda lore)
  | Sequential
  deriving (Eq, Ord, Show)
-- | The operator part of a 'Screma': a list of scans, a list of
-- reductions, and a map function.
--
-- 'scremaType' lists the result types of a 'Screma' in this order:
-- - the scan results,
-- - then the reduction results,
-- - then further map-produced types, as arrays of the outer width.
--
-- The derived instances carry an inferred @Decorations lore@ context.
data ScremaForm lore
  = ScremaForm
      [Scan lore]
      [Reduce lore]
      (Lambda lore)
  deriving (Eq, Ord, Show)
-- | Combine several operator lambdas into one lambda that evaluates all of
-- them side by side.
--
-- Each input lambda's parameters are split at @k@, its number of return
-- values.  The first @k@ parameters form its "x" half and the rest its
-- "y" half.  The combined lambda:
--
-- * takes the "x" halves of all lambdas, in order, followed by all of
--   their "y" halves;
-- * has the concatenated statements of every body;
-- * returns the concatenation of all body results and return types.
--
-- NOTE(review): this presumes each input is a binary operator whose
-- parameter list is twice its return arity; that is not checked here.
singleBinOp :: Bindable lore => [Lambda lore] -> Lambda lore
singleBinOp lams =
  Lambda
    { lambdaParams = concat xparams ++ concat yparams,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody
          (mconcat (map (bodyStms . lambdaBody) lams))
          (concatMap (bodyResult . lambdaBody) lams)
    }
  where
    -- Split every lambda's parameters at its number of return values.
    (xparams, yparams) = unzip $ map splitParams lams
    splitParams lam = splitAt (length (lambdaReturnType lam)) (lambdaParams lam)
-- | A scan operator together with its neutral elements.
--
-- The derived instances carry an inferred @Decorations lore@ context.
data Scan lore = Scan
  { -- | The scan operator.
    scanLambda :: Lambda lore,
    -- | Neutral elements for 'scanLambda'.  Their count is the number of
    -- results this scan contributes (see 'scanResults').
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)
-- | The total number of results produced by a list of scans, counted as
-- the total number of neutral elements across all of them.
scanResults :: [Scan lore] -> Int
scanResults = sum . map (length . scanNeutral)
-- | Merge several scans into one.
--
-- The operators are combined with 'singleBinOp', and the neutral elements
-- are concatenated in the same order.
singleScan :: Bindable lore => [Scan lore] -> Scan lore
singleScan scans =
  let scan_nes = concatMap scanNeutral scans
      scan_lam = singleBinOp $ map scanLambda scans
   in Scan scan_lam scan_nes
-- | A reduction operator together with its commutativity and neutral
-- elements.
--
-- The derived instances carry an inferred @Decorations lore@ context.
data Reduce lore = Reduce
  { -- | The 'Commutativity' of 'redLambda'.
    redComm :: Commutativity,
    -- | The reduction operator.
    redLambda :: Lambda lore,
    -- | Neutral elements for 'redLambda'.  Their count is the number of
    -- results this reduction contributes (see 'redResults').
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)
-- | The total number of results produced by a list of reductions, counted
-- as the total number of neutral elements across all of them.
redResults :: [Reduce lore] -> Int
redResults = sum . map (length . redNeutral)
-- | Merge several reductions into one.
--
-- The parts are combined as follows:
--
-- * operators, with 'singleBinOp';
-- * neutral elements, by concatenation in the same order;
-- * commutativities, with 'mconcat' over 'Commutativity'.  Presumably the
--   result is commutative only when every input is; check the 'Monoid'
--   instance.
singleReduce :: Bindable lore => [Reduce lore] -> Reduce lore
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
      red_comm = mconcat $ map redComm reds
   in Reduce red_comm red_lam red_nes
-- | The result types of a 'Screma' whose input arrays have outer size
-- @w@, in this order:
--
-- * the scan results, each as an array of @w@ elements;
--
-- * the reduction results, exactly as returned by the reduction lambdas;
--
-- * the map lambda's remaining results (those after the values fed to
--   the scans and reductions), each as an array of @w@ elements.
scremaType :: SubExp -> ScremaForm lore -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  concat [scanTypes, redTypes, map (`arrayOfRow` w) mapOnlyTypes]
  where
    scanTypes =
      [ t `arrayOfRow` w
        | scan <- scans,
          t <- lambdaReturnType (scanLambda scan)
      ]
    redTypes =
      [t | red <- reds, t <- lambdaReturnType (redLambda red)]
    -- The map lambda first produces the inputs to the scans and
    -- reductions; only what follows is a plain map result.
    mapOnlyTypes =
      drop (length scanTypes + length redTypes) (lambdaReturnType map_lam)
-- | Build a lambda with one fresh parameter (named after @"x"@) per
-- given type, whose body has no statements and simply returns the
-- parameters in order.  The lambda's return types are the given types.
mkIdentityLambda ::
  (Bindable lore, MonadFreshNames m) =>
  [Type] ->
  m (Lambda lore)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  let body = mkBody mempty $ map (Var . paramName) params
  pure $ Lambda params body ts
-- | Does the lambda's body return exactly its own parameters, in
-- order?  Only the body's result list is compared; its statements are
-- not inspected.
isIdentityLambda :: Lambda lore -> Bool
isIdentityLambda (Lambda params body _) =
  bodyResult body == [Var (paramName p) | p <- params]
-- | The empty lambda: no parameters, a body with no statements and no
-- results, and no return types.
nilFn :: Bindable lore => Lambda lore
nilFn =
  Lambda
    { lambdaParams = [],
      lambdaBody = mkBody mempty [],
      lambdaReturnType = []
    }
-- | A 'ScremaForm' with the given scans, no reductions, and the given
-- map lambda.
scanomapSOAC :: [Scan lore] -> Lambda lore -> ScremaForm lore
scanomapSOAC scans map_lam = ScremaForm scans [] map_lam
-- | A 'ScremaForm' with no scans, the given reductions, and the given
-- map lambda.
redomapSOAC :: [Reduce lore] -> Lambda lore -> ScremaForm lore
redomapSOAC reds map_lam = ScremaForm [] reds map_lam
-- | A 'ScremaForm' performing the given scans directly on its inputs.
-- The map lambda is an identity lambda over the concatenated return
-- types of the scan lambdas.
scanSOAC ::
  (Bindable lore, MonadFreshNames m) =>
  [Scan lore] ->
  m (ScremaForm lore)
scanSOAC scans = do
  identity <- mkIdentityLambda $ concatMap (lambdaReturnType . scanLambda) scans
  pure $ scanomapSOAC scans identity
-- | A 'ScremaForm' performing the given reductions directly on its
-- inputs.  The map lambda is an identity lambda over the concatenated
-- return types of the reduction lambdas.
reduceSOAC ::
  (Bindable lore, MonadFreshNames m) =>
  [Reduce lore] ->
  m (ScremaForm lore)
reduceSOAC reds = do
  identity <- mkIdentityLambda $ concatMap (lambdaReturnType . redLambda) reds
  pure $ redomapSOAC reds identity
-- | A 'ScremaForm' that is a plain map: no scans, no reductions.
mapSOAC :: Lambda lore -> ScremaForm lore
mapSOAC map_lam = ScremaForm [] [] map_lam
-- | Recognise a 'ScremaForm' with at least one scan and no reductions,
-- returning its scans and its map lambda.
isScanomapSOAC :: ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC (ScremaForm scans reds map_lam)
  | null reds, not (null scans) = Just (scans, map_lam)
  | otherwise = Nothing
-- | Recognise a pure scan: a scanomap (see 'isScanomapSOAC') whose map
-- lambda is an identity lambda (see 'isIdentityLambda').
isScanSOAC :: ScremaForm lore -> Maybe [Scan lore]
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scans, map_lam) | isIdentityLambda map_lam -> Just scans
    _ -> Nothing
-- | Recognise a 'ScremaForm' with at least one reduction and no scans,
-- returning its reductions and its map lambda.
isRedomapSOAC :: ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC (ScremaForm scans reds map_lam)
  | null scans, not (null reds) = Just (reds, map_lam)
  | otherwise = Nothing
-- | Recognise a pure reduction: a redomap (see 'isRedomapSOAC') whose
-- map lambda is an identity lambda (see 'isIdentityLambda').
isReduceSOAC :: ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC form =
  case isRedomapSOAC form of
    Just (reds, map_lam) | isIdentityLambda map_lam -> Just reds
    _ -> Nothing
-- | Recognise a plain map (no scans and no reductions), returning its
-- map lambda.
isMapSOAC :: ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC (ScremaForm scans reds map_lam)
  | null scans && null reds = Just map_lam
  | otherwise = Nothing
-- | Group the flat results of a scatter lambda by output array.
--
-- Each entry of the output specification is @(shape, n, array)@.  The
-- index/value pairs from 'groupScatterResults'' are split into chunks
-- of @n@ pairs, one chunk per specification entry.  Each chunk is
-- paired with that entry's shape and array.
groupScatterResults :: [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  zip3 shapes arrays $ chunks ns $ groupScatterResults' output_spec results
  where
    (shapes, ns, arrays) = unzip3 output_spec
-- | Pair every scattered value with its index tuple.
--
-- The results are first split into indices and values by
-- 'splitScatterResults'.  For a specification entry @(shape, n, _)@
-- there are @n@ index tuples, each of @length shape@ indices.  The
-- flat index list is chunked accordingly and zipped with the values.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  zip (chunks tuple_sizes indices) values
  where
    (indices, values) = splitScatterResults output_spec results
    tuple_sizes =
      concatMap (\(shp, n, _) -> replicate n (length shp)) output_spec
-- | Split the flat results of a scatter lambda into the index part and
-- the value part.  The index part has length equal to the sum of
-- @n * length shape@ over the entries @(shape, n, _)@ of the output
-- specification; everything after it is the value part.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results =
  splitAt num_indices results
  where
    num_indices = sum [n * length shp | (shp, n, _) <- output_spec]
-- | The monadic transformations that 'mapSOACM' applies to the
-- immediate constituents of a 'SOAC'.
data SOACMapper flore tlore m = SOACMapper
  { -- | Applied to each 'SubExp' mentioned directly by the SOAC.
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Applied to each 'Lambda' held directly by the SOAC.
    mapOnSOACLambda :: Lambda flore -> m (Lambda tlore),
    -- | Applied to each 'VName' mentioned directly by the SOAC.
    mapOnSOACVName :: VName -> m VName
  }
-- | A mapper whose subexpression, lambda, and name transformations all
-- return their argument unchanged.
identitySOACMapper :: Monad m => SOACMapper lore lore m
identitySOACMapper =
  -- Field order: subexpressions, lambdas, names.
  SOACMapper pure pure pure
-- | Rebuild a 'SOAC' by applying the mapper's transformations to each
-- of its immediate 'SubExp', 'VName', and 'Lambda' constituents.
-- Effects run in the order the fields appear in each constructor.
-- 'StreamOrd', 'Commutativity', and the integer counts in scatter
-- outputs are copied unchanged.  The mapper itself decides whether to
-- descend into lambda bodies.
mapSOACM ::
  (Applicative m, Monad m) =>
  SOACMapper flore tlore m ->
  SOAC flore ->
  m (SOAC tlore)
mapSOACM tv soac =
  case soac of
    Stream size arrs form accs lam ->
      Stream
        <$> onSubExp size
        <*> traverse onVName arrs
        <*> onStreamForm form
        <*> traverse onSubExp accs
        <*> onLambda lam
    Scatter len lam ivs dests ->
      Scatter
        <$> onSubExp len
        <*> onLambda lam
        <*> traverse onVName ivs
        <*> traverse onScatterDest dests
    Hist len ops bucket_fun imgs ->
      Hist
        <$> onSubExp len
        <*> traverse onHistOp ops
        <*> onLambda bucket_fun
        <*> traverse onVName imgs
    Screma w arrs (ScremaForm scans reds map_lam) ->
      Screma
        <$> onSubExp w
        <*> traverse onVName arrs
        <*> ( ScremaForm
                <$> traverse onScan scans
                <*> traverse onReduce reds
                <*> onLambda map_lam
            )
  where
    onSubExp = mapOnSOACSubExp tv
    onVName = mapOnSOACVName tv
    onLambda = mapOnSOACLambda tv

    -- Only a parallel stream carries a lambda to transform.
    onStreamForm (Parallel o comm lam0) = Parallel o comm <$> onLambda lam0
    onStreamForm Sequential = pure Sequential

    -- A scatter destination: its shape, its count, and its array.
    onScatterDest (aw, an, a) =
      (,,) <$> traverse onSubExp aw <*> pure an <*> onVName a

    onHistOp (HistOp e rf arrs nes op) =
      HistOp
        <$> onSubExp e
        <*> onSubExp rf
        <*> traverse onVName arrs
        <*> traverse onSubExp nes
        <*> onLambda op

    onScan (Scan scan_lam scan_nes) =
      Scan <$> onLambda scan_lam <*> traverse onSubExp scan_nes

    onReduce (Reduce comm red_lam red_nes) =
      Reduce comm <$> onLambda red_lam <*> traverse onSubExp red_nes
instance ASTLore lore => FreeIn (SOAC lore) where
  -- Traverse the SOAC with 'mapSOACM', accumulating the free variables of
  -- every sub-expression, lambda and array name into the state.  The rebuilt
  -- SOAC itself is discarded; only the accumulated 'FV' is returned.
  freeIn' soac = execState (mapSOACM collector soac) mempty
    where
      -- Append @f x@ to the accumulated state and return @x@ unchanged.
      note f x = do
        modify (<> f x)
        pure x
      collector =
        SOACMapper
          { mapOnSOACSubExp = note freeIn',
            mapOnSOACLambda = note freeIn',
            mapOnSOACVName = note freeIn'
          }
instance ASTLore lore => Substitute (SOAC lore) where
  -- Apply the name substitution @subst@ to every sub-expression, lambda and
  -- array name reachable through 'mapSOACM'.
  substituteNames subst soac = runIdentity $ mapSOACM renamer soac
    where
      -- The signature keeps this helper polymorphic even though it closes
      -- over @subst@.
      substM :: Substitute a => a -> Identity a
      substM = pure . substituteNames subst
      renamer =
        SOACMapper
          { mapOnSOACSubExp = substM,
            mapOnSOACLambda = substM,
            mapOnSOACVName = substM
          }
instance ASTLore lore => Rename (SOAC lore) where
  -- Each component of the SOAC is renamed through its own 'Rename' instance.
  rename =
    mapSOACM
      SOACMapper
        { mapOnSOACSubExp = rename,
          mapOnSOACLambda = rename,
          mapOnSOACVName = rename
        }
-- | The result types of a SOAC.
--
-- * 'Stream': the lambda's return types.  Each name bound by one of the
--   lambda's first @1 + length accs@ parameters is replaced, wherever it
--   occurs in an array shape, by the outer size (first parameter) or by the
--   corresponding initial accumulator value.
-- * 'Scatter': each value return type of the lambda (those left after
--   dropping the index results) is lifted to an array of the matching
--   destination shape.
-- * 'Hist': for every operator, the operator lambda's return types as arrays
--   of that operator's width.
-- * 'Screma': delegated to 'scremaType'.
soacType :: SOAC lore -> [Type]
soacType (Stream outersize _ _ accs lam) =
  substNamesInType substs <$> lambdaReturnType lam
  where
    leading = take (1 + length accs) (lambdaParams lam)
    substs = M.fromList $ zip (paramName <$> leading) (outersize : accs)
soacType (Scatter _w lam _ivs dests) =
  zipWith arrayOfShape valueTypes shapes
  where
    (shapes, counts, _) = unzip3 dests
    -- For every destination, @count * rank@ leading results are indices;
    -- the remaining lambda results are the values to be written.
    numIndexResults = sum [n * length shape | (shape, n) <- zip shapes counts]
    valueTypes = drop numIndexResults $ lambdaReturnType lam
soacType (Hist _len ops _bucket_fun _imgs) =
  concatMap opResultTypes ops
  where
    opResultTypes op =
      map (`arrayOfRow` histWidth op) $ lambdaReturnType $ histOp op
soacType (Screma w _arrs form) = scremaType w form
instance TypedOp (SOAC lore) where
  -- The result types from 'soacType' have static shapes; they are lifted to
  -- 'ExtType' with 'staticShapes' and returned without consulting the scope.
  opType soac = pure $ staticShapes $ soacType soac
instance (ASTLore lore, Aliased lore) => AliasedOp (SOAC lore) where
  -- One empty alias set per SOAC result.
  opAliases = map (const mempty) . soacType

  -- Screma: only the map lambda is inspected.  A consumed name that is one
  -- of the map lambda's parameters is translated to the input array bound to
  -- that parameter; any other consumed name is reported as-is.
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames consumedArray $ consumedByLambda map_lam
    where
      consumedArray v = fromMaybe v $ lookup v params_to_arrs
      params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs
  -- Stream: the translation does not depend on the stream form, so the form
  -- is ignored (previously both branches of a @case@ on it were identical).
  -- The first lambda parameter, which 'soacType' pairs with the outer size,
  -- is skipped; the rest correspond to the accumulators followed by the
  -- input arrays.  Consumed accumulators that are constants are dropped by
  -- 'subExpVars'.
  consumedInOp (Stream _ arrs _ accs lam) =
    namesFromList $
      subExpVars $
        map consumedArray $
          namesToList $
            consumedByLambda lam
    where
      consumedArray v = fromMaybe (Var v) $ lookup v paramsToInput
      paramsToInput =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
  -- Scatter consumes every destination array.
  consumedInOp (Scatter _ _ _ as) =
    namesFromList $ map (\(_, _, a) -> a) as
  -- Hist consumes the destination arrays of every operator.
  consumedInOp (Hist _ ops _ _) =
    namesFromList $ concatMap histDest ops
-- | Transform the operator lambda of a 'HistOp' with the given function.
-- All other fields are copied unchanged.
mapHistOp ::
  (Lambda flore -> Lambda tlore) ->
  HistOp flore ->
  HistOp tlore
mapHistOp onLambda (HistOp width rf_se dest_arrs neutral op_lam) =
  HistOp width rf_se dest_arrs neutral (onLambda op_lam)
instance
  ( ASTLore lore,
    ASTLore (Aliases lore),
    CanBeAliased (Op lore)
  ) =>
  CanBeAliased (SOAC lore)
  where
  type OpWithAliases (SOAC lore) = SOAC (Aliases lore)

  -- Every lambda nested in the SOAC (including those in a parallel stream
  -- form, histogram operators, scans and reductions) is annotated with
  -- 'Alias.analyseLambda' under the given alias table.  Non-lambda
  -- components are carried over unchanged.
  addOpAliases aliases soac =
    case soac of
      Stream size arr form accs lam ->
        Stream size arr (onForm form) accs (onLam lam)
      Scatter len lam ivs as ->
        Scatter len (onLam lam) ivs as
      Hist len ops bucket_fun imgs ->
        Hist len (map (mapHistOp onLam) ops) (onLam bucket_fun) imgs
      Screma w arrs (ScremaForm scans reds map_lam) ->
        Screma w arrs $
          ScremaForm (map onScan scans) (map onRed reds) (onLam map_lam)
    where
      onLam = Alias.analyseLambda aliases
      onForm (Parallel o comm lam0) = Parallel o comm (onLam lam0)
      onForm Sequential = Sequential
      onScan scan = scan {scanLambda = onLam $ scanLambda scan}
      onRed red = red {redLambda = onLam $ redLambda red}

  -- Strip alias annotations from every nested lambda; sub-expressions and
  -- names pass through unchanged.
  removeOpAliases = runIdentity . mapSOACM stripper
    where
      stripper =
        SOACMapper
          { mapOnSOACSubExp = pure,
            mapOnSOACLambda = pure . removeLambdaAliases,
            mapOnSOACVName = pure
          }
instance ASTLore lore => IsOp (SOAC lore) where
  -- Every SOAC, whatever its contents, is reported as not safe and as cheap.
  safeOp = const False
  cheapOp = const True
-- | Substitute variables occurring in the shape of an array type according
-- to the given map.  Primitive, accumulator and memory types are returned
-- unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array btp shp u ->
      Array btp (Shape $ substNamesInSubExp subs <$> shapeDims shp) u
    Mem {} -> t
    Prim {} -> t
    Acc {} -> t
-- | Replace a variable by its entry in the given map, keeping the variable
-- when it has no entry.  Constants are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Var v -> fromMaybe se $ M.lookup v subs
    Constant {} -> se
instance (ASTLore lore, CanBeWise (Op lore)) => CanBeWise (SOAC lore) where
  type OpWithWisdom (SOAC lore) = SOAC (Wise lore)

  -- Strip wisdom annotations from every nested lambda with
  -- 'removeLambdaWisdom'; sub-expressions and names pass through unchanged.
  removeOpWisdom = runIdentity . mapSOACM stripper
    where
      stripper =
        SOACMapper
          { mapOnSOACSubExp = pure,
            mapOnSOACLambda = pure . removeLambdaWisdom,
            mapOnSOACVName = pure
          }
-- | Symbol-table indexing for SOACs.  For @indexOp vtable k soac [i]@,
-- try to express element @i@ of the @k@th result of @soac@ as a
-- 'PrimExp' together with the certificates it depends on.  This is done
-- by treating each map-lambda parameter as its input array indexed at
-- @i@, then inlining the lambda body's single-name statements that can be
-- converted with 'primExpFromExp'.  Only 'Screma' indexed by a single
-- index is handled; every other case yields 'Nothing'.
instance Decorations lore => ST.IndexOp (SOAC lore) where
  indexOp vtable k soac [i] = do
    (lam, se, arr_params, arrs) <- lambdaAndSubExp soac
    -- Map each lambda parameter to its input array indexed at @i@, when
    -- the symbol table can express that index.
    let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs
        -- Grow the table with the body statements that become 'PrimExp's.
        arr_indexes' = foldl expandPrimExpTable arr_indexes $ bodyStms $ lambdaBody lam
    case se of
      Var v -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes'
      _ -> Nothing
    where
      lambdaAndSubExp (Screma _ arrs (ScremaForm scans reds map_lam)) =
        nthMapOut (scanResults scans + redResults reds) map_lam arrs
      lambdaAndSubExp _ =
        Nothing

      -- The map lambda of a Screma takes one parameter per input array
      -- (it is type-checked against exactly the array arguments), so ALL
      -- of its parameters line up with @arrs@; none are accumulators.
      --
      -- NOTE(review): assumes @k@ is the position of the indexed name in
      -- the statement's full pattern, and that both the pattern and the
      -- map lambda's body result list scan operands, then reduce operands,
      -- then map results — confirm against ST.index' / scremaType.
      -- Scan and reduce results (@k < num_accs@) are not per-element
      -- functions of the inputs, so they are never indexed here.
      nthMapOut num_accs lam arrs
        | k < num_accs = Nothing
        | otherwise = do
            se <- maybeNth k $ bodyResult $ lambdaBody lam
            return (lam, se, lambdaParams lam, arrs)

      -- Look up @arr[i]@ in the symbol table and bind it to parameter @p@.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        return (paramName p, (pe, cs))

      -- Add a single-name statement to the table if its expression can be
      -- turned into a 'PrimExp' and all of its certificates are in scope.
      expandPrimExpTable table stm
        | [v] <- patternNames $ stmPattern stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
          all (`ST.elem` vtable) (unCertificates $ stmCerts stm) =
          M.insert v (pe, stmCerts stm <> cs) table
        | otherwise =
          table

      -- Resolve a variable either from the table (recording its
      -- certificates) or, for primitive-typed names in the symbol table,
      -- as a leaf expression; anything else aborts the conversion.
      asPrimExp table v
        | Just (e, cs) <- M.lookup v table = tell cs >> return e
        | Just (Prim pt) <- ST.lookupType v vtable =
          return $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
typeCheckSOAC :: TC.Checkable lore => SOAC (Aliases lore) -> TC.TypeM lore ()
typeCheckSOAC :: forall lore. Checkable lore => SOAC (Aliases lore) -> TypeM lore ()
typeCheckSOAC (Stream SubExp
size [VName]
arrexps StreamForm (Aliases lore)
form [SubExp]
accexps Lambda (Aliases lore)
lam) = do
[Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
[Arg]
accargs <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
accexps
[Type]
arrargs <- (VName -> TypeM lore Type) -> [VName] -> TypeM lore [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> TypeM lore Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
arrexps
[Arg]
_ <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
let chunk :: Param (LParamInfo lore)
chunk = [Param (LParamInfo lore)] -> Param (LParamInfo lore)
forall a. [a] -> a
head ([Param (LParamInfo lore)] -> Param (LParamInfo lore))
-> [Param (LParamInfo lore)] -> Param (LParamInfo lore)
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases lore) -> [LParam (Aliases lore)]
forall {lore}. LambdaT lore -> [Param (LParamInfo lore)]
lambdaParams Lambda (Aliases lore)
lam
let asArg :: a -> (a, b)
asArg a
t = (a
t, b
forall a. Monoid a => a
mempty)
inttp :: TypeBase shape u
inttp = PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
lamarrs' :: [Type]
lamarrs' = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (Param (LParamInfo lore) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo lore)
chunk)) [Type]
arrargs
let acc_len :: Int
acc_len = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
let lamrtp :: [Type]
lamrtp = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
acc_len ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"Stream with inconsistent accumulator type in lambda."
()
_ <- case StreamForm (Aliases lore)
form of
Parallel StreamOrd
_ Commutativity
_ Lambda (Aliases lore)
lam0 -> do
let acct :: [Type]
acct = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs
outerRetType :: [Type]
outerRetType = Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam0
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam0 ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
accargs
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
acct [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
outerRetType) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
String
"Initial value is of type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
acct
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
", but stream's reduce lambda returns type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
outerRetType
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"."
StreamForm (Aliases lore)
Sequential -> () -> TypeM lore ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
let fake_lamarrs' :: [Arg]
fake_lamarrs' = (Type -> Arg) -> [Type] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg Type
forall {shape} {u}. TypeBase shape u
inttp Arg -> [Arg] -> [Arg]
forall a. a -> [a] -> [a]
: [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter SubExp
w Lambda (Aliases lore)
lam [VName]
ivs [(Shape, Int, VName)]
as) = do
[Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
let ([Shape]
as_ws, [Int]
as_ns, [VName]
_as_vs) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
indexes :: Int
indexes = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws
rts :: [Type]
rts = Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
lam
rtsI :: [Type]
rtsI = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
indexes [Type]
rts
rtsV :: [Type]
rtsV = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
indexes [Type]
rts
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
rts Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws)) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"Scatter: number of index types, value types and array outputs do not match."
[Type] -> (Type -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtsI ((Type -> TypeM lore ()) -> TypeM lore ())
-> (Type -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \Type
rtI ->
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
== Type
rtI) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError String
"Scatter: Index return type must be i64."
[([Type], (Shape, Int, VName))]
-> (([Type], (Shape, Int, VName)) -> TypeM lore ())
-> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[Type]]
-> [(Shape, Int, VName)] -> [([Type], (Shape, Int, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ([Int] -> [Type] -> [[Type]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns [Type]
rtsV) [(Shape, Int, VName)]
as) ((([Type], (Shape, Int, VName)) -> TypeM lore ()) -> TypeM lore ())
-> (([Type], (Shape, Int, VName)) -> TypeM lore ())
-> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \([Type]
rtVs, (Shape
aw, Int
_, VName
a)) -> do
(SubExp -> TypeM lore ()) -> Shape -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
aw
[Type] -> (Type -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtVs ((Type -> TypeM lore ()) -> TypeM lore ())
-> (Type -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \Type
rtV -> [Type] -> VName -> TypeM lore ()
forall lore. Checkable lore => [Type] -> VName -> TypeM lore ()
TC.requireI [Type -> Shape -> Type
arrayOfShape Type
rtV Shape
aw] VName
a
Names -> TypeM lore ()
forall lore. Checkable lore => Names -> TypeM lore ()
TC.consume (Names -> TypeM lore ()) -> TypeM lore Names -> TypeM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM lore Names
forall lore. Checkable lore => VName -> TypeM lore Names
TC.lookupAliases VName
a
[Arg]
arrargs <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
ivs
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
lam [Arg]
arrargs
typeCheckSOAC (Hist SubExp
len [HistOp (Aliases lore)]
ops Lambda (Aliases lore)
bucket_fun [VName]
imgs) = do
[Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
len
[HistOp (Aliases lore)]
-> (HistOp (Aliases lore) -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [HistOp (Aliases lore)]
ops ((HistOp (Aliases lore) -> TypeM lore ()) -> TypeM lore ())
-> (HistOp (Aliases lore) -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
dest_w SubExp
rf [VName]
dests [SubExp]
nes Lambda (Aliases lore)
op) -> do
[Arg]
nes' <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
nes
[Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
dest_w
[Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
rf
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
op ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
nes_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
op) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
String
"Operator has return type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
op)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
nes_t
[(Type, VName)]
-> ((Type, VName) -> TypeM lore ()) -> TypeM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
nes_t [VName]
dests) (((Type, VName) -> TypeM lore ()) -> TypeM lore ())
-> ((Type, VName) -> TypeM lore ()) -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
dest) -> do
[Type] -> VName -> TypeM lore ()
forall lore. Checkable lore => [Type] -> VName -> TypeM lore ()
TC.requireI [Type
t Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
dest_w] VName
dest
Names -> TypeM lore ()
forall lore. Checkable lore => Names -> TypeM lore ()
TC.consume (Names -> TypeM lore ()) -> TypeM lore Names -> TypeM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM lore Names
forall lore. Checkable lore => VName -> TypeM lore Names
TC.lookupAliases VName
dest
[Arg]
img' <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
len [VName]
imgs
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
bucket_fun [Arg]
img'
[Type]
nes_ts <- [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[Type]] -> [Type]) -> TypeM lore [[Type]] -> TypeM lore [Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp (Aliases lore) -> TypeM lore [Type])
-> [HistOp (Aliases lore)] -> TypeM lore [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> TypeM lore Type) -> [SubExp] -> TypeM lore [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType ([SubExp] -> TypeM lore [Type])
-> (HistOp (Aliases lore) -> [SubExp])
-> HistOp (Aliases lore)
-> TypeM lore [Type]
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases lore) -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral) [HistOp (Aliases lore)]
ops
let bucket_ret_t :: [Type]
bucket_ret_t = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([HistOp (Aliases lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Aliases lore)]
ops) (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
nes_ts
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
bucket_ret_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
bucket_fun) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
String
"Bucket function has return type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
bucket_fun)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but should have type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
bucket_ret_t
typeCheckSOAC (Screma SubExp
w [VName]
arrs (ScremaForm [Scan (Aliases lore)]
scans [Reduce (Aliases lore)]
reds Lambda (Aliases lore)
map_lam)) = do
[Type] -> SubExp -> TypeM lore ()
forall lore. Checkable lore => [Type] -> SubExp -> TypeM lore ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
[Arg]
arrs' <- SubExp -> [VName] -> TypeM lore [Arg]
forall lore.
Checkable lore =>
SubExp -> [VName] -> TypeM lore [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
arrs
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
map_lam [Arg]
arrs'
[Arg]
scan_nes' <- ([[Arg]] -> [Arg]) -> TypeM lore [[Arg]] -> TypeM lore [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [[Arg]] -> [Arg]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat (TypeM lore [[Arg]] -> TypeM lore [Arg])
-> TypeM lore [[Arg]] -> TypeM lore [Arg]
forall a b. (a -> b) -> a -> b
$
[Scan (Aliases lore)]
-> (Scan (Aliases lore) -> TypeM lore [Arg]) -> TypeM lore [[Arg]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan (Aliases lore)]
scans ((Scan (Aliases lore) -> TypeM lore [Arg]) -> TypeM lore [[Arg]])
-> (Scan (Aliases lore) -> TypeM lore [Arg]) -> TypeM lore [[Arg]]
forall a b. (a -> b) -> a -> b
$ \(Scan Lambda (Aliases lore)
scan_lam [SubExp]
scan_nes) -> do
[Arg]
scan_nes' <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
scan_nes
let scan_t :: [Type]
scan_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
scan_nes'
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
scan_lam ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
scan_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
scan_nes'
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
scan_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
scan_lam) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
String
"Scan function returns type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
scan_lam)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
scan_t
[Arg] -> TypeM lore [Arg]
forall (m :: * -> *) a. Monad m => a -> m a
return [Arg]
scan_nes'
[Arg]
red_nes' <- ([[Arg]] -> [Arg]) -> TypeM lore [[Arg]] -> TypeM lore [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [[Arg]] -> [Arg]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat (TypeM lore [[Arg]] -> TypeM lore [Arg])
-> TypeM lore [[Arg]] -> TypeM lore [Arg]
forall a b. (a -> b) -> a -> b
$
[Reduce (Aliases lore)]
-> (Reduce (Aliases lore) -> TypeM lore [Arg])
-> TypeM lore [[Arg]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Reduce (Aliases lore)]
reds ((Reduce (Aliases lore) -> TypeM lore [Arg]) -> TypeM lore [[Arg]])
-> (Reduce (Aliases lore) -> TypeM lore [Arg])
-> TypeM lore [[Arg]]
forall a b. (a -> b) -> a -> b
$ \(Reduce Commutativity
_ Lambda (Aliases lore)
red_lam [SubExp]
red_nes) -> do
[Arg]
red_nes' <- (SubExp -> TypeM lore Arg) -> [SubExp] -> TypeM lore [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM lore Arg
forall lore. Checkable lore => SubExp -> TypeM lore Arg
TC.checkArg [SubExp]
red_nes
let red_t :: [Type]
red_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
red_nes'
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
forall lore.
Checkable lore =>
Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
TC.checkLambda Lambda (Aliases lore)
red_lam ([Arg] -> TypeM lore ()) -> [Arg] -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
red_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
red_nes'
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
red_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
red_lam) (TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
String
"Reduce function returns type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
red_lam)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
red_t
[Arg] -> TypeM lore [Arg]
forall (m :: * -> *) a. Monad m => a -> m a
return [Arg]
red_nes'
let map_lam_ts :: [Type]
map_lam_ts = Lambda (Aliases lore) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Aliases lore)
map_lam
Bool -> TypeM lore () -> TypeM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless
( Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take ([Arg] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Arg]
scan_nes' Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Arg] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Arg]
red_nes') [Type]
map_lam_ts
[Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType ([Arg]
scan_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
red_nes')
)
(TypeM lore () -> TypeM lore ()) -> TypeM lore () -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$ ErrorCase lore -> TypeM lore ()
forall lore a. ErrorCase lore -> TypeM lore a
TC.bad (ErrorCase lore -> TypeM lore ())
-> ErrorCase lore -> TypeM lore ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase lore
forall lore. String -> ErrorCase lore
TC.TypeError (String -> ErrorCase lore) -> String -> ErrorCase lore
forall a b. (a -> b) -> a -> b
$
String
"Map function return type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
map_lam_ts
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" wrong for given scan and reduction functions."
instance OpMetrics (Op lore) => OpMetrics (SOAC lore) where
opMetrics :: SOAC lore -> MetricsM ()
opMetrics (Stream SubExp
_ [VName]
_ StreamForm lore
_ [SubExp]
_ Lambda lore
lam) =
Text -> MetricsM () -> MetricsM ()
inside Text
"Stream" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics Lambda lore
lam
opMetrics (Scatter SubExp
_len Lambda lore
lam [VName]
_ivs [(Shape, Int, VName)]
_as) =
Text -> MetricsM () -> MetricsM ()
inside Text
"Scatter" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics Lambda lore
lam
opMetrics (Hist SubExp
_len [HistOp lore]
ops Lambda lore
bucket_fun [VName]
_imgs) =
Text -> MetricsM () -> MetricsM ()
inside Text
"Hist" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ (HistOp lore -> MetricsM ()) -> [HistOp lore] -> MetricsM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics (Lambda lore -> MetricsM ())
-> (HistOp lore -> Lambda lore) -> HistOp lore -> MetricsM ()
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp lore -> Lambda lore
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp lore]
ops MetricsM () -> MetricsM () -> MetricsM ()
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics Lambda lore
bucket_fun
opMetrics (Screma SubExp
_ [VName]
_ (ScremaForm [Scan lore]
scans [Reduce lore]
reds Lambda lore
map_lam)) =
Text -> MetricsM () -> MetricsM ()
inside Text
"Screma" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ do
(Scan lore -> MetricsM ()) -> [Scan lore] -> MetricsM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics (Lambda lore -> MetricsM ())
-> (Scan lore -> Lambda lore) -> Scan lore -> MetricsM ()
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Scan lore -> Lambda lore
forall lore. Scan lore -> Lambda lore
scanLambda) [Scan lore]
scans
(Reduce lore -> MetricsM ()) -> [Reduce lore] -> MetricsM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics (Lambda lore -> MetricsM ())
-> (Reduce lore -> Lambda lore) -> Reduce lore -> MetricsM ()
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Reduce lore -> Lambda lore
forall lore. Reduce lore -> Lambda lore
redLambda) [Reduce lore]
reds
Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics Lambda lore
map_lam
instance PrettyLore lore => PP.Pretty (SOAC lore) where
ppr :: SOAC lore -> Doc
ppr (Stream SubExp
size [VName]
arrs StreamForm lore
form [SubExp]
acc Lambda lore
lam) =
case StreamForm lore
form of
Parallel StreamOrd
o Commutativity
comm Lambda lore
lam0 ->
let ord_str :: String
ord_str = if StreamOrd
o StreamOrd -> StreamOrd -> Bool
forall a. Eq a => a -> a -> Bool
== StreamOrd
Disorder then String
"Per" else String
""
comm_str :: String
comm_str = case Commutativity
comm of
Commutativity
Commutative -> String
"Comm"
Commutativity
Noncommutative -> String
""
in String -> Doc
text (String
"streamPar" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
ord_str String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
comm_str)
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
size Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [VName] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [VName]
arrs Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
lam0 Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [SubExp] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [SubExp]
acc Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
lam
)
StreamForm lore
Sequential ->
String -> Doc
text String
"streamSeq"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
size Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [VName] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [VName]
arrs Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [SubExp] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [SubExp]
acc Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
lam
)
ppr (Scatter SubExp
w Lambda lore
lam [VName]
ivs [(Shape, Int, VName)]
as) =
Doc
"scatter"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
w Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
lam Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [Doc] -> Doc
commasep ([VName] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [VName]
ivs Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
: ((Shape, Int, VName) -> Doc) -> [(Shape, Int, VName)] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map (Shape, Int, VName) -> Doc
forall a. Pretty a => a -> Doc
ppr [(Shape, Int, VName)]
as)
)
ppr (Hist SubExp
len [HistOp lore]
ops Lambda lore
bucket_fun [VName]
imgs) =
SubExp -> [HistOp lore] -> Lambda lore -> [VName] -> Doc
forall lore inp.
(PrettyLore lore, Pretty inp) =>
SubExp -> [HistOp lore] -> Lambda lore -> [inp] -> Doc
ppHist SubExp
len [HistOp lore]
ops Lambda lore
bucket_fun [VName]
imgs
ppr (Screma SubExp
w [VName]
arrs (ScremaForm [Scan lore]
scans [Reduce lore]
reds Lambda lore
map_lam))
| [Scan lore] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Scan lore]
scans,
[Reduce lore] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Reduce lore]
reds =
String -> Doc
text String
"map"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
w Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [VName] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [VName]
arrs Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
map_lam
)
| [Scan lore] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Scan lore]
scans =
String -> Doc
text String
"redomap"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
w Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [VName] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [VName]
arrs Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (Reduce lore -> Doc) -> [Reduce lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map Reduce lore -> Doc
forall a. Pretty a => a -> Doc
ppr [Reduce lore]
reds) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
map_lam
)
| [Reduce lore] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Reduce lore]
reds =
String -> Doc
text String
"scanomap"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
w Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [VName] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [VName]
arrs Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (Scan lore -> Doc) -> [Scan lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map Scan lore -> Doc
forall a. Pretty a => a -> Doc
ppr [Scan lore]
scans) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
map_lam
)
ppr (Screma SubExp
w [VName]
arrs ScremaForm lore
form) = SubExp -> [VName] -> ScremaForm lore -> Doc
forall lore inp.
(PrettyLore lore, Pretty inp) =>
SubExp -> [inp] -> ScremaForm lore -> Doc
ppScrema SubExp
w [VName]
arrs ScremaForm lore
form
ppScrema ::
(PrettyLore lore, Pretty inp) => SubExp -> [inp] -> ScremaForm lore -> Doc
ppScrema :: forall lore inp.
(PrettyLore lore, Pretty inp) =>
SubExp -> [inp] -> ScremaForm lore -> Doc
ppScrema SubExp
w [inp]
arrs (ScremaForm [Scan lore]
scans [Reduce lore]
reds Lambda lore
map_lam) =
String -> Doc
text String
"screma"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
w Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [inp] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [inp]
arrs Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (Scan lore -> Doc) -> [Scan lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map Scan lore -> Doc
forall a. Pretty a => a -> Doc
ppr [Scan lore]
scans) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (Reduce lore -> Doc) -> [Reduce lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map Reduce lore -> Doc
forall a. Pretty a => a -> Doc
ppr [Reduce lore]
reds) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
map_lam
)
instance PrettyLore lore => Pretty (Scan lore) where
ppr :: Scan lore -> Doc
ppr (Scan Lambda lore
scan_lam [SubExp]
scan_nes) =
Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
scan_lam Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
commasep ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc) -> [SubExp] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr [SubExp]
scan_nes)
ppComm :: Commutativity -> Doc
ppComm :: Commutativity -> Doc
ppComm Commutativity
Noncommutative = Doc
forall a. Monoid a => a
mempty
ppComm Commutativity
Commutative = String -> Doc
text String
"commutative "
instance PrettyLore lore => Pretty (Reduce lore) where
ppr :: Reduce lore -> Doc
ppr (Reduce Commutativity
comm Lambda lore
red_lam [SubExp]
red_nes) =
Commutativity -> Doc
ppComm Commutativity
comm Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
red_lam Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
commasep ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc) -> [SubExp] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr [SubExp]
red_nes)
ppHist ::
(PrettyLore lore, Pretty inp) =>
SubExp ->
[HistOp lore] ->
Lambda lore ->
[inp] ->
Doc
ppHist :: forall lore inp.
(PrettyLore lore, Pretty inp) =>
SubExp -> [HistOp lore] -> Lambda lore -> [inp] -> Doc
ppHist SubExp
len [HistOp lore]
ops Lambda lore
bucket_fun [inp]
imgs =
String -> Doc
text String
"hist"
Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
parens
( SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
len Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (HistOp lore -> Doc) -> [HistOp lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map HistOp lore -> Doc
forall {lore}. PrettyLore lore => HistOp lore -> Doc
ppOp [HistOp lore]
ops) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
bucket_fun Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> [Doc] -> Doc
commasep ((inp -> Doc) -> [inp] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map inp -> Doc
forall a. Pretty a => a -> Doc
ppr [inp]
imgs)
)
where
ppOp :: HistOp lore -> Doc
ppOp (HistOp SubExp
w SubExp
rf [VName]
dests [SubExp]
nes Lambda lore
op) =
SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
w Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma Doc -> Doc -> Doc
<+> SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
rf Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma Doc -> Doc -> Doc
<+> Doc -> Doc
PP.braces ([Doc] -> Doc
commasep ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ (VName -> Doc) -> [VName] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Doc
forall a. Pretty a => a -> Doc
ppr [VName]
dests) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Doc -> Doc
PP.braces ([Doc] -> Doc
commasep ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc) -> [SubExp] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr [SubExp]
nes) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
comma
Doc -> Doc -> Doc
</> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
op