{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
Transformations of collections of datasets.
-}
module Data.Array.Accelerate.CUFFT.Batched (
   Priv.Transform,
   Priv.transform,

   Handle,
   plan1D,
   plan2D,
   plan3D,

   RC.Real,
   Mode,
   Priv.forwardComplex, Priv.inverseComplex,
   Priv.forwardReal, Priv.inverseReal,
   Batch0, Batch1, Batch2, Batch3,

   Priv.getBestTarget,
   ) where

import qualified Data.Array.Accelerate.CUFFT.Private as Priv
import Data.Array.Accelerate.CUFFT.Private
          (Batch0, Batch1, Batch2, Batch3,
           Mode, wrapFallback, Handle, makeHandle, atTarget)

import qualified Data.Array.Accelerate.CUFFT.RealClass as RC
import Data.Array.Accelerate.LLVM.PTX (PTX)

import qualified Data.Array.Accelerate.Fourier.Planned as Fourier

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Elt, Shape, Slice, (:.)((:.)), )

import qualified Foreign.CUDA.FFT as CUFFT


{- |
Prepare a handle for a batch of one-dimensional transforms.

The innermost dimension of the shape is the transform length;
the remaining dimensions are flattened via 'A.arraySize'
and passed to cuFFT as the batch count.

The handle must be created in the very same
'Data.Array.Accelerate.LLVM.PTX.PTX' target
in which 'Priv.transform' is run later.
Consequently, when using cuFFT
you always have to use one of the @run*With@ functions.
-}
plan1D ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   PTX -> Mode (Batch1 sh) e a b -> Batch1 sh -> IO (Handle (Batch1 sh) e a b)
plan1D target mode (batch:.width) =
   atTarget target
      (makeHandle mode width
         (\direction -> wrapFallback mode (Fourier.transform direction width))
         (\fftType ->
            CUFFT.planMany extents Nothing Nothing fftType batchCount))
   where
      -- extents of a single transform, as expected by CUFFT.planMany
      extents = [width]
      -- number of transforms in the batch
      batchCount = A.arraySize batch
{- |
Prepare a handle for a batch of two-dimensional transforms.

The two innermost dimensions of the shape (height, then width)
are the extents of a single transform;
the remaining dimensions are flattened via 'A.arraySize'
and passed to cuFFT as the batch count.

The plan is created within the given 'PTX' target using 'atTarget'.
-}
plan2D ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   PTX -> Mode (Batch2 sh) e a b -> Batch2 sh -> IO (Handle (Batch2 sh) e a b)
plan2D target mode fullShape@(batch:.height:.width) =
   atTarget target
      (makeHandle mode width
         (\direction -> wrapFallback mode (Priv.transform2D fullShape direction))
         (\fftType ->
            CUFFT.planMany extents Nothing Nothing fftType batchCount))
   where
      -- extents of a single transform, outermost dimension first
      extents = [height, width]
      -- number of transforms in the batch
      batchCount = A.arraySize batch
{- |
Prepare a handle for a batch of three-dimensional transforms.

The three innermost dimensions of the shape (depth, height, width)
are the extents of a single transform;
the remaining dimensions are flattened via 'A.arraySize'
and passed to cuFFT as the batch count.

The plan is created within the given 'PTX' target using 'atTarget'.
-}
plan3D ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   PTX -> Mode (Batch3 sh) e a b -> Batch3 sh -> IO (Handle (Batch3 sh) e a b)
plan3D target mode fullShape@(batch:.depth:.height:.width) =
   atTarget target
      (makeHandle mode width
         (\direction -> wrapFallback mode (Priv.transform3D fullShape direction))
         (\fftType ->
            CUFFT.planMany extents Nothing Nothing fftType batchCount))
   where
      -- extents of a single transform, outermost dimension first
      extents = [depth, height, width]
      -- number of transforms in the batch
      batchCount = A.arraySize batch