{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.SeqMem
  ( SeqMem

  -- * Simplification
  , simplifyProg
  , simpleSeqMem

    -- * Module re-exports
  , module Futhark.Representation.Mem
  , module Futhark.Representation.Kernels.Kernel
  )
  where

import Futhark.Analysis.PrimExp.Convert
import Futhark.Pass
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Pretty
import Futhark.Representation.Kernels.Kernel
import qualified Futhark.TypeCheck as TC
import Futhark.Representation.Mem
import Futhark.Representation.Mem.Simplify
import Futhark.Pass.ExplicitAllocations (BinderOps(..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.Optimise.Simplify.Engine as Engine

-- | Type-level tag for the memory-annotated lore defined in this module.
-- It has no constructors and is never inhabited; it only indexes the
-- type families and classes instantiated below.
data SeqMem

-- | Each annotation slot of 'SeqMem' uses the corresponding @*Mem@
-- annotation type.  The only operation is 'MemOp' with a trivial @()@
-- inner op, so besides allocation there is nothing else an 'Op' can be.
instance Annotations SeqMem where
  type Op SeqMem = MemOp ()
  type LetAttr SeqMem = LetAttrMem
  type FParamAttr SeqMem = FParamMem
  type LParamAttr SeqMem = LParamMem
  type RetType SeqMem = RetTypeMem
  type BranchType SeqMem = BranchTypeMem

-- | Attribute queries for 'SeqMem'.
instance Attributes SeqMem where
  -- 'bodyReturnsFromPattern' yields a pair of @(name, BodyReturns)@ lists.
  -- The first component is discarded and the 'BodyReturns' of the second
  -- are returned (presumably the pattern's value elements, with the first
  -- list being the context part; confirm against 'bodyReturnsFromPattern').
  expTypesFromPattern = pure . map snd . snd . bodyReturnsFromPattern

-- | Return annotations produced by the operations of 'SeqMem'.
instance OpReturns SeqMem where
  -- An allocation yields exactly one result: a 'MemMem' in the space that
  -- was requested by the 'Alloc'.  The size operand does not affect it.
  opReturns (Alloc _ space) = pure [MemMem space]
  -- The trivial inner op produces no results.
  opReturns (Inner ()) = pure []

-- | Pretty-printing for 'SeqMem'; no methods are overridden, so the
-- class defaults are used.
instance PrettyLore SeqMem

-- | Type checking of 'SeqMem' operations.
instance TC.CheckableOp SeqMem where
  -- The size operand of an allocation must be an @int64@; the memory
  -- space is not checked here.
  checkOp (Alloc size _) = TC.require [Prim int64] size
  -- The trivial inner op has nothing to check.
  checkOp (Inner ()) = pure ()

-- | Type checking for 'SeqMem'.  Every method delegates to the generic
-- helpers for memory-annotated lores.
instance TC.Checkable SeqMem where
  -- Function parameters, lambda parameters and let-bound names all carry
  -- 'MemInfo' annotations and are checked the same way.
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo

  -- Only the types extracted by 'retTypeValues' are checked, one by one.
  checkRetType = mapM_ TC.checkExtType . retTypeValues

  -- A primitive-typed function parameter is annotated with 'MemPrim'.
  primFParam name t = pure $ Param name (MemPrim t)

  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType

-- | Binder operations for plain 'SeqMem'.  Expressions and bodies carry
-- the unit annotation @()@.
instance BinderOps SeqMem where
  -- Expressions have no annotation, so the pattern and expression are
  -- ignored.
  mkExpAttrB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  -- Delegate to the shared helper, which requires an 'Allocator' instance
  -- for 'PatAllocM'; @()@ is the expression annotation it attaches.
  mkLetNamesB = mkLetNamesB' ()

-- | Binder operations for 'SeqMem' wrapped in simplifier wisdom.  The
-- underlying (unit) annotations are lifted with the 'Engine' helpers.
instance BinderOps (Engine.Wise SeqMem) where
  -- Wrap the unit expression annotation with wisdom computed from the
  -- pattern and expression.
  mkExpAttrB pat e = pure $ Engine.mkWiseExpAttr pat () e
  -- Same for bodies: the unit body annotation is wrapped by 'mkWiseBody'.
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Simplify a 'SeqMem' program using the generic memory-lore simplifier.
--
-- The inner op of 'SeqMem' is @()@, so simplifying it is trivial.  The
-- op is returned unchanged and no extra statements are produced.
simplifyProg :: Prog SeqMem -> PassM (Prog SeqMem)
simplifyProg = simplifyProgGeneric (const (pure ((), mempty)))

-- | Simplifier operations for 'SeqMem', built with the generic
-- memory-lore helper.
--
-- As in 'simplifyProg', the @()@ inner op is left unchanged and produces
-- no extra statements.
simpleSeqMem :: Engine.SimpleOps SeqMem
simpleSeqMem = simpleGeneric (const (pure ((), mempty)))