{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Fourier.Private where

import qualified Data.Array.Accelerate.Fourier.Sign as Sign
import qualified Data.Array.Accelerate.Convolution.Small as Cyclic
import Data.Array.Accelerate.Fourier.Sign (Sign, )

import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import qualified Data.Array.Accelerate.Utility.Sliced1 as Sliced1

import qualified Data.Array.Accelerate.LinearAlgebra as LinAlg
import Data.Array.Accelerate.LinearAlgebra (zipExtrudedVectorWith, )

import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Exp (expr)

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
          (Exp, Acc, Array, DIM1, DIM2, Elt,
           Z(Z), (:.)((:.)), Slice, Shape, (!), (?), )
import Data.Complex (Complex((:+)), )


type Transform sh a = Acc (Array sh a) -> Acc (Array sh a)

data SubTransform a =
        SubTransform
           (forall sh. (Shape sh, Slice sh) => Transform (sh:.Int) a)

data SubTransformPair a =
        SubTransformPair
           (forall sh. (Shape sh, Slice sh) => Transform (sh:.Int) a)
           (forall sh. (Shape sh, Slice sh) => Transform (sh:.Int) a)

data SubTransformTriple a =
        SubTransformTriple
           (forall sh. (Shape sh, Slice sh) => Transform (sh:.Int) a)
           (forall sh. (Shape sh, Slice sh) => Transform (sh:.Int) a)
           (forall sh. (Shape sh, Slice sh) => Transform (sh:.Int) a)


type PairTransform sh a =
        (Acc (Array sh a), Acc (Array sh a)) ->
        (Acc (Array sh a), Acc (Array sh a))

data SubPairTransform a =
        SubPairTransform
           (forall sh. (Shape sh, Slice sh) => PairTransform (sh:.Int:.Int) a)


cache2 :: (sign ~ Exp (Sign b), a ~ Exp (Complex b), A.RealFloat b) =>
   sign -> a
cache3 :: (sign ~ Exp (Sign b), a ~ Exp (Complex b), A.RealFloat b) =>
   sign -> (a,a)
cache4 :: (sign ~ Exp (Sign b), a ~ Exp (Complex b), A.RealFloat b) =>
   sign -> (a,a,a)
cache5 ::
   (sign ~ Exp (Sign b), a ~ Exp (Complex b),
    A.RealFloat b, A.FromIntegral Int b) =>
   sign -> (a,a,a,a)

cache2 _sign = -1

cache3 sign =
   let sqrt3d2 = sqrt 3 / 2
       mhalf = -1/2
       s = Sign.toSign sign
   in  (A.lift $ mhalf :+ s*sqrt3d2,
        A.lift $ mhalf :+ (-s)*sqrt3d2)

cache4 sign =
   let s = Sign.toSign sign
   in  (A.lift $ 0 :+ s, -1, A.lift $ 0 :+ (-s))

cache5 sign =
   let z = Sign.cisRat sign 5
   in  (z 1, z 2, z 3, z 4)


flatten2 ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (a,a)) ->
   Acc (Array (sh:.Int) a)
flatten2 x =
   A.generate
      (Exp.indexCons (A.shape x) (A.constant 2))
      (Exp.modify (expr :. expr) $
       \(ix :. k)  ->  let xi = x ! ix in k A.== 0 ? (A.fst xi, A.snd xi))

transform2 ::
   (Shape sh, Slice sh, a ~ Complex b, A.RealFloat b) =>
   Exp a -> Transform (sh:.Int) a
transform2 z arr =
   flatten2 $
   A.zipWith (\x0 x1 -> A.lift (x0+x1, x0+z*x1))
      (A.slice arr (A.lift $ A.Any :. (0::Int)))
      (A.slice arr (A.lift $ A.Any :. (1::Int)))


flatten3 ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (a,a,a)) ->
   Acc (Array (sh:.Int) a)
flatten3 x =
   A.generate
      (Exp.indexCons (A.shape x) (A.constant (3::Int)))
      (Exp.modify (expr :. expr) $
       \(ix :. k)  ->
          let (x0,x1,x2) = A.unlift $ x ! ix
          in  flip (A.caseof k) x0 $
                 ((A.==1), x1) :
                 ((A.==2), x2) :
                 [])

transform3 ::
   (Shape sh, Slice sh, a ~ Complex b, A.RealFloat b) =>
   (Exp a, Exp a) -> Transform (sh:.Int) a
transform3 (z,z2) arr =
   flatten3 $
   A.zipWith3
      (\x0 x1 x2 ->
         let ((s,_), (zx1,zx2)) = Cyclic.sumAndConvolvePair (x1,x2) (z,z2)
         in  A.lift (x0+s, x0+zx1, x0+zx2))
      (A.slice arr (A.lift $ A.Any :. (0::Int)))
      (A.slice arr (A.lift $ A.Any :. (1::Int)))
      (A.slice arr (A.lift $ A.Any :. (2::Int)))


flatten4 ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (a,a,a,a)) ->
   Acc (Array (sh:.Int) a)
flatten4 x =
   A.generate
      (Exp.indexCons (A.shape x) (A.constant (4::Int)))
      (Exp.modify (expr :. expr) $
       \(ix :. k)  ->
          let (x0,x1,x2,x3) = A.unlift $ x ! ix
          in  flip (A.caseof k) x0 $
                 ((A.==1), x1) :
                 ((A.==2), x2) :
                 ((A.==3), x3) :
                 [])

transform4 ::
   (Shape sh, Slice sh, a ~ Complex b, A.RealFloat b) =>
   (Exp a, Exp a, Exp a) -> Transform (sh:.Int) a
transform4 (z,z2,z3) arr =
   flatten4 $
   A.zipWith4
      (\x0 x1 x2 x3 ->
         let x02a = x0+x2; x02b = x0+z2*x2
             x13a = x1+x3; x13b = x1+z2*x3
         in  A.lift (x02a+   x13a, x02b+z *x13b,
                     x02a+z2*x13a, x02b+z3*x13b))
      (A.slice arr (A.lift $ A.Any :. (0::Int)))
      (A.slice arr (A.lift $ A.Any :. (1::Int)))
      (A.slice arr (A.lift $ A.Any :. (2::Int)))
      (A.slice arr (A.lift $ A.Any :. (3::Int)))


flatten5 ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (a,a,a,a,a)) ->
   Acc (Array (sh:.Int) a)
flatten5 x =
   A.generate
      (Exp.indexCons (A.shape x) (A.constant (5::Int)))
      (Exp.modify (expr :. expr) $
       \(ix :. k)  ->
          let (x0,x1,x2,x3,x4) = A.unlift $ x ! ix
          in  flip (A.caseof k) x0 $
                 ((A.==1), x1) :
                 ((A.==2), x2) :
                 ((A.==3), x3) :
                 ((A.==4), x4) :
                 [])


{-
Use Rader's trick for mapping the transform to a convolution
and apply Karatsuba's trick at two levels (i.e. total three times)
to that convolution.

0 0 0 0 0
0 1 2 3 4
0 2 4 1 3
0 3 1 4 2
0 4 3 2 1

Permutation.T: 0 1 2 4 3

0 0 0 0 0
0 1 2 4 3
0 2 4 3 1
0 4 3 1 2
0 3 1 2 4
-}
transform5 ::
   (Shape sh, Slice sh, a ~ Complex b, A.RealFloat b) =>
   (Exp a, Exp a, Exp a, Exp a) -> Transform (sh:.Int) a
transform5 (z1,z2,z3,z4) arr =
   flatten5 $
   A.zipWith5
      (\x0 x1 x2 x3 x4 ->
         let ((s,_), (d1,d2,d4,d3)) =
                Cyclic.sumAndConvolveQuadruple (x1,x3,x4,x2) (z1,z2,z4,z3)
         in  A.lift (x0+s, x0+d1, x0+d2, x0+d3, x0+d4))
      (A.slice arr (A.lift $ A.Any :. (0::Int)))
      (A.slice arr (A.lift $ A.Any :. (1::Int)))
      (A.slice arr (A.lift $ A.Any :. (2::Int)))
      (A.slice arr (A.lift $ A.Any :. (3::Int)))
      (A.slice arr (A.lift $ A.Any :. (4::Int)))


twist ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Transform (sh:.Int:.Int) a
twist fac x =
   let sh :. m :. n = Exp.unlift (expr :. expr :. expr) $ A.shape x
   in  A.backpermute
          (A.lift $ sh :. fac*m :. div n fac)
          (Exp.modify (expr :. expr :. expr) $
           \(globalIx :. k :. j) -> globalIx :. div k fac :. fac*j + mod k fac)
          x


merge ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int) a)
merge x =
   let sh :. m :. n = Exp.unlift (expr :. expr :. expr) $ A.shape x
   in  A.backpermute
          (A.lift $ sh :. m*n)
          (Exp.modify (expr :. expr) $
           \(ix :. k) -> ix :. mod k m :. div k m)
          x

stack ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) ->
   Acc (Array (sh:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a)
stack x y =
   A.generate
      (Exp.modify (expr :. expr)
         (\(sh :. n) -> sh :. (2::Int) :. n)
         (A.shape x))
      (Exp.modify (expr :. expr :. expr) $
       \(globalIx :. evenOdd :. k) ->
          let ix = A.lift $ globalIx :. k
          in  evenOdd A.== 0 ? (x ! ix, y ! ix))


{- |
twiddle factors for radix-2 Cooley-Tukey transforms
-}
twiddleFactors2 ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Acc (A.Vector (Complex a))
twiddleFactors2 sign len2 =
   A.generate (A.lift $ Z:.len2) $ twiddle2 sign len2 . A.indexHead

twiddle2 ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Exp Int -> Exp (Complex a)
twiddle2 sign n2i ki =
   let n2 = A.fromIntegral n2i
       k  = A.fromIntegral ki
   in  Sign.cis sign $ pi*k/n2


twiddleFactors ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Exp Int -> Acc (Array DIM2 (Complex a))
twiddleFactors sign lenk lenj =
   A.generate (A.lift $ Z:.lenk:.lenj) $
   Exp.modify (expr :. expr :. expr) $
   \(_z :. k :. j) -> twiddle sign (lenk*lenj) k j

twiddle ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Exp Int -> Exp Int -> Exp (Complex a)
twiddle sign n k j =
   Sign.cisRat sign n $ mod (k*j) n


transformRadix2InterleavedTime ::
   (Shape sh, Slice sh, a ~ Complex b, A.RealFloat b) =>
   Acc (Array DIM1 a) ->
   Transform (sh:.Int:.Int) a ->
   Transform (sh:.Int) a
transformRadix2InterleavedTime twiddles subTransform arr =
   let (sh:.len) = Exp.unlift (expr:.expr) $ A.shape arr
       len2 = div len 2
       subs =
          subTransform $
          if True
            then Sliced.sliceHorizontal (A.lift $ Z:.(2::Int):.len2) arr
            else
               LinAlg.transpose $
               A.reshape (A.lift $ sh:.len2:.(2::Int)) arr
       evens = A.slice subs (A.lift $ A.Any :. (0::Int) :. A.All)
       odds =
          zipExtrudedVectorWith (*) twiddles $
          A.slice subs (A.lift $ A.Any :. (1::Int) :. A.All)
   in  A.zipWith (+) evens odds  A.++  A.zipWith (-) evens odds


initSplitRadix ::
   (Slice sh, Shape sh, a ~ Complex b, A.RealFloat b) =>
   Acc (Array (sh:.Int) a) ->
   (Acc (Array (sh:.Int:.Int) a), Acc (Array (sh:.Int:.Int) a))
initSplitRadix arr =
   let (sh:.len) = Exp.unlift (expr:.expr) $ A.shape arr
   in  (A.replicate (A.lift $ A.Any :. (1::Int) :. A.All) arr,
        A.fill (A.lift $ sh:.(0::Int):.div len 2) 0)

finishSplitRadix ::
   (Slice sh, Shape sh, a ~ Complex b, A.RealFloat b) =>
   Acc (Array (sh:.Int:.Int) a) -> Acc (Array (sh:.Int) a)
finishSplitRadix =
   flip A.slice (A.lift $ A.Any :. (0::Int) :. A.All)


initSplitRadixFlat ::
   (Slice sh, Shape sh, a ~ Complex b, A.RealFloat b) =>
   Acc (Array (sh:.Int) a) ->
   (Acc (Array DIM2 a), Acc (Array DIM2 a))
initSplitRadixFlat arr =
   let (sh:.len) = Exp.unlift (expr:.expr) $ A.shape arr
   in  (A.reshape (A.lift $ Z :. A.shapeSize sh :. len) arr,
        A.fill (A.lift $ Z:.(0::Int):.div len 2) 0)

finishSplitRadixFlat ::
   (Slice sh, Shape sh, a ~ Complex b, A.RealFloat b) =>
   Exp (sh:.Int) -> Acc (Array DIM2 a) -> Acc (Array (sh:.Int) a)
finishSplitRadixFlat = A.reshape


imagSplitRadixPlain ::
   (Num a) =>
   Sign a -> Complex a
imagSplitRadixPlain sign = 0 :+ Sign.getSign sign

imagSplitRadix ::
   (A.Num a) =>
   Exp (Sign a) -> Exp (Complex a)
imagSplitRadix sign =
   A.lift (0 :+ Sign.toSign sign)

ditSplitRadixReorder ::
   (Slice sh, Shape sh, Elt a) =>
   PairTransform (sh:.Int:.Int) a
ditSplitRadixReorder (arr2, arr1) =
   let evens = Sliced.sieve 2 0 arr2
       odds  = Sliced.sieve 2 1 arr2
   in  (Sliced1.append evens arr1, twist 2 odds)

ditSplitRadixBase ::
   (Slice sh, Shape sh, A.RealFloat a) =>
   PairTransform (sh:.Int:.Int) (Complex a)
ditSplitRadixBase (arr2, arr1) = (transform2 (-1) arr2, arr1)

ditSplitRadixStep ::
   (Slice sh, Shape sh, a ~ Complex b, A.RealFloat b) =>
   Exp a ->
   (Acc (Array DIM1 a), Acc (Array DIM1 a)) ->
   PairTransform (sh:.Int:.Int) a
ditSplitRadixStep imag (twiddles1, twiddles3) (u, zIntl) =
   let twiddledZEven =
          zipExtrudedVectorWith (*) twiddles1 $ Sliced1.sieve 2 0 zIntl
       twiddledZOdd =
          zipExtrudedVectorWith (*) twiddles3 $ Sliced1.sieve 2 1 zIntl
       zSum = A.zipWith (+) twiddledZEven twiddledZOdd
       zDiff = A.map (imag *) $ A.zipWith (-) twiddledZEven twiddledZOdd
       zComplete = zSum A.++ zDiff
   in  (A.zipWith (+) u zComplete
        A.++
        A.zipWith (-) u zComplete,
           Sliced1.drop (Sliced1.length zComplete) u)


twiddleSR ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Exp Int -> Exp Int -> Exp (Complex a)
twiddleSR sign n4i ki ji =
   let n4 = A.fromIntegral n4i
       k  = A.fromIntegral ki
       j  = A.fromIntegral ji
   in  Sign.cis sign $ pi*(k*j)/(2*n4)

twiddleFactorsSR ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Exp Int -> Acc (Array DIM1 (Complex a))
twiddleFactorsSR sign len4 k =
   A.generate (A.lift $ Z:.len4) $ twiddleSR sign len4 k . A.indexHead

twiddleFactorsSRPair ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int ->
   (Acc (Array DIM1 (Complex a)), Acc (Array DIM1 (Complex a)))
twiddleFactorsSRPair sign len4 =
   (twiddleFactorsSR sign len4 1,
    twiddleFactorsSR sign len4 3)


cycleDim3 ::
   (Slice sh, Shape sh, Elt a) =>
   Acc (Array (sh:.Int:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int:.Int) a)
cycleDim3 arr =
   A.backpermute
      (Exp.modify (expr:.expr:.expr:.expr)
         (\(sh:.k:.m:.n) -> (sh:.n:.k:.m)) $
       A.shape arr)
      (Exp.modify (expr:.expr:.expr:.expr)
         (\(ix:.n:.k:.m) -> (ix:.k:.m:.n)))
      arr


chirp ::
   (A.RealFloat a, A.FromIntegral Int a) =>
   Exp (Sign a) -> Exp Int -> Exp a -> A.Acc (A.Array DIM1 (Complex a))
chirp sign padLen lenFloat =
   A.generate (A.index1 padLen) $
   \ix ->
      let k = A.unindex1 ix
          sk = A.fromIntegral (padLen A.> 2*k ? (k, k-padLen))
      in  Sign.cis sign (pi*sk*sk/lenFloat)