{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
module Futhark.Analysis.HORepresentation.MapNest
( Nesting (..)
, MapNest (..)
, typeOf
, params
, inputs
, setInputs
, fromSOAC
, toSOAC
)
where
import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Futhark.Analysis.HORepresentation.SOAC as SOAC
import Futhark.Analysis.HORepresentation.SOAC (SOAC)
import qualified Futhark.Representation.SOACS.SOAC as Futhark
import Futhark.Transform.Substitute
import Futhark.Representation.AST hiding (typeOf)
import Futhark.MonadFreshNames
import Futhark.Construct
-- | One intermediate level of a 'MapNest'.  It describes the lambda of a
-- map whose body is a single binding of the next map inward, with the
-- results of that binding returned directly.
data Nesting lore = Nesting
  { nestingParamNames :: [VName]
    -- ^ Parameter names of this level's lambda.
  , nestingResult     :: [VName]
    -- ^ Names bound by the inner map and returned as this lambda's result.
  , nestingReturnType :: [Type]
    -- ^ Return type of this level's lambda.
  , nestingWidth      :: SubExp
    -- ^ Width of the map nested directly inside this level's lambda.
  } deriving (Eq, Ord, Show)
-- | A nest of maps in which every level except the innermost consists of
-- exactly one binding of the next map, whose results it returns.
data MapNest lore =
  MapNest
    SubExp          -- ^ Width of the outermost map.
    (Lambda lore)   -- ^ The innermost lambda.
    [Nesting lore]  -- ^ Intermediate levels, outermost first.
    [SOAC.Input]    -- ^ Inputs of the outermost map.
  deriving (Show)
-- | Types of the arrays produced by the outermost map: the row types
-- returned by the outermost lambda, each lifted by the outer width.
typeOf :: MapNest lore -> [Type]
typeOf (MapNest w lam nests _) = map (`arrayOfRow` w) rowTypes
  where
    -- With no intermediate levels the innermost lambda is the outermost.
    rowTypes =
      case nests of
        []       -> lambdaReturnType lam
        nest : _ -> nestingReturnType nest
-- | Parameter names of the outermost lambda of the nest.
params :: MapNest lore -> [VName]
params (MapNest _ lam nests _) =
  case nests of
    []       -> map paramName $ lambdaParams lam
    nest : _ -> nestingParamNames nest
-- | Inputs consumed by the outermost map of the nest.
inputs :: MapNest lore -> [SOAC.Input]
inputs mn = case mn of
  MapNest _ _ _ outerInps -> outerInps
-- | Replace the inputs of the outermost map.
--
-- For a non-empty input list, the first input's type determines the
-- widths.  Its outermost dimension becomes the outer width, and its
-- remaining dimensions become the widths of the nesting levels in order.
-- An empty list keeps the existing outer width.
--
-- NOTE(review): 'zipWith' truncates, so nesting levels beyond the first
-- input's rank are dropped.  Callers presumably guarantee enough
-- dimensions; confirm.
setInputs :: [SOAC.Input] -> MapNest lore -> MapNest lore
setInputs newInps nest =
  case (newInps, nest) of
    ([], MapNest oldW body ns _) ->
      MapNest oldW body ns []
    (firstInp : _, MapNest _ body ns _) ->
      let firstT  = SOAC.inputType firstInp
          outerW  = arraySize 0 firstT
          innerWs = drop 1 $ arrayDims firstT
          resized = zipWith (\n nw -> n { nestingWidth = nw }) ns innerWs
      in MapNest outerW body resized newInps
-- | Try to view a SOAC as a 'MapNest'.  Returns 'Nothing' unless the SOAC
-- is a 'SOAC.Screma' whose second and third form components carry empty
-- lists (presumably the scan and reduce parts), i.e. a plain map.
-- Nesting starts with no enclosing levels.
fromSOAC :: (Bindable lore, MonadFreshNames m,
             LocalScope lore m,
             Op lore ~ Futhark.SOAC lore) =>
            SOAC lore -> m (Maybe (MapNest lore))
fromSOAC = fromSOAC' []
-- | Worker for 'fromSOAC'.
--
-- @bound@ holds the lambda parameters of all enclosing nest levels seen so
-- far.  When the innermost lambda is reached, any of those identifiers it
-- uses freely are renamed to fresh @_wasfree@ parameters.  Each one is then
-- fed in as an extra input, replicated to the outer width.
--
-- Returns 'Nothing' for any SOAC that is not a 'SOAC.Screma' whose second
-- and third form components carry empty lists (a plain map).
fromSOAC' :: (Bindable lore, MonadFreshNames m,
              LocalScope lore m,
              Op lore ~ Futhark.SOAC lore) =>
             [Ident]
          -> SOAC lore
          -> m (Maybe (MapNest lore))
fromSOAC' bound (SOAC.Screma w (SOAC.ScremaForm (_, []) (_, _, []) lam) inps) = do
  -- A lambda whose body is a single statement, returning exactly the names
  -- that statement binds, may itself wrap another nest level.
  maybenest <- case (stmsToList $ bodyStms $ lambdaBody lam,
                     bodyResult $ lambdaBody lam) of
    ([Let pat _ e], res) | res == map Var (patternNames pat) ->
      localScope (scopeOfLParams $ lambdaParams lam) $
      SOAC.fromExp e >>=
      either (return . Left) (fmap (Right . fmap (pat,)) . fromSOAC' bound')
    _ ->
      return $ Right Nothing

  case maybenest of
    -- The single statement is itself a map nest: wrap it in a new outer
    -- nesting level and reconcile our inputs with the inner ones.
    Right (Just (pat, mn@(MapNest inner_w body' ns' inps'))) -> do
      (ps, inps'') <-
        unzip <$>
        fixInputs w (zip (map paramName $ lambdaParams lam) inps)
                    (zip (params mn) inps')
      let n' = Nesting {
            nestingParamNames = ps
            , nestingResult = patternNames pat
            , nestingReturnType = typeOf mn
            , nestingWidth = inner_w
            }
      return $ Just $ MapNest w body' (n':ns') inps''
    -- Otherwise this lambda is the innermost one.
    _ -> do
      -- Identifiers bound by enclosing levels that this lambda uses freely.
      let boundUsedInBody =
            mapMaybe (\name -> find ((name ==) . identName) bound) $
            S.toList $ freeInLambda lam
      newParams <- mapM (newIdent' (++"_wasfree")) boundUsedInBody
      -- Rename the free uses to the fresh parameters, and supply each one
      -- as an extra input replicated to the outer width.
      let subst = M.fromList $
                  zip (map identName boundUsedInBody) (map identName newParams)
          inps' = map (substituteNames subst) inps ++
                  map (SOAC.addTransform (SOAC.Replicate mempty $ Shape [w]) . SOAC.identInput)
                  boundUsedInBody
          lam' =
            lam { lambdaBody =
                    substituteNames subst $ lambdaBody lam
                , lambdaParams =
                    lambdaParams lam ++ [ Param name t
                                        | Ident name t <- newParams ]
                }
      return $ Just $ MapNest w lam' [] inps'
  where bound' = bound <> map paramIdent (lambdaParams lam)
fromSOAC' _ _ = return Nothing
-- | Rebuild a (possibly nested) map SOAC from a 'MapNest'.
--
-- Each nesting level becomes an outer map.  Its lambda takes one parameter
-- per outer input, binds the results of the next level inward, and returns
-- those bound names.
toSOAC :: (MonadFreshNames m, HasScope lore m,
           Bindable lore, BinderOps lore, Op lore ~ Futhark.SOAC lore) =>
          MapNest lore -> m (SOAC lore)
toSOAC (MapNest w lam nests inps) =
  case nests of
    [] ->
      return $ SOAC.Screma w (Futhark.mapSOAC lam) inps
    Nesting npnames nres nrettype nw : rest -> do
      -- One parameter per outer input, typed by that input's row type.
      let outerParams = zipWith Param npnames $ map SOAC.inputRowType inps
          innerInputs = map (SOAC.identInput . paramIdent) outerParams
      (innerExp, innerStms) <-
        runBinder $ localScope (scopeOfLParams outerParams) $
          toSOAC (MapNest nw lam rest innerInputs) >>= SOAC.toExp
      resultStm <- mkLetNames nres innerExp
      let outerBody = mkBody (innerStms <> oneStm resultStm) $ map Var nres
          outerLam  = Lambda { lambdaParams     = outerParams
                             , lambdaBody       = outerBody
                             , lambdaReturnType = nrettype
                             }
      return $ SOAC.Screma w (Futhark.mapSOAC outerLam) inps
-- | Reconcile the inputs of an outer map with those of the nest directly
-- inside it, so that the inner nest can be driven by the outer inputs.
--
-- @ourInps@ pairs each outer lambda parameter with its input.
-- @childInps@ pairs each inner parameter with its input.  Child inputs are
-- resolved left to right:
--
--   * If the input array is an outer parameter not yet consumed, that
--     parameter is reused, and the child's transforms are applied to the
--     rows of the outer input.
--
--   * If it is an outer parameter that has already been consumed, a fresh
--     parameter name is bound to the already-resolved input.
--
--   * Otherwise the array is free in the outer map.  It is replicated @w@
--     times and bound to a fresh @_rep@ parameter.
--
-- Outer parameters that the child never refers to are dropped.
fixInputs :: MonadFreshNames m =>
             SubExp -> [(VName, SOAC.Input)] -> [(VName, SOAC.Input)]
          -> m [(VName, SOAC.Input)]
fixInputs w ourInps childInps =
  -- Results are accumulated by consing, so reverse to restore child order.
  fmap (reverse . snd) (foldM resolve (ourInps, []) childInps)
  where
    -- The unique entry whose parameter is @v@, together with all other
    -- entries.  Zero or several matches give 'Nothing'.
    lookupUnique :: [(VName, SOAC.Input)]
                 -> VName
                 -> Maybe ((VName, SOAC.Input), [(VName, SOAC.Input)])
    lookupUnique entries v =
      case partition ((v ==) . fst) entries of
        ([hit], others) -> Just (hit, others)
        _               -> Nothing

    resolve :: MonadFreshNames m =>
               ([(VName, SOAC.Input)], [(VName, SOAC.Input)])
            -> (VName, SOAC.Input)
            -> m ([(VName, SOAC.Input)], [(VName, SOAC.Input)])
    resolve (remaining, acc) (childParam, SOAC.Input ts arr arrT) =
      -- The second lookup is forced only when the first yields Nothing.
      case (lookupUnique remaining arr, lookupUnique acc arr) of
        (Just ((p, pInp), remaining'), _) ->
          return (remaining', (p, SOAC.transformRows ts pInp) : acc)
        (Nothing, Just ((p, pInp), _)) -> do
          -- The outer parameter was already consumed by an earlier child
          -- input.
          -- NOTE(review): the earlier, already-transformed input is reused
          -- as is; this child's own transforms @ts@ are not applied.
          -- Confirm this is intended.
          p' <- newNameFromString $ baseString p
          return (remaining, (p', pInp) : acc)
        (Nothing, Nothing) -> do
          childParam' <- newNameFromString (baseString childParam ++ "_rep")
          let replicated =
                SOAC.Input (ts SOAC.|> SOAC.Replicate mempty (Shape [w])) arr arrT
          return (remaining, (childParam', replicated) : acc)