{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.Prop
( module Futhark.IR.Prop.Reshape,
module Futhark.IR.Prop.Rearrange,
module Futhark.IR.Prop.Types,
module Futhark.IR.Prop.Constants,
module Futhark.IR.Prop.TypeOf,
module Futhark.IR.Prop.Patterns,
module Futhark.IR.Prop.Names,
module Futhark.IR.RetType,
isBuiltInFunction,
builtInFunctions,
asBasicOp,
safeExp,
subExpVars,
subExpVar,
commutativeLambda,
entryPointSize,
defAux,
stmCerts,
certify,
expExtTypesFromPattern,
attrsForAssert,
ASTConstraints,
IsOp (..),
ASTLore (..),
)
where
import Data.List (find)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust, mapMaybe)
import qualified Data.Set as S
import Futhark.IR.Pretty
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Patterns
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Prop.Reshape
import Futhark.IR.Prop.TypeOf
import Futhark.IR.Prop.Types
import Futhark.IR.RetType
import Futhark.IR.Syntax
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitutable, Substitute)
import Futhark.Util.Pretty
-- | @isBuiltInFunction fnm@ is 'True' exactly when @fnm@ is a key of
-- 'builtInFunctions'.
isBuiltInFunction :: Name -> Bool
isBuiltInFunction fnm = fnm `M.member` builtInFunctions
-- | The built-in functions, keyed by name.  Each entry maps to the pair
-- @(return type, parameter types)@.
--
-- The table is derived from 'primFuns', whose entries are keyed by 'String'
-- and carry @(parameter types, return type, evaluator)@; the key is
-- converted with 'nameFromString', the two type components are swapped, and
-- the evaluator is dropped.
builtInFunctions :: M.Map Name (PrimType, [PrimType])
builtInFunctions = M.fromList $ map namify $ M.toList primFuns
  where
    -- Re-key one 'primFuns' entry and keep only its type information.
    namify :: (String, (b, a, c)) -> (Name, (a, b))
    namify (k, (paramts, ret, _)) = (nameFromString k, (ret, paramts))
-- | Extract the payload of a 'BasicOp' expression; any other kind of
-- expression yields 'Nothing'.
asBasicOp :: Exp lore -> Maybe BasicOp
asBasicOp (BasicOp op) = Just op
asBasicOp _ = Nothing
-- | Decide whether an expression is considered safe, by case analysis on
-- its constructor:
--
-- * 'BasicOp': decided by the local @safeBasicOp@ (see below).
--
-- * 'DoLoop': safe when its body is ('safeBody').
--
-- * 'Apply': safe when the callee is a built-in ('isBuiltInFunction').
--
-- * 'If': safe when both branches are ('safeBody').
--
-- * 'Op': delegated to 'safeOp'.
safeExp :: IsOp (Op lore) => Exp lore -> Bool
safeExp (BasicOp op) = safeBasicOp op
  where
    -- Clause order matters: the specific division/modulo/power cases must
    -- come before the catch-all @BinOp {}@ clause.
    safeBasicOp :: BasicOp -> Bool
    -- Division-like operators explicitly tagged 'Safe' are always accepted.
    safeBasicOp (BinOp (SDiv _ Safe) _ _) = True
    safeBasicOp (BinOp (SDivUp _ Safe) _ _) = True
    safeBasicOp (BinOp (SQuot _ Safe) _ _) = True
    safeBasicOp (BinOp (UDiv _ Safe) _ _) = True
    safeBasicOp (BinOp (UDivUp _ Safe) _ _) = True
    safeBasicOp (BinOp (SMod _ Safe) _ _) = True
    safeBasicOp (BinOp (SRem _ Safe) _ _) = True
    safeBasicOp (BinOp (UMod _ Safe) _ _) = True
    -- Otherwise a division-like operator is accepted only when its second
    -- operand is a constant that 'zeroIsh' rejects; a non-constant second
    -- operand is never accepted.
    safeBasicOp (BinOp SDiv {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SDiv {} _ _) = False
    safeBasicOp (BinOp SDivUp {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SDivUp {} _ _) = False
    safeBasicOp (BinOp UDiv {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UDiv {} _ _) = False
    safeBasicOp (BinOp UDivUp {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UDivUp {} _ _) = False
    safeBasicOp (BinOp SMod {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SMod {} _ _) = False
    safeBasicOp (BinOp UMod {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UMod {} _ _) = False
    safeBasicOp (BinOp SQuot {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SQuot {} _ _) = False
    safeBasicOp (BinOp SRem {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SRem {} _ _) = False
    -- 'Pow' is accepted only with a constant second operand that
    -- 'negativeIsh' rejects.
    safeBasicOp (BinOp Pow {} _ (Constant y)) = not $ negativeIsh y
    safeBasicOp (BinOp Pow {} _ _) = False
    -- Constructors that are accepted unconditionally.
    safeBasicOp ArrayLit {} = True
    safeBasicOp BinOp {} = True
    safeBasicOp SubExp {} = True
    safeBasicOp UnOp {} = True
    safeBasicOp CmpOp {} = True
    safeBasicOp ConvOp {} = True
    safeBasicOp Scratch {} = True
    safeBasicOp Concat {} = True
    safeBasicOp Reshape {} = True
    safeBasicOp Rearrange {} = True
    safeBasicOp Manifest {} = True
    safeBasicOp Iota {} = True
    safeBasicOp Replicate {} = True
    safeBasicOp Copy {} = True
    -- Every remaining 'BasicOp' is rejected.
    safeBasicOp _ = False
safeExp (DoLoop _ _ _ body) = safeBody body
safeExp (Apply fname _ _ _) = isBuiltInFunction fname
safeExp (If _ tbranch fbranch _) = safeBody tbranch && safeBody fbranch
safeExp (Op op) = safeOp op
-- | A body is safe when the expression of every one of its statements is
-- safe according to 'safeExp'.  A body without statements is therefore safe.
safeBody :: IsOp (Op lore) => Body lore -> Bool
safeBody = all (safeExp . stmExp) . bodyStms
-- | Collect, in order, the variable names of all 'Var' subexpressions in the
-- list; 'Constant's are skipped.  Duplicates are not removed.
subExpVars :: [SubExp] -> [VName]
subExpVars = mapMaybe subExpVar
-- | The variable name of a 'Var' subexpression, or 'Nothing' for a
-- 'Constant'.
subExpVar :: SubExp -> Maybe VName
subExpVar (Var v) = Just v
subExpVar Constant {} = Nothing
-- | Syntactic check for whether a lambda is a known commutative operator.
--
-- The parameters are split into a first half @xps@ and a second half @yps@.
-- The result is 'True' only if all of the following hold:
--
-- * the number of parameters is even;
--
-- * the body returns exactly half as many results as there are parameters;
--
-- * for every triple @(xp, yp, res)@ obtained by zipping the two halves with
--   the body results, @res@ is a variable bound somewhere in the body by a
--   single-element pattern (with empty context) to a 'BinOp' that satisfies
--   'commutativeBinOp' and whose operands are exactly the variables @xp@ and
--   @yp@, in either order.
--
-- This is purely pattern matching; anything not of that exact shape is
-- reported as not commutative.
commutativeLambda :: Lambda lore -> Bool
commutativeLambda lam =
  let body = lambdaBody lam
      params = lambdaParams lam
      n2 = length params `div` 2
      (xps, yps) = splitAt n2 params

      -- Does some statement of the body compute this result component with
      -- a commutative operator applied to the matching parameter pair?
      okComponent c = isJust $ find (okBinOp c) $ bodyStms body

      okBinOp
        (xp, yp, Var r)
        (Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
          patElemName pe == r
            && commutativeBinOp op
            && ( (x == paramName xp && y == paramName yp)
                   || (y == paramName xp && x == paramName yp)
               )
      okBinOp _ _ = False
   in n2 * 2 == length params
        && n2 == length (bodyResult body)
        && all okComponent (zip3 xps yps $ bodyResult body)
-- | The size associated with an entry point type: the 'Int' stored in a
-- 'TypeOpaque', and @1@ for 'TypeUnsigned' and 'TypeDirect'.
--
-- NOTE(review): presumably this is the number of core-language values the
-- entry point type occupies -- confirm against the definition of
-- 'EntryPointType'.
entryPointSize :: EntryPointType -> Int
entryPointSize (TypeOpaque _ x) = x
entryPointSize TypeUnsigned = 1
entryPointSize TypeDirect = 1
-- | Wrap a decoration in a 'StmAux' whose certificates and attributes are
-- both 'mempty'.
defAux :: dec -> StmAux dec
defAux = StmAux mempty mempty
-- | The certificates stored in the auxiliary information ('stmAux') of a
-- statement.
stmCerts :: Stm lore -> Certificates
stmCerts = stmAuxCerts . stmAux
-- | Add certificates to a statement.  The new certificates @cs1@ are
-- appended after the ones the statement already carries (@cs2 <> cs1@); the
-- pattern, attributes, decoration and expression are left unchanged.
certify :: Certificates -> Stm lore -> Stm lore
certify cs1 (Let pat (StmAux cs2 attrs dec) e) =
  Let pat (StmAux (cs2 <> cs1) attrs dec) e
-- | Constraint synonym bundling the classes required of a type @a@ wherever
-- 'ASTConstraints' is demanded (for instance as a superclass of 'IsOp').
type ASTConstraints a =
  (Eq a, Ord a, Show a, Rename a, Substitute a, FreeIn a, Pretty a)
-- | Class of operation types.  Besides the 'ASTConstraints' and 'TypedOp'
-- superclasses, an operation must answer two yes/no queries.
class (ASTConstraints op, TypedOp op) => IsOp op where
  -- | Counterpart of 'safeExp' for operations: 'safeExp' returns this value
  -- for an 'Op' expression.
  safeOp :: op -> Bool

  -- | Whether the operation counts as cheap.
  -- NOTE(review): no consumer of this flag is visible in this module --
  -- confirm the intended meaning at the use sites.
  cheapOp :: op -> Bool
-- | The unit operation is both safe and cheap.
instance IsOp () where
  safeOp () = True
  cheapOp () = True
-- | Class of lores usable in the AST.  The superclass context collects the
-- requirements placed on the lore itself and on its associated decoration
-- and operation types.
class
  ( Decorations lore,
    PrettyLore lore,
    Renameable lore,
    Substitutable lore,
    FreeDec (ExpDec lore),
    FreeIn (LetDec lore),
    FreeDec (BodyDec lore),
    FreeIn (FParamInfo lore),
    FreeIn (LParamInfo lore),
    FreeIn (RetType lore),
    FreeIn (BranchType lore),
    IsOp (Op lore)
  ) =>
  ASTLore lore
  where
  -- | Given a pattern, compute a list of branch types for it, in any monad
  -- that provides a scope for the lore.
  expTypesFromPattern ::
    (HasScope lore m, Monad m) =>
    Pattern lore ->
    m [BranchType lore]
-- | Derive a list of 'ExtType's from a pattern: take the types of the
-- pattern's value elements, lift them with 'staticShapes', and then
-- existentialise them with respect to the pattern's context names.
expExtTypesFromPattern :: Typed dec => PatternT dec -> [ExtType]
expExtTypesFromPattern pat =
  existentialiseExtTypes (patternContextNames pat) $
    staticShapes $
      map patElemType $
        patternValueElements pat
-- | Filter an attribute set, keeping only the attribute
-- @warn(safety_checks)@ (that is, @AttrComp "warn" ["safety_checks"]@) and
-- discarding everything else.
attrsForAssert :: Attrs -> Attrs
attrsForAssert (Attrs attrs) =
  Attrs $ S.filter attrForAssert attrs
  where
    -- The string literals rely on OverloadedStrings.
    attrForAssert :: Attr -> Bool
    attrForAssert = (== AttrComp "warn" ["safety_checks"])