{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.SOACS.SOAC
( SOAC(..)
, StreamForm(..)
, ScremaForm(..)
, GenReduceOp(..)
, Scan
, Reduce
, getStreamOrder
, getStreamAccums
, scremaType
, soacType
, typeCheckSOAC
, mkIdentityLambda
, isIdentityLambda
, composeLambda
, nilFn
, scanomapSOAC
, redomapSOAC
, scanSOAC
, reduceSOAC
, mapSOAC
, isScanomapSOAC
, isRedomapSOAC
, isScanSOAC
, isReduceSOAC
, isMapSOAC
, ppScrema
, ppGenReduce
, SOACMapper(..)
, identitySOACMapper
, mapSOACM
)
where
import Control.Monad.Writer
import Control.Monad.Identity
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List
import Futhark.Representation.AST
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty (ppr, Doc, Pretty, parens, comma, (</>), commasep, text)
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Ranges (Ranges, removeLambdaRanges)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Representation.Aliases (Aliases, removeLambdaAliases)
import Futhark.Analysis.Usage
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.Range as Range
import Futhark.Construct
import Futhark.Util (maybeNth, chunks, splitAt3)
-- | A second-order array combinator, parameterised by the lore of the
-- lambdas it contains.
data SOAC lore =
    Stream SubExp (StreamForm lore) (LambdaT lore) [VName]
    -- ^ Width, stream form, fold lambda, and input arrays.  The type
    -- checker treats the first lambda parameter as the chunk size,
    -- followed by the accumulators and then the array chunks.
  | Scatter SubExp (LambdaT lore) [VName] [(SubExp, Int, VName)]
    -- ^ Width, lambda, input arrays, and the destinations.  Each
    -- destination is a triple of its width, the number of values
    -- written to it, and the array itself (which is consumed).
  | GenReduce SubExp [GenReduceOp lore] (LambdaT lore) [VName]
    -- ^ Width, the reduction operators, the bucket function, and the
    -- input arrays.
  | Screma SubExp (ScremaForm lore) [VName]
    -- ^ Width, the combined scan/reduce/map form, and the input arrays.
  | CmpThreshold SubExp String
    -- ^ An @i32@ operand and a string; the result type is 'Bool'.
  deriving (Eq, Ord, Show)
-- | A single operator of a 'GenReduce'.
data GenReduceOp lore = GenReduceOp { genReduceWidth :: SubExp
                                      -- ^ Width of the destination arrays.
                                    , genReduceDest :: [VName]
                                      -- ^ Destination arrays; these are
                                      -- consumed by the 'GenReduce'.
                                    , genReduceNeutral :: [SubExp]
                                      -- ^ Neutral elements of the operator.
                                    , genReduceOp :: LambdaT lore
                                      -- ^ The operator itself; its return
                                      -- type must match the neutral elements.
                                    }
                      deriving (Eq, Ord, Show)
-- | The form of a 'Stream'.  Both forms carry the initial accumulator
-- values as their last field.
data StreamForm lore =
    Parallel StreamOrd Commutativity (LambdaT lore) [SubExp]
    -- ^ Ordering, commutativity, the lambda used to combine
    -- accumulators, and the initial accumulators.
  | Sequential [SubExp]
    -- ^ Only the initial accumulators.
  deriving (Eq, Ord, Show)
-- | The essential parts of a 'Screma': a scan, a reduction, and a map
-- lambda.  The map lambda's results are laid out as scan inputs, then
-- reduction inputs, then the plain map results (see 'scremaType').
data ScremaForm lore = ScremaForm
                         (Scan lore)
                         (Reduce lore)
                         (LambdaT lore)
  deriving (Eq, Ord, Show)
-- | A scan operator together with its neutral elements.
type Scan lore = (LambdaT lore, [SubExp])
-- | A reduction operator with its commutativity and neutral elements.
type Reduce lore = (Commutativity, LambdaT lore, [SubExp])
-- | The result types of a 'Screma' of width @w@: the scan results as
-- arrays, the reduction results as they are, and finally the remaining
-- map results as arrays.
scremaType :: SubExp -> ScremaForm lore -> [Type]
scremaType w (ScremaForm (scan_lam, _) (_, red_lam, _) map_lam) =
  concat [asArrays scan_ts, red_ts, asArrays map_ts]
  where scan_ts = lambdaReturnType scan_lam
        red_ts = lambdaReturnType red_lam
        -- The leading map results feed the scan and the reduction.
        num_accs = length scan_ts + length red_ts
        map_ts = drop num_accs (lambdaReturnType map_lam)
        asArrays = map (`arrayOfRow` w)
-- | Construct a lambda that takes parameters of the given types and
-- returns them unchanged.
mkIdentityLambda :: (Bindable lore, MonadFreshNames m) =>
                    [Type] -> m (Lambda lore)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  let results = [ Var (paramName p) | p <- params ]
  return $ Lambda params (mkBody mempty results) ts
-- | Does the body result of this lambda consist of exactly its
-- parameters, in order?
isIdentityLambda :: Lambda lore -> Bool
isIdentityLambda (Lambda params body _) =
  bodyResult body == [ Var (paramName p) | p <- params ]
-- | Fuse a scan lambda, a reduction lambda and a map lambda into one
-- lambda.  The map body runs first; its leading results are bound to
-- the second halves of the scan and reduction parameters, after which
-- the scan and reduction bodies run.  The resulting lambda takes the
-- first halves of the scan and reduction parameters followed by the
-- map parameters.
composeLambda :: (Bindable lore, BinderOps lore, MonadFreshNames m,
                  HasScope somelore m, SameScope somelore lore) =>
                 Lambda lore
              -> Lambda lore
              -> Lambda lore
              -> m (Lambda lore)
composeLambda scan_fun red_fun map_fun = do
  body <- runBodyBinder $ inScopeOf scan_fun $ inScopeOf red_fun $ inScopeOf map_fun $ do
    mapM_ addStm $ bodyStms map_body
    let (scan_res, red_res, map_res) =
          splitAt3 num_scan num_red $ bodyResult map_body
    zipWithM_ bindParam scan_y_params scan_res
    zipWithM_ bindParam red_y_params red_res
    mapM_ addStm $ bodyStms scan_body
    mapM_ addStm $ bodyStms red_body
    resultBodyM $ bodyResult scan_body ++ bodyResult red_body ++ map_res
  return Lambda { lambdaParams = scan_x_params ++ red_x_params ++ lambdaParams map_fun
                , lambdaBody = body
                , lambdaReturnType = lambdaReturnType map_fun }
  where scan_body = lambdaBody scan_fun
        red_body = lambdaBody red_fun
        map_body = lambdaBody map_fun
        num_scan = length $ lambdaReturnType scan_fun
        num_red = length $ lambdaReturnType red_fun
        (scan_x_params, scan_y_params) = splitAt num_scan $ lambdaParams scan_fun
        (red_x_params, red_y_params) = splitAt num_red $ lambdaParams red_fun
        -- Bind an operator parameter to a result of the map body.
        bindParam p se = letBindNames_ [paramName p] $ BasicOp $ SubExp se
-- | A lambda with no parameters, an empty body, and no results.
nilFn :: Bindable lore => LambdaT lore
nilFn = Lambda { lambdaParams = mempty
               , lambdaBody = mkBody mempty mempty
               , lambdaReturnType = mempty }
-- | Is this lambda completely empty: no parameters, no return types,
-- no statements, and no results?
isNilFn :: LambdaT lore -> Bool
isNilFn lam =
  and [ null (lambdaParams lam)
      , null (lambdaReturnType lam)
      , null (bodyStms body)
      , null (bodyResult body) ]
  where body = lambdaBody lam
-- | A 'ScremaForm' with the given scan operator and map lambda, and an
-- empty reduction.
scanomapSOAC :: Bindable lore =>
                Lambda lore -> [SubExp] -> Lambda lore -> ScremaForm lore
scanomapSOAC scan_lam scan_nes map_lam =
  ScremaForm (scan_lam, scan_nes) (mempty, nilFn, mempty) map_lam
-- | A 'ScremaForm' with the given reduction operator and map lambda,
-- and an empty scan.
redomapSOAC :: Bindable lore =>
               Commutativity -> Lambda lore -> [SubExp] -> Lambda lore -> ScremaForm lore
redomapSOAC comm red_lam red_nes map_lam =
  ScremaForm (nilFn, mempty) (comm, red_lam, red_nes) map_lam
-- | A 'ScremaForm' that only scans: the map lambda is an identity
-- lambda over the scan operator's return types.
scanSOAC :: (Bindable lore, MonadFreshNames m) =>
            Lambda lore -> [SubExp] -> m (ScremaForm lore)
scanSOAC lam nes = do
  map_lam <- mkIdentityLambda $ lambdaReturnType lam
  return $ scanomapSOAC lam nes map_lam
-- | A 'ScremaForm' that only reduces: the map lambda is an identity
-- lambda over the reduction operator's return types.
reduceSOAC :: (Bindable lore, MonadFreshNames m) =>
              Commutativity -> Lambda lore -> [SubExp] -> m (ScremaForm lore)
reduceSOAC comm lam nes = do
  map_lam <- mkIdentityLambda $ lambdaReturnType lam
  return $ redomapSOAC comm lam nes map_lam
-- | A 'ScremaForm' with the given map lambda and neither a scan nor a
-- reduction.
mapSOAC :: Bindable lore => Lambda lore -> ScremaForm lore
mapSOAC map_lam =
  ScremaForm (nilFn, mempty) (mempty, nilFn, mempty) map_lam
-- | If the form has scan neutral elements but no reduction neutral
-- elements, return the scan lambda, its neutral elements, and the map
-- lambda.
isScanomapSOAC :: ScremaForm lore -> Maybe (Lambda lore, [SubExp], Lambda lore)
isScanomapSOAC (ScremaForm (scan_lam, scan_nes) (_, _, red_nes) map_lam)
  | null red_nes, not (null scan_nes) = Just (scan_lam, scan_nes, map_lam)
  | otherwise = Nothing
-- | Like 'isScanomapSOAC', but additionally requires the map lambda to
-- be an identity lambda.
isScanSOAC :: ScremaForm lore -> Maybe (Lambda lore, [SubExp])
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scan_lam, scan_nes, map_lam)
      | isIdentityLambda map_lam -> Just (scan_lam, scan_nes)
    _ -> Nothing
-- | If the form has reduction neutral elements but no scan neutral
-- elements, return the commutativity, the reduction lambda, its
-- neutral elements, and the map lambda.
isRedomapSOAC :: ScremaForm lore -> Maybe (Commutativity, Lambda lore, [SubExp], Lambda lore)
isRedomapSOAC (ScremaForm (_, scan_nes) (comm, red_lam, red_nes) map_lam)
  | null scan_nes, not (null red_nes) = Just (comm, red_lam, red_nes, map_lam)
  | otherwise = Nothing
-- | Like 'isRedomapSOAC', but additionally requires the map lambda to
-- be an identity lambda.
isReduceSOAC :: ScremaForm lore -> Maybe (Commutativity, Lambda lore, [SubExp])
isReduceSOAC form =
  case isRedomapSOAC form of
    Just (comm, red_lam, red_nes, map_lam)
      | isIdentityLambda map_lam -> Just (comm, red_lam, red_nes)
    _ -> Nothing
-- | If the form has neither scan nor reduction neutral elements,
-- return its map lambda.
isMapSOAC :: ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC (ScremaForm (_, scan_nes) (_, _, red_nes) map_lam)
  | null scan_nes && null red_nes = Just map_lam
  | otherwise = Nothing
-- | The monadic functions that 'mapSOACM' applies to the immediate
-- components of a 'SOAC'.  Lambdas may change lore from @flore@ to
-- @tlore@.
data SOACMapper flore tlore m = SOACMapper {
    mapOnSOACSubExp :: SubExp -> m SubExp
    -- ^ Applied to every 'SubExp' directly in the SOAC.
  , mapOnSOACLambda :: Lambda flore -> m (Lambda tlore)
    -- ^ Applied to every lambda directly in the SOAC.
  , mapOnSOACVName :: VName -> m VName
    -- ^ Applied to every 'VName' directly in the SOAC.
  }
-- | A mapper that returns every component unchanged.
identitySOACMapper :: Monad m => SOACMapper lore lore m
identitySOACMapper = SOACMapper return return return
-- | Apply a 'SOACMapper' to the immediate components of a SOAC,
-- visiting them left to right in the order in which they appear in
-- the constructor.  'Int' and 'String' fields, as well as stream
-- ordering and commutativity, are kept as they are.
mapSOACM :: (Applicative m, Monad m) =>
            SOACMapper flore tlore m -> SOAC flore -> m (SOAC tlore)
mapSOACM tv (Stream size form lam arrs) = do
  size' <- mapOnSOACSubExp tv size
  form' <- onStreamForm form
  lam' <- mapOnSOACLambda tv lam
  arrs' <- mapM (mapOnSOACVName tv) arrs
  return $ Stream size' form' lam' arrs'
  where onStreamForm (Parallel o comm lam0 acc) = do
          lam0' <- mapOnSOACLambda tv lam0
          acc' <- mapM (mapOnSOACSubExp tv) acc
          return $ Parallel o comm lam0' acc'
        onStreamForm (Sequential acc) = do
          acc' <- mapM (mapOnSOACSubExp tv) acc
          return $ Sequential acc'
mapSOACM tv (Scatter len lam ivs as) = do
  len' <- mapOnSOACSubExp tv len
  lam' <- mapOnSOACLambda tv lam
  ivs' <- mapM (mapOnSOACVName tv) ivs
  as' <- forM as $ \(aw, an, a) -> do
    aw' <- mapOnSOACSubExp tv aw
    a' <- mapOnSOACVName tv a
    return (aw', an, a')
  return $ Scatter len' lam' ivs' as'
mapSOACM tv (GenReduce len ops bucket_fun imgs) = do
  len' <- mapOnSOACSubExp tv len
  ops' <- forM ops $ \(GenReduceOp e arrs nes op) -> do
    e' <- mapOnSOACSubExp tv e
    arrs' <- mapM (mapOnSOACVName tv) arrs
    nes' <- mapM (mapOnSOACSubExp tv) nes
    op' <- mapOnSOACLambda tv op
    return $ GenReduceOp e' arrs' nes' op'
  bucket_fun' <- mapOnSOACLambda tv bucket_fun
  imgs' <- mapM (mapOnSOACVName tv) imgs
  return $ GenReduce len' ops' bucket_fun' imgs'
mapSOACM tv (Screma w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs) = do
  w' <- mapOnSOACSubExp tv w
  scan_lam' <- mapOnSOACLambda tv scan_lam
  scan_nes' <- mapM (mapOnSOACSubExp tv) scan_nes
  red_lam' <- mapOnSOACLambda tv red_lam
  red_nes' <- mapM (mapOnSOACSubExp tv) red_nes
  map_lam' <- mapOnSOACLambda tv map_lam
  arrs' <- mapM (mapOnSOACVName tv) arrs
  return $ Screma w'
    (ScremaForm (scan_lam', scan_nes') (comm, red_lam', red_nes') map_lam')
    arrs'
mapSOACM tv (CmpThreshold what s) = do
  what' <- mapOnSOACSubExp tv what
  return $ CmpThreshold what' s
-- | The free names of a SOAC are collected by traversing it with a
-- mapper that records the free names of every component in a writer
-- and returns the component unchanged.
instance Attributes lore => FreeIn (SOAC lore) where
  freeIn soac = execWriter $ mapSOACM collector soac
    where collector = SOACMapper (record freeIn) (record freeInLambda) (record freeIn)
          record f x = do tell (f x)
                          return x
-- | Name substitution is applied to every immediate component of the
-- SOAC.
instance Attributes lore => Substitute (SOAC lore) where
  substituteNames subst soac = runIdentity $ mapSOACM mapper soac
    where mapper = SOACMapper
                   (pure . substituteNames subst)
                   (pure . substituteNames subst)
                   (pure . substituteNames subst)
-- | Renaming is applied to every immediate component of the SOAC.
instance Attributes lore => Rename (SOAC lore) where
  rename = mapSOACM (SOACMapper { mapOnSOACSubExp = rename
                                , mapOnSOACLambda = rename
                                , mapOnSOACVName = rename
                                })
-- | The result types of a SOAC.
soacType :: SOAC lore -> [Type]
soacType (Stream outersize form (Lambda params _ rtp) _) =
  map (substNamesInType substs) rtp
  where accs = getStreamAccums form
        -- The chunk parameter stands for the full width and each
        -- accumulator parameter for its initial value; 'zip' ignores
        -- the remaining (array) parameters.
        substs = M.fromList $ zip (map paramName params) (outersize : accs)
soacType (Scatter _ lam _ dests) =
  zipWith arrayOfRow val_ts ws
  where (ws, ns, _) = unzip3 dests
        -- Skip the index types, then take the first value type of
        -- each destination's chunk.
        val_chunks = chunks ns $ drop (sum ns) $ lambdaReturnType lam
        val_ts = concatMap (take 1) val_chunks
soacType (GenReduce _ ops _ _) =
  [ t `arrayOfRow` genReduceWidth op
  | op <- ops
  , t <- lambdaReturnType (genReduceOp op) ]
soacType (Screma w form _) = scremaType w form
soacType CmpThreshold{} = [Prim Bool]
-- | The type of a SOAC is its 'soacType' with static shapes.
instance TypedOp (SOAC lore) where
  opType soac = pure (staticShapes (soacType soac))
-- | No SOAC result aliases anything.  What a SOAC consumes depends on
-- its kind: 'Screma' and 'Stream' consume the inputs corresponding to
-- the lambda parameters that the lambda consumes, 'Scatter' and
-- 'GenReduce' consume their destination arrays.
instance (Attributes lore, Aliased lore) => AliasedOp (SOAC lore) where
  opAliases soac = mempty <$ soacType soac

  consumedInOp (Screma _ (ScremaForm _ _ map_lam) arrs) =
    S.map toInput $ consumedByLambda map_lam
    where inputs = zip (map paramName $ lambdaParams map_lam) arrs
          toInput v = fromMaybe v $ lookup v inputs
  consumedInOp (Stream _ form lam arrs) =
    S.fromList $ subExpVars $ map toInput $ S.toList $ consumedByLambda lam
    where accs = getStreamAccums form
          -- The first parameter (the chunk size) is skipped.
          inputs = zip (map paramName $ drop 1 $ lambdaParams lam)
                       (accs ++ map Var arrs)
          toInput v = fromMaybe (Var v) $ lookup v inputs
  consumedInOp (Scatter _ _ _ dests) =
    S.fromList [ a | (_, _, a) <- dests ]
  consumedInOp (GenReduce _ ops _ _) =
    S.fromList [ dest | op <- ops, dest <- genReduceDest op ]
  consumedInOp CmpThreshold{} = mempty
-- | Transform the operator lambda of a 'GenReduceOp', leaving its
-- other fields as they are.
mapGenReduceOp :: (LambdaT flore -> LambdaT tlore)
               -> GenReduceOp flore -> GenReduceOp tlore
mapGenReduceOp f op = op { genReduceOp = f (genReduceOp op) }
-- | Alias information is added by analysing every lambda of the SOAC
-- and removed by stripping it from every lambda; all other components
-- are left untouched.
instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (SOAC lore) where
  type OpWithAliases (SOAC lore) = SOAC (Aliases lore)

  addOpAliases = runIdentity . mapSOACM analyse
    where analyse = SOACMapper return (return . Alias.analyseLambda) return

  removeOpAliases = runIdentity . mapSOACM remove
    where remove = SOACMapper return (return . removeLambdaAliases) return
-- | Only 'CmpThreshold' is considered safe; every SOAC is considered
-- cheap.
instance Attributes lore => IsOp (SOAC lore) where
  safeOp soac = case soac of
    CmpThreshold{} -> True
    _ -> False
  cheapOp = const True
-- | Apply 'substNamesInSubExp' to the size of a memory block or to
-- every dimension of an array; primitive types are returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t = case t of
  Prim _ -> t
  Mem se space -> Mem (substNamesInSubExp subs se) space
  Array bt shape u ->
    Array bt (Shape $ map (substNamesInSubExp subs) $ shapeDims shape) u
-- | Replace a variable by its entry in the map, if it has one.
-- Constants and unmapped variables are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se = case se of
  Var v -> fromMaybe se (M.lookup v subs)
  Constant _ -> se
-- | Every result of a SOAC has an unknown range.
instance (Ranged inner) => RangedOp (SOAC inner) where
  opRanges = map (const unknownRange) . soacType
-- | Range information is added by running the range analysis
-- separately on every lambda of the SOAC and removed by stripping it
-- from every lambda; all other components are left untouched.
instance (Attributes lore, CanBeRanged (Op lore)) => CanBeRanged (SOAC lore) where
  type OpWithRanges (SOAC lore) = SOAC (Ranges lore)

  removeOpRanges = runIdentity . mapSOACM remove
    where remove = SOACMapper return (return . removeLambdaRanges) return

  addOpRanges = runIdentity . mapSOACM analyse
    where analyse = SOACMapper return
                    (return . Range.runRangeM . Range.analyseLambda)
                    return
-- | Wisdom is removed by stripping it from every lambda of the SOAC.
instance (Attributes lore, CanBeWise (Op lore)) => CanBeWise (SOAC lore) where
  type OpWithWisdom (SOAC lore) = SOAC (Wise lore)

  removeOpWisdom soac = runIdentity (mapSOACM unwise soac)
    where unwise = SOACMapper { mapOnSOACSubExp = return
                              , mapOnSOACLambda = return . removeLambdaWisdom
                              , mapOnSOACVName = return
                              }
-- | Symbolic indexing of the @k@th result of a SOAC at a single index
-- @i@.  Only 'Screma' is handled, and only with exactly one index;
-- everything else yields 'Nothing'.
instance Annotations lore => ST.IndexOp (SOAC lore) where
  indexOp vtable k soac [i] = do
    (lam,se,arr_params,arrs) <- lambdaAndSubExp soac
    -- Start with the lambda parameters whose input arrays can be
    -- indexed at @i@ through the symbol table, then extend the table
    -- with every body statement that can be expressed as a 'PrimExp'.
    let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs
        arr_indexes' = foldl expandPrimExpTable arr_indexes $ bodyStms $ lambdaBody lam
    -- Only a variable result can be looked up in the table.
    case se of
      Var v -> M.lookup v arr_indexes'
      _ -> Nothing
    where lambdaAndSubExp (Screma _ (ScremaForm (_, scan_nes) (_, _, red_nes) map_lam) arrs) =
            nthMapOut (length scan_nes + length red_nes) map_lam arrs
          lambdaAndSubExp _ =
            Nothing
          -- Pick the result at position @num_accs+k@ of the lambda body.
          -- NOTE(review): the first @num_accs@ lambda parameters are
          -- dropped before pairing with the input arrays; confirm this
          -- is intended when @num_accs@ is nonzero.
          nthMapOut num_accs lam arrs = do
            se <- maybeNth (num_accs+k) $ bodyResult $ lambdaBody lam
            return (lam, se, drop num_accs $ lambdaParams lam, arrs)
          arrIndex p arr = do
            (pe,cs) <- ST.index' arr [i] vtable
            return (paramName p, (pe,cs))
          -- Add a statement to the table if it binds exactly one name,
          -- its expression converts to a 'PrimExp', and all of its
          -- certificates are present in the symbol table.
          expandPrimExpTable table stm
            | [v] <- patternNames $ stmPattern stm,
              Just (pe,cs) <-
                runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
              all (`ST.elem` vtable) (unCertificates $ stmCerts stm) =
                M.insert v (pe, stmCerts stm <> cs) table
            | otherwise =
                table
          -- Resolve a variable: prefer the table (recording its
          -- certificates), fall back to a leaf for primitive-typed
          -- variables known to the symbol table, and fail otherwise.
          asPrimExp table v
            | Just (e,cs) <- M.lookup v table = tell cs >> return e
            | Just (Prim pt) <- ST.lookupType v vtable =
                return $ LeafExp v pt
            | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
-- | Usage information is only derived for 'Screma', from its map
-- lambda and input arrays.
instance Aliased lore => UsageInOp (SOAC lore) where
  usageInOp soac = case soac of
    Screma _ (ScremaForm _ _ map_lam) arrs -> usageInLambda map_lam arrs
    _ -> mempty
-- | Type-check a SOAC whose lambdas carry alias information.  All
-- widths must have type @i32@, every lambda is checked against the
-- arguments it is applied to, and the destination arrays of 'Scatter'
-- and 'GenReduce' are marked as consumed.
typeCheckSOAC :: TC.Checkable lore => SOAC (Aliases lore) -> TC.TypeM lore ()
typeCheckSOAC (CmpThreshold what _) = TC.require [Prim int32] what

typeCheckSOAC (Stream size form lam arrexps) = do
  let accexps = getStreamAccums form
  TC.require [Prim int32] size
  accargs <- mapM TC.checkArg accexps
  arrargs <- mapM lookupType arrexps
  _ <- TC.checkSOACArrayArgs size arrexps
  -- The first lambda parameter is the chunk size.  Match on the
  -- parameter list rather than using the partial 'head', so that a
  -- lambda without parameters is reported as an ordinary type error.
  chunk <- case lambdaParams lam of
    p:_ -> return p
    [] -> TC.bad $ TC.TypeError "Stream lambda must take a chunk size parameter."
  let asArg t = (t, mempty)
      inttp = Prim int32
      -- Inside the lambda the arrays have the chunk size as outer size.
      lamarrs' = map (`setOuterSize` Var (paramName chunk)) arrargs
      acc_len = length accexps
      lamrtp = take acc_len $ lambdaReturnType lam
  unless (map TC.argType accargs == lamrtp) $
    TC.bad $ TC.TypeError "Stream with inconsistent accumulator type in lambda."
  -- Check the lambda that combines accumulators, if there is one.
  _ <- case form of
    Parallel _ _ lam0 _ -> do
      let acct = map TC.argType accargs
          outerRetType = lambdaReturnType lam0
      TC.checkLambda lam0 $ map TC.noArgAliases $ accargs ++ accargs
      unless (acct == outerRetType) $
        TC.bad $ TC.TypeError $
        "Initial value is of type " ++ prettyTuple acct ++
        ", but stream's reduce lambda returns type " ++ prettyTuple outerRetType ++ "."
    _ -> return ()
  -- The array arguments are passed without aliases.
  let fake_lamarrs' = map asArg lamarrs'
  TC.checkLambda lam $ asArg inttp : accargs ++ fake_lamarrs'

typeCheckSOAC (Scatter w lam ivs as) = do
  TC.require [Prim int32] w
  -- The lambda returns all index types followed by all value types.
  let (_as_ws, as_ns, _as_vs) = unzip3 as
      rts = lambdaReturnType lam
      rtsLen = length rts `div` 2
      rtsI = take rtsLen rts
      rtsV = drop rtsLen rts
  -- There must be exactly as many index types as value types, and as
  -- many of each as there are writes.  The halving above rounds down,
  -- so an odd number of return types has to be rejected explicitly.
  unless (length rts == 2 * rtsLen && rtsLen == sum as_ns)
    $ TC.bad $ TC.TypeError "Scatter: Uneven number of index types, value types, and arrays outputs."
  forM_ rtsI $ \rtI -> unless (Prim int32 == rtI) $
    TC.bad $ TC.TypeError "Scatter: Index return type must be i32."
  forM_ (zip (chunks as_ns rtsV) as) $ \(rtVs, (aw, _, a)) -> do
    TC.require [Prim int32] aw
    -- Each destination must be an array of the written value type...
    forM_ rtVs $ \rtV -> TC.requireI [rtV `arrayOfRow` aw] a
    -- ...and is consumed by the scatter.
    TC.consume =<< TC.lookupAliases a
  arrargs <- TC.checkSOACArrayArgs w ivs
  TC.checkLambda lam arrargs

typeCheckSOAC (GenReduce len ops bucket_fun imgs) = do
  TC.require [Prim int32] len
  -- Check the operators.
  forM_ ops $ \(GenReduceOp dest_w dests nes op) -> do
    nes' <- mapM TC.checkArg nes
    TC.require [Prim int32] dest_w
    -- The operator type must match the type of the neutral elements.
    TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes'
    let nes_t = map TC.argType nes'
    unless (nes_t == lambdaReturnType op) $
      TC.bad $ TC.TypeError $ "Operator has return type " ++
      prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
      prettyTuple nes_t
    -- The destinations must have matching array types; they are consumed.
    forM_ (zip nes_t dests) $ \(t, dest) -> do
      TC.requireI [t `arrayOfRow` dest_w] dest
      TC.consume =<< TC.lookupAliases dest
  img' <- TC.checkSOACArrayArgs len imgs
  TC.checkLambda bucket_fun img'
  -- The bucket function returns one index per operator followed by
  -- the values to combine.
  nes_ts <- concat <$> mapM (mapM subExpType . genReduceNeutral) ops
  let bucket_ret_t = replicate (length ops) (Prim int32) ++ nes_ts
  unless (bucket_ret_t == lambdaReturnType bucket_fun) $
    TC.bad $ TC.TypeError $ "Bucket function has return type " ++
    prettyTuple (lambdaReturnType bucket_fun) ++ " but should have type " ++
    prettyTuple bucket_ret_t

typeCheckSOAC (Screma w (ScremaForm (scan_lam, scan_nes) (_, red_lam, red_nes) map_lam) arrs) = do
  TC.require [Prim int32] w
  arrs' <- TC.checkSOACArrayArgs w arrs
  scan_nes' <- mapM TC.checkArg scan_nes
  red_nes' <- mapM TC.checkArg red_nes
  TC.checkLambda map_lam $ map TC.noArgAliases arrs'
  TC.checkLambda scan_lam $ map TC.noArgAliases $ scan_nes' ++ scan_nes'
  TC.checkLambda red_lam $ map TC.noArgAliases $ red_nes' ++ red_nes'
  let scan_t = map TC.argType scan_nes'
      red_t = map TC.argType red_nes'
      map_lam_ts = lambdaReturnType map_lam
  unless (scan_t == lambdaReturnType scan_lam) $
    TC.bad $ TC.TypeError $ "Scan function returns type " ++
    prettyTuple (lambdaReturnType scan_lam) ++ " but neutral element has type " ++
    prettyTuple scan_t
  unless (red_t == lambdaReturnType red_lam) $
    TC.bad $ TC.TypeError $ "Reduce function returns type " ++
    prettyTuple (lambdaReturnType red_lam) ++ " but neutral element has type " ++
    prettyTuple red_t
  -- The leading map results feed the scan and the reduction, so their
  -- types must match the neutral elements.
  unless (take (length scan_nes + length red_nes) map_lam_ts ==
          map TC.argType (scan_nes' ++ red_nes')) $
    TC.bad $ TC.TypeError $ "Map function return type " ++ prettyTuple map_lam_ts ++
    " wrong for given scan and reduction functions."
-- | The initial accumulators of a stream form.
getStreamAccums :: StreamForm lore -> [SubExp]
getStreamAccums form = case form of
  Parallel _ _ _ accs -> accs
  Sequential accs -> accs
-- | The ordering of a stream form; a sequential stream is always
-- 'InOrder'.
getStreamOrder :: StreamForm lore -> StreamOrd
getStreamOrder form = case form of
  Parallel o _ _ _ -> o
  Sequential _ -> InOrder
-- | Metrics for a SOAC are the metrics of its lambdas, nested under
-- the name of the SOAC.  The combining lambda of a parallel stream
-- form is not visited.
instance OpMetrics (Op lore) => OpMetrics (SOAC lore) where
  opMetrics (Stream _ _ lam _) =
    inside "Stream" (lambdaMetrics lam)
  opMetrics (Scatter _ lam _ _) =
    inside "Scatter" (lambdaMetrics lam)
  opMetrics (GenReduce _ ops bucket_fun _) =
    inside "GenReduce" $ do
      mapM_ (lambdaMetrics . genReduceOp) ops
      lambdaMetrics bucket_fun
  opMetrics (Screma _ (ScremaForm (scan_lam, _) (_, red_lam, _) map_lam) _) =
    inside "Screma" $ do
      lambdaMetrics scan_lam
      lambdaMetrics red_lam
      lambdaMetrics map_lam
  opMetrics CmpThreshold{} = seen "CmpThreshold"
-- | Pretty-printing of SOACs.  A 'Screma' whose scan and/or reduction
-- part is empty is printed as @map@, @redomap@ or @scanomap@; any
-- other 'Screma' is printed by 'ppScrema'.
instance PrettyLore lore => PP.Pretty (SOAC lore) where
  ppr (Stream size form lam arrs) =
    case form of
      Parallel o comm lam0 acc ->
        let ord_str = if o == Disorder then "Per" else ""
            comm_str = case comm of Commutative -> "Comm"
                                    Noncommutative -> ""
        in text ("streamPar"++ord_str++comm_str) <>
           -- Every argument is followed directly by a comma, exactly
           -- as in the sequential case below.
           parens (ppr size <> comma </>
                   ppr lam0 <> comma </>
                   ppr lam <> comma </>
                   commasep ( PP.braces (commasep $ map ppr acc) : map ppr arrs ))
      Sequential acc ->
        text "streamSeq" <>
        parens (ppr size <> comma </> ppr lam <> comma </>
                commasep ( PP.braces (commasep $ map ppr acc) : map ppr arrs ))
  ppr (Scatter len lam ivs as) =
    -- The destination widths are not printed.
    ppSOAC "scatter" len [lam] (Just (map Var ivs)) (map (\(_,n,a) -> (n,a)) as)
  ppr (GenReduce len ops bucket_fun imgs) =
    ppGenReduce len ops bucket_fun imgs
  ppr (Screma w (ScremaForm (scan_lam, scan_nes) (_, red_lam, red_nes) map_lam) arrs)
    | isNilFn scan_lam, null scan_nes,
      isNilFn red_lam, null red_nes =
        text "map" <> parens (ppr w <> comma </>
                              ppr map_lam <> comma </>
                              commasep (map ppr arrs))
    | isNilFn scan_lam, null scan_nes =
        text "redomap" <> parens (ppr w <> comma </>
                                  ppr red_lam <> comma </>
                                  commasep (map ppr red_nes) <> comma </>
                                  ppr map_lam <> comma </>
                                  commasep (map ppr arrs))
    | isNilFn red_lam, null red_nes =
        text "scanomap" <> parens (ppr w <> comma </>
                                   ppr scan_lam <> comma </>
                                   commasep (map ppr scan_nes) <> comma </>
                                   ppr map_lam <> comma </>
                                   commasep (map ppr arrs))
  -- Reached when none of the guards above hold.
  ppr (Screma w form arrs) = ppScrema w form arrs
  ppr (CmpThreshold what s) = text "cmpThreshold(" <> ppr what <> comma PP.<+> text (show s) <> text ")"
-- | Pretty-print a general 'Screma' with the given width, form and
-- inputs.  The name is @scremaComm@ if the reduction is commutative
-- and @screma@ otherwise.  All arguments, including the two
-- brace-enclosed lists of neutral elements, are separated by commas.
ppScrema :: (PrettyLore lore, Pretty inp) =>
            SubExp -> ScremaForm lore -> [inp] -> Doc
ppScrema w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs =
  text s <> parens (ppr w <> comma </>
                    ppr scan_lam <> comma </>
                    PP.braces (commasep $ map ppr scan_nes) <> comma </>
                    ppr red_lam <> comma </>
                    PP.braces (commasep $ map ppr red_nes) <> comma </>
                    ppr map_lam <> comma </>
                    commasep (map ppr arrs))
  where s = case comm of Noncommutative -> "screma"
                         Commutative -> "scremaComm"
-- | Pretty-print a 'GenReduce' with the given width, operators,
-- bucket function and inputs.  The operators are printed inside
-- braces, each as its width, destinations, neutral elements and
-- operator lambda.
ppGenReduce :: (PrettyLore lore, Pretty inp) =>
               SubExp -> [GenReduceOp lore] -> Lambda lore -> [inp] -> Doc
ppGenReduce len ops bucket_fun imgs =
  text "gen_reduce" <> parens args
  where args = ppr len <> comma </>
               ops_doc <> comma </>
               ppr bucket_fun <> comma </>
               commasep (map ppr imgs)
        ops_doc = PP.braces $ mconcat $
                  intersperse (comma <> PP.line) $ map ppOp ops
        ppOp (GenReduceOp w dests nes op) =
          ppr w <> comma <> PP.braces (commasep $ map ppr dests) <> comma </>
          PP.braces (commasep $ map ppr nes) <> comma </>
          ppr op
-- | Pretty-print a SOAC application: the name, then in parentheses
-- the size, the functions, and finally the optional tuple of
-- subexpressions followed by the remaining arguments.
ppSOAC :: (Pretty fn, Pretty v) =>
          String -> SubExp -> [fn] -> Maybe [SubExp] -> [v] -> Doc
ppSOAC name size funs es as =
  text name <> parens (ppr size <> comma </>
                       ppList funs </>
                       commasep args)
  where args = case es of
                 Nothing -> map ppr as
                 Just es' -> ppTuple' es' : map ppr as
-- | Pretty-print every element followed by a comma and join the
-- elements with '</>'.  An empty list gives an empty document.
ppList :: Pretty a => [a] -> Doc
ppList as = case map ((<> comma) . ppr) as of
  [] -> mempty
  d:ds -> foldl (</>) d ds