{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Accelerate interface to the native CUDA implementation
of the Fourier Transform provided by the CUFFT library.
-}
module Data.Array.Accelerate.CUFFT.Private where

import qualified Data.Array.Accelerate.CUFFT.RealClass as RC

import qualified Data.Array.Accelerate.Fourier.Preprocessed as Prep
import qualified Data.Array.Accelerate.Fourier.Planned as Fourier

import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import Data.Array.Accelerate.Utility.Lift.Exp (expr)

import qualified Data.Array.Accelerate.LLVM.PTX.Foreign as AF
import qualified Data.Array.Accelerate.LLVM.PTX as PTX
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex (Complex((:+)), real, imag, conjugate)
import Data.Array.Accelerate.Lifetime (withLifetime)
import Data.Array.Accelerate
         (Acc, Array, Elt, Shape, Slice, (:.)((:.)), Exp, (!), (?))

import qualified Foreign.CUDA.FFT as CUFFT
import qualified Foreign.CUDA.Driver as CUDA
import Foreign.CUDA.Ptr (DevicePtr)

import qualified System.Mem.Weak as Weak

import Control.Exception (bracket_)


{- |
Array-to-array function in 'Acc' that keeps the shape type @sh@
and maps elements of type @a@ to elements of type @b@.
-}
type Transform sh a b = Acc (Array sh a) -> Acc (Array sh b)


{- |
Direction of a transform, stored in two representations:
the 'Int' is the flag that 'modeC2C' hands to the CUFFT execution function,
the 'Fourier.Sign' is stored in the 'fsign' field of the 'Mode'
and is given to the fallback implementation by 'makeHandle'.
-}
type Sign a = (Int, Fourier.Sign a)

{- |
Direction of the forward transform:
the flag @-1@ together with 'Fourier.forward'.
-}
forwardSign :: Num a => Sign a
forwardSign = (-1, Fourier.forward)

{- |
Direction of the inverse transform:
the flag @1@ together with 'Fourier.inverse'.
-}
inverseSign :: Num a => Sign a
inverseSign = (1, Fourier.inverse)


{- |
Everything 'transform' needs in order to run one planned transform.
Values are created by 'makeHandle'.
-}
data Handle sh e a b =
   Handle
      (Transform sh a b)
         -- ^ fallback implementation written in plain Accelerate
      (Mode sh e a b)
         -- ^ kind of transform and the associated layout conversions
      Int
         -- ^ transform width as passed to 'makeHandle'
      CUFFT.Handle
         -- ^ the CUFFT plan

{- |
Assemble a 'Handle' from a 'Mode', the transform width,
a fallback implementation and a function that creates the CUFFT plan.

The planner is called with the CUFFT type of the mode ('types').
A finalizer that calls 'CUFFT.destroy' is attached to the plan it returns.
The fallback is applied to the sign stored in the mode ('fsign').

NOTE(review): the documentation of "System.Mem.Weak" warns
that a finalizer attached with 'Weak.addFinalizer' to an ordinary value
may run earlier than expected.
Confirm that the plan cannot be destroyed while the 'Handle' is still in use.
-}
makeHandle ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e a b -> Int ->
   (Fourier.Sign e -> Transform sh a b) ->
   (CUFFT.Type -> IO CUFFT.Handle) ->
   IO (Handle sh e a b)
makeHandle mode width fallback planner =
   planner (types mode) >>= \plan -> do
      Weak.addFinalizer plan $ CUFFT.destroy plan
      return (Handle (fallback (fsign mode)) mode width plan)



{- |
Initialise the CUDA driver and create a PTX target for device number 0,
passing the 'CUDA.SchedAuto' flag.

Despite its name the function does not search for the best device:
the call to @PTX.selectBestDevice@ is commented out.
-}
getBestTarget :: IO AF.PTX
getBestTarget =
   CUDA.initialise [] >>
   -- (dev,prop) <- PTX.selectBestDevice
   CUDA.device 0 >>= \dev ->
   CUDA.props dev >>= \prop ->
   PTX.createTargetForDevice dev prop [CUDA.SchedAuto]

{- |
Run an 'IO' action while the CUDA context of the given target is pushed.
'bracket_' makes sure that 'CUDA.pop' is called afterwards,
also when the action throws an exception.
-}
atTarget :: AF.PTX -> IO a -> IO a
atTarget target act =
   withLifetime context $ \ctx ->
      bracket_ (CUDA.push ctx) CUDA.pop act
   where
      context = AF.deviceContext (AF.ptxContext target)



{- |
Shape synonyms:
@BatchN sh@ is the shape @sh@ extended by @N@ trailing 'Int' dimensions.
'transform2D' and 'transform3D' use 'Batch2' and 'Batch3', respectively.
-}
type Batch0 sh = sh
type Batch1 sh = Batch0 sh :. Int
type Batch2 sh = Batch1 sh :. Int
type Batch3 sh = Batch2 sh :. Int

{- |
Two-dimensional transform written in plain Accelerate.
It combines two planned one-dimensional transforms via 'Prep.transform2d':
the first one has the length of the last extent of the given shape,
the second one has the length of the last but one extent.
The remaining extents of the shape are ignored.
-}
transform2D ::
   (Shape sh, Slice sh, RC.Real a) =>
   Batch2 sh -> Fourier.Sign a ->
   Fourier.Transform (Batch2 sh) (Complex a)
transform2D (_batch:.rows:.columns) sign =
   Prep.transform2d
      (Prep.SubTransformPair
         (Fourier.transform sign columns)
         (Fourier.transform sign rows))


{- |
Three-dimensional transform written in plain Accelerate.
It combines three planned one-dimensional transforms via 'Prep.transform3d'.
Their lengths are the last three extents of the given shape,
passed in the order innermost, middle, outermost.
The remaining extents of the shape are ignored.
-}
transform3D ::
   (Shape sh, Slice sh, RC.Real a) =>
   Batch3 sh -> Fourier.Sign a ->
   Fourier.Transform (Batch3 sh) (Complex a)
transform3D (_batch:.planes:.rows:.columns) sign =
   Prep.transform3d
      (Prep.SubTransformTriple
         (Fourier.transform sign columns)
         (Fourier.transform sign rows)
         (Fourier.transform sign planes))


{- |
Run the transform described by a 'Handle'.

The implementation works on all arrays of rank less than or equal to 3.
The result is un-normalised.

The CUFFT call is embedded with 'A.foreignAcc'
together with the fallback stored in the handle.
-}
transform ::
   (Shape sh, Slice sh, RC.Real e) =>
   Handle (sh:.Int) e a b ->
   Transform (sh:.Int) a b
transform hndl@(Handle fallback mode width _) =
   wrap mode widthExp foreignTransform
   where
      widthExp = A.constant width

      {-
      The foreign function operates on the data layout expected by CUFFT.
      That is why the fallback has to be converted to that layout
      with 'unwrap' first and the whole thing is converted back with 'wrap'.
      Fusion might remove the redundant conversions.
      The optimal solution would be to make the backend explicit in the type.
      This would allow us to declare back-end specific functions
      without a fallback implementation.
      -}
      foreignTransform =
         A.foreignAcc
            (AF.ForeignAcc "transformForeign" $ transformForeign hndl)
            (unwrap mode widthExp fallback)


{- |
Mode for a complex-to-complex transform in forward direction.
'RC.switch' chooses between the 'Float' and the 'Double' variant.
-}
forwardComplex ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e (Complex e) (Complex e)
forwardComplex =
   getModeC2C
      (RC.switch (modeC2CFloat forwardSign) (modeC2CDouble forwardSign))

{- |
Mode for a complex-to-complex transform in inverse direction.
'RC.switch' chooses between the 'Float' and the 'Double' variant.
-}
inverseComplex ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e (Complex e) (Complex e)
inverseComplex =
   getModeC2C
      (RC.switch (modeC2CFloat inverseSign) (modeC2CDouble inverseSign))

{- |
Mode for a transform from real to complex data.

In contrast to the plain CUFFT functions the result is redundant:
An array of shape @sh@ is transformed to an array of shape @sh@.
This way all dimensions of an array are handled the same way.
There is a good chance
that the internal post-processing is fused with subsequent array operations
such that the redundant data is never stored in a manifest array.
-}
forwardReal ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode (sh:.Int) e e (Complex e)
forwardReal =
   getModeR2C
      (RC.switch
         (modeR2C CUFFT.R2C CUFFT.execR2C)
         (modeR2C CUFFT.D2Z CUFFT.execD2Z))

{- |
Mode for a transform from complex to real data.
The complex input has the same shape as the real output;
only its leading half (see 'takeHalf') is passed to CUFFT.
-}
inverseReal ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode (sh:.Int) e (Complex e) e
inverseReal =
   getModeC2R
      (RC.switch
         (modeC2R CUFFT.C2R CUFFT.execC2R)
         (modeC2R CUFFT.Z2D CUFFT.execZ2D))


{- |
Kind of transform independent of the floating point precision:
real to complex, complex to real, complex to complex.
'transformForeign' selects the shape of the output array by this value.
-}
data Types = R2C | C2R | C2C
   deriving (Eq, Ord, Enum, Show)


{- |
Description of one kind of transform
with real number type @e@, input element type @a@
and output element type @b@.
-}
data Mode sh e a b =
   Mode {
      -- | Transform type that 'makeHandle' passes to the CUFFT planner.
      types :: CUFFT.Type,
      -- | Kind of transform, used by 'transformForeign'
      -- in order to determine the shape of the output array.
      plainTypes :: Types,
      -- | Run a plan, given the input and the output device pointer.
      execute :: Execute (Sugar.EltRepr e),
      -- | Convert a transform on the plain array of @e@ values
      -- to a transform on the element types @a@ and @b@.
      -- The first argument is the transform width.
      wrap :: Exp Int -> Fourier.Transform (sh:.Int) e -> Transform sh a b,
      -- | Counterpart of 'wrap'.
      -- 'transform' uses it to adapt the fallback implementation.
      unwrap :: Exp Int -> Transform sh a b -> Fourier.Transform (sh:.Int) e,
      -- | Adapt a complex-to-complex transform
      -- to the element types @a@ and @b@.
      wrapFallback :: Fourier.Transform sh (Complex e) -> Transform sh a b,
      -- | Sign that 'makeHandle' passes to the fallback implementation.
      fsign :: Fourier.Sign e
   }

{- |
'Mode' for complex input and complex output,
wrapped such that the real number type @e@ is the last type parameter.
-}
newtype ModeC2C sh e =
   ModeC2C {getModeC2C :: Mode sh e (Complex e) (Complex e)}

{- |
'Mode' for real input and complex output,
wrapped such that the real number type @e@ is the last type parameter.
-}
newtype ModeR2C sh e =
   ModeR2C {getModeR2C :: Mode sh e e (Complex e)}

{- |
'Mode' for complex input and real output,
wrapped such that the real number type @e@ is the last type parameter.
-}
newtype ModeC2R sh e =
   ModeC2R {getModeC2R :: Mode sh e (Complex e) e}


{- |
Type of a CUFFT execution function:
plan, input device pointer, output device pointer.
-}
type Execute e = CUFFT.Handle -> DevicePtr e -> DevicePtr e -> IO ()
{- |
Like 'Execute' but with an additional 'Int' argument.
'modeC2C' passes the direction flag of a 'Sign' there.
-}
type ExecuteSign e = CUFFT.Handle -> DevicePtr e -> DevicePtr e -> Int -> IO ()


{- |
Build a complex-to-complex mode
from a CUFFT type, an execution function and a direction.
The integer part of the direction is passed to the execution function,
the other part becomes the sign for the fallback.
The transform width is not needed for the layout conversions of this mode.
-}
modeC2C ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> ExecuteSign (Sugar.EltRepr e) -> Sign e -> ModeC2C sh e
modeC2C typ exec (direction, fallbackSign) =
   ModeC2C $
   Mode {
      types = typ,
      plainTypes = C2C,
      execute = \plan src dst -> exec plan src dst direction,
      fsign = fallbackSign,
      wrapFallback = id,
      wrap = \ _ f -> deinterleave . f . interleave,
      unwrap = \ _ f -> interleave . f . deinterleave
   }

{- |
Complex-to-complex mode for single precision:
'CUFFT.C2C' executed by 'CUFFT.execC2C'.
-}
modeC2CFloat :: (Shape sh, Slice sh) => Sign Float -> ModeC2C sh Float
modeC2CFloat sign = modeC2C CUFFT.C2C CUFFT.execC2C sign

{- |
Complex-to-complex mode for double precision:
'CUFFT.Z2Z' executed by 'CUFFT.execZ2Z'.
-}
modeC2CDouble :: (Shape sh, Slice sh) => Sign Double -> ModeC2C sh Double
modeC2CDouble sign = modeC2C CUFFT.Z2Z CUFFT.execZ2Z sign


{-
The fallback implementation is inefficient
because it does not exploit the symmetries that occur.
However, it works generally for all dimensions
and also for odd data set sizes.
-}
{- |
Build a real-to-complex mode from a CUFFT type and an execution function.

'wrap' gives the real input an additional dimension of extent 1 ('addDim'),
converts the interleaved result to complex numbers ('deinterleave')
and extends it to the full width ('mirror').
'unwrap' performs the opposite conversions around the fallback,
keeping only the leading half of its result ('takeHalf').
The fallback gets its real input as complex numbers with zero imaginary part
and always uses 'Fourier.forward'.
-}
modeR2C ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> Execute (Sugar.EltRepr e) -> ModeR2C (sh:.Int) e
modeR2C typ exec =
   ModeR2C $
   Mode {
      types = typ,
      plainTypes = R2C,
      execute = exec,
      fsign = Fourier.forward,
      wrapFallback = \fallback -> fallback . A.map (Exp.modify expr (:+0)),
      wrap = \n f -> mirror n . deinterleave . f . addDim,
      unwrap = \n f -> interleave . takeHalf n . f . removeDim
   }

{- |
Build a complex-to-real mode from a CUFFT type and an execution function.

'wrap' keeps only the leading half of the complex input ('takeHalf'),
converts it to the interleaved layout ('interleave')
and strips the dimension of extent 1 from the real result ('removeDim').
'unwrap' performs the opposite conversions around the fallback,
restoring the full width with 'mirror'.
The fallback result is reduced to its real part
and the fallback always uses 'Fourier.inverse'.
-}
modeC2R ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> Execute (Sugar.EltRepr e) -> ModeC2R (sh:.Int) e
modeC2R typ exec =
   ModeC2R $
   Mode {
      types = typ,
      plainTypes = C2R,
      execute = exec,
      fsign = Fourier.inverse,
      wrapFallback = \fallback -> A.map real . fallback,
      wrap = \n f -> removeDim . f . interleave . takeHalf n,
      unwrap = \n f -> addDim . f . mirror n . deinterleave
   }


{- |
The function that is run by the PTX backend in place of the fallback.

It allocates the output array, runs the CUFFT plan
on the device pointers of input and output,
and records a checkpoint on the given stream.
The resulting event is returned to both 'withDevicePtr' calls.

The output shape consists of the leading dimensions of the input,
the width from the handle (@div width 2 + 1@ for real-to-complex)
and a last dimension of extent 2 for complex and 1 for real output.

NOTE(review): the last two extents of the input are ignored.
It is not checked whether they match the width stored in the handle.
-}
transformForeign ::
   (Shape sh, RC.Real e) =>
   Handle (sh:.Int) e a b -> AF.Stream ->
   Array (sh:.Int:.Int) e -> AF.LLVM AF.PTX (Array (sh:.Int:.Int) e)
transformForeign (Handle _ mode width hndl) stream input = do
   output <- AF.allocateRemote outputSh
   withDevicePtr input $ \iptr ->
      withDevicePtr output $ \optr -> do
         AF.liftIO (execute mode hndl iptr optr)
         ev <- AF.checkpoint stream
         let done = Just ev
         return (done, (done, output))
   where
      (shape :. _width :. _tupleSize) = A.arrayShape input
      outputSh =
         case plainTypes mode of
            R2C -> shape :. div width 2 + 1 :. (2::Int)
            C2R -> shape :. width :. 1
            C2C -> shape :. width :. 2


{- |
Wrapper around the type of 'withDevicePtr'
that has the element type @e@ as last type parameter.
This way 'RC.switch' can select an implementation per element type.
-}
newtype WithDevicePtr target r sh e =
   WithDevicePtr {
      runWithDevicePtr ::
         Array sh e ->
         (CUDA.DevicePtr (Sugar.EltRepr e) ->
          AF.LLVM target (Maybe AF.Event, r)) ->
         AF.LLVM target r
   }

{- |
Run an action on the device pointer of an array.
The action returns an optional event together with its result;
both are passed through to 'AF.withDevicePtr'.

The two arguments of 'RC.switch' are textually identical.
They are presumably type-checked at 'Float' and 'Double', respectively,
which is why the code cannot be shared in a single expression.
-}
withDevicePtr ::
   (RC.Real e) =>
   Array sh e ->
   (CUDA.DevicePtr (Sugar.EltRepr e) -> AF.LLVM AF.PTX (Maybe AF.Event, r)) ->
   AF.LLVM AF.PTX r
withDevicePtr =
   runWithDevicePtr
      (RC.switch
         -- unpack the array data and defer to the backend function
         (WithDevicePtr $ \(Sugar.Array _ dat) -> AF.withDevicePtr dat)
         (WithDevicePtr $ \(Sugar.Array _ dat) -> AF.withDevicePtr dat))


{-
These rules remove pairs of layout conversions that cancel each other.
Every rule is named after the inner and then the outer function
of the expression it matches.

The rule "deinterleave/interleave",
which replaces @interleave (deinterleave x)@ by @x@,
may turn a bottom into the identity,
if the input array has not extent 2 at the least-significant dimension.
Presumably the same applies to the rule "removeDim/addDim"
if the least-significant dimension has not extent 1.
The rules are only safe for the usage in this module.
-}
{-# RULES
  "interleave/deinterleave" forall x. deinterleave (interleave x) = x;
  "deinterleave/interleave" forall x. interleave (deinterleave x) = x;
  "addDim/removeDim" forall x. removeDim (addDim x) = x;
  "removeDim/addDim" forall x. addDim (removeDim x) = x;
 #-}

{- |
Imitate the cuComplex types
by storing real and imaginary parts alternately.
The result has an additional least-significant dimension of extent 2:
index 0 selects the real part, every other index the imaginary part.
-}
{-# NOINLINE[1] interleave #-}
interleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (Complex a)) -> Acc (Array (sh:.Int) a)
interleave arr = A.generate interleavedShape component
   where
      interleavedShape = A.lift (A.shape arr :. (2::Int))
      component ix =
         let z = arr ! A.indexTail ix
         in  A.indexHead ix A.== 0 ? (real z, imag z)

{- |
Counterpart of 'interleave':
Remove the least-significant dimension
and combine the elements at its indices 0 and 1
to the real and imaginary part of a complex number.
-}
{-# NOINLINE[1] deinterleave #-}
deinterleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh (Complex a))
deinterleave arr =
   A.generate (A.indexTail (A.shape arr)) $ \ix ->
      let component n = arr ! A.lift (ix :. (n::Int))
      in  A.lift (component 0 :+ component 1)

{- |
Reshape an array such that it gets
an additional least-significant dimension of extent 1.
-}
{-# NOINLINE[1] addDim #-}
addDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh a) -> Acc (Array (sh:.Int) a)
addDim arr = A.reshape extended arr
   where
      extended = A.lift (A.shape arr :. (1::Int))

{- |
Reshape an array to the shape without its least-significant dimension.
It is the counterpart of 'addDim'.
-}
{-# NOINLINE[1] removeDim #-}
removeDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh a)
removeDim arr = A.reshape reduced arr
   where
      reduced = A.indexTail (A.shape arr)


{- |
Keep the first @div width 2 + 1@ elements
along the least-significant dimension.
This is the extent that 'transformForeign' uses
for the complex side of a real transform.
-}
takeHalf ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Fourier.Transform (sh:.Int) a
takeHalf width = Sliced.take halfWidth
   where
      halfWidth = div width 2 + 1

{- |
Extend an array along the least-significant dimension to @newWidth@.
Indices @k@ below the input width are copied.
All other elements are the complex conjugate
of the input element at index @newWidth - k@.

NOTE(review): only the least-significant index is reflected,
the leading indices are used unchanged.
Confirm that this is intended
for real transforms over more than one dimension.
-}
mirror ::
   (Shape sh, Slice sh, A.Num a) =>
   Exp Int -> Fourier.Transform (sh:.Int) (Complex a)
mirror newWidth arr =
   A.generate (A.lift (sh :. newWidth)) $
   Exp.modify (expr:.expr) $ \(ix:.k) ->
      let copied = arr ! Exp.indexCons ix k
          reflected = conjugate (arr ! Exp.indexCons ix (newWidth - k))
      in  k A.< width ? (copied, reflected)
   where
      (sh:.width) = Exp.unlift (expr:.expr) $ A.shape arr