{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.MCMem
( MCMem,
simplifyProg,
module Futhark.IR.Mem,
module Futhark.IR.SegOp,
module Futhark.IR.MC.Op,
)
where
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.MC.Op
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.IR.SegOp
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BinderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC
-- | Lore tag for the memory-annotated MC representation.  It has no
-- constructors and is used purely as a type-level index, e.g. to select the
-- decoration types in its 'Decorations' instance.
data MCMem
-- | Decorations of 'MCMem'.  Every decoration is the corresponding
-- memory-annotated type ('LetDecMem', 'FParamMem', ...), and the operation
-- type is 'MCOp' (with '()' as its extension operation) wrapped in 'MemOp',
-- so that allocations can appear alongside MC operations.
instance Decorations MCMem where
  type LetDec MCMem = LetDecMem
  type FParamInfo MCMem = FParamMem
  type LParamInfo MCMem = LParamMem
  type RetType MCMem = RetTypeMem
  type BranchType MCMem = BranchTypeMem
  type Op MCMem = MemOp (MCOp MCMem ())
-- | Basic AST operations for 'MCMem'.
instance ASTLore MCMem where
  -- Branch types are derived from the pattern with 'bodyReturnsFromPattern'.
  -- That yields a pair of @(name, type)@ lists; only the second list is kept
  -- and the names are dropped.
  -- NOTE(review): the second list is presumably the value (non-context)
  -- part of the pattern -- confirm.
  expTypesFromPattern = pure . map snd . snd . bodyReturnsFromPattern
-- | Return types of 'MCMem' operations.
instance OpReturns MCMem where
  -- An allocation yields a single memory value in the requested space.
  opReturns (Alloc _ space) = pure [MemMem space]
  -- A parallel operation is described by its 'SegOp'; the first field of
  -- 'ParOp' (a 'Maybe' 'SegOp') is not consulted.
  opReturns (Inner (ParOp _ op)) = segOpReturns op
  -- The '()' extension operation returns nothing.
  opReturns (Inner (OtherOp ())) = pure []
-- | Pretty-printing of 'MCMem'.  No methods are overridden, so the class
-- defaults are used.
instance PrettyLore MCMem
-- | Type checking of (alias-annotated) 'MCMem' operations.
instance TC.CheckableOp MCMem where
  -- The size operand of an allocation must be an 'int64'.
  checkOp (Alloc size _) = TC.require [Prim int64] size
  -- MC operations are checked by 'typeCheckMCOp'; the '()' extension
  -- operation needs no checking, hence the trivial 'pure' checker.
  checkOp (Inner op) = typeCheckMCOp pure op
-- | Type checking of 'MCMem' programs.  Every hook delegates to the generic
-- checks that work for any lore satisfying 'Mem'.
instance TC.Checkable MCMem where
  -- Function parameters, lambda parameters and let-bound names all carry
  -- memory information checked by 'checkMemInfo'.
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  -- Each return type is checked as an existential type.
  checkRetType = mapM_ (TC.checkExtType . declExtTypeOf)
  -- A primitive function parameter is decorated with 'MemPrim'.
  primFParam name t = pure $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem
-- | Binder operations for 'MCMem'.  Expressions and bodies carry no
-- decoration ('()'); let-bindings are built by @mkLetNamesB'@.
instance BinderOps MCMem where
  mkExpDecB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()
-- | Binder operations for the simplifier's 'Engine.Wise' wrapping of
-- 'MCMem'.  The plain '()' decorations are extended with wisdom through
-- 'Engine.mkWiseExpDec' and 'Engine.mkWiseBody'; let-bindings are built by
-- @mkLetNamesB''@.
instance BinderOps (Engine.Wise MCMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''
-- | Simplify an 'MCMem' program with the generic memory-aware simplifier
-- ('simplifyProgGeneric'), using the operations in 'simpleMCMem'.
simplifyProg :: Prog MCMem -> PassM (Prog MCMem)
simplifyProg = simplifyProgGeneric simpleMCMem
-- | Simplifier operations for 'MCMem', built with 'simpleGeneric'.
simpleMCMem :: Engine.SimpleOps MCMem
simpleMCMem =
  -- Every operation reports an empty usage table.
  simpleGeneric (const mempty) $
    simplifyMCOp $
      -- The '()' extension operation simplifies to itself and emits no
      -- additional statements.
      const $ pure ((), mempty)