{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Scan (
mkScanl, mkScanl1, mkScanl',
mkScanr, mkScanr1, mkScanr',
) where
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target
import LLVM.AST.Type.Representation
import qualified Foreign.CUDA.Analysis as CUDA
import Control.Applicative
import Control.Monad ( (>=>), void )
import Data.String ( fromString )
import Data.Coerce as Safe
import Data.Bits as P
import Prelude as P hiding ( last )
-- | Direction in which a scan traverses the innermost dimension of the array.
--
data Direction
  = L   -- ^ passed by 'mkScanl', 'mkScanl1' and 'mkScanl''
  | R   -- ^ passed by 'mkScanr', 'mkScanr1' and 'mkScanr''
-- | Kernels for a left scan that includes a seed element.
--
-- If the outer shape @sh@ matches 'Z' (the array is a vector) the three
-- @scanP*@ kernels are generated, otherwise the single kernel produced by
-- 'mkScanDim' is used. In both cases the kernel from 'mkScanFill', whose
-- element function is constantly the seed, is appended last.
--
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl ptx aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      phase1 <- mkScanAllP1 L dev aenv combine (Just seed) arr
      phase2 <- mkScanAllP2 L dev aenv combine
      phase3 <- mkScanAllP3 L dev aenv combine (Just seed)
      fill   <- mkScanFill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
    Nothing -> do
      scan <- mkScanDim L dev aenv combine (Just seed) arr
      fill <- mkScanFill ptx aenv seed
      return (scan +++ fill)
  where
    -- properties of the device attached to this target's context
    dev = deviceProperties (ptxContext ptx)
-- | Kernels for a left scan without a seed element.
--
-- Vectors (outer shape matches 'Z') use the three @scanP*@ kernels; any other
-- shape uses the single kernel produced by 'mkScanDim'. No fill kernel is
-- generated for this variant.
--
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanl1 ptx aenv combine arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      phase1 <- mkScanAllP1 L dev aenv combine Nothing arr
      phase2 <- mkScanAllP2 L dev aenv combine
      phase3 <- mkScanAllP3 L dev aenv combine Nothing
      return (phase1 +++ (phase2 +++ phase3))
    Nothing ->
      mkScanDim L dev aenv combine Nothing arr
  where
    -- properties of the device attached to this target's context
    dev = deviceProperties (ptxContext ptx)
-- | Kernels for a left scan with a seed element that returns the scanned
-- array and the final reduction value(s) as two separate arrays.
--
-- Vectors (outer shape matches 'Z') use the three primed @scanP*@ kernels;
-- any other shape uses 'mkScan'Dim'. The kernel from 'mkScan'Fill' is appended
-- last in both cases.
--
mkScanl'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanl' ptx aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      phase1 <- mkScan'AllP1 L dev aenv combine seed arr
      phase2 <- mkScan'AllP2 L dev aenv combine
      phase3 <- mkScan'AllP3 L dev aenv combine
      fill   <- mkScan'Fill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
    Nothing -> do
      scan <- mkScan'Dim L dev aenv combine seed arr
      fill <- mkScan'Fill ptx aenv seed
      return (scan +++ fill)
  where
    -- properties of the device attached to this target's context
    dev = deviceProperties (ptxContext ptx)
-- | Kernels for a right scan that includes a seed element.
--
-- Mirrors 'mkScanl' but passes the direction 'R' to every kernel generator:
-- vectors use the three @scanP*@ kernels, other shapes use 'mkScanDim', and
-- the 'mkScanFill' kernel is appended last.
--
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr ptx aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      phase1 <- mkScanAllP1 R dev aenv combine (Just seed) arr
      phase2 <- mkScanAllP2 R dev aenv combine
      phase3 <- mkScanAllP3 R dev aenv combine (Just seed)
      fill   <- mkScanFill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
    Nothing -> do
      scan <- mkScanDim R dev aenv combine (Just seed) arr
      fill <- mkScanFill ptx aenv seed
      return (scan +++ fill)
  where
    -- properties of the device attached to this target's context
    dev = deviceProperties (ptxContext ptx)
-- | Kernels for a right scan without a seed element.
--
-- Mirrors 'mkScanl1' but passes the direction 'R' to every kernel generator.
--
mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanr1 ptx aenv combine arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      phase1 <- mkScanAllP1 R dev aenv combine Nothing arr
      phase2 <- mkScanAllP2 R dev aenv combine
      phase3 <- mkScanAllP3 R dev aenv combine Nothing
      return (phase1 +++ (phase2 +++ phase3))
    Nothing ->
      mkScanDim R dev aenv combine Nothing arr
  where
    -- properties of the device attached to this target's context
    dev = deviceProperties (ptxContext ptx)
-- | Kernels for a right scan with a seed element that returns the scanned
-- array and the final reduction value(s) as two separate arrays.
--
-- Mirrors 'mkScanl'' but passes the direction 'R' to every kernel generator.
--
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScanr' ptx aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      phase1 <- mkScan'AllP1 R dev aenv combine seed arr
      phase2 <- mkScan'AllP2 R dev aenv combine
      phase3 <- mkScan'AllP3 R dev aenv combine
      fill   <- mkScan'Fill ptx aenv seed
      return (phase1 +++ (phase2 +++ (phase3 +++ fill)))
    Nothing -> do
      scan <- mkScan'Dim R dev aenv combine seed arr
      fill <- mkScan'Fill ptx aenv seed
      return (scan +++ fill)
  where
    -- properties of the device attached to this target's context
    dev = deviceProperties (ptxContext ptx)
-- | First kernel ("scanP1") of the multi-kernel vector scan.
--
-- Every thread block walks over block-sized chunks of the input. Within a
-- chunk each thread reads one element, the block performs a cooperative scan
-- via 'scanBlockSMem', and each thread stores its result into @out@. When the
-- grid has more than one block, the last thread of the block additionally
-- stores its value into the @tmp@ array.
--
-- If a seed is supplied, thread 0 of chunk 0 writes the seed into @out@ and
-- combines it with its own input element before the block scan.
--
mkScanAllP1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> IRDelayed PTX aenv (Vector e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP1 dir dev aenv combine mseed IRDelayed{..} =
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Length of the input vector
    len   <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- This block begins at chunk (start + blockIdx) and advances by gridDim
    bid   <- blockIdx
    gd    <- gridDim
    first <- A.add numType start bid

    imapFromStepTo first gd end $ \chunk -> do
      bd   <- blockDim
      base <- A.mul numType chunk bd
      tid  <- threadIdx

      -- Index this thread reads from; counted from the far end for 'R'
      src  <- case dir of
                L -> A.add numType base tid
                R -> do a <- A.sub numType len base
                        b <- A.sub numType a tid
                        A.sub numType b (lift 1)

      -- Index this thread writes to; shifted by one only for a seeded left scan
      dst  <- case (mseed, dir) of
                (Just _, L) -> A.add numType src (lift 1)
                _           -> return src

      -- Only threads whose read index lies inside the input take part
      let hasInput = case dir of
                       L -> A.lt  scalarType src len
                       R -> A.gte scalarType src (lift 0)

      when hasInput $ do
        x  <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType src
        x' <- case mseed of
                Nothing   -> return x
                Just seed ->
                  if A.eq scalarType tid (lift 0) `A.land` A.eq scalarType chunk (lift 0)
                    then do
                      z <- seed
                      case dir of
                        L -> do writeArray arrOut (lift 0 :: IR Int32) z
                                app2 combine z x
                        R -> do writeArray arrOut len z
                                app2 combine x z
                    else
                      return x

        -- Elements left from this chunk onwards; a short chunk passes its
        -- element count on to the block scan
        remaining <- A.sub numType len base
        total     <- if A.gte scalarType remaining bd
                       then scanBlockSMem dir dev combine Nothing          x'
                       else scanBlockSMem dir dev combine (Just remaining) x'

        writeArray arrOut dst total

        -- The last thread of the block publishes its value to the temporary
        -- array, but only when more than one block is running
        lastThread <- A.sub numType bd (lift 1)
        when (A.gt scalarType gd (lift 1) `land` A.eq scalarType tid lastThread) $
          case dir of
            L -> writeArray arrTmp chunk total
            R -> do u <- A.sub numType end chunk
                    v <- A.sub numType u (lift 1)
                    writeArray arrTmp v total

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv
    config                  = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]

    -- Dynamic shared memory (in bytes) requested for a block of @n@ threads
    smem n = warps * (1 + per_warp) * bytes
      where
        ws       = CUDA.warpSize dev
        warps    = n `P.quot` ws
        per_warp = ws + ws `P.quot` 2
        bytes    = sizeOf (eltType (undefined :: e))
-- | Second kernel ("scanP2") of the multi-kernel vector scan.
--
-- Scans the @tmp@ array in place. The launch configuration fixes the grid
-- size at one, and the block steps over the index range in strides of the
-- block dimension. A single shared-memory slot carries the value of the last
-- thread from one stride into thread 0 of the next.
--
mkScanAllP2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP2 dir dev aenv combine =
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do

    -- One-element shared buffer holding the carry between strides
    carry <- staticSharedMem 1
    bd    <- blockDim

    imapFromStepTo start bd end $ \offset -> do
      tid <- threadIdx

      -- Element of the temporary array handled by this thread
      ix  <- case dir of
               L -> A.add numType offset tid
               R -> do a <- A.sub numType end offset
                       b <- A.sub numType a tid
                       A.sub numType b (lift 1)

      let inRange = case dir of
                      L -> A.lt  scalarType ix end
                      R -> A.gte scalarType ix start

      when inRange $ do
        -- Barrier before the carry slot is read
        __syncthreads

        x  <- readArray arrTmp ix

        -- After the first stride, thread 0 folds in the carried value
        x' <- if A.gt scalarType offset (lift 0) `land` A.eq scalarType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x
                    R -> app2 combine x c
                else
                  return x

        remaining <- A.sub numType end offset
        total     <- if A.gte scalarType remaining bd
                       then scanBlockSMem dir dev combine Nothing          x'
                       else scanBlockSMem dir dev combine (Just remaining) x'

        writeArray arrTmp ix total

        -- The last thread of the block stores the carry for the next stride
        lastThread <- A.sub numType bd (lift 1)
        when (A.eq scalarType tid lastThread) $
          writeArray carry (lift 0 :: IR Int32) total

    return_
  where
    (start, end, paramGang) = gangParam
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv
    config                  = launchConfig dev (CUDA.incWarp dev) smem grid gridQ

    -- Grid size is always one, whatever the two arguments are
    grid _ _ = 1
    gridQ    = [|| \_ _ -> 1 ||]

    -- Dynamic shared memory (in bytes) requested for a block of @n@ threads
    smem n = warps * (1 + per_warp) * bytes
      where
        ws       = CUDA.warpSize dev
        warps    = n `P.quot` ws
        per_warp = ws + ws `P.quot` 2
        bytes    = sizeOf (eltType (undefined :: e))
-- | Third kernel ("scanP3") of the multi-kernel vector scan.
--
-- For every chunk index, reads one value from the @tmp@ array and combines it
-- into each element of the corresponding range of @out@. The range is derived
-- from the @ix.stride@ kernel parameter and is shifted by one when a seed is
-- present. Threads with an index not below the stride do nothing.
--
mkScanAllP3
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e))
mkScanAllP3 dir dev aenv combine mseed =
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do

    -- Length of the output vector
    len <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
    tid <- threadIdx

    when (A.lt scalarType tid stride) $ do
      bid   <- blockIdx
      gd    <- gridDim
      first <- A.add numType start bid

      imapFromStepTo first gd end $ \chunk -> do

        -- Half-open range [lo, hi) of the output that receives the carry-in
        (lo, hi) <- case dir of
          L -> do
            k    <- A.add numType chunk (lift 1)
            base <- A.mul numType stride k
            case mseed of
              Just{}  -> do
                from  <- A.add numType base (lift 1)
                limit <- A.add numType from stride
                upto  <- A.min scalarType limit len
                return (from, upto)
              Nothing -> do
                limit <- A.add numType base stride
                upto  <- A.min scalarType limit len
                return (base, upto)
          R -> do
            k    <- A.sub numType end chunk
            back <- A.mul numType stride k
            top  <- A.sub numType len back
            case mseed of
              Just{}  -> do
                upto  <- A.sub numType top (lift 1)
                below <- A.sub numType upto stride
                from  <- A.max scalarType below (lift 0)
                return (from, upto)
              Nothing -> do
                below <- A.sub numType top stride
                from  <- A.max scalarType below (lift 0)
                return (from, top)

        -- Carry-in value for this chunk
        carry <- case dir of
                   L -> readArray arrTmp chunk
                   R -> readArray arrTmp =<< A.add numType chunk (lift 1)

        -- Threads stride over the range by the block dimension
        bd <- blockDim
        i0 <- A.add numType lo tid
        imapFromStepTo i0 bd hi $ \i -> do
          old <- readArray arrOut i
          new <- case dir of
                   L -> app2 combine carry old
                   R -> app2 combine old carry
          writeArray arrOut i new

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv

    -- Extra scalar kernel parameter and the local operand that refers to it
    stride                  = local           scalarType ("ix.stride" :: Name Int32)
    paramStride             = scalarParameter scalarType ("ix.stride" :: Name Int32)

    -- This kernel requests no dynamic shared memory
    config                  = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
-- | First kernel ("scanP1") of the multi-kernel vector scan that returns the
-- final reduction value separately.
--
-- Each thread reads one input element, the block scans cooperatively via
-- 'scanBlockSMem', and each thread writes its result one position further
-- along in the scan direction, provided that position is still inside the
-- array. The last thread that had input in the block stores its value into
-- the @tmp@ array. Thread 0 of segment 0 writes the seed into @out@ and
-- combines it with its input element first.
--
mkScan'AllP1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Vector e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP1 dir dev aenv combine seed IRDelayed{..} =
  makeOpenAccWith config "scanP1" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do

    -- Length of the input vector
    len   <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- This block begins at segment (start + blockIdx) and advances by gridDim
    bid   <- blockIdx
    gd    <- gridDim
    first <- A.add numType start bid

    imapFromStepTo first gd end $ \seg -> do
      bd   <- blockDim
      base <- A.mul numType seg bd
      tid  <- threadIdx

      -- Index this thread reads from; counted from the far end for 'R'
      src  <- case dir of
                L -> A.add numType base tid
                R -> do a <- A.sub numType len base
                        b <- A.sub numType a tid
                        A.sub numType b (lift 1)

      -- Index this thread writes to: one step further in the scan direction
      dst  <- case dir of
                L -> A.add numType src (lift 1)
                R -> A.sub numType src (lift 1)

      let hasInput = case dir of
                       L -> A.lt  scalarType src len
                       R -> A.gte scalarType src (lift 0)

      when hasInput $ do
        x  <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType src
        x' <- if A.eq scalarType tid (lift 0) `A.land` A.eq scalarType seg (lift 0)
                then do
                  z <- seed
                  writeArray arrOut src z
                  case dir of
                    L -> app2 combine z x
                    R -> app2 combine x z
                else
                  return x

        remaining <- A.sub numType len base
        total     <- if A.gte scalarType remaining bd
                       then scanBlockSMem dir dev combine Nothing          x'
                       else scanBlockSMem dir dev combine (Just remaining) x'

        -- Store the result unless the shifted index has left the array
        case dir of
          L -> when (A.lt  scalarType dst len)      $ writeArray arrOut dst total
          R -> when (A.gte scalarType dst (lift 0)) $ writeArray arrOut dst total

        -- Index of the last thread that had input in this segment
        lastActive <- do k <- A.min scalarType remaining bd
                         A.sub numType k (lift 1)

        when (A.eq scalarType tid lastActive) $
          case dir of
            L -> writeArray arrTmp seg total
            R -> do u <- A.sub numType end seg
                    v <- A.sub numType u (lift 1)
                    writeArray arrTmp v total

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv
    config                  = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]

    -- Dynamic shared memory (in bytes) requested for a block of @n@ threads
    smem n = warps * (1 + per_warp) * bytes
      where
        ws       = CUDA.warpSize dev
        warps    = n `P.quot` ws
        per_warp = ws + ws `P.quot` 2
        bytes    = sizeOf (eltType (undefined :: e))
-- | Second kernel ("scanP2") of the multi-kernel vector scan that returns the
-- final reduction value separately.
--
-- Scans the @tmp@ array in place with a grid size fixed at one, carrying the
-- value of the last thread with input from one stride to the next through a
-- one-element shared buffer. After the loop, thread 0 copies the carry slot
-- into the @sum@ array.
--
mkScan'AllP2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP2 dir dev aenv combine =
  makeOpenAccWith config "scanP2" (paramGang ++ paramTmp ++ paramSum ++ paramEnv) $ do

    -- One-element shared buffer holding the carry between strides
    carry <- staticSharedMem 1
    tid   <- threadIdx
    bd    <- blockDim

    imapFromStepTo start bd end $ \offset -> do

      -- Element of the temporary array handled by this thread
      ix <- case dir of
              L -> A.add numType offset tid
              R -> do a <- A.sub numType end offset
                      b <- A.sub numType a tid
                      A.sub numType b (lift 1)

      let inRange = case dir of
                      L -> A.lt  scalarType ix end
                      R -> A.gte scalarType ix start

      when inRange $ do
        -- Barrier before the carry slot is read
        __syncthreads

        x  <- readArray arrTmp ix

        -- After the first stride, thread 0 folds in the carried value
        x' <- if A.gt scalarType offset (lift 0) `A.land` A.eq scalarType tid (lift 0)
                then do
                  c <- readArray carry (lift 0 :: IR Int32)
                  case dir of
                    L -> app2 combine c x
                    R -> app2 combine x c
                else
                  return x

        remaining <- A.sub numType end offset
        total     <- if A.gte scalarType remaining bd
                       then scanBlockSMem dir dev combine Nothing          x'
                       else scanBlockSMem dir dev combine (Just remaining) x'

        writeArray arrTmp ix total

        -- The last thread that had input in this stride stores the carry
        lastActive <- do k <- A.min scalarType bd remaining
                         A.sub numType k (lift 1)

        when (A.eq scalarType tid lastActive) $
          writeArray carry (lift 0 :: IR Int32) total

    -- Barrier, then thread 0 copies the final carry into the sum array
    __syncthreads
    when (A.eq scalarType tid (lift 0)) $ do
      final <- readArray carry (lift 0 :: IR Int32)
      writeArray arrSum (lift 0 :: IR Int32) final

    return_
  where
    (start, end, paramGang) = gangParam
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    (arrSum, paramSum)      = mutableArray ("sum" :: Name (Scalar e))
    paramEnv                = envParam aenv
    config                  = launchConfig dev (CUDA.incWarp dev) smem grid gridQ

    -- Grid size is always one, whatever the two arguments are
    grid _ _ = 1
    gridQ    = [|| \_ _ -> 1 ||]

    -- Dynamic shared memory (in bytes) requested for a block of @n@ threads
    smem n = warps * (1 + per_warp) * bytes
      where
        ws       = CUDA.warpSize dev
        warps    = n `P.quot` ws
        per_warp = ws + ws `P.quot` 2
        bytes    = sizeOf (eltType (undefined :: e))
-- | Third kernel ("scanP3") of the multi-kernel vector scan that returns the
-- final reduction value separately.
--
-- For every chunk index, reads one value from the @tmp@ array and combines it
-- into each element of the corresponding range of @out@. The range is derived
-- from the @ix.stride@ kernel parameter and offset by one. Threads with an
-- index not below the stride do nothing.
--
mkScan'AllP3
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> CodeGen (IROpenAcc PTX aenv (Vector e, Scalar e))
mkScan'AllP3 dir dev aenv combine =
  makeOpenAccWith config "scanP3" (paramGang ++ paramTmp ++ paramOut ++ paramStride : paramEnv) $ do

    -- Length of the output vector
    len <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
    tid <- threadIdx

    when (A.lt scalarType tid stride) $ do
      bid   <- blockIdx
      gd    <- gridDim
      first <- A.add numType start bid

      imapFromStepTo first gd end $ \chunk -> do

        -- Half-open range [lo, hi) of the output that receives the carry-in
        (lo, hi) <- case dir of
          L -> do
            k     <- A.add numType chunk (lift 1)
            base  <- A.mul numType stride k
            from  <- A.add numType base (lift 1)
            limit <- A.add numType from stride
            upto  <- A.min scalarType limit len
            return (from, upto)
          R -> do
            k     <- A.sub numType end chunk
            back  <- A.mul numType stride k
            top   <- A.sub numType len back
            upto  <- A.sub numType top (lift 1)
            below <- A.sub numType upto stride
            from  <- A.max scalarType below (lift 0)
            return (from, upto)

        -- Carry-in value for this chunk
        carry <- case dir of
                   L -> readArray arrTmp chunk
                   R -> readArray arrTmp =<< A.add numType chunk (lift 1)

        -- Threads stride over the range by the block dimension
        bd <- blockDim
        i0 <- A.add numType lo tid
        imapFromStepTo i0 bd hi $ \i -> do
          old <- readArray arrOut i
          new <- case dir of
                   L -> app2 combine carry old
                   R -> app2 combine old carry
          writeArray arrOut i new

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv

    -- Extra scalar kernel parameter and the local operand that refers to it
    stride                  = local           scalarType ("ix.stride" :: Name Int32)
    paramStride             = scalarParameter scalarType ("ix.stride" :: Name Int32)

    -- This kernel requests no dynamic shared memory
    config                  = launchConfig dev (CUDA.incWarp dev) (const 0) const [|| const ||]
-- | Single-kernel ("scan") scan along the innermost dimension of an array.
--
-- Each thread block processes whole segments (rows of the innermost
-- dimension), starting at segment @start + blockIdx@ and advancing by
-- @gridDim@. Within a segment the threads of the block repeatedly scan
-- @blockDim@ elements at a time using 'scanBlockSMem', passing a carry value
-- between iterations through a one-element shared buffer.
--
-- With a seed ('Just'), thread 0 writes the seed to the output and to the
-- carry slot before the loop starts. Without a seed ('Nothing'), the first
-- block-wide scan is performed before the loop.
--
mkScanDim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IRExp PTX aenv e)
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e))
mkScanDim dir dev aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Dynamic shared memory (in bytes) requested for a block of @n@ threads
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramEnv) $ do

    -- One-element shared buffer used to pass the carry between iterations
    carry <- staticSharedMem 1

    -- Extent of the innermost dimension of the input
    sz    <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- Each block handles segments start+blockIdx, start+blockIdx+gridDim, ...
    bid <- blockIdx
    gd  <- gridDim
    s0  <- A.add numType start bid
    imapFromStepTo s0 gd end $ \seg -> do

      -- Linear index this thread reads from: from the front of the segment
      -- for 'L', from the back for 'R'
      tid <- threadIdx
      i0  <- case dir of
               L -> do x <- A.mul numType seg sz
                       y <- A.add numType x tid
                       return y

               R -> do x <- A.add numType seg (lift 1)
                       y <- A.mul numType x sz
                       z <- A.sub numType y tid
                       w <- A.sub numType z (lift 1)
                       return w

      -- Linear index this thread writes to. With a seed it is recomputed from
      -- the innermost extent of the output array rather than of the input.
      j0  <- case mseed of
               Nothing -> return i0
               Just{}  -> do szp1 <- A.fromIntegral integralType numType (indexHead (irArrayShape arrOut))
                             case dir of
                               L -> do x <- A.mul numType seg szp1
                                       y <- A.add numType x tid
                                       return y

                               R -> do x <- A.add numType seg (lift 1)
                                       y <- A.mul numType x szp1
                                       z <- A.sub numType y tid
                                       w <- A.sub numType z (lift 1)
                                       return w

      -- Advance an index by one block width in the scan direction
      bd <- blockDim
      let next ix = case dir of
                      L -> A.add numType ix bd
                      R -> A.sub numType ix bd

      -- Initial loop state: (elements remaining, read index, write index)
      r <-
        case mseed of
          -- Seeded: thread 0 stores the seed in the output and the carry
          -- slot; every thread moves its write index one step along. All
          -- 'sz' input elements remain to be processed.
          Just seed -> do
            when (A.eq scalarType tid (lift 0)) $ do
              z <- seed
              writeArray arrOut j0 z
              writeArray carry (lift 0 :: IR Int32) z
            j1 <- case dir of
                   L -> A.add numType j0 (lift 1)
                   R -> A.sub numType j0 (lift 1)
            return $ A.trip sz i0 j1

          -- Unseeded: scan the first block-width of elements straight away.
          -- The last thread of the block stores its result as the carry.
          Nothing -> do
            when (A.lt scalarType tid sz) $ do
              x0 <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i0
              r0 <- if A.gte scalarType sz bd
                      then scanBlockSMem dir dev combine Nothing   x0
                      else scanBlockSMem dir dev combine (Just sz) x0
              writeArray arrOut j0 r0

              ll <- A.sub numType bd (lift 1)
              when (A.eq scalarType tid ll) $
                writeArray carry (lift 0 :: IR Int32) r0

            n1 <- A.sub numType sz bd
            i1 <- next i0
            j1 <- next j0
            return $ A.trip n1 i1 j1

      -- Process the rest of the segment one block-width at a time, while the
      -- remaining element count is positive
      void $ while
        (\(A.fst3   -> n)       -> A.gt scalarType n (lift 0))
        (\(A.untrip -> (n,i,j)) -> do

          -- Barrier before the carry slot is read in this iteration
          __syncthreads

          -- Threads with remaining input read their element; all other
          -- threads get an undefined value of the element type, so that
          -- every thread still executes the block-wide scan below.
          x <- if A.lt scalarType tid n
                 then app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                 else let
                          go :: TupleType a -> Operands a
                          go UnitTuple       = OP_Unit
                          go (PairTuple a b) = OP_Pair (go a) (go b)
                          go (SingleTuple t) = ir' t (undef t)
                      in
                      return . IR $ go (eltType (undefined::e))

          -- Thread 0 folds the carry into its element
          y <- if A.eq scalarType tid (lift 0)
                 then do
                   c <- readArray carry (lift 0 :: IR Int32)
                   case dir of
                     L -> app2 combine c x
                     R -> app2 combine x c
                 else
                   return x

          -- Block-wide scan; a partially filled block passes its element count
          z <- if A.gte scalarType n bd
                 then scanBlockSMem dir dev combine Nothing  y
                 else scanBlockSMem dir dev combine (Just n) y

          -- Only threads that had input store a result
          when (A.lt scalarType tid n) $ do
            writeArray arrOut j z

            -- The last thread of the block stores the carry for the next
            -- iteration
            w <- A.sub numType bd (lift 1)
            when (A.eq scalarType tid w) $
              writeArray carry (lift 0 :: IR Int32) z

          -- Next loop state
          n' <- A.sub numType n bd
          i' <- next i
          j' <- next j
          return $ A.trip n' i' j')
        r

    return_
-- | Single-kernel ("scan") seeded scan along the innermost dimension that
-- also writes the final value of each segment into a separate @sum@ array.
--
-- Each thread block processes whole segments, starting at segment
-- @start + blockIdx@ and advancing by @gridDim@. Thread 0 writes the seed to
-- the first position (in scan direction) of the segment; every scanned value
-- is written one position further along, and the value that would fall
-- outside the segment ends up in the shared carry slot, which thread 0
-- finally copies to @sum@ at the segment index.
--
mkScan'Dim
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRExp PTX aenv e
    -> IRDelayed PTX aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Dim dir dev aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Array sh e))
      paramEnv                  = envParam aenv
      --
      config                    = launchConfig dev (CUDA.incWarp dev) smem const [|| const ||]
      -- Dynamic shared memory (in bytes) requested for a block of @n@ threads
      smem n                    = warps * (1 + per_warp) * bytes
        where
          ws        = CUDA.warpSize dev
          warps     = n `P.quot` ws
          per_warp  = ws + ws `P.quot` 2
          bytes     = sizeOf (eltType (undefined :: e))
  in
  makeOpenAccWith config "scan" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do

    -- One-element shared buffer used to pass the carry between iterations
    carry <- staticSharedMem 1

    -- Extent of the innermost dimension of the input
    sz    <- A.fromIntegral integralType numType . indexHead =<< delayedExtent

    -- Threads whose index is greater than the innermost extent skip the
    -- whole kernel body.
    -- NOTE(review): the comparison is '<=' rather than '<', so the thread
    -- with index equal to 'sz' still enters — confirm this is intended.
    tid   <- threadIdx
    when (A.lte scalarType tid sz) $ do

      -- Each block handles segments start+blockIdx, start+blockIdx+gridDim, ...
      bid <- blockIdx
      gd  <- gridDim
      s0  <- A.add numType start bid
      imapFromStepTo s0 gd end $ \seg -> do

        -- Linear index range [inf, sup) covered by this segment
        inf <- A.mul numType seg sz
        sup <- A.add numType inf sz

        -- Index this thread reads from: from the front of the segment for
        -- 'L', from the back for 'R'
        i0  <- case dir of
                 L -> A.add numType inf tid
                 R -> do x <- A.sub numType sup tid
                         y <- A.sub numType x (lift 1)
                         return y

        -- Index this thread writes to: one step further in scan direction
        j0  <- case dir of
                 L -> A.add numType i0 (lift 1)
                 R -> A.sub numType i0 (lift 1)

        -- Thread 0 stores the seed in the output and in the carry slot
        when (A.eq scalarType tid (lift 0)) $ do
          z <- seed
          writeArray arrOut i0 z
          writeArray carry (lift 0 :: IR Int32) z

        -- Advance an index by one block width in the scan direction
        bd  <- blockDim
        let next ix = case dir of
                        L -> A.add numType ix bd
                        R -> A.sub numType ix bd

        -- Loop state is (elements remaining, read index, write index); the
        -- loop runs while the remaining count is positive
        n0  <- A.sub numType sup inf
        void $ while
          (\(A.fst3   -> n)       -> A.gt scalarType n (lift 0))
          (\(A.untrip -> (n,i,j)) -> do

            -- Barrier before the carry slot is read in this iteration
            __syncthreads

            _ <- if A.gte scalarType n bd
                   -- At least a full block-width remains: every thread has
                   -- input
                   then do
                     x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                     -- Thread 0 folds the carry into its element
                     y <- if A.eq scalarType tid (lift 0)
                            then do
                              c <- readArray carry (lift 0 :: IR Int32)
                              case dir of
                                L -> app2 combine c x
                                R -> app2 combine x c
                            else
                              return x
                     z <- scanBlockSMem dir dev combine Nothing y

                     -- Store the result only if the shifted write index is
                     -- still inside the segment
                     case dir of
                       L -> when (A.lt  scalarType j sup) $ writeArray arrOut j z
                       R -> when (A.gte scalarType j inf) $ writeArray arrOut j z

                     -- The last thread of the block stores the carry
                     bd1 <- A.sub numType bd (lift 1)
                     when (A.eq scalarType tid bd1) $
                       writeArray carry (lift 0 :: IR Int32) z

                     return (IR OP_Unit :: IR ())

                   -- Fewer than a block-width of elements remain: only the
                   -- first 'n' threads take part
                   else do
                     when (A.lt scalarType tid n) $ do
                       x <- app1 delayedLinearIndex =<< A.fromIntegral integralType numType i
                       y <- if A.eq scalarType tid (lift 0)
                              then do
                                c <- readArray carry (lift 0 :: IR Int32)
                                case dir of
                                  L -> app2 combine c x
                                  R -> app2 combine x c
                              else
                                return x
                       z <- scanBlockSMem dir dev combine (Just n) y

                       -- Threads before the last participating one write to
                       -- the output; the last one writes to the carry slot
                       m <- A.sub numType n (lift 1)
                       _ <- if A.lt scalarType tid m
                              then writeArray arrOut j z                    >> return (IR OP_Unit :: IR ())
                              else writeArray carry (lift 0 :: IR Int32) z >> return (IR OP_Unit :: IR ())
                       return ()

                     return (IR OP_Unit :: IR ())

            -- Next loop state
            A.trip <$> A.sub numType n bd <*> next i <*> next j)
          (A.trip n0 i0 j0)

        -- Barrier, then thread 0 copies the carry slot into the sum array at
        -- this segment's index
        __syncthreads
        when (A.eq scalarType tid (lift 0)) $
          writeArray arrSum seg =<< readArray carry (lift 0 :: IR Int32)

    return_
-- | A generate kernel (see 'mkGenerate') whose element function ignores its
-- index argument and always yields the seed expression.
--
mkScanFill
    :: (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array sh e))
mkScanFill ptx aenv seed = mkGenerate ptx aenv constantly
  where
    -- every index maps to the seed
    constantly = IRFun1 (\_ -> seed)
-- | The same generate kernel as 'mkScanFill', built at the type of the second
-- component (@Array sh e@) and then coerced to the result type of the primed
-- scans.
--
mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => PTX
    -> Gamma aenv
    -> IRExp PTX aenv e
    -> CodeGen (IROpenAcc PTX aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill ptx aenv seed = Safe.coerce <$> fill
  where
    -- The annotation pins the array type of the kernel before coercion
    fill :: CodeGen (IROpenAcc PTX aenv (Array sh e))
    fill = mkGenerate ptx aenv (IRFun1 (const seed))
-- | Cooperative scan across the threads of a block using shared memory.
--
-- Runs in two steps: each warp first scans its own values ('scanWarpSMem'),
-- then every warp other than warp 0 combines the stored results of the
-- preceding warps into its partial result.
--
-- The optional element count limits how many warp results are folded into
-- the prefix: at most @n `quot` warpSize@ of them.
--
scanBlockSMem
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> IRFun2 PTX aenv (e -> e -> e)
    -> Maybe (IR Int32)
    -> IR e
    -> CodeGen (IR e)
scanBlockSMem dir dev combine nelem input0 =
  warpScan input0 >>= warpPrefix
  where
    warpSz = CUDA.warpSize dev

    int32 :: Integral a => a -> IR Int32
    int32 n = lift (P.fromIntegral n)

    -- Per-warp scratch space: one and a half warps' worth of elements
    elemsPerWarp = warpSz + (warpSz `P.quot` 2)
    bytesPerWarp = elemsPerWarp * sizeOf (eltType (undefined::e))

    -- Step 1: scan within each warp, using that warp's own region of the
    -- dynamic shared memory (located via warp id * bytes per warp)
    warpScan :: IR e -> CodeGen (IR e)
    warpScan x = do
      wid     <- warpId
      off     <- A.mul numType wid (int32 bytesPerWarp)
      scratch <- dynamicSharedMem (int32 elemsPerWarp) off
      scanWarpSMem dir dev combine scratch x

    -- Step 2: exchange the per-warp results and add the prefix of all
    -- preceding warps to this thread's partial result
    warpPrefix :: IR e -> CodeGen (IR e)
    warpPrefix partial = do
      -- One slot per warp, placed after the per-warp scratch regions
      bd     <- blockDim
      nwarps <- A.quot integralType bd (int32 warpSz)
      off    <- A.mul numType nwarps (int32 bytesPerWarp)
      aggs   <- dynamicSharedMem nwarps off

      -- The last lane of each warp publishes that warp's result
      wid    <- warpId
      lane   <- laneId
      when (A.eq scalarType lane (int32 (warpSz - 1))) $
        writeArray aggs wid partial

      __syncthreads

      -- Warp 0 has nothing in front of it
      if A.eq scalarType wid (lift 0)
        then return partial
        else do
          -- Number of preceding warp results to fold together
          count  <- case nelem of
                      Nothing -> return wid
                      Just n  -> A.min scalarType wid =<< A.quot integralType n (int32 warpSz)

          first  <- readArray aggs (lift 0 :: IR Int32)
          prefix <- iterFromStepTo (lift 1) (lift 1) count first $ \k acc -> do
                      agg <- readArray aggs k
                      case dir of
                        L -> app2 combine acc agg
                        R -> app2 combine agg acc

          case dir of
            L -> app2 combine prefix partial
            R -> app2 combine partial prefix
-- | Scan across the lanes of a single warp through a shared memory buffer.
--
-- The scan is unrolled at code-generation time into @floor (log2 warpSize)@
-- steps. In step @k@ every lane stores its current partial result at index
-- @lane + warpSize/2@ of the buffer; lanes with index at least @2^k@ then
-- read the value stored @2^k@ positions earlier and combine it with their own.
--
-- NOTE(review): no synchronisation is emitted between the store and the load
-- of a step, so this appears to rely on the lanes of a warp executing in
-- lockstep — confirm for the devices being targeted.
--
scanWarpSMem
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> IRFun2 PTX aenv (e -> e -> e)
    -> IRArray (Vector e)
    -> IR e
    -> CodeGen (IR e)
scanWarpSMem dir dev combine smem = scan 0
  where
    -- Number of doubling steps: floor (log2 warpSize). Computed by repeated
    -- integer halving rather than via floating-point 'logBase', so the result
    -- cannot be affected by rounding.
    steps :: Int
    steps     = P.length (P.takeWhile (>= 2) (P.iterate (`P.quot` 2) (CUDA.warpSize dev)))
    halfWarp  = P.fromIntegral (CUDA.warpSize dev `P.quot` 2)

    -- Unfold the scan as a recursive code generation function
    scan :: Int -> IR e -> CodeGen (IR e)
    scan step x
      | step >= steps               = return x
      | offset <- 1 `P.shiftL` step = do
          -- Publish this lane's partial result
          lane <- laneId
          i    <- A.add numType lane (lift halfWarp)
          writeArray smem i x

          -- Lanes at or beyond the offset fold in the value 'offset' lanes
          -- earlier; the remaining lanes keep their value unchanged
          x'   <- if A.gte scalarType lane (lift offset)
                    then do
                      i' <- A.sub numType i (lift offset)     -- lane + halfWarp - offset
                      x' <- readArray smem i'

                      case dir of
                        L -> app2 combine x' x
                        R -> app2 combine x x'

                    else
                      return x

          scan (step+1) x'