{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.Analysis.HORepresentation.SOAC
(
SOAC (..)
, Futhark.ScremaForm(..)
, inputs
, setInputs
, lambda
, setLambda
, typeOf
, width
, NotSOAC (..)
, fromExp
, toExp
, toSOAC
, Input (..)
, varInput
, identInput
, isVarInput
, isVarishInput
, addTransform
, addTransforms
, addInitialTransforms
, inputArray
, inputRank
, inputType
, inputRowType
, transformRows
, transposeInput
, ArrayTransforms
, noTransforms
, singleTransform
, nullTransforms
, (|>)
, (<|)
, viewf
, ViewF(..)
, viewl
, ViewL(..)
, ArrayTransform(..)
, transformFromExp
, soacToStream
)
where
import Data.Foldable as Foldable
import Data.Maybe
import Data.Monoid ((<>))
import qualified Data.Sequence as Seq
import qualified Futhark.Representation.AST as Futhark
import Futhark.Representation.SOACS.SOAC
(StreamForm(..), ScremaForm(..), scremaType, getStreamAccums, GenReduceOp(..))
import qualified Futhark.Representation.SOACS.SOAC as Futhark
import Futhark.Representation.AST
hiding (Var, Iota, Rearrange, Reshape, Replicate, typeOf)
import Futhark.Transform.Substitute
import Futhark.Construct hiding (toExp)
import Futhark.Transform.Rename (renameLambda)
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty (ppr, text)
-- | A single pending transformation of an input array.  Each carries the
-- 'Certificates' that are attached (via 'certifying') to the statement
-- that eventually materialises it.
data ArrayTransform
  = Rearrange Certificates [Int]
    -- ^ Permute the dimensions of the array.
  | Reshape Certificates (ShapeChange SubExp)
    -- ^ Reshape the whole array.
  | ReshapeOuter Certificates (ShapeChange SubExp)
    -- ^ Replace the outermost dimension with the given dimensions.
  | ReshapeInner Certificates (ShapeChange SubExp)
    -- ^ Keep the outermost dimension and replace all the others.
  | Replicate Certificates Shape
    -- ^ Replicate the array, adding the given shape (see 'arrayOfShape').
  deriving (Show, Eq, Ord)
-- | Substitution renames the certificates and any names occurring in the
-- shape arguments; permutations contain no names and are kept as is.
instance Substitute ArrayTransform where
  substituteNames substs tr =
    case tr of
      Rearrange cs perm ->
        Rearrange (substituteNames substs cs) perm
      Reshape cs shape ->
        Reshape (substituteNames substs cs) (substituteNames substs shape)
      ReshapeOuter cs shape ->
        ReshapeOuter (substituteNames substs cs) (substituteNames substs shape)
      ReshapeInner cs shape ->
        ReshapeInner (substituteNames substs cs) (substituteNames substs shape)
      Replicate cs shape ->
        Replicate (substituteNames substs cs) (substituteNames substs shape)
-- | An ordered sequence of array transformations.  They are applied from
-- the first to the last element of the sequence (see 'inputType', which
-- folds over them left to right).
newtype ArrayTransforms =
  ArrayTransforms (Seq.Seq ArrayTransform)
  deriving (Eq, Ord, Show)
-- | Concatenation appends the transformations of the right operand one
-- at a time with '|>', so neighbouring transformations that can be
-- combined are fused at the seam.
instance Semigroup ArrayTransforms where
  ts1 <> ArrayTransforms ts2 = Foldable.foldl (|>) ts1 ts2

-- | The identity is the empty sequence of transformations.
instance Monoid ArrayTransforms where
  mempty = ArrayTransforms Seq.empty
-- | Substitute names in every transformation of the sequence.
instance Substitute ArrayTransforms where
  substituteNames substs (ArrayTransforms ts) =
    ArrayTransforms (fmap (substituteNames substs) ts)
-- | The empty sequence of transformations.
noTransforms :: ArrayTransforms
noTransforms = ArrayTransforms mempty

-- | Does the sequence contain no transformations at all?
nullTransforms :: ArrayTransforms -> Bool
nullTransforms (ArrayTransforms ts) = Foldable.null ts

-- | A sequence consisting of exactly one transformation.
singleTransform :: ArrayTransform -> ArrayTransforms
singleTransform t = ArrayTransforms (Seq.singleton t)
-- | View a sequence of transformations from the front: the first
-- transformation and the remaining ones.
viewf :: ArrayTransforms -> ViewF
viewf (ArrayTransforms ts) =
  case Seq.viewl ts of
    Seq.EmptyL -> EmptyF
    t Seq.:< rest -> t :< ArrayTransforms rest

-- | Result of 'viewf'.
data ViewF = EmptyF
           | ArrayTransform :< ArrayTransforms

-- | View a sequence of transformations from the back: all but the last
-- transformation, and the last one.  Note that despite its name this
-- inspects the /right/ end of the sequence (it uses 'Seq.viewr').
viewl :: ArrayTransforms -> ViewL
viewl (ArrayTransforms ts) =
  case Seq.viewr ts of
    Seq.EmptyR -> EmptyL
    rest Seq.:> t -> ArrayTransforms rest :> t

-- | Result of 'viewl'.
data ViewL = EmptyL
           | ArrayTransforms :> ArrayTransform
-- | Append a transformation at the end of the sequence.  If it can be
-- combined with the current last transformation (see
-- 'combineTransforms'), the two are merged instead.
(|>) :: ArrayTransforms -> ArrayTransform -> ArrayTransforms
ts |> t = addTransform' takeLast snoc swapPair t ts
  where takeLast ts' =
          case viewl ts' of
            EmptyL -> Nothing
            rest :> lastT -> Just (lastT, rest)
        snoc t' (ArrayTransforms s) = ArrayTransforms (s Seq.|> t')
        -- The new transformation comes after the existing last one, so
        -- it goes first in the pair handed to 'combineTransforms'.
        swapPair (existing, new) = (new, existing)
-- | Prepend a transformation at the front of the sequence.  If it can be
-- combined with the current first transformation (see
-- 'combineTransforms'), the two are merged instead.
(<|) :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
t <| ts = addTransform' takeFirst cons id t ts
  where takeFirst ts' =
          case viewf ts' of
            EmptyF -> Nothing
            firstT :< rest -> Just (firstT, rest)
        cons t' (ArrayTransforms s) = ArrayTransforms (t' Seq.<| s)
-- | Shared worker for '|>' and '<|'.
--
-- * @extract@ peels off the transformation at the end being added to.
-- * @add@ attaches a transformation at that end.
-- * @swap@ puts the (existing, new) pair into the argument order passed to
--   'combineTransforms'.
--
-- When the new transformation combines with its neighbour, the
-- combination replaces the neighbour. It is dropped entirely if it is the
-- identity; otherwise it is itself recursively added to the rest.
-- When no combination is possible, the new transformation is just added.
addTransform' :: (ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms))
              -> (ArrayTransform -> ArrayTransforms -> ArrayTransforms)
              -> ((ArrayTransform,ArrayTransform) -> (ArrayTransform,ArrayTransform))
              -> ArrayTransform -> ArrayTransforms
              -> ArrayTransforms
addTransform' extract add swap t ts =
  case extract ts of
    Nothing -> t `add` ts
    Just (neighbour, rest) ->
      case uncurry combineTransforms (swap (neighbour, t)) of
        Nothing -> t `add` ts
        Just combined
          | identityTransform combined -> rest
          | otherwise -> addTransform' extract add swap combined rest
-- | Is the transformation a no-op?  Only a 'Rearrange' whose permutation
-- is @[0, 1, .., n-1]@ is recognised as such.
identityTransform :: ArrayTransform -> Bool
identityTransform tr =
  case tr of
    Rearrange _ perm -> perm == [0 .. length perm - 1]
    _ -> False
-- | Try to fuse two transformations into one.  Currently only two
-- 'Rearrange's are fused: their permutations are composed with
-- 'rearrangeCompose' and their certificates are concatenated.
combineTransforms :: ArrayTransform -> ArrayTransform -> Maybe ArrayTransform
combineTransforms t2 t1 =
  case (t2, t1) of
    (Rearrange cs2 perm2, Rearrange cs1 perm1) ->
      Just (Rearrange (cs1 <> cs2) (rearrangeCompose perm2 perm1))
    _ -> Nothing
-- | Recognise an expression that is a single array transformation of a
-- variable.  Returns the variable and the corresponding transformation
-- (carrying the given certificates).  Handles 'Futhark.Rearrange',
-- 'Futhark.Reshape', and 'Futhark.Replicate' of a variable.
transformFromExp :: Certificates -> Exp lore -> Maybe (VName, ArrayTransform)
transformFromExp cs e =
  case e of
    BasicOp (Futhark.Rearrange perm v) -> Just (v, Rearrange cs perm)
    BasicOp (Futhark.Reshape shape v) -> Just (v, Reshape cs shape)
    BasicOp (Futhark.Replicate shape (Futhark.Var v)) -> Just (v, Replicate cs shape)
    _ -> Nothing
-- | An input array to a SOAC: a variable together with the
-- transformations pending on it.  The 'Type' is that of the underlying
-- variable before any transformation ('inputType' gives the transformed
-- type).
data Input = Input ArrayTransforms VName Type
  deriving (Show, Eq, Ord)
-- | Substitute names in the transformations, the variable and its type.
instance Substitute Input where
  substituteNames substs (Input ts v t) = Input ts' v' t'
    where ts' = substituteNames substs ts
          v' = substituteNames substs v
          t' = substituteNames substs t
-- | An untransformed input for the given variable, whose type is looked
-- up in the current scope.
varInput :: HasScope t f => VName -> f Input
varInput v = Input noTransforms v <$> lookupType v
-- | An untransformed input for the given identifier.
identInput :: Ident -> Input
identInput ident = Input noTransforms (identName ident) (identType ident)
-- | The underlying variable, if the input has no transformations.
isVarInput :: Input -> Maybe VName
isVarInput (Input ts v _)
  | nullTransforms ts = Just v
  | otherwise = Nothing
-- | Like 'isVarInput', but also looks through leading 'Reshape'
-- transformations that consist of a single 'DimCoercion' and carry no
-- certificates.
isVarishInput :: Input -> Maybe VName
isVarishInput (Input ts v t)
  | nullTransforms ts = Just v
  | otherwise =
      case viewf ts of
        Reshape cs [DimCoercion _] :< rest
          | cs == mempty -> isVarishInput (Input rest v t)
        _ -> Nothing
-- | Add a transformation after the existing ones of the input.
addTransform :: ArrayTransform -> Input -> Input
addTransform tr (Input existing arr arr_t) = Input (existing |> tr) arr arr_t

-- | Add transformations after the existing ones of the input.
addTransforms :: ArrayTransforms -> Input -> Input
addTransforms new (Input existing arr arr_t) = Input (existing <> new) arr arr_t

-- | Add transformations before the existing ones of the input.
addInitialTransforms :: ArrayTransforms -> Input -> Input
addInitialTransforms new (Input existing arr arr_t) = Input (new <> existing) arr arr_t
-- | Materialise each input by binding one new variable per pending
-- transformation, first to last, each wrapped in that transformation's
-- certificates.  Returns the name of the final, fully transformed array
-- for every input.
inputsToSubExps :: (MonadBinder m) =>
                   [Input] -> m [VName]
inputsToSubExps = mapM materialise
  where materialise (Input (ArrayTransforms ts) arr _) =
          foldlM applyTransform arr ts

        applyTransform arr tr =
          case tr of
            Replicate cs n ->
              certifying cs $ letExp "repeat" $
              BasicOp $ Futhark.Replicate n (Futhark.Var arr)
            Rearrange cs perm ->
              certifying cs $ letExp "rearrange" $
              BasicOp $ Futhark.Rearrange perm arr
            Reshape cs shape ->
              certifying cs $ letExp "reshape" $
              BasicOp $ Futhark.Reshape shape arr
            ReshapeOuter cs shape -> do
              -- The full reshape needs the current shape of the array.
              arr_shape <- arrayShape <$> lookupType arr
              certifying cs $ letExp "reshape_outer" $
                BasicOp $ Futhark.Reshape (reshapeOuter shape 1 arr_shape) arr
            ReshapeInner cs shape -> do
              arr_shape <- arrayShape <$> lookupType arr
              certifying cs $ letExp "reshape_inner" $
                BasicOp $ Futhark.Reshape (reshapeInner shape 1 arr_shape) arr
-- | The variable underlying an input, ignoring its transformations.
inputArray :: Input -> VName
inputArray (Input _ arr _) = arr
-- | The type of an input once all of its transformations have been
-- applied.  It is computed by folding the transformations, first to last,
-- over the type of the underlying variable.
inputType :: Input -> Type
inputType (Input (ArrayTransforms ts) _ array_t) =
  Foldable.foldl applyToType array_t ts
  where applyToType t tr =
          case tr of
            Replicate _ shape -> arrayOfShape t shape
            Rearrange _ perm -> rearrangeType perm t
            Reshape _ shape -> setArrayShape t (newShape shape)
            ReshapeOuter _ shape ->
              -- Only the outermost dimension is replaced.
              let Shape dims = arrayShape t
              in setArrayShape t $ Shape $ newDims shape ++ drop 1 dims
            ReshapeInner _ shape ->
              -- The outermost dimension is kept; the rest is replaced.
              let Shape dims = arrayShape t
              in setArrayShape t $ Shape $ take 1 dims ++ newDims shape
-- | The row type of the transformed input.
inputRowType :: Input -> Type
inputRowType inp = rowType (inputType inp)

-- | The rank of the transformed input.
inputRank :: Input -> Int
inputRank inp = arrayRank (inputType inp)
-- | Apply transformations to every row of an input, leaving its outermost
-- dimension untouched.
--
-- * A 'Rearrange' is shifted past the outer dimension.
-- * A 'Reshape' becomes a 'ReshapeInner'.
-- * A 'Replicate' of shape @[n_1..n_k]@ replicates every row. This is done
--   by replicating a (partially) transposed input and transposing back.
--   Any number @k@ of replicated dimensions is supported.
--
-- Any other transformation is rejected with an error.
transformRows :: ArrayTransforms -> Input -> Input
transformRows (ArrayTransforms ts) inp0 =
  Foldable.foldl transformRows' inp0 ts
  where transformRows' inp (Rearrange cs perm) =
          addTransform (Rearrange cs (0:map (+1) perm)) inp
        transformRows' inp (Reshape cs shape) =
          addTransform (ReshapeInner cs shape) inp
        transformRows' inp (Replicate cs n)
          | inputRank inp == 1 =
            -- [a] ==replicate==> [n_1..n_k][a] ==rearrange==> [a][n_1..n_k]
            Rearrange mempty (k : [0..k-1]) `addTransform`
            (Replicate cs n `addTransform` inp)
          | otherwise =
            -- [a][b][rest] ==rearrange==> [b][a][rest]
            --   ==replicate==> [n_1..n_k][b][a][rest]
            --   ==rearrange==> [a][n_1..n_k][b][rest]
            -- For k == 1 the final permutation is 2:0:1:[3..r].
            Rearrange mempty ((k+1) : [0..k-1] ++ k : [k+2..r+k-1]) `addTransform`
            (Replicate cs n `addTransform`
             (Rearrange mempty (1:0:[2..r-1]) `addTransform` inp))
          where Shape rep_dims = n
                k = length rep_dims
                r = inputRank inp
        transformRows' inp nts =
          error $ "transformRows: Cannot transform this yet:\n" ++ show nts ++ "\n" ++ show inp
-- | Add a 'Rearrange' (without certificates) whose permutation is
-- @transposeIndex k n@ applied to the identity permutation of the
-- input's rank.  Presumably this moves dimension @k@ by @n@ positions;
-- confirm against 'transposeIndex'.
transposeInput :: Int -> Int -> Input -> Input
transposeInput k n inp =
  let perm = transposeIndex k n [0 .. inputRank inp - 1]
  in addTransform (Rearrange mempty perm) inp
-- | A SOAC whose array arguments are 'Input's (variables with pending
-- transformations) rather than plain variable names.  The leading
-- 'SubExp' of each constructor is its width (see 'width').
data SOAC lore
  = Stream SubExp (StreamForm lore) (Lambda lore) [Input]
  | Scatter SubExp (Lambda lore) [Input] [(SubExp, Int, VName)]
  | Screma SubExp (ScremaForm lore) [Input]
  | GenReduce SubExp [GenReduceOp lore] (Lambda lore) [Input]
  deriving (Eq, Show)
-- | Prints the array name wrapped in one pseudo-call per transformation.
-- The first transformation is innermost.
instance PP.Pretty Input where
  ppr (Input (ArrayTransforms ts) arr _) = foldl wrap (ppr arr) ts
    where call name cs args = text name <> ppr cs <> PP.apply args
          wrap e tr =
            case tr of
              Rearrange cs perm ->
                call "rearrange" cs [PP.apply (map ppr perm), e]
              Reshape cs shape ->
                call "reshape" cs [PP.apply (map ppr shape), e]
              ReshapeOuter cs shape ->
                call "reshape_outer" cs [PP.apply (map ppr shape), e]
              ReshapeInner cs shape ->
                call "reshape_inner" cs [PP.apply (map ppr shape), e]
              Replicate cs ne ->
                call "replicate" cs [ppr ne, e]
-- | 'Screma' and 'GenReduce' use the core pretty-printers; the other
-- SOACs fall back to their 'Show' representation.
instance PrettyLore lore => PP.Pretty (SOAC lore) where
  ppr soac =
    case soac of
      Screma w form arrs -> Futhark.ppScrema w form arrs
      GenReduce len ops bucket_fun imgs -> Futhark.ppGenReduce len ops bucket_fun imgs
      _ -> text $ show soac
-- | The array inputs of a SOAC.
inputs :: SOAC lore -> [Input]
inputs soac =
  case soac of
    Stream _ _ _ inps -> inps
    Scatter _ _ inps _ -> inps
    Screma _ _ inps -> inps
    GenReduce _ _ _ inps -> inps
-- | Replace the inputs of a SOAC.  For 'Stream' and 'Scatter' the width
-- is recomputed from the new inputs (see 'newWidth'); 'Screma' and
-- 'GenReduce' keep their existing width.
setInputs :: [Input] -> SOAC lore -> SOAC lore
setInputs new_inps soac =
  case soac of
    Stream w form lam _ -> Stream (newWidth new_inps w) form lam new_inps
    Scatter w lam _ dests -> Scatter (newWidth new_inps w) lam new_inps dests
    Screma w form _ -> Screma w form new_inps
    GenReduce w ops lam _ -> GenReduce w ops lam new_inps
-- | The outer size of the first input's transformed type, or the given
-- fallback width if there are no inputs.
newWidth :: [Input] -> SubExp -> SubExp
newWidth inps w = maybe w (arraySize 0 . inputType) (listToMaybe inps)
-- | The main lambda of a SOAC.  For a 'Screma' this is the map function
-- of its form.
lambda :: SOAC lore -> Lambda lore
lambda soac =
  case soac of
    Stream _ _ lam _ -> lam
    Scatter _ lam _ _ -> lam
    Screma _ (ScremaForm _ _ lam) _ -> lam
    GenReduce _ _ lam _ -> lam
-- | Replace the main lambda of a SOAC (the one returned by 'lambda').
setLambda :: Lambda lore -> SOAC lore -> SOAC lore
setLambda lam soac =
  case soac of
    Stream w form _ arrs -> Stream w form lam arrs
    Scatter len _ ivs dests -> Scatter len lam ivs dests
    Screma w (ScremaForm scan red _) arrs -> Screma w (ScremaForm scan red lam) arrs
    GenReduce w ops _ inps -> GenReduce w ops lam inps
-- | The types of the results of a SOAC.
typeOf :: SOAC lore -> [Type]
typeOf (Stream w form lam _) =
  -- Accumulator results keep their type; array results are chunks that
  -- are concatenated into arrays of the full width.
  let nes = getStreamAccums form
      accrtps = take (length nes) $ lambdaReturnType lam
      arrtps = [ arrayOf (stripArray 1 t) (Shape [w]) NoUniqueness
               | t <- drop (length nes) (lambdaReturnType lam) ]
  in accrtps ++ arrtps
typeOf (Scatter _w lam _ivs dests) =
  zipWith arrayOfRow val_ts aws
  where (aws, ns, _) = unzip3 dests
        -- The lambda returns all indexes first, followed by all values.
        -- The Int in each destination triple is taken to be the number of
        -- values written to that destination, so the values form one
        -- contiguous group per destination.  Every destination gets the
        -- type of the first value in its group.  Pairing the first
        -- @length dests@ value types with the destinations would
        -- misalign them as soon as any count exceeds one.
        val_ts = firstOfEach ns $ drop (sum ns) $ lambdaReturnType lam
        firstOfEach [] _ = []
        firstOfEach (k:ks) ts = take 1 ts ++ firstOfEach ks (drop k ts)
typeOf (Screma w form _) =
  scremaType w form
typeOf (GenReduce _ ops _ _) = do
  -- One array of the operation's destination width per result of each
  -- operation's lambda.
  op <- ops
  map (`arrayOfRow` genReduceWidth op) (lambdaReturnType $ genReduceOp op)
-- | The width (outer iteration count) of a SOAC.
width :: SOAC lore -> SubExp
width soac =
  case soac of
    Stream w _ _ _ -> w
    Scatter w _ _ _ -> w
    Screma w _ _ -> w
    GenReduce w _ _ _ -> w
-- | Convert back to an 'Op' expression, materialising all pending input
-- transformations as statements (via 'toSOAC').
toExp :: (MonadBinder m, Op (Lore m) ~ Futhark.SOAC (Lore m)) =>
         SOAC (Lore m) -> m (Exp (Lore m))
toExp = fmap Op . toSOAC
-- | Convert back to a core SOAC.  The inputs are materialised with
-- 'inputsToSubExps', which binds new statements for their
-- transformations.
toSOAC :: MonadBinder m =>
          SOAC (Lore m) -> m (Futhark.SOAC (Lore m))
toSOAC soac =
  case soac of
    Stream w form lam inps ->
      Futhark.Stream w form lam <$> inputsToSubExps inps
    Scatter len lam ivs dests ->
      (\ivs' -> Futhark.Scatter len lam ivs' dests) <$> inputsToSubExps ivs
    Screma w form inps ->
      Futhark.Screma w form <$> inputsToSubExps inps
    GenReduce w ops lam inps ->
      Futhark.GenReduce w ops lam <$> inputsToSubExps inps
-- | Returned by 'fromExp' when the expression is not a SOAC.
data NotSOAC = NotSOAC
  deriving Show
-- | Convert an expression to a SOAC with untransformed inputs.
--
-- * A 'Copy' of an array becomes an identity map ('Screma').
-- * The SOAC operators are translated directly.
-- * Anything else yields @'Left' 'NotSOAC'@.
fromExp :: (Op lore ~ Futhark.SOAC lore, Bindable lore,
            HasScope lore m, MonadFreshNames m) =>
           Exp lore -> m (Either NotSOAC (SOAC lore))
fromExp (BasicOp (Copy arr)) = do
  arr_t <- lookupType arr
  let row_t = rowType arr_t
  p <- Param <$> newVName "copy_p" <*> pure row_t
  -- The identity function on rows of the copied array.
  let identity_body = mkBody mempty [Futhark.Var $ paramName p]
      lam = Lambda [p] identity_body [row_t]
  inp <- varInput arr
  pure $ Right $ Screma (arraySize 0 arr_t) (Futhark.mapSOAC lam) [inp]
fromExp (Op (Futhark.Stream w form lam arrs)) =
  Right . Stream w form lam <$> mapM varInput arrs
fromExp (Op (Futhark.Scatter len lam ivs dests)) =
  (\ivs' -> Right $ Scatter len lam ivs' dests) <$> mapM varInput ivs
fromExp (Op (Futhark.Screma w form arrs)) =
  Right . Screma w form <$> mapM varInput arrs
fromExp (Op (Futhark.GenReduce w ops lam arrs)) =
  Right . GenReduce w ops lam <$> mapM varInput arrs
fromExp _ =
  pure $ Left NotSOAC
-- | Convert a SOAC into an equivalent 'Stream' whose lambda processes one
-- chunk of the input at a time.
--
-- Only a 'Screma' that is recognised as one of the following is converted:
--
-- * a plain map ('Futhark.isMapSOAC'),
-- * a scanomap ('Futhark.isScanomapSOAC'),
-- * a redomap ('Futhark.isRedomapSOAC').
--
-- Any other SOAC is returned unchanged.  The second component of the
-- result holds the stream lambda's accumulator parameters (as 'Ident's).
-- It is non-empty only in the scanomap case; the other cases return @[]@.
soacToStream :: (MonadFreshNames m, Bindable lore, Op lore ~ Futhark.SOAC lore) =>
                SOAC lore -> m (SOAC lore,[Ident])
soacToStream soac = do
  -- The first parameter of every stream lambda is the chunk size.
  chunk_param <- newParam "chunk" $ Prim int32
  let chvar= Futhark.Var $ paramName chunk_param
      (lam, inps) = (lambda soac, inputs soac)
      w = width soac
  -- A renamed copy of the main lambda, used inside the per-chunk SOAC.
  lam' <- renameLambda lam
  let arrrtps= mapType w lam
      -- Chunk-sized types of the lambda's array results and of the inputs.
      loutps = [ arrayOfRow t chvar | t <- map rowType arrrtps ]
      lintps = [ arrayOfRow t chvar | t <- map inputRowType inps ]
  strm_inpids <- mapM (newParam "inp") lintps
  case soac of
    Screma _ form _
      | Just _ <- Futhark.isMapSOAC form -> do
          -- Map: the stream body applies the same map to one chunk.  The
          -- stream is Parallel/Disorder/Commutative with an empty reduction
          -- lambda and no accumulators.
          strm_resids <- mapM (newIdent "res") loutps
          let insoac = Futhark.Screma chvar (Futhark.mapSOAC lam') $ map paramName strm_inpids
              insbnd = mkLet [] strm_resids $ Op insoac
              strmbdy= mkBody (oneStm insbnd) $ map (Futhark.Var . identName) strm_resids
              strmpar= chunk_param:strm_inpids
              strmlam= Lambda strmpar strmbdy loutps
              empty_lam = Lambda [] (mkBody mempty []) []
          return (Stream w (Parallel Disorder Commutative empty_lam []) strmlam inps, [])

      | Just (scan_lam, nes, _) <- Futhark.isScanomapSOAC form -> do
          -- Scanomap: a Sequential stream whose accumulators start at the
          -- neutral elements.  Per chunk:
          --   1. scanomap the chunk (scan0_ids, map_resids);
          --   2. take the last element of each scan result, or the neutral
          --      elements if the chunk is empty;
          --   3. combine the incoming accumulator with every scanned element;
          --   4. the new accumulator is the incoming accumulator combined
          --      with the last elements.
          let scan_arr_ts = map (`arrayOfRow` chvar) $ lambdaReturnType scan_lam
              map_arr_ts = drop (length nes) loutps
              accrtps = lambdaReturnType scan_lam
          strm_resids <- mapM (newIdent "res") scan_arr_ts
          scan0_ids <- mapM (newIdent "resarr0") scan_arr_ts
          map_resids <- mapM (newIdent "map_res") map_arr_ts
          lastel_ids <- mapM (newIdent "lstel") accrtps
          lastel_tmp_ids <- mapM (newIdent "lstel_tmp") accrtps
          empty_arr <- newIdent "empty_arr" $ Prim Bool
          inpacc_ids <- mapM (newParam "inpacc") accrtps
          outszm1id <- newIdent "szm1" $ Prim int32
          let -- Step 1: scanomap over the chunk.
              insbnd = mkLet [] (scan0_ids++map_resids) $ Op $
                       Futhark.Screma chvar (Futhark.scanomapSOAC scan_lam nes lam') $
                       map paramName strm_inpids
              -- szm1 = chunk - 1, the index of the last element.
              outszm1bnd = mkLet [] [outszm1id] $ BasicOp $
                           BinOp (Sub Int32)
                                 (Futhark.Var $ paramName chunk_param)
                                 (constant (1::Int32))
              -- empty_arr = szm1 < 0, i.e. the chunk is empty.
              empty_arr_bnd = mkLet [] [empty_arr] $ BasicOp $ CmpOp (CmpSlt Int32)
                                (Futhark.Var $ identName outszm1id)
                                (constant (0::Int32))
              -- Index the last element of every scan result.
              leltmpbnds= zipWith (\ lid arrid -> mkLet [] [lid] $ BasicOp $
                                      Index (identName arrid) $
                                      fullSlice (identType arrid)
                                      [DimFix $ Futhark.Var $ identName outszm1id]
                                  ) lastel_tmp_ids scan0_ids
              -- Step 2: guard the indexing, since an empty chunk has no last
              -- element; use the neutral elements instead.
              lelbnd = mkLet [] lastel_ids $
                       If (Futhark.Var $ identName empty_arr)
                          (mkBody mempty nes)
                          (mkBody (stmsFromList leltmpbnds) $
                           map (Futhark.Var . identName) lastel_tmp_ids) $
                       ifCommon $ map identType lastel_tmp_ids
          -- Step 3: map that combines the incoming accumulator with each
          -- element of the chunk's scan result.
          maplam <- mkMapPlusAccLam (map (Futhark.Var . paramName) inpacc_ids) scan_lam
          let mapbnd = mkLet [] strm_resids $ Op $
                       Futhark.Screma chvar (Futhark.mapSOAC maplam) $
                       map identName scan0_ids
          -- Step 4: new accumulator = inpacc `scan_lam` last elements.
          addlelbdy <- mkPlusBnds scan_lam $ map Futhark.Var $
                       map paramName inpacc_ids++map identName lastel_ids
          let (addlelbnd,addlelres) = (bodyStms addlelbdy, bodyResult addlelbdy)
              strmbdy= mkBody (stmsFromList [insbnd,outszm1bnd,empty_arr_bnd,lelbnd,mapbnd]<>addlelbnd) $
                       addlelres ++ map (Futhark.Var . identName) (strm_resids ++ map_resids)
              strmpar= chunk_param:inpacc_ids++strm_inpids
              strmlam= Lambda strmpar strmbdy (accrtps++loutps)
          return (Stream w (Sequential nes) strmlam inps,
                  map paramIdent inpacc_ids)

      | Just (comm, lamin, nes, _) <- Futhark.isRedomapSOAC form -> do
          -- Redomap: each chunk is reduced with a redomap; the chunk result
          -- is then combined with the incoming accumulator.  The stream is
          -- Parallel/InOrder with a renamed copy of the reduction lambda.
          let accrtps= take (length nes) $ lambdaReturnType lam
              loutps' = drop (length nes) loutps
              foldlam = lam'
          strm_resids <- mapM (newIdent "res") loutps'
          inpacc_ids <- mapM (newParam "inpacc") accrtps
          acc0_ids <- mapM (newIdent "acc0" ) accrtps
          let insoac = Futhark.Screma chvar (Futhark.redomapSOAC comm lamin nes foldlam) $
                       map paramName strm_inpids
              insbnd = mkLet [] (acc0_ids++strm_resids) $ Op insoac
          -- New accumulator = inpacc `lamin` chunk result.
          addaccbdy <- mkPlusBnds lamin $ map Futhark.Var $
                       map paramName inpacc_ids++map identName acc0_ids
          let (addaccbnd,addaccres) = (bodyStms addaccbdy, bodyResult addaccbdy)
              strmbdy= mkBody (oneStm insbnd <> addaccbnd) $
                       addaccres ++ map (Futhark.Var . identName) strm_resids
              strmpar= chunk_param:inpacc_ids++strm_inpids
              strmlam= Lambda strmpar strmbdy (accrtps++loutps')
          lam0 <- renameLambda lamin
          return (Stream w (Parallel InOrder comm lam0 nes) strmlam inps, [])

    _ -> return (soac,[])

  where -- Turn the binary operator @plus@ into a lambda over its remaining
        -- parameters: the first @length accs@ parameters are bound to
        -- @accs@ with statements prepended to the body.  The result is
        -- renamed.
        mkMapPlusAccLam :: (MonadFreshNames m, Bindable lore)
                        => [SubExp] -> Lambda lore -> m (Lambda lore)
        mkMapPlusAccLam accs plus = do
          let lampars = lambdaParams plus
              (accpars, rempars) = ( take (length accs) lampars,
                                     drop (length accs) lampars )
              parbnds = zipWith (\ par se -> mkLet [] [paramIdent par]
                                               (BasicOp $ SubExp se)
                                ) accpars accs
              plus_bdy = lambdaBody plus
              newlambdy = Body (bodyAttr plus_bdy)
                               (stmsFromList parbnds <> bodyStms plus_bdy)
                               (bodyResult plus_bdy)
          renameLambda $ Lambda rempars newlambdy $ lambdaReturnType plus

        -- Inline a renamed copy of @plus@ applied to @accels@: its
        -- parameters are bound to the given values and its body is
        -- returned.
        mkPlusBnds :: (MonadFreshNames m, Bindable lore)
                   => Lambda lore -> [SubExp] -> m (Body lore)
        mkPlusBnds plus accels = do
          plus' <- renameLambda plus
          let parbnds = zipWith (\ par se -> mkLet [] [paramIdent par]
                                               (BasicOp $ SubExp se)
                                ) (lambdaParams plus') accels
              body = lambdaBody plus'
          return $ body { bodyStms = stmsFromList parbnds <> bodyStms body }