{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Permute (
mkPermute,
) where
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar ( Array, Vector, Shape, Elt, eltType )
import Data.Array.Accelerate.Error
import qualified Data.Array.Accelerate.Array.Sugar as S
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Permute
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Loop
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.RMW as RMW
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import Foreign.CUDA.Analysis
import Data.Typeable
import Control.Monad ( void )
import Prelude
-- | Generate a permutation kernel for the PTX backend.
--
-- If the combination function was recognised as an atomic read-modify-write
-- operation ('atomicRMW' is 'Just') and the element type is exactly 4 or 8
-- bytes wide, the 'mkPermute_rmw' kernel is produced from the recognised
-- operation and its update function.
--
-- In every other case the general 'mkPermute_mutex' kernel is produced from
-- the full 'combine' function. That kernel guards each output element with a
-- lock.
--
mkPermute
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => PTX
    -> Gamma aenv
    -> IRPermuteFun PTX aenv (e -> e -> e)
    -> IRFun1 PTX aenv (sh -> sh')
    -> IRDelayed PTX aenv (Array sh e)
    -> CodeGen (IROpenAcc PTX aenv (Array sh' e))
mkPermute ptx aenv IRPermuteFun{..} project arr
  | Just (rmwOp, updateFn) <- atomicRMW
  , atomicWidth
  = mkPermute_rmw ptx aenv rmwOp updateFn project arr
  | otherwise
  = mkPermute_mutex ptx aenv combine project arr
  where
    -- Size in bytes of one element of type 'e'.
    elemBytes   = sizeOf (eltType (undefined :: e))
    -- The atomic path is only attempted for 4- or 8-byte elements.
    atomicWidth = elemBytes == 4 || elemBytes == 8
-- | Generate a permutation kernel that updates each output element with a
-- single atomic read-modify-write operation.
--
-- 'mkPermute' chooses this kernel when the combination function was
-- recognised as an 'RMWOperation' and the element occupies 4 or 8 bytes.
-- @update@ is applied to the incoming source value first. The result is then
-- combined into the output element at the projected index.
--
-- Code generation per element kind:
--
--   * 'Exchange': a plain 'writeArray' of the updated value. No atomic
--     instruction is emitted.
--
--   * Integral types: a native @atomicrmw@ is used when any of these hold:
--
--       * the device has compute capability >= 3.2;
--       * the element is 4 bytes wide;
--       * the operation is 'RMW.Add' or 'RMW.Sub'.
--
--     Otherwise the operation is emulated with the 'atomicCAS_rmw' /
--     'atomicCAS_cmp' helpers.
--
--   * Floating types: 'RMW.Add' on 4-byte elements uses 'atomicAdd_f'. Every
--     other supported operation goes through the CAS helpers.
--
--   * 'Char': the pointer and value are reinterpreted as 'Word32' and a native
--     @atomicrmw@ is emitted.
--
-- If the element type is not a single scalar, or its representation cannot be
-- cast as required, code generation stops with an @internalError@. Previously
-- it failed with a non-exhaustive pattern match.
--
mkPermute_rmw
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => PTX
    -> Gamma aenv
    -> RMWOperation
    -> IRFun1 PTX aenv (e -> e)
    -> IRFun1 PTX aenv (sh -> sh')
    -> IRDelayed PTX aenv (Array sh e)
    -> CodeGen (IROpenAcc PTX aenv (Array sh' e))
mkPermute_rmw ptx@(deviceProperties . ptxContext -> dev) aenv rmw update project IRDelayed{..} =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut)      = mutableArray ("out" :: Name (Array sh' e))
      paramEnv                = envParam aenv
      bytes                   = sizeOf (eltType (undefined :: e))
      compute                 = computeCapability dev
      compute32               = Compute 3 2
  in
  makeOpenAcc ptx "permute_rmw" (paramGang ++ paramOut ++ paramEnv) $ do
    sh <- delayedExtent

    imapFromTo start end $ \i -> do
      i'  <- A.fromIntegral integralType numType i
      ix  <- indexOfInt sh i'
      ix' <- app1 project ix

      -- Elements projected onto the 'ignore' index are dropped.
      unless (ignore ix') $ do
        j <- intOfIndex (irArrayShape arrOut) ix'
        x <- app1 delayedLinearIndex i'
        r <- app1 update x

        case rmw of
          -- Exchange is emitted as an ordinary store of the new value.
          Exchange
            -> writeArray arrOut j r

          _ | SingleTuple s <- eltType (undefined :: e)
            , Just adata    <- gcast (irArrayData arrOut)
            , Just r'       <- gcast r
            -> do
                 -- Address of the destination element.
                 addr <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op s adata)) [op integralType j]
                 let
                     rmw_integral :: IntegralType t -> Operand (Ptr t) -> Operand t -> CodeGen ()
                     rmw_integral t ptr val
                       | primOk    = void . instr' $ AtomicRMW t NonVolatile rmw ptr val (CrossThread, AcquireRelease)
                       | otherwise =
                           case rmw of
                             RMW.And -> atomicCAS_rmw s' (A.band t (ir t val)) ptr
                             RMW.Or  -> atomicCAS_rmw s' (A.bor t (ir t val)) ptr
                             RMW.Xor -> atomicCAS_rmw s' (A.xor t (ir t val)) ptr
                             RMW.Min -> atomicCAS_cmp s' A.lt ptr val
                             RMW.Max -> atomicCAS_cmp s' A.gt ptr val
                             _       -> $internalError "mkPermute_rmw.integral" "unexpected transition"
                       where
                         s'     = NumScalarType (IntegralNumType t)
                         -- The native instruction is avoided only for 8-byte
                         -- non-Add/Sub operations on devices older than 3.2.
                         -- Presumably this is a hardware limitation on 64-bit
                         -- atomics; TODO confirm against the PTX ISA.
                         primOk = compute >= compute32
                               || bytes == 4
                               || case rmw of
                                    RMW.Add -> True
                                    RMW.Sub -> True
                                    _       -> False

                     rmw_floating :: FloatingType t -> Operand (Ptr t) -> Operand t -> CodeGen ()
                     rmw_floating t ptr val =
                       case rmw of
                         RMW.Min -> atomicCAS_cmp s' A.lt ptr val
                         RMW.Max -> atomicCAS_cmp s' A.gt ptr val
                         -- NOTE(review): the integral native 'Sub' above stores
                         -- @old - val@ (atomicrmw semantics). Here
                         -- 'atomicCAS_rmw' presumably applies @A.sub n val@ to
                         -- the old value, giving @val - old@. Confirm which
                         -- operand order the RMW.Sub analysis intends.
                         RMW.Sub -> atomicCAS_rmw s' (A.sub n (ir t val)) ptr
                         RMW.Add
                           | primAdd   -> atomicAdd_f t ptr val
                           | otherwise -> atomicCAS_rmw s' (A.add n (ir t val)) ptr
                         _       -> $internalError "mkPermute_rmw.floating" "unexpected transition"
                       where
                         n       = FloatingNumType t
                         s'      = NumScalarType n
                         -- Only 4-byte floating add uses the dedicated instruction.
                         primAdd = bytes == 4

                     rmw_nonnum :: NonNumType t -> Operand (Ptr t) -> Operand t -> CodeGen ()
                     rmw_nonnum TypeChar{} ptr val = do
                       -- Reinterpret the Char element as a Word32 so an integral
                       -- atomicrmw can be emitted on it.
                       ptr32 <- instr' $ PtrCast (primType :: PrimType (Ptr Word32)) ptr
                       val32 <- instr' $ BitCast (scalarType :: ScalarType Word32) val
                       void $ instr' $ AtomicRMW (integralType :: IntegralType Word32) NonVolatile rmw ptr32 val32 (CrossThread, AcquireRelease)
                     rmw_nonnum _ _ _ =
                       $internalError "mkPermute_rmw.nonnum" "unexpected transition"

                 case s of
                   NumScalarType (IntegralNumType t) -> rmw_integral t addr (op t r')
                   NumScalarType (FloatingNumType t) -> rmw_floating t addr (op t r')
                   NonNumScalarType t                -> rmw_nonnum t addr (op t r')
                   _                                 -> $internalError "mkPermute_rmw" "unexpected transition"

          -- Previously missing: report unsupported element representations
          -- explicitly instead of failing with a pattern-match error.
          _ -> $internalError "mkPermute_rmw" "unexpected element type"

    return_
-- | Generate a permutation kernel that uses the full combination function.
--
-- Each read-combine-write of an output element is wrapped in 'atomically'.
-- The lock used is the word of the @lock@ array at the element's linear
-- output index.
--
-- Kernel parameters, in order:
--
--   1. the gang parameters;
--   2. the output array @out@;
--   3. the @lock@ array of 'Word32';
--   4. the environment.
--
-- The combination function is applied as @combine new old@. Here @new@ is the
-- source element and @old@ is the value currently stored in the output.
--
mkPermute_mutex
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRFun1 PTX aenv (sh -> sh')
    -> IRDelayed PTX aenv (Array sh e)
    -> CodeGen (IROpenAcc PTX aenv (Array sh' e))
mkPermute_mutex ptx aenv combine project IRDelayed{..} =
  makeOpenAcc ptx "permute_mutex" kernelParams $ do
    extent <- delayedExtent
    imapFromTo start end $ \n -> do
      linear <- A.fromIntegral integralType numType n
      srcIx  <- indexOfInt extent linear
      dstIx  <- app1 project srcIx
      -- Skip elements sent to the 'ignore' index.
      unless (ignore dstIx) $ do
        slot   <- intOfIndex (irArrayShape arrOut) dstIx
        newVal <- app1 delayedLinearIndex linear
        atomically arrLock slot $ do
          oldVal <- readArray arrOut slot
          writeArray arrOut slot =<< app2 combine newVal oldVal
    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array sh' e))
    (arrLock, paramLock)    = mutableArray ("lock" :: Name (Vector Word32))
    kernelParams            = paramGang ++ paramOut ++ paramLock ++ envParam aenv
-- | Run the code generated by @action@ while holding the spin lock stored at
-- index @i@ of the @barriers@ array.
--
-- The lock word holds 1 while the lock is taken and 0 while it is free. Each
-- thread repeats the following:
--
--   1. Atomically exchange 1 into the lock word, with acquire ordering.
--   2. If the previous value was 0, the thread now owns the lock. It runs
--      @action@ and then releases the lock by exchanging 0 back, with release
--      ordering.
--   3. Otherwise it skips the critical section and tries again.
--
-- The critical section and its release sit inside the loop body, not after a
-- separate spin loop. The lock holder therefore finishes and releases within
-- the same iteration. Presumably this avoids the SIMT deadlock where lanes
-- spinning on a lock starve the lane that holds it; TODO confirm.
--
-- @action@ is emitted once, in the critical-section basic block. Its result is
-- returned.
--
-- NOTE(review): '__syncthreads' is emitted inside this loop, whose trip count
-- differs between threads. The callers in this module also invoke 'atomically'
-- under @unless (ignore ...)@ inside 'imapFromTo'. CUDA only defines a
-- block-wide barrier when every thread of the block reaches it. Confirm this
-- cannot hang on the targeted architectures; a warp-level barrier may be what
-- is intended.
--
atomically
    :: IRArray (Vector Word32)
    -> IR Int
    -> CodeGen a
    -> CodeGen a
atomically barriers i action = do
  let
      lock    = integral integralType 1   -- stored while the lock is held
      unlock  = integral integralType 0   -- stored while the lock is free
      unlock' = lift 0                    -- the free value as an IR value, for comparison
  --
  spin <- newBlock "spinlock.entry"
  crit <- newBlock "spinlock.critical-start"
  skip <- newBlock "spinlock.critical-end"
  exit <- newBlock "spinlock.exit"

  -- Address of this element's lock word.
  addr <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op integralType (irArrayData barriers))) [op integralType i]
  _    <- br spin

  -- Try to take the lock: swap in 1 and inspect the previous value.
  setBlock spin
  old  <- instr $ AtomicRMW integralType NonVolatile Exchange addr lock (CrossThread, Acquire)
  ok   <- A.eq scalarType old unlock'
  no   <- cbr ok crit skip

  -- Lock acquired: run the critical section, then release the lock.
  setBlock crit
  r    <- action
  _    <- instr $ AtomicRMW integralType NonVolatile Exchange addr unlock (CrossThread, Release)
  yes  <- br skip

  -- Join point. 'done' is True for threads arriving from the critical section
  -- and False for threads that failed to take the lock and must retry.
  setBlock skip
  done <- phi [(lift True, yes), (lift False, no)]
  __syncthreads
  _    <- cbr done exit spin

  setBlock exit
  return r
-- | Test whether a projected index equals the magic 'S.ignore' index. Such an
-- index means the element is dropped from the permutation.
--
-- Each component of the index is compared with the matching component of
-- 'S.ignore'. For a pair, both halves are computed before being combined with
-- 'land'', so every component comparison is always emitted.
--
ignore :: forall ix. Shape ix => IR ix -> CodeGen (IR Bool)
ignore (IR ix) = matches (S.eltType (undefined :: ix)) (S.fromElt (S.ignore :: ix)) ix
  where
    matches :: TupleType t -> t -> Operands t -> CodeGen (IR Bool)
    matches UnitTuple () OP_Unit =
      pure (lift True)
    matches (PairTuple tyInit tyLast) (sentInit, sentLast) (OP_Pair valInit valLast) = do
      okInit <- matches tyInit sentInit valInit
      okLast <- matches tyLast sentLast valLast
      land' okInit okLast
    matches (SingleTuple ty) sentinel val =
      A.eq ty (ir ty (scalar ty sentinel)) (ir ty (op' ty val))