{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.KernelsMem
  ( KernelsMem

  -- * Simplification
  , simplifyProg
  , simplifyStms
  , simpleKernelsMem

    -- * Module re-exports
  , module Futhark.Representation.Mem
  , module Futhark.Representation.Kernels.Kernel
  )
  where

import Futhark.Analysis.PrimExp.Convert
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Pretty
import Futhark.Representation.Kernels.Kernel
import Futhark.Representation.Kernels.Simplify (simplifyKernelOp)
import qualified Futhark.TypeCheck as TC
import Futhark.Representation.Mem
import Futhark.Representation.Mem.Simplify
import Futhark.Pass.ExplicitAllocations (BinderOps(..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.Optimise.Simplify.Engine as Engine

-- | Type-level tag (declared with no constructors, so it has no values)
-- identifying this representation. Its per-lore annotation types are
-- assigned by the 'Annotations' instance for 'KernelsMem'.
data KernelsMem

-- | Annotation types of 'KernelsMem'. Each annotation slot is filled with
-- the corresponding @*Mem@ annotation type. Operations are 'HostOp's
-- wrapped in 'MemOp', with @()@ as the extension operation.
instance Annotations KernelsMem where
  type Op KernelsMem = MemOp (HostOp KernelsMem ())
  type RetType KernelsMem = RetTypeMem
  type BranchType KernelsMem = BranchTypeMem
  type LetAttr KernelsMem = LetAttrMem
  type FParamAttr KernelsMem = FParamMem
  type LParamAttr KernelsMem = LParamMem

-- | Generic attribute operations for 'KernelsMem'.
instance Attributes KernelsMem where
  -- 'bodyReturnsFromPattern' yields a pair of @(VName, BodyReturns)@
  -- lists. Only the second list is used here, stripped of its names.
  -- NOTE(review): presumably the second list holds the value (not
  -- context) elements of the pattern -- confirm against
  -- bodyReturnsFromPattern.
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern

-- | Return information for 'KernelsMem' operations.
instance OpReturns KernelsMem where
  -- An allocation yields a single memory block in the requested space.
  opReturns (Alloc _ space) =
    return [MemMem space]
  -- Segmented operations have their own return-information logic.
  opReturns (Inner (SegOp op)) = segOpReturns op
  -- Any other operation: derive the returns from its extended types.
  opReturns k = extReturns <$> opType k

-- | Pretty-printing for 'KernelsMem'. No methods are defined here, so the
-- class's default implementations (if any) are used.
instance PrettyLore KernelsMem

-- | Type checking of 'KernelsMem' operations.
instance TC.CheckableOp KernelsMem where
  -- An operation checked directly by the type checker has no enclosing
  -- segment level.
  checkOp = typeCheckMemoryOp Nothing
    where
      -- Allocation sizes must be 64-bit integers; the space is not checked.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- Host operations are delegated to 'typeCheckHostOp'. Nested
      -- operations are checked recursively, with the 'SegLevel' they are
      -- found at wrapped in 'Just'. The @()@ extension operation needs no
      -- checking.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ return ()) op

-- | Type-checker hooks for 'KernelsMem'. Most checks are delegated to the
-- generic memory-representation checkers.
instance TC.Checkable KernelsMem where
  -- Function-parameter, lambda-parameter and let-bound annotations are
  -- all checked by the same 'checkMemInfo'.
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  -- Return types are projected to their extended value types, and each
  -- one is then checked with 'TC.checkExtType'.
  checkRetType = mapM_ TC.checkExtType . retTypeValues
  -- A primitive-typed function parameter carries 'MemPrim' information.
  primFParam name t = return $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType

-- | Binder operations for plain 'KernelsMem'.
instance BinderOps KernelsMem where
  -- Expression and body annotations are both the unit value.
  mkExpAttrB _ _ = return ()
  mkBodyB stms res = return $ Body () stms res
  -- Binding construction is delegated to mkLetNamesB' (from
  -- Futhark.Pass.ExplicitAllocations), given the unit expression
  -- annotation.
  mkLetNamesB = mkLetNamesB' ()

-- | Binder operations for 'KernelsMem' annotated with simplifier wisdom.
-- These mirror the plain instance: the unit annotations are wrapped by the
-- simplifier engine's @mkWise*@ helpers.
instance BinderOps (Engine.Wise KernelsMem) where
  mkExpAttrB pat e = return $ Engine.mkWiseExpAttr pat () e
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Simplify a whole 'KernelsMem' program. It uses the generic
-- memory-aware simplifier, with 'simplifyKernelOp' handling host
-- operations. The @()@ extension operation has nothing to simplify, so it
-- is returned as-is with no extra statements.
simplifyProg :: Prog KernelsMem -> PassM (Prog KernelsMem)
simplifyProg =
  simplifyProgGeneric $ simplifyKernelOp $ const $ return ((), mempty)

-- | Simplify a sequence of 'KernelsMem' statements in the ambient scope.
-- Returns the simplifier's symbol table together with the simplified
-- statements. Host operations are simplified by 'simplifyKernelOp'. The
-- @()@ extension operation is returned as-is with no extra statements.
simplifyStms :: (HasScope KernelsMem m, MonadFreshNames m) =>
                 Stms KernelsMem
             -> m (Engine.SymbolTable (Engine.Wise KernelsMem),
                   Stms KernelsMem)
simplifyStms =
  simplifyStmsGeneric $ simplifyKernelOp $ const $ return ((), mempty)

-- | Simplifier operations for 'KernelsMem', built by 'simpleGeneric'.
-- Host operations are simplified by 'simplifyKernelOp'. The @()@
-- extension operation is returned as-is with no extra statements.
simpleKernelsMem :: Engine.SimpleOps KernelsMem
simpleKernelsMem =
  simpleGeneric $ simplifyKernelOp $ const $ return ((), mempty)