{-# LANGUAGE FlexibleContexts #-}
module Futhark.Optimise.Simplify.ClosedForm
( foldClosedForm
, loopClosedForm
)
where
import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Construct
import Futhark.Representation.AST
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Rule
-- | Look up the defining expression of a variable, together with the
-- certificates attached to that binding, if one is known.
type VarLookup lore = VName -> Maybe (Exp lore, Certificates)

-- | Try to replace a fold over @arrs@ (with initial accumulators @accs@
-- and operator @lam@) by a closed-form computation.
--
-- Only patterns consisting of a single primitive-typed element are
-- handled; anything else, or a lambda body that 'checkResults' cannot
-- express in closed form, aborts with 'cannotSimplify'.  The generated
-- code returns @accs@ unchanged when the input arrays are empty, and the
-- closed-form body otherwise.
foldClosedForm :: (Attributes lore, BinderOps lore) =>
                  VarLookup lore
               -> Pattern lore
               -> Lambda lore
               -> [SubExp] -> [VName]
               -> RuleM lore ()
foldClosedForm look pat lam accs arrs = do
  size <- arraysSize 0 <$> mapM lookupType arrs

  resType <- case patternTypes pat of
               [Prim pt] -> return pt
               _         -> cannotSimplify

  let known     = determineKnownBindings look lam accs arrs
      lamParams = map paramName $ lambdaParams lam
  closed <- checkResults (patternNames pat) size mempty known
            lamParams (lambdaBody lam) accs

  -- Guard: with no input elements the fold yields the initial accumulators.
  emptyFlag <- newVName "fold_input_is_empty"
  letBindNames_ [emptyFlag] $
    BasicOp $ CmpOp (CmpEq int32) size (intConst Int32 0)

  whenEmpty <- resultBodyM accs
  whenNonEmpty <- renameBody closed
  letBind_ pat $
    If (Var emptyFlag) whenEmpty whenNonEmpty $
    IfAttr [primBodyType resType] IfNormal
-- | Try to replace a @for@-loop with merge parameters @merge@, loop
-- index names @i@, iteration bound @bound@ and body @body@ by a
-- closed-form computation.
--
-- Only patterns consisting of a single primitive-typed element are
-- handled; anything else, or a body that 'checkResults' cannot express in
-- closed form, aborts with 'cannotSimplify'.  The generated code returns
-- the initial merge values when @bound <= 0@ (the loop would run zero
-- iterations), and the closed-form body otherwise.
loopClosedForm :: (Attributes lore, BinderOps lore) =>
                  Pattern lore
               -> [(FParam lore,SubExp)]
               -> Names -> SubExp -> Body lore
               -> RuleM lore ()
loopClosedForm pat merge i bound body = do
  t <- case patternTypes pat of
         [Prim pt] -> return pt
         _         -> cannotSimplify

  closedBody <- checkResults mergenames bound i knownBnds
                mergenames body mergeexp

  -- A loop with a bound of zero (or less) runs no iterations and so
  -- yields the initial merge values.  The closed form does not reduce to
  -- that in general: for LogAnd it computes @acc && e@ regardless of the
  -- count, and for Add/FAdd a negative count would scale @e@ negatively.
  -- Hence the guard must be @bound <= 0@, not just @bound < 0@.
  isEmpty <- newVName "bound_is_zero"
  letBindNames_ [isEmpty] $
    BasicOp $ CmpOp (CmpSle Int32) bound (intConst Int32 0)

  letBind_ pat =<< (If (Var isEmpty)
                    <$> resultBodyM mergeexp
                    <*> renameBody closedBody
                    <*> pure (IfAttr [primBodyType t] IfNormal))
  where (mergepat, mergeexp) = unzip merge
        mergenames = map paramName mergepat
        -- Each merge parameter is known to equal its initial value.
        knownBnds = M.fromList $ zip mergenames mergeexp
-- | Build a body that computes the results of @body@ in closed form,
-- assuming @body@ is executed @size@ times.
--
-- Result @j@ is bound to @pat !! j@ and paired with the @j@th accumulator
-- parameter (the first @length accs@ entries of @params@) and its initial
-- value in @accs@.  A result is accepted only if it is a variable defined
-- in @body@ as a binary operation between that accumulator parameter and
-- an operand that is free in the body (or has a substitute in
-- @knownBnds@):
--
--   * @LogAnd@: @acc && e@ (note: @size@ is not used here)
--   * @Add t@:  @acc + e * size@ (size sign-extended/converted to @t@)
--   * @FAdd t@: @acc + e * size@ (size converted to float type @t@)
--
-- Any other shape aborts with 'cannotSimplify'.  A variable counts as
-- non-free if it is bound inside @body@, is one of @params@, or is in
-- @untouchable@.
checkResults :: BinderOps lore =>
                [VName]
             -> SubExp
             -> Names
             -> M.Map VName SubExp
             -> [VName]
             -> Body lore
             -> [SubExp]
             -> RuleM lore (Body lore)
checkResults pat size untouchable knownBnds params body accs = do
  ((), stms) <- collectStms $
    zipWithM_ closeOver (zip pat (bodyResult body)) (zip accparams accs)
  mkBodyM stms $ map Var pat
  where
    bindings   = makeBindMap body
    accparams  = take (length accs) params
    notFree    = boundInBody body <> S.fromList params <> untouchable

    -- Emit the closed-form binding for one result, or give up.
    closeOver (p, Var v) (accparam, acc)
      | Just (BasicOp (BinOp op x y)) <- M.lookup v bindings = do
          (initial, step) <- liftMaybe $ splitAccumulator accparam acc x y
          case op of
            LogAnd ->
              letBindNames_ [p] $ BasicOp $ BinOp LogAnd initial step
            Add it -> do
              n <- intSize it
              letBindNames_ [p] =<<
                eBinOp (Add it) (eSubExp initial)
                (pure $ BasicOp $ BinOp (Mul it) step n)
            FAdd ft -> do
              n <- floatSize ft
              letBindNames_ [p] =<<
                eBinOp (FAdd ft) (eSubExp initial)
                (pure $ BasicOp $ BinOp (FMul ft) step n)
            _ -> cannotSimplify
    closeOver _ _ = cannotSimplify

    -- Identify which operand is the accumulator parameter; the other must
    -- be free.  Returns the initial accumulator value and the free operand.
    splitAccumulator accparam acc x y
      | y == Var accparam, Just free <- asFree x = Just (acc, free)
      | x == Var accparam, Just free <- asFree y = Just (acc, free)
      | otherwise                                = Nothing

    -- A non-free variable is usable only through its known substitute.
    asFree :: SubExp -> Maybe SubExp
    asFree (Var v)
      | v `S.member` notFree = M.lookup v knownBnds
    asFree se = Just se

    -- The iteration count expressed in integer type @it@.
    intSize Int32 = return size
    intSize it    = letSubExp "converted_size" $
                    BasicOp $ ConvOp (SExt Int32 it) size

    -- The iteration count expressed in float type @ft@.
    floatSize ft = letSubExp "converted_size" $
                   BasicOp $ ConvOp (SIToFP Int32 ft) size
-- | Compute substitutes for the parameters of a fold lambda.
--
-- The first @length accs@ lambda parameters are the accumulators and map
-- to the corresponding initial values in @accs@.  Each remaining
-- (array-element) parameter maps to @ve@ when its array argument is
-- bound, with empty certificates, to @Replicate _ ve@ according to
-- @look@.  Accumulator entries take precedence on key clashes
-- (left-biased '<>').
determineKnownBindings :: VarLookup lore -> Lambda lore -> [SubExp] -> [VName]
                       -> M.Map VName SubExp
determineKnownBindings look lam accs arrs = fromAccs <> fromArrs
  where
    (accparams, arrparams) = splitAt (length accs) (lambdaParams lam)

    fromAccs = M.fromList $ zip (map paramName accparams) accs

    fromArrs = M.fromList
      [ (paramName p, ve)
      | (p, arr) <- zip arrparams arrs
      , Just (BasicOp (Replicate _ ve), cs) <- [look arr]
      , cs == mempty
      ]
-- | Map every variable bound by a single-element pattern in the body's
-- statements to its defining expression.  Statements whose pattern binds
-- more or fewer than one name are skipped.
makeBindMap :: Body lore -> M.Map VName (Exp lore)
makeBindMap body =
  M.fromList
    [ (v, e)
    | Let pat _ e <- stmsToList (bodyStms body)
    , [v] <- [patternNames pat]
    ]