{-# LANGUAGE TypeFamilies, FlexibleContexts, FlexibleInstances, ConstraintKinds #-}
module Futhark.Representation.AST.Attributes
( module Futhark.Representation.AST.Attributes.Reshape
, module Futhark.Representation.AST.Attributes.Rearrange
, module Futhark.Representation.AST.Attributes.Types
, module Futhark.Representation.AST.Attributes.Constants
, module Futhark.Representation.AST.Attributes.TypeOf
, module Futhark.Representation.AST.Attributes.Patterns
, module Futhark.Representation.AST.Attributes.Names
, module Futhark.Representation.AST.RetType
, isBuiltInFunction
, builtInFunctions
, funDefByName
, asBasicOp
, safeExp
, subExpVars
, subExpVar
, shapeVars
, commutativeLambda
, entryPointSize
, defAux
, stmCerts
, certify
, expExtTypesFromPattern
, patternFromParams
, IsOp (..)
, Attributes (..)
)
where
import Data.List
import Data.Maybe (mapMaybe, isJust)
import Data.Monoid ((<>))
import qualified Data.Map.Strict as M
import Futhark.Representation.AST.Attributes.Reshape
import Futhark.Representation.AST.Attributes.Rearrange
import Futhark.Representation.AST.Attributes.Types
import Futhark.Representation.AST.Attributes.Constants
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.AST.Attributes.TypeOf
import Futhark.Representation.AST.RetType
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Pretty
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitute, Substitutable)
import Futhark.Util.Pretty
-- | Is the given name one of the functions in 'builtInFunctions'?
isBuiltInFunction :: Name -> Bool
isBuiltInFunction fnm = M.member fnm builtInFunctions
-- | Every entry of 'primFuns', re-keyed by 'Name' and mapped to a pair
-- of (return type, parameter types).  The evaluation function stored
-- in 'primFuns' is dropped.
builtInFunctions :: M.Map Name (PrimType,[PrimType])
builtInFunctions =
  M.fromList [ (nameFromString fname, (ret, paramts))
             | (fname, (paramts, ret, _)) <- M.toList primFuns ]
-- | Look up a function definition in the program by name.  Returns the
-- first function in 'progFunctions' whose 'funDefName' is equal to the
-- given name, or 'Nothing' if none matches.
funDefByName :: Name -> Prog lore -> Maybe (FunDef lore)
funDefByName fname prog =
  find (\fd -> fname == funDefName fd) (progFunctions prog)
-- | Extract the 'BasicOp' from an expression, if the expression is a
-- 'BasicOp'; otherwise 'Nothing'.
asBasicOp :: Exp lore -> Maybe (BasicOp lore)
asBasicOp e =
  case e of
    BasicOp op -> Just op
    _          -> Nothing
-- | A conservative, purely syntactic check classifying an expression as
-- "safe".
--
-- * Division-like binary operators ('SDiv', 'UDiv', 'SMod', 'UMod',
--   'SQuot', 'SRem') are safe only when the divisor is a constant that
--   is not 'zeroIsh'.
-- * 'Pow' is safe only when the exponent is a constant that is not
--   'negativeIsh'.
-- * Other 'BinOp's, and a fixed list of other basic operations
--   ('ArrayLit', 'SubExp', 'UnOp', 'CmpOp', 'ConvOp', 'Scratch',
--   'Concat', 'Reshape', 'Manifest', 'Iota', 'Replicate', 'Copy'), are
--   safe.  Every other basic operation is considered unsafe.
-- * Loops and branches are safe when every statement in their bodies is
--   safe.  NOTE(review): the loop case looks only at the body, so loop
--   termination is not taken into account.
-- * Function applications are safe only for built-in functions.
-- * 'Op's defer to 'safeOp'.
safeExp :: IsOp (Op lore) => Exp lore -> Bool
safeExp e =
  case e of
    BasicOp op             -> safeBasicOp op
    DoLoop _ _ _ body      -> safeBody body
    Apply fname _ _ _      -> isBuiltInFunction fname
    If _ tbranch fbranch _ -> safeBody tbranch && safeBody fbranch
    Op op                  -> safeOp op
  where
    -- When neither guard holds, matching falls through to the
    -- 'BinOp{}' equation below.
    safeBasicOp (BinOp bop _ rhs)
      | divisionLike bop = nonZeroConstant rhs
      | Pow{} <- bop     = nonNegativeConstant rhs
    safeBasicOp BinOp{}     = True
    safeBasicOp ArrayLit{}  = True
    safeBasicOp SubExp{}    = True
    safeBasicOp UnOp{}      = True
    safeBasicOp CmpOp{}     = True
    safeBasicOp ConvOp{}    = True
    safeBasicOp Scratch{}   = True
    safeBasicOp Concat{}    = True
    safeBasicOp Reshape{}   = True
    safeBasicOp Manifest{}  = True
    safeBasicOp Iota{}      = True
    safeBasicOp Replicate{} = True
    safeBasicOp Copy{}      = True
    safeBasicOp _           = False

    -- Operators whose right operand must be nonzero.
    divisionLike SDiv{}  = True
    divisionLike UDiv{}  = True
    divisionLike SMod{}  = True
    divisionLike UMod{}  = True
    divisionLike SQuot{} = True
    divisionLike SRem{}  = True
    divisionLike _       = False

    -- Only a literal constant can be proven nonzero / non-negative here.
    nonZeroConstant (Constant v) = not $ zeroIsh v
    nonZeroConstant _            = False

    nonNegativeConstant (Constant v) = not $ negativeIsh v
    nonNegativeConstant _            = False
-- | A body is safe (in the sense of 'safeExp') when the expression of
-- every one of its statements is safe.  The body result is not
-- inspected.
safeBody :: IsOp (Op lore) => Body lore -> Bool
safeBody body = all (\stm -> safeExp (stmExp stm)) (bodyStms body)
-- | The variable names among the given subexpressions, in order;
-- constants are skipped.
subExpVars :: [SubExp] -> [VName]
subExpVars ses = [ v | Var v <- ses ]
-- | The name of the variable if the subexpression is a 'Var', and
-- 'Nothing' if it is a 'Constant'.
subExpVar :: SubExp -> Maybe VName
subExpVar se =
  case se of
    Var v      -> Just v
    Constant{} -> Nothing
-- | The variable names used as dimensions in a shape; constant
-- dimensions are skipped.
shapeVars :: Shape -> [VName]
shapeVars shape = subExpVars (shapeDims shape)
-- | A purely syntactic check for whether a lambda is commutative.
--
-- The parameters must split evenly into a first half @xs@ and a second
-- half @ys@, and there must be exactly one body result per pair.  For
-- each index @i@, the result @r_i@ must be a variable bound somewhere in
-- the body by a 'Let' with an empty context and a single pattern
-- element.  That 'Let' must apply a 'BinOp' accepted by
-- 'commutativeBinOp' to the parameters @x_i@ and @y_i@, in either order.
commutativeLambda :: Lambda lore -> Bool
commutativeLambda lam
  | half * 2 /= length params   = False
  | half /= length results      = False
  | otherwise                   = and $ zipWith3 componentOk xps yps results
  where
    params     = lambdaParams lam
    body       = lambdaBody lam
    results    = bodyResult body
    half       = length params `div` 2
    (xps, yps) = splitAt half params

    -- Some statement of the body must define this result commutatively.
    componentOk xp yp res =
      isJust $ find (definesCommutatively xp yp res) $ bodyStms body

    definesCommutatively xp yp (Var r)
      (Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
        patElemName pe == r &&
        commutativeBinOp op &&
        ((x == paramName xp && y == paramName yp) ||
         (y == paramName xp && x == paramName yp))
    definesCommutatively _ _ _ _ = False
-- | The size recorded for an entry point type: the stored count for
-- 'TypeOpaque', and 1 for 'TypeUnsigned' and 'TypeDirect'.
entryPointSize :: EntryPointType -> Int
entryPointSize ept =
  case ept of
    TypeOpaque _ n -> n
    TypeUnsigned   -> 1
    TypeDirect     -> 1
-- | A 'StmAux' with the given attribute and empty ('mempty')
-- certificates.
defAux :: attr -> StmAux attr
defAux attr = StmAux mempty attr
-- | The certificates stored in a statement's auxiliary information.
stmCerts :: Stm lore -> Certificates
stmCerts stm = stmAuxCerts (stmAux stm)
-- | Add certificates to a statement.  The new certificates are appended
-- (with '<>') after the ones the statement already has; the pattern,
-- attribute and expression are unchanged.
certify :: Certificates -> Stm lore -> Stm lore
certify newCerts (Let pat (StmAux oldCerts attr) e) =
  Let pat aux' e
  where aux' = StmAux (oldCerts <> newCerts) attr
-- | The class of operation types that can be embedded in expressions via
-- the 'Op' constructor.
class ( Eq op, Ord op, Show op
      , TypedOp op
      , Rename op
      , Substitute op
      , FreeIn op
      , Pretty op
      ) => IsOp op where
  -- | Whether the operation is safe; 'safeExp' defers to this for 'Op'
  -- expressions.
  safeOp :: op -> Bool
  -- | Whether the operation is considered cheap.  NOTE(review): what
  -- "cheap" is used for is decided by callers not visible here.
  cheapOp :: op -> Bool
-- | The unit type as an operation type: it is always safe and always
-- cheap.
instance IsOp () where
  cheapOp () = True
  safeOp () = True
-- | The constraints a lore must satisfy to be usable with the generic
-- machinery in this module: annotations, pretty-printing, renaming,
-- substitution, free-variable computation for all of its attributes,
-- and an 'IsOp' operation type.
class ( Annotations lore
      , PrettyLore lore
      , Renameable lore
      , Substitutable lore
      , FreeAttr (ExpAttr lore)
      , FreeIn (LetAttr lore)
      , FreeAttr (BodyAttr lore)
      , FreeIn (FParamAttr lore)
      , FreeIn (LParamAttr lore)
      , FreeIn (RetType lore)
      , FreeIn (BranchType lore)
      , IsOp (Op lore)
      ) => Attributes lore where
  -- | Compute branch types from a pattern, in any monad that gives
  -- access to the lore's scope.
  expTypesFromPattern :: (HasScope lore m, Monad m) =>
                         Pattern lore -> m [BranchType lore]
-- | The types of the value elements of a pattern, lifted with
-- 'staticShapes' and then passed through 'existentialiseExtTypes'
-- using the names of the pattern's context elements.
expExtTypesFromPattern :: Typed attr => PatternT attr -> [ExtType]
expExtTypesFromPattern pat =
  existentialiseExtTypes ctxNames valueTypes
  where
    ctxNames   = patternContextNames pat
    valueTypes = staticShapes (map patElemType (patternValueElements pat))
-- | A pattern with an empty context, whose value elements are the given
-- parameters in order.  Each element takes the parameter's name and
-- attribute.
patternFromParams :: [Param attr] -> PatternT attr
patternFromParams params =
  Pattern [] [ PatElem (paramName p) (paramAttr p) | p <- params ]