module Data.Array.Accelerate.CUFFT.Private where
import qualified Data.Array.Accelerate.CUFFT.RealClass as RC
import qualified Data.Array.Accelerate.Fourier.Preprocessed as Prep
import qualified Data.Array.Accelerate.Fourier.Planned as Fourier
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import Data.Array.Accelerate.Utility.Lift.Exp (expr)
import qualified Data.Array.Accelerate.LLVM.PTX.Foreign as AF
import qualified Data.Array.Accelerate.LLVM.PTX as PTX
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex (Complex((:+)), real, imag, conjugate)
import Data.Array.Accelerate.Lifetime (withLifetime)
import Data.Array.Accelerate
(Acc, Array, Elt, Shape, Slice, (:.)((:.)), Exp, (!), (?))
import qualified Foreign.CUDA.FFT as CUFFT
import qualified Foreign.CUDA.Driver as CUDA
import Foreign.CUDA.Ptr (DevicePtr)
import qualified System.Mem.Weak as Weak
import Control.Exception (bracket_)
-- | A function between Accelerate array computations that keeps the
-- shape type @sh@ and maps element type @a@ to element type @b@.
type Transform sh a b = Acc (Array sh a) -> Acc (Array sh b)
-- | Transform direction in two representations:
-- an 'Int' that 'modeC2C' forwards to the CUFFT execute function,
-- and the 'Fourier.Sign' that 'modeC2C' stores in the 'fsign' field.
type Sign a = (Int, Fourier.Sign a)
-- | The two transform directions.
--
-- The integer component is the direction argument
-- handed to the CUFFT complex-to-complex execute functions (see 'modeC2C').
-- It follows the CUFFT convention:
-- @CUFFT_FORWARD = -1@ and @CUFFT_INVERSE = 1@.
-- The second component is the matching sign for the fallback transform.
forwardSign, inverseSign :: Num a => Sign a
forwardSign = (-1, Fourier.forward)
inverseSign = ( 1, Fourier.inverse)
-- | A prepared transform.
--
-- Judging from 'makeHandle' and 'transform' the fields are, in order:
-- the fallback transform, the 'Mode' of the transform,
-- the width (size of the innermost transformed dimension)
-- and the CUFFT plan.
data
   Handle sh e a b =
      Handle (Transform sh a b) (Mode sh e a b) Int CUFFT.Handle
-- | Create a 'Handle' for the given mode and width.
--
-- The @planner@ is called with the CUFFT type of the mode
-- and must return a CUFFT plan.
-- 'CUFFT.destroy' is registered as a finalizer on that plan.
-- The fallback transform is built from the sign stored in the mode.
--
-- NOTE(review): "System.Mem.Weak" warns that finalizers attached to
-- ordinary values may run earlier than expected;
-- confirm that this is safe for 'CUFFT.Handle'.
makeHandle ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e a b -> Int ->
   (Fourier.Sign e -> Transform sh a b) ->
   (CUFFT.Type -> IO CUFFT.Handle) ->
   IO (Handle sh e a b)
makeHandle mode width fallback planner =
   planner (types mode) >>= \cufftPlan -> do
      Weak.addFinalizer cufftPlan $ CUFFT.destroy cufftPlan
      return (Handle (fallback (fsign mode)) mode width cufftPlan)
-- | Initialise the CUDA driver and create a PTX target
-- with the 'CUDA.SchedAuto' flag.
--
-- NOTE(review): despite the name, this always picks device number 0.
getBestTarget :: IO AF.PTX
getBestTarget =
   CUDA.initialise [] >>
   CUDA.device 0 >>= \device ->
   CUDA.props device >>= \properties ->
   PTX.createTargetForDevice device properties [CUDA.SchedAuto]
-- | Run an action while the CUDA context of the given target is pushed.
-- 'bracket_' makes sure that the context is popped again
-- even if the action throws an exception.
atTarget :: AF.PTX -> IO a -> IO a
atTarget target act =
   withLifetime targetContext $ \ctx ->
      bracket_ (CUDA.push ctx) CUDA.pop act
   where
      targetContext = AF.deviceContext (AF.ptxContext target)
-- | Shape synonyms: @BatchN sh@ is the shape @sh@
-- extended by @N@ trailing 'Int' dimensions.
type Batch0 sh = sh
type Batch1 sh = Batch0 sh :. Int
type Batch2 sh = Batch1 sh :. Int
type Batch3 sh = Batch2 sh :. Int
-- | Two-dimensional complex transform built from "accelerate-fourier".
--
-- The two innermost dimensions of the given shape are height and width.
-- The sub-transform for the width is passed first,
-- the one for the height second.
transform2D ::
   (Shape sh, Slice sh, RC.Real a) =>
   Batch2 sh -> Fourier.Sign a ->
   Fourier.Transform (Batch2 sh) (Complex a)
transform2D (_batch :. rows :. columns) direction =
   Prep.transform2d
      (Prep.SubTransformPair
         (Fourier.transform direction columns)
         (Fourier.transform direction rows))
-- | Three-dimensional complex transform built from "accelerate-fourier".
--
-- The three innermost dimensions of the given shape
-- are depth, height and width.
-- The sub-transforms are passed in the order width, height, depth.
transform3D ::
   (Shape sh, Slice sh, RC.Real a) =>
   Batch3 sh -> Fourier.Sign a ->
   Fourier.Transform (Batch3 sh) (Complex a)
transform3D (_batch :. slices :. rows :. columns) direction =
   Prep.transform3d
      (Prep.SubTransformTriple
         (Fourier.transform direction columns)
         (Fourier.transform direction rows)
         (Fourier.transform direction slices))
-- | Turn a 'Handle' into an Accelerate transform.
--
-- The result is a foreign function call to 'transformForeign'
-- with the handle's fallback transform as the alternative implementation.
-- The foreign function works on the flat representation of the arrays,
-- thus the fallback is adapted with 'unwrap'
-- and the whole foreign call is adapted back with 'wrap'.
transform ::
   (Shape sh, Slice sh, RC.Real e) =>
   Handle (sh:.Int) e a b ->
   Transform (sh:.Int) a b
transform hndl@(Handle fallback mode width _) =
   wrap mode widthExp
      (A.foreignAcc
         (AF.ForeignAcc "transformForeign" (transformForeign hndl))
         (unwrap mode widthExp fallback))
   where
      widthExp = A.constant width
-- | Modes for complex-to-complex transforms in forward and inverse direction.
-- 'RC.switch' selects between the 'Float' and the 'Double' variant.
forwardComplex, inverseComplex ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e (Complex e) (Complex e)
forwardComplex =
   getModeC2C
      (RC.switch
         (modeC2CFloat forwardSign)
         (modeC2CDouble forwardSign))
inverseComplex =
   getModeC2C
      (RC.switch
         (modeC2CFloat inverseSign)
         (modeC2CDouble inverseSign))
-- | Mode for the real-to-complex transform.
-- 'RC.switch' selects between the variant built from
-- 'CUFFT.R2C' and the one built from 'CUFFT.D2Z'.
forwardReal ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode (sh:.Int) e e (Complex e)
forwardReal =
   getModeR2C
      (RC.switch
         (modeR2C CUFFT.R2C CUFFT.execR2C)
         (modeR2C CUFFT.D2Z CUFFT.execD2Z))
-- | Mode for the complex-to-real transform.
-- 'RC.switch' selects between the variant built from
-- 'CUFFT.C2R' and the one built from 'CUFFT.Z2D'.
inverseReal ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode (sh:.Int) e (Complex e) e
inverseReal =
   getModeC2R
      (RC.switch
         (modeC2R CUFFT.C2R CUFFT.execC2R)
         (modeC2R CUFFT.Z2D CUFFT.execZ2D))
-- | Kind of a transform independent of the floating point precision:
-- real-to-complex, complex-to-real or complex-to-complex.
-- 'transformForeign' uses it to choose the shape of the output array.
data Types = R2C | C2R | C2C
   deriving (Eq, Ord, Enum, Show)
-- | Description of one kind of transform
-- from arrays with element type @a@ to arrays with element type @b@,
-- where @e@ is the underlying real number type.
data Mode sh e a b =
   Mode {
      -- | CUFFT transform type; 'makeHandle' passes it to the planner.
      types :: CUFFT.Type,
      -- | Precision-independent kind of the transform.
      plainTypes :: Types,
      -- | Run a CUFFT plan on an input and an output device pointer,
      -- in this order (see the call in 'transformForeign').
      execute ::
         CUFFT.Handle ->
         CUDA.DevicePtr (Sugar.EltRepr e) ->
         CUDA.DevicePtr (Sugar.EltRepr e) ->
         IO (),
      -- | Convert a transform on the flat representation
      -- (one additional innermost dimension, plain @e@ elements)
      -- to a transform on arrays of @a@ and @b@.
      -- The first argument is the width of the transform.
      wrap :: Exp Int -> Fourier.Transform (sh:.Int) e -> Transform sh a b,
      -- | Opposite direction of 'wrap':
      -- convert a transform on arrays of @a@ and @b@
      -- to a transform on the flat representation.
      unwrap :: Exp Int -> Transform sh a b -> Fourier.Transform (sh:.Int) e,
      -- | Adapt a complex-to-complex transform
      -- to the element types of this mode.
      wrapFallback :: Fourier.Transform sh (Complex e) -> Transform sh a b,
      -- | Sign that 'makeHandle' passes to the constructor
      -- of the fallback transform.
      fsign :: Fourier.Sign e
   }
-- | 'Mode' specialised to complex-to-complex transforms.
-- The wrapper makes the real type @e@ the last type parameter;
-- presumably this is the form that 'RC.switch' requires.
newtype
   ModeC2C sh e =
      ModeC2C {getModeC2C :: Mode sh e (Complex e) (Complex e)}
-- | 'Mode' specialised to real-to-complex transforms.
-- The wrapper makes the real type @e@ the last type parameter;
-- presumably this is the form that 'RC.switch' requires.
newtype
   ModeR2C sh e =
      ModeR2C {getModeR2C :: Mode sh e e (Complex e)}
-- | 'Mode' specialised to complex-to-real transforms.
-- The wrapper makes the real type @e@ the last type parameter;
-- presumably this is the form that 'RC.switch' requires.
newtype
   ModeC2R sh e =
      ModeC2R {getModeC2R :: Mode sh e (Complex e) e}
-- | Type of a CUFFT execute function that takes a plan
-- and two device pointers.
type Execute e = CUFFT.Handle -> DevicePtr e -> DevicePtr e -> IO ()
-- | Like 'Execute' but with an additional 'Int' argument;
-- 'modeC2C' passes the integer component of a 'Sign' there.
type ExecuteSign e = CUFFT.Handle -> DevicePtr e -> DevicePtr e -> Int -> IO ()
-- | Build a complex-to-complex mode from a CUFFT type,
-- a CUFFT execute function with direction argument and a 'Sign'.
--
-- The integer part of the sign becomes the direction argument
-- of the execute function, the other part is stored in 'fsign'.
-- Wrapping and unwrapping only convert between complex elements
-- and pairs of reals in an additional dimension; the width is not needed.
modeC2C ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> ExecuteSign (Sugar.EltRepr e) -> Sign e -> ModeC2C sh e
modeC2C typ exec (direction, fallbackSign) =
   ModeC2C
      (Mode {
         plainTypes = C2C,
         types = typ,
         fsign = fallbackSign,
         execute = \plan src dst -> exec plan src dst direction,
         wrap = \ _width onFlat arr -> deinterleave (onFlat (interleave arr)),
         unwrap =
            \ _width onComplex arr -> interleave (onComplex (deinterleave arr)),
         wrapFallback = id
      })
-- | Complex-to-complex mode for 'Float',
-- using 'CUFFT.C2C' and 'CUFFT.execC2C'.
modeC2CFloat :: (Shape sh, Slice sh) => Sign Float -> ModeC2C sh Float
modeC2CFloat sign = modeC2C CUFFT.C2C CUFFT.execC2C sign
-- | Complex-to-complex mode for 'Double',
-- using 'CUFFT.Z2Z' and 'CUFFT.execZ2Z'.
modeC2CDouble :: (Shape sh, Slice sh) => Sign Double -> ModeC2C sh Double
modeC2CDouble sign = modeC2C CUFFT.Z2Z CUFFT.execZ2Z sign
-- | Build a real-to-complex mode from a CUFFT type and execute function.
--
-- 'wrap' adds a dimension of size one to the real input,
-- runs the flat transform, converts the result to complex numbers
-- and extends it to the full width with 'mirror'.
-- 'unwrap' performs the opposite conversions
-- and keeps only the first @div width 2 + 1@ elements of the result.
-- The fallback receives the real input converted to complex numbers
-- with zero imaginary part and always uses 'Fourier.forward'.
modeR2C ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> Execute (Sugar.EltRepr e) -> ModeR2C (sh:.Int) e
modeR2C typ exec =
   ModeR2C
      (Mode {
         plainTypes = R2C,
         types = typ,
         fsign = Fourier.forward,
         execute = exec,
         wrap =
            \width onFlat arr ->
               mirror width (deinterleave (onFlat (addDim arr))),
         unwrap =
            \width onReal arr ->
               interleave (takeHalf width (onReal (removeDim arr))),
         wrapFallback =
            \onComplex -> onComplex . A.map (Exp.modify expr (:+0))
      })
-- | Build a complex-to-real mode from a CUFFT type and execute function.
--
-- 'wrap' keeps only the first @div width 2 + 1@ elements of the complex
-- input, converts them to the flat representation,
-- runs the flat transform and removes the innermost dimension of the result.
-- 'unwrap' performs the opposite conversions
-- and restores the full width with 'mirror'.
-- The fallback result is reduced to its real part
-- and always uses 'Fourier.inverse'.
modeC2R ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> Execute (Sugar.EltRepr e) -> ModeC2R (sh:.Int) e
modeC2R typ exec =
   ModeC2R
      (Mode {
         plainTypes = C2R,
         types = typ,
         fsign = Fourier.inverse,
         execute = exec,
         wrap =
            \width onFlat arr ->
               removeDim (onFlat (interleave (takeHalf width arr))),
         unwrap =
            \width onComplex arr ->
               addDim (onComplex (mirror width (deinterleave arr))),
         wrapFallback = \onComplex -> A.map real . onComplex
      })
-- | The foreign function behind 'transform':
-- run the CUFFT plan of the handle on a flat input array.
--
-- The output array is allocated with a shape that depends on the kind
-- of transform: the two innermost dimensions are
-- @div width 2 + 1@ and 2 for real-to-complex,
-- @width@ and 1 for complex-to-real,
-- @width@ and 2 for complex-to-complex.
-- The remaining dimensions are taken over from the input.
-- After starting the transform a checkpoint event on the stream is created
-- and returned to both enclosing 'withDevicePtr' calls.
transformForeign ::
   (Shape sh, RC.Real e) =>
   Handle (sh:.Int) e a b -> AF.Stream ->
   Array (sh:.Int:.Int) e -> AF.LLVM AF.PTX (Array (sh:.Int:.Int) e)
transformForeign (Handle _ mode width plan) stream input = do
   output <- AF.allocateRemote outputShape
   withDevicePtr input $ \src ->
      withDevicePtr output $ \dst -> do
         AF.liftIO (execute mode plan src dst)
         done <- AF.checkpoint stream
         return (Just done, (Just done, output))
   where
      (batch :. _width :. _tupleSize) = A.arrayShape input
      outputShape =
         case plainTypes mode of
            R2C -> batch :. div width 2 + 1 :. (2::Int)
            C2R -> batch :. width :. 1
            C2C -> batch :. width :. 2
-- | Wrapper around the type of a function that provides
-- the device pointer of an array to a continuation.
-- The continuation returns an optional event together with its result.
-- The element type @e@ is the last type parameter;
-- 'withDevicePtr' passes values of this type to 'RC.switch'.
newtype
   WithDevicePtr target r sh e =
      WithDevicePtr {
         runWithDevicePtr ::
            Array sh e ->
            (CUDA.DevicePtr (Sugar.EltRepr e) ->
               AF.LLVM target (Maybe AF.Event, r)) ->
            AF.LLVM target r
      }
-- | Run a continuation on the device pointer of an array of real numbers.
--
-- Both alternatives of 'RC.switch' consist of the same code:
-- they unpack the array data and call 'AF.withDevicePtr'.
-- Presumably they are needed twice
-- because each one is type-checked at a different element type.
withDevicePtr ::
   (RC.Real e) =>
   Array sh e ->
   (CUDA.DevicePtr (Sugar.EltRepr e) -> AF.LLVM AF.PTX (Maybe AF.Event, r)) ->
   AF.LLVM AF.PTX r
withDevicePtr =
   runWithDevicePtr $
   RC.switch
      (WithDevicePtr $ \(Sugar.Array _ adata) -> AF.withDevicePtr adata)
      (WithDevicePtr $ \(Sugar.Array _ adata) -> AF.withDevicePtr adata)
-- | Convert an array of complex numbers to an array of their components.
-- The result has an additional innermost dimension of size 2
-- with the real part at index 0 and the imaginary part otherwise.
interleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (Complex a)) -> Acc (Array (sh:.Int) a)
interleave arr =
   A.generate (A.lift (A.shape arr :. (2::Int))) $ \ix ->
      let z = arr ! A.indexTail ix
          part = A.indexHead ix
      in  part A.== 0 ? (real z, imag z)
-- | Convert an array of components to an array of complex numbers.
-- The innermost dimension is removed:
-- the element at index 0 becomes the real part
-- and the element at index 1 becomes the imaginary part.
deinterleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh (Complex a))
deinterleave arr =
   A.generate (A.indexTail (A.shape arr)) $ \ix ->
      let component n = arr ! A.lift (ix :. (n::Int))
      in  A.lift (component 0 :+ component 1)
-- | Reshape an array such that it gets
-- an additional innermost dimension of size 1.
addDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh a) -> Acc (Array (sh:.Int) a)
addDim arr =
   let extended = A.lift (A.shape arr :. (1::Int))
   in  A.reshape extended arr
-- | Reshape an array to the shape without its innermost dimension.
removeDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh a)
removeDim arr =
   let reduced = A.indexTail (A.shape arr)
   in  A.reshape reduced arr
-- | Apply 'Sliced.take' with the length @div width 2 + 1@,
-- where @width@ is the first argument.
takeHalf ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Fourier.Transform (sh:.Int) a
takeHalf width arr =
   Sliced.take halfLength arr
   where
      halfLength = width `div` 2 + 1
-- | Extend the innermost dimension of an array of complex numbers
-- to the size @newWidth@.
--
-- Indices @k@ below the width of the input array are copied.
-- Every other index @k@ is filled with the complex conjugate
-- of the element at index @newWidth - k@.
-- This is how 'modeR2C' and 'modeC2R' restore a full spectrum
-- from the half that was kept by 'takeHalf'.
mirror ::
   (Shape sh, Slice sh, A.Num a) =>
   Exp Int -> Fourier.Transform (sh:.Int) (Complex a)
mirror newWidth arr =
   let (sh:.width) = Exp.unlift (expr:.expr) $ A.shape arr
   in  A.generate (A.lift $ sh :. newWidth) $
       Exp.modify (expr:.expr) $ \(ix:.k) ->
          k A.< width ?
             (arr ! Exp.indexCons ix k,
              -- mirrored index; must be a difference, not an application
              conjugate (arr ! Exp.indexCons ix (newWidth - k)))