{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpCode.Kernels
( Program
, Function
, FunctionT (Function)
, Code
, KernelCode
, KernelConst (..)
, KernelConstExp
, HostOp (..)
, KernelOp (..)
, AtomicOp (..)
, CallKernel (..)
, MapKernel (..)
, Kernel (..)
, LocalMemoryUse
, KernelUse (..)
, module Futhark.CodeGen.ImpCode
, module Futhark.Representation.Kernels.Sizes
, getKernels
)
where
import Control.Monad.Writer
import Data.List
import qualified Data.Set as S
import Futhark.CodeGen.ImpCode hiding (Function, Code)
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.Representation.Kernels.Sizes
import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.AST.Pretty ()
import Futhark.Util.Pretty
-- | A collection of functions whose operations are 'HostOp's.
type Program = Functions HostOp
-- | A single function whose operations are 'HostOp's.
type Function = Imp.Function HostOp
-- | Imperative code whose operations are 'CallKernel's.
--
-- NOTE(review): 'Program' and 'Function' are instantiated with 'HostOp',
-- whereas this alias uses 'CallKernel' directly -- confirm that the
-- difference is intended.
type Code = Imp.Code CallKernel
-- | Imperative code whose operations are 'KernelOp's; this is the type of
-- the bodies stored in 'MapKernel' and 'Kernel'.
type KernelCode = Imp.Code KernelOp
-- | A constant that can appear inside a 'KernelConstExp'.  The only form
-- is a reference to a named size, which the 'Pretty' instance renders as
-- @get_size(key)@.
newtype KernelConst
  = SizeConst VName
  deriving (Eq, Ord, Show)

-- | A primitive expression whose leaves are 'KernelConst's.
type KernelConstExp = PrimExp KernelConst
-- | The operations available in host-level code.
data HostOp
  = -- | Invoke a kernel.
    CallKernel CallKernel
  | -- | @GetSize dest key size_class@; pretty-printed as
    -- @dest <- get_size(key, size_class)@.
    GetSize VName VName SizeClass
  | -- | @CmpSizeLe dest name size_class x@; pretty-printed as an
    -- assignment to @dest@ of a comparison between
    -- @get_size(name, size_class)@ and @x@.
    CmpSizeLe VName VName SizeClass Imp.Exp
  | -- | @GetSizeMax dest size_class@; pretty-printed as
    -- @dest <- get_size_max(size_class)@.
    GetSizeMax VName SizeClass
  deriving (Show)
-- | The different kinds of kernel that host code can launch.
data CallKernel
  = -- | A kernel described by a 'MapKernel'.
    Map MapKernel
  | -- | A kernel described by a 'Kernel'.
    AnyKernel Kernel
  | -- | A transposition kernel.  Following the names used by the
    -- 'Pretty' and 'FreeIn' instances, the fields are: element type,
    -- destination memory, destination offset, source memory, source
    -- offset, number of arrays, size in x, size in y, input size and
    -- output size.
    MapTranspose PrimType VName Exp VName Exp Exp Exp Exp Exp Exp
  deriving (Show)
-- | Description of a map-style kernel.
data MapKernel = MapKernel
  { mapKernelThreadNum :: VName
    -- ^ Name bound to the thread number; the pretty-printer shows it as
    -- @<- get_thread_number()@ and the 'FreeIn' instance removes it from
    -- the free names of the body.
  , mapKernelDesc :: String
    -- ^ Descriptive string for the kernel.
  , mapKernelBody :: Imp.Code KernelOp
    -- ^ The code making up the kernel body.
  , mapKernelUses :: [KernelUse]
    -- ^ The 'KernelUse's recorded for this kernel.
  , mapKernelNumGroups :: DimSize
    -- ^ Number of groups.
  , mapKernelGroupSize :: DimSize
    -- ^ Size of each group.
  , mapKernelSize :: Imp.Exp
    -- ^ Size expression of the kernel.
    -- NOTE(review): presumably the total number of threads -- confirm.
  }
  deriving (Show)
-- | Description of a general kernel.
data Kernel = Kernel
  { kernelBody :: Imp.Code KernelOp
    -- ^ The code making up the kernel body.
  , kernelLocalMemory :: [LocalMemoryUse]
    -- ^ The local memory entries of the kernel, each a name paired with
    -- its size in bytes.
  , kernelUses :: [KernelUse]
    -- ^ The 'KernelUse's recorded for this kernel.
  , kernelNumGroups :: DimSize
    -- ^ Number of groups.
  , kernelGroupSize :: DimSize
    -- ^ Size of each group.
  , kernelName :: VName
    -- ^ Name of the kernel.
  , kernelDesc :: String
    -- ^ Descriptive string for the kernel.
  }
  deriving (Show)
-- | A name paired with a size that is either a 'MemSize' ('Left') or a
-- 'KernelConstExp' ('Right').  The 'Pretty' instance for 'Kernel' prints
-- the former as @bytes@ and the latter as @bytes (const)@.
type LocalMemoryUse = (VName, Either MemSize KernelConstExp)
-- | Something that a kernel uses, identified by name.
data KernelUse
  = -- | A name with a primitive type; pretty-printed as
    -- @scalar_copy(name, type)@.
    ScalarUse VName PrimType
  | -- | A name with a size; pretty-printed as @mem_copy(name, size)@.
    MemoryUse VName Imp.DimSize
  | -- | A name with a constant expression; pretty-printed as
    -- @const(name, exp)@.
    ConstUse VName KernelConstExp
  deriving (Eq, Show)
-- | Collect every kernel invoked through a 'CallKernel' host operation in
-- the program.  Among 'MapTranspose' kernels only the first one for each
-- element type is kept; all other kernels are retained as-is.
getKernels :: Program -> [CallKernel]
getKernels prog = nubBy duplicate collected
  where
    -- Walk all host operations, logging the kernels that are called.
    collected = execWriter (traverse record prog)

    record op =
      case op of
        CallKernel k -> tell [k]
        _ -> return ()

    -- Two kernels count as duplicates only when both are transpositions
    -- over the same element type.
    duplicate k1 k2 =
      case (k1, k2) of
        (MapTranspose t1 _ _ _ _ _ _ _ _ _, MapTranspose t2 _ _ _ _ _ _ _ _ _) ->
          t1 == t2
        _ ->
          False
-- | A size constant is shown as @get_size(key)@.
instance Pretty KernelConst where
  ppr (SizeConst sizeKey) = text "get_size" <> argument
    where
      argument = parens $ ppr sizeKey
-- | Each use is shown in call syntax: a tag followed by the
-- comma-separated name and its associated type, size or expression.
instance Pretty KernelUse where
  ppr use =
    case use of
      ScalarUse name t -> call "scalar_copy" [ppr name, ppr t]
      MemoryUse name size -> call "mem_copy" [ppr name, ppr size]
      ConstUse name e -> call "const" [ppr name, ppr e]
    where
      call tag args = text tag <> parens (commasep args)
-- | Host operations are shown as assignments to their destination, except
-- for kernel calls, which are shown as the kernel itself.
instance Pretty HostOp where
  ppr (GetSize dest key size_class) =
    ppr dest <+> text "<-" <+>
    text "get_size" <> parens (commasep [ppr key, ppr size_class])
  ppr (GetSizeMax dest size_class) =
    ppr dest <+> text "<-" <+> text "get_size_max" <> parens (ppr size_class)
  ppr (CmpSizeLe dest name size_class x) =
    -- The constructor denotes a less-than-or-equal comparison, so it is
    -- rendered with "<=" (it was previously shown as a strict "<").
    ppr dest <+> text "<-" <+>
    text "get_size" <> parens (commasep [ppr name, ppr size_class]) <+>
    text "<=" <+> ppr x
  ppr (CallKernel c) =
    ppr c
-- | Free names of a host operation.  Kernel calls defer to the kernel;
-- 'CmpSizeLe' contributes its destination, its size name and the names of
-- the compared expression; the remaining operations contribute only
-- their destination.
instance FreeIn HostOp where
  freeIn op =
    case op of
      CallKernel c -> freeIn c
      CmpSizeLe dest name _ x -> mconcat [freeIn dest, freeIn name, freeIn x]
      GetSizeMax dest _ -> freeIn dest
      GetSize dest _ _ -> freeIn dest
-- | 'Map' and 'AnyKernel' are shown as the kernel they wrap.  A
-- 'MapTranspose' is shown as a @mapTranspose(...)@ call listing the
-- element type, the destination and source locations (as
-- @base + offset@), and then the five size expressions.
instance Pretty CallKernel where
  ppr (Map k) = ppr k
  ppr (AnyKernel k) = ppr k
  ppr (MapTranspose bt dest destoffset src srcoffset num_arrays size_x size_y in_size out_size) =
    text "mapTranspose" <> parens arguments
    where
      arguments =
        ppr bt <> comma </>
        memLoc dest destoffset <> comma </>
        memLoc src srcoffset <> comma </>
        sizes

      -- The trailing size arguments, separated by ", ".
      sizes =
        ppr num_arrays <> comma <+>
        ppr size_x <> comma <+>
        ppr size_y <> comma <+>
        ppr in_size <> comma <+>
        ppr out_size

      memLoc base offset = ppr base <+> text "+" <+> ppr offset
-- | Free names of a kernel call.  'Map' and 'AnyKernel' defer to the
-- wrapped kernel; a 'MapTranspose' contributes both memory names and the
-- free names of all of its expressions (the element type is ignored).
instance FreeIn CallKernel where
  freeIn call =
    case call of
      Map k -> freeIn k
      AnyKernel k -> freeIn k
      MapTranspose _ dest destoffset src srcoffset num_arrays size_x size_y in_size out_size ->
        freeIn [dest, src] <>
        freeIn [destoffset, srcoffset, num_arrays, size_x, size_y, in_size, out_size]
-- | Free names of a kernel: those of its body together with those of its
-- number of groups and its group size.
instance FreeIn Kernel where
  freeIn Kernel { kernelBody = body
                , kernelNumGroups = num_groups
                , kernelGroupSize = group_size
                } =
    freeIn body <> freeIn [num_groups, group_size]
-- | A map kernel is shown as a @mapKernel@ block with a @uses@ section
-- and a @body@ section; the body begins with the binding of the thread
-- number.
instance Pretty MapKernel where
  ppr kernel = text "mapKernel" <+> brace (usesSection </> bodySection)
    where
      usesSection =
        text "uses" <+> brace (commasep (map ppr (mapKernelUses kernel)))

      bodySection =
        text "body" <+> brace (threadNumBinding </> ppr (mapKernelBody kernel))

      threadNumBinding =
        ppr (mapKernelThreadNum kernel) <+> text "<- get_thread_number()"
-- | A kernel is shown as a @kernel@ block with, in order, @groups@,
-- @group_size@, @local_memory@, @uses@ and @body@ sections.
instance Pretty Kernel where
  ppr kernel =
    text "kernel" <+>
    brace (groups </> groupSize </> localMemory </> uses </> body)
    where
      section title contents = text title <+> brace contents

      groups = section "groups" (ppr (kernelNumGroups kernel))
      groupSize = section "group_size" (ppr (kernelGroupSize kernel))
      localMemory =
        section "local_memory" $
        commasep (map localMemoryUse (kernelLocalMemory kernel))
      uses = section "uses" (commasep (map ppr (kernelUses kernel)))
      body = section "body" (ppr (kernelBody kernel))

      -- Sizes given as kernel constants are marked with "(const)".
      localMemoryUse (name, size) =
        ppr name <+> parens (either plainBytes constBytes size)
      plainBytes n = ppr n <+> text "bytes"
      constBytes n = ppr n <+> text "bytes (const)"
-- | Free names of a map kernel: the free names of its body with the
-- thread-number name removed.
instance FreeIn MapKernel where
  freeIn kernel = S.delete (mapKernelThreadNum kernel) bodyNames
    where
      bodyNames = freeIn (mapKernelBody kernel)
-- | The operations available inside kernel code.  Constructors carrying a
-- 'VName' are pretty-printed as an assignment to that name.
--
-- NOTE(review): the 'Int' argument is printed as the single argument of
-- the corresponding query; presumably it is a dimension index -- confirm.
data KernelOp
  = -- | Printed as @dest <- get_group_id(i)@.
    GetGroupId VName Int
  | -- | Printed as @dest <- get_local_id(i)@.
    GetLocalId VName Int
  | -- | Printed as @dest <- get_local_size(i)@.
    GetLocalSize VName Int
  | -- | Printed as @dest <- get_global_size(i)@.
    GetGlobalSize VName Int
  | -- | Printed as @dest <- get_global_id(i)@.
    GetGlobalId VName Int
  | -- | Printed as @dest <- get_lockstep_width()@.
    GetLockstepWidth VName
  | -- | An atomic operation; see 'AtomicOp'.
    Atomic AtomicOp
  | -- | Printed as @barrier()@.
    Barrier
  | -- | Printed as @mem_fence()@.
    MemFence
  deriving (Show)
-- | Atomic operations.  For every constructor the fields are, in order:
-- the name that receives the result (printed on the left of @<-@), the
-- array name, the index into the array, and the operand expression(s).
-- 'AtomicCmpXchg' is the only constructor with two operand expressions.
data AtomicOp
  = AtomicAdd VName VName (Count Bytes) Exp
  | AtomicSMax VName VName (Count Bytes) Exp
  | AtomicSMin VName VName (Count Bytes) Exp
  | AtomicUMax VName VName (Count Bytes) Exp
  | AtomicUMin VName VName (Count Bytes) Exp
  | AtomicAnd VName VName (Count Bytes) Exp
  | AtomicOr VName VName (Count Bytes) Exp
  | AtomicXor VName VName (Count Bytes) Exp
  | AtomicCmpXchg VName VName (Count Bytes) Exp Exp
  | AtomicXchg VName VName (Count Bytes) Exp
  deriving (Show)
-- | Free names of an atomic operation: the array, the index and every
-- operand expression.  The name receiving the result is not included.
instance FreeIn AtomicOp where
  freeIn op =
    case op of
      AtomicAdd _ arr i x -> accessed arr i [x]
      AtomicSMax _ arr i x -> accessed arr i [x]
      AtomicSMin _ arr i x -> accessed arr i [x]
      AtomicUMax _ arr i x -> accessed arr i [x]
      AtomicUMin _ arr i x -> accessed arr i [x]
      AtomicAnd _ arr i x -> accessed arr i [x]
      AtomicOr _ arr i x -> accessed arr i [x]
      AtomicXor _ arr i x -> accessed arr i [x]
      AtomicCmpXchg _ arr i x y -> accessed arr i [x, y]
      AtomicXchg _ arr i x -> accessed arr i [x]
    where
      accessed arr i operands = freeIn arr <> freeIn i <> freeIn operands
-- | Kernel operations that produce a value are shown as
-- @dest <- operation(...)@; 'Barrier' and 'MemFence' are shown as bare
-- calls.  Atomic operations show their target as @arr[ind]@ followed by
-- the operand expression(s).
instance Pretty KernelOp where
  ppr op =
    case op of
      GetGroupId dest i -> dest `receives` query "get_group_id" i
      GetLocalId dest i -> dest `receives` query "get_local_id" i
      GetLocalSize dest i -> dest `receives` query "get_local_size" i
      GetGlobalSize dest i -> dest `receives` query "get_global_size" i
      GetGlobalId dest i -> dest `receives` query "get_global_id" i
      GetLockstepWidth dest -> dest `receives` text "get_lockstep_width()"
      Barrier -> text "barrier()"
      MemFence -> text "mem_fence()"
      Atomic atomic -> pprAtomic atomic
    where
      -- An assignment of the right-hand side to the given name.
      receives dest rhs = ppr dest <+> text "<-" <+> rhs

      -- A query taking a single argument.
      query fname i = text fname <> parens (ppr i)

      -- An atomic call on arr[ind] with the given operands.
      access fname arr ind operands =
        text fname <>
        parens (commasep ((ppr arr <> brackets (ppr ind)) : map ppr operands))

      pprAtomic atomic =
        case atomic of
          AtomicAdd old arr ind x ->
            old `receives` access "atomic_add" arr ind [x]
          AtomicSMax old arr ind x ->
            old `receives` access "atomic_smax" arr ind [x]
          AtomicSMin old arr ind x ->
            old `receives` access "atomic_smin" arr ind [x]
          AtomicUMax old arr ind x ->
            old `receives` access "atomic_umax" arr ind [x]
          AtomicUMin old arr ind x ->
            old `receives` access "atomic_umin" arr ind [x]
          AtomicAnd old arr ind x ->
            old `receives` access "atomic_and" arr ind [x]
          AtomicOr old arr ind x ->
            old `receives` access "atomic_or" arr ind [x]
          AtomicXor old arr ind x ->
            old `receives` access "atomic_xor" arr ind [x]
          AtomicCmpXchg old arr ind x y ->
            old `receives` access "atomic_cmp_xchg" arr ind [x, y]
          AtomicXchg old arr ind x ->
            old `receives` access "atomic_xchg" arr ind [x]
-- | Only atomic operations have free names; every other kernel operation
-- contributes none.
instance FreeIn KernelOp where
  freeIn op =
    case op of
      Atomic atomic -> freeIn atomic
      _ -> mempty
-- | Wrap a document in braces, with the body indented by two columns and
-- the braces joined to it with '</>'.
--
-- The opening brace carries no leading space: every use in this module
-- attaches the block to its label with '<+>', which already inserts the
-- separating space, so a space here would be printed twice.
brace :: Doc -> Doc
brace body = text "{" </> indent 2 body </> text "}"