{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.Representation.AST.Attributes.Ranges
( Bound
, KnownBound (..)
, boundToScalExp
, minimumBound
, maximumBound
, Range
, unknownRange
, ScalExpRange
, Ranged
, RangeOf (..)
, RangesOf (..)
, expRanges
, RangedOp (..)
, CanBeRanged (..)
)
where
import Data.Monoid ((<>))
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Syntax
import qualified Futhark.Analysis.ScalExp as SE
import qualified Futhark.Analysis.AlgSimplify as AS
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import qualified Futhark.Util.Pretty as PP
-- | A bound on a value that is known in some form.  Compare 'Bound',
-- which additionally allows the bound to be unknown.
data KnownBound
  = VarBound VName
    -- ^ The bound is the value of the given variable.
  | MinimumBound KnownBound KnownBound
    -- ^ The smaller of two bounds.
  | MaximumBound KnownBound KnownBound
    -- ^ The larger of two bounds.
  | ScalarBound SE.ScalExp
    -- ^ The bound is given by a scalar expression.
  deriving (Eq, Ord, Show)
-- | Apply a name substitution to every variable mentioned in a bound,
-- recursing through minimum/maximum bounds and scalar expressions.
instance Substitute KnownBound where
  substituteNames substs bound =
    case bound of
      VarBound name      -> VarBound $ sub name
      MinimumBound b1 b2 -> MinimumBound (sub b1) (sub b2)
      MaximumBound b1 b2 -> MaximumBound (sub b1) (sub b2)
      ScalarBound se     -> ScalarBound $ sub se
    -- The explicit signature keeps 'sub' polymorphic; it is used at
    -- VName, KnownBound and ScalExp.
    where sub :: Substitute t => t -> t
          sub = substituteNames substs
-- | Renaming a bound is delegated to 'substituteRename'.
instance Rename KnownBound where
  rename bound = substituteRename bound
-- | The free variables of a bound.  For a minimum/maximum bound these
-- are the free variables of both operands, combined with '<>'.
instance FreeIn KnownBound where
  freeIn bound =
    case bound of
      VarBound v         -> freeIn v
      MinimumBound b1 b2 -> fromBoth b1 b2
      MaximumBound b1 b2 -> fromBoth b1 b2
      ScalarBound e      -> freeIn e
    where fromBoth l r = freeIn l <> freeIn r
-- | Bounds contribute no precomputed free-variable information; the
-- given value is passed through unchanged.
instance FreeAttr KnownBound where
  precomputed _ fvs = fvs
-- | Pretty-print a bound: variables as @variable v@, minimum/maximum
-- bounds as @min(a, b)@ / @max(a, b)@, and scalar bounds as the
-- expression itself.
instance PP.Pretty KnownBound where
  ppr bound =
    case bound of
      VarBound v         -> PP.text "variable " <> PP.ppr v
      MinimumBound b1 b2 -> call "min" b1 b2
      MaximumBound b1 b2 -> call "max" b1 b2
      ScalarBound e      -> PP.ppr e
    where call fname l r =
            PP.text fname <> PP.parens (PP.ppr l <> PP.comma PP.<+> PP.ppr r)
-- | Express a bound as a scalar expression, if possible.
--
-- Returns 'Nothing' if a 'VarBound' occurs anywhere in the bound.
-- Minimum and maximum bounds become 'SE.MaxMin' with flag 'True' for
-- minimum and 'False' for maximum.
boundToScalExp :: KnownBound -> Maybe SE.ScalExp
boundToScalExp bound =
  case bound of
    VarBound _         -> Nothing
    ScalarBound se     -> Just se
    MinimumBound b1 b2 -> both True b1 b2
    MaximumBound b1 b2 -> both False b1 b2
  where both isMin l r =
          (\l' r' -> SE.MaxMin isMin [l', r'])
            <$> boundToScalExp l
            <*> boundToScalExp r
-- | A possibly-unknown bound: 'Nothing' means no bound is known.
type Bound = Maybe KnownBound
-- | The smaller of two bounds.  The result is unknown ('Nothing') as
-- soon as either argument is unknown.
minimumBound :: Bound -> Bound -> Bound
minimumBound x y = MinimumBound <$> x <*> y
-- | The larger of two bounds.  The result is unknown ('Nothing') as
-- soon as either argument is unknown.
maximumBound :: Bound -> Bound -> Bound
maximumBound x y = MaximumBound <$> x <*> y
-- | A range as a pair @(lower, upper)@ of possibly-unknown bounds.
type Range = (Bound, Bound)
-- | The range about which nothing is known: both bounds are 'Nothing'.
unknownRange :: Range
unknownRange =
  let none = Nothing
  in (none, none)
-- | A range whose bounds are optional scalar expressions; presumably
-- ordered @(lower, upper)@ like 'Range' -- confirm at use sites.
type ScalExpRange = (Maybe SE.ScalExp, Maybe SE.ScalExp)
-- | The constraints a lore must satisfy to take part in range
-- analysis: its operations are 'RangedOp's, let-bound attributes yield
-- a 'Range', and body attributes yield one 'Range' per result.
type Ranged lore = (Attributes lore,
                    RangedOp (Op lore),
                    RangeOf (LetAttr lore),
                    RangesOf (BodyAttr lore))
-- | Things that have a single 'Range'.
class RangeOf a where
  rangeOf :: a -> Range
-- | A range is its own range.
instance RangeOf Range where
  rangeOf r = r
-- | A pattern element's range is the range of its attribute.
instance RangeOf attr => RangeOf (PatElemT attr) where
  rangeOf pe = rangeOf (patElemAttr pe)
-- | Both bounds of a subexpression are known: it is bounded by itself.
instance RangeOf SubExp where
  rangeOf se =
    let (lo, hi) = subExpKnownRange se
    in (Just lo, Just hi)
-- | Things that have a list of ranges, e.g. one per pattern element or
-- per body result.
class RangesOf a where
  rangesOf :: a -> [Range]
-- | The ranges of a list are the ranges of its elements, in order.
instance RangeOf a => RangesOf [a] where
  rangesOf xs = [rangeOf x | x <- xs]
-- | One range per element of the pattern, in 'patternElements' order.
instance RangeOf attr => RangesOf (PatternT attr) where
  rangesOf pat = [rangeOf pe | pe <- patternElements pat]
-- | The ranges of a body are those recorded in its body attribute.
instance Ranged lore => RangesOf (Body lore) where
  rangesOf body = rangesOf $ bodyAttr body
-- | Lower and upper bound of a subexpression.  Both are the
-- subexpression itself: a variable bound or a constant scalar bound.
subExpKnownRange :: SubExp -> (KnownBound, KnownBound)
subExpKnownRange se = (bound, bound)
  where bound =
          case se of
            Var v        -> VarBound v
            Constant val -> ScalarBound $ SE.Val val
-- | The exact range of a scalar expression: lower and upper bound are
-- both the expression itself.
scalExpRange :: SE.ScalExp -> Range
scalExpRange se = (bound, bound)
  where bound = Just $ ScalarBound se
-- | The ranges of the results of a basic operation, one 'Range' per
-- result.  Operations without a specific rule get 'unknownRange'.
primOpRanges :: BasicOp lore -> [Range]
primOpRanges (SubExp se) =
  [rangeOf se]
primOpRanges (BinOp (Add t) x y) =
  [scalExpRange $ SE.SPlus (SE.subExpToScalExp x $ IntType t)
                           (SE.subExpToScalExp y $ IntType t)]
primOpRanges (BinOp (Sub t) x y) =
  [scalExpRange $ SE.SMinus (SE.subExpToScalExp x $ IntType t)
                            (SE.subExpToScalExp y $ IntType t)]
primOpRanges (BinOp (Mul t) x y) =
  [scalExpRange $ SE.STimes (SE.subExpToScalExp x $ IntType t)
                            (SE.subExpToScalExp y $ IntType t)]
primOpRanges (BinOp (SDiv t) x y) =
  [scalExpRange $ SE.SDiv (SE.subExpToScalExp x $ IntType t)
                          (SE.subExpToScalExp y $ IntType t)]
-- Sign extension to a wider type preserves the value, hence its range.
primOpRanges (ConvOp (SExt from to) x)
  | from < to = [rangeOf x]
-- @iota n x s@ produces x, x+s, ..., x+(n-1)*s.  Which endpoint is the
-- smallest depends on the sign of the stride @s@, so the endpoints may
-- only be used as (lower, upper) directly when @s@ is known to be
-- non-negative.
primOpRanges (Iota n x s Int32) =
  [(Just $ ScalarBound lower, Just $ ScalarBound upper)]
  where toScalExp se = SE.subExpToScalExp se $ IntType Int32
        start = toScalExp x
        finish = start + (toScalExp n - 1) * toScalExp s
        (lower, upper) =
          case s of
            Constant (IntValue (Int32Value k))
              | k >= 0    -> (start, finish)
              | otherwise -> (finish, start)
            -- Stride of unknown sign: bound by both endpoints.
            _ -> (SE.MaxMin True [start, finish],
                  SE.MaxMin False [start, finish])
-- The following operations produce values drawn from their input, so
-- the result has the same range as the input.
primOpRanges (Replicate _ v) =
  [rangeOf v]
primOpRanges (Rearrange _ v) =
  [rangeOf $ Var v]
primOpRanges (Copy se) =
  [rangeOf $ Var se]
primOpRanges (Index v _) =
  [rangeOf $ Var v]
-- An array literal is bounded by the minimum/maximum of its elements.
primOpRanges (ArrayLit (e:es) _) =
  [(Just lower, Just upper)]
  where (e_lower, e_upper) = subExpKnownRange e
        (es_lower, es_upper) = unzip $ map subExpKnownRange es
        lower = foldl MinimumBound e_lower es_lower
        upper = foldl MaximumBound e_upper es_upper
primOpRanges _ =
  [unknownRange]
-- | The ranges of the values produced by an expression, one 'Range' per
-- value.  Expressions without a specific rule get 'unknownRange' for
-- each of their 'expExtTypeSize' results.
expRanges :: Ranged lore =>
             Exp lore -> [Range]
expRanges (BasicOp op) =
  primOpRanges op
expRanges (If _ tbranch fbranch _) =
  -- Either branch may be taken: widen to the smaller lower bound and
  -- the larger upper bound of the two branches.
  zip
  (zipWith minimumBound t_lower f_lower)
  (zipWith maximumBound t_upper f_upper)
  where (t_lower, t_upper) = unzip $ rangesOf tbranch
        (f_lower, f_upper) = unzip $ rangesOf fbranch
expRanges (DoLoop ctxmerge valmerge (ForLoop i Int32 iterations _) body) =
  zipWith returnedRange valmerge val_ranges
  where body_ranges = rangesOf body

        -- NOTE(review): body results are assumed to be ordered like
        -- @ctxmerge ++ valmerge@.  Taking the trailing ranges pairs each
        -- value parameter with its own result whether or not ranges for
        -- the context results are present.
        val_ranges = drop (length body_ranges - length valmerge) body_ranges

        -- Names bound by the loop itself (index, merge parameters and
        -- body statements); anything mentioning them is not
        -- loop-invariant.
        bound_in_loop =
          S.fromList $ i : map (paramName . fst) (ctxmerge++valmerge) ++
          concatMap (patternNames . stmPattern) (bodyStms body)

        returnedRange mergeparam (lower, upper) =
          (returnedBound mergeparam lower,
           returnedBound mergeparam upper)

        -- If a bound on one iteration's result has the form @param + d@
        -- with loop-invariant @d@, then after the loop the parameter is
        -- bounded by @init + d * max(iterations, 0)@.  The increment @d@
        -- is new value minus old value, i.e. @bound - param@.
        returnedBound (param, mergeinit) (Just bound)
          | paramType param == Prim (IntType Int32),
            Just bound' <- boundToScalExp bound,
            let se_diff =
                  AS.simplify
                  (SE.SMinus bound' (SE.Id (paramName param) $ IntType Int32))
                  M.empty,
            S.null $ S.intersection bound_in_loop $ freeIn se_diff =
              Just $ ScalarBound $
              SE.SPlus (SE.subExpToScalExp mergeinit $ IntType Int32) $
              SE.STimes se_diff $ SE.MaxMin False
              [SE.subExpToScalExp iterations $ IntType Int32, 0]
        returnedBound _ _ = Nothing
expRanges (Op op) = opRanges op
expRanges e =
  replicate (expExtTypeSize e) unknownRange
-- | Operations (the 'Op' of a lore) whose results can be given ranges.
class IsOp op => RangedOp op where
  -- | The ranges of the operation's results, in order.
  opRanges :: op -> [Range]
-- | The unit operation has no results, hence no ranges.
instance RangedOp () where
  opRanges () = mempty
-- | An operation type with a range-annotated counterpart, together
-- with conversions in both directions.
class RangedOp (OpWithRanges op) =>
      CanBeRanged op where
  -- | The range-annotated version of @op@.
  type OpWithRanges op :: *
  -- | Strip the range annotations.
  removeOpRanges :: OpWithRanges op -> op
  -- | Add range annotations.
  addOpRanges :: op -> OpWithRanges op
-- | The unit operation carries no ranges, so both conversions are the
-- identity.
instance CanBeRanged () where
  type OpWithRanges () = ()
  addOpRanges op = op
  removeOpRanges op = op