{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE NoImplicitPrelude   #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.Primitive.Contiguous.FFT
  ( dft
  , idft
  , overlapDFT
  ) where

import qualified Prelude

import Data.Eq (Eq((==)))
import Data.Function (($))
import Control.Monad
import Data.Ord
import Control.Monad.ST
import Data.Complex hiding (cis)
import qualified Data.Complex as C
import Data.Primitive.Contiguous
import GHC.Num (Num(..))
import GHC.Float
import GHC.Real
import GHC.Exts (Int)

-- | Point on the unit circle at the fraction @step / total@ of a full turn,
--   i.e. @cos t :+ sin t@ with @t = 2 * pi * step / total@.
cis :: Floating a => a -> a -> Complex a
cis step total = C.cis angle
  where
    -- Same left-to-right evaluation order as writing the product inline.
    angle = 2 * pi * step / total
{-# INLINE cis #-}

-- | Build a complex number from a real and an imaginary component.
--   Both arguments are forced before the value is constructed.
mkComplex :: x -> x -> Complex x
mkComplex !re !im = (:+) re im
{-# INLINE mkComplex #-}

-- | Overwrite a mutable buffer of complex samples with its discrete Fourier
--   transform and return that same buffer. Naive /O(n^2)/ algorithm.
--
--   Indices are rotated by half the buffer length on both sides: the sample
--   used for time step @j@ is read from position @getII j@, and the bin for
--   frequency @i@ is stored at position @getII i@.
--
--   Every output bin depends on every input sample, so the bins are first
--   accumulated in a scratch buffer and only copied back once all of them are
--   known. Writing straight into the input would corrupt the samples that the
--   remaining bins still have to read.
dftMutable :: forall arr x s. (RealFloat x, Contiguous arr, Element arr (Complex x))
     => Mutable arr s (Complex x)
     -> ST s (Mutable arr s (Complex x))
dftMutable !mut = do
  !sz <- sizeMutable mut
  !scratch <- new sz :: ST s (Mutable arr s (Complex x))

  let getII !ix = (ix + sz `Prelude.div` 2) `Prelude.mod` sz

      go :: Int        -- i value (frequency bin being computed)
         -> Int        -- j value (sample being folded in)
         -> Complex x  -- accumulator for bin i
         -> ST s ()
      go !i !j !acc = if i == sz then return () else if j < sz
        then do
          let !jj = getII j
          !atJJ <- read mut jj
          let same :: x
              !same = (-2) * pi * (fromIntegral (i * j)) / (fromIntegral sz)
              -- Full complex product, so the imaginary part of the sample
              -- contributes to the bin as well.
              !val  = acc + atJJ * mkComplex (cos same) (sin same)
          go i (j + 1) val
        else do
          let !ii = getII i
          !_ <- write scratch ii acc :: ST s ()
          go (i + 1) 0 0

      -- Copy the finished bins back over the input, element by element.
      copyBack :: Int -> ST s ()
      copyBack !k = if k < sz
        then do
          !v <- read scratch k
          !_ <- write mut k v :: ST s ()
          copyBack (k + 1)
        else return ()

  !_ <- go 0 0 0
  !_ <- copyBack 0

  return mut

-- | Discrete Fourier transform of a real-valued array, returned as a new
--   array of complex values. Pure wrapper that runs 'dftInternal' inside 'ST'.
dft :: forall arr x. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
     => arr x
     -> arr (Complex x)
dft !signal = runST (dftInternal signal)

-- | Naive /O(n^2)/ discrete Fourier transform of a real-valued array.
--   Not in-place: the result is built in a freshly allocated buffer that is
--   frozen and returned.
--
--   Indices are rotated by half the array length on both sides: the sample
--   for time step @t@ is read from position @centred t@, and the bin for
--   frequency @freq@ is stored at position @centred freq@.
dftInternal :: forall arr x s. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr x
  -> ST s (arr (Complex x))
dftInternal !signal = do
  let !len = size signal

      -- Rotate an index by half the length, wrapping around.
      centred :: Int -> Int
      centred !k = (k + len `Prelude.div` 2) `Prelude.mod` len

      -- Fold the terms for samples @t .. len - 1@ of bin @freq@ onto @total@.
      bin :: Int -> Int -> Complex x -> Complex x
      bin !freq !t !total
        | t < len =
            let !sample = index signal (centred t)
                !theta  = (-2) * pi * fromIntegral (freq * t) / fromIntegral len :: x
                !term   = mkComplex (sample * cos theta) (sample * sin theta)
            in bin freq (t + 1) (total + term)
        | Prelude.otherwise = total

  !out <- new len :: ST s (Mutable arr s (Complex x))

  let -- Compute and store one bin per step until every bin is filled.
      fill :: Int -> ST s ()
      fill !freq = unless (freq == len) $ do
        let !total = bin freq 0 0
        write out (centred freq) total
        fill (freq + 1)

  fill 0

  unsafeFreeze out

-- | Inverse discrete Fourier transform of a complex-valued array, returned as
--   a new array of real values. Pure wrapper that runs 'idftInternal' inside
--   'ST'.
idft :: forall arr x. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr (Complex x)
  -> arr x
idft !spectrum = runST (idftInternal spectrum)

-- | Naive /O(n^2)/ inverse discrete Fourier transform that keeps only the
--   real part of each reconstructed sample. Not in-place: the result is built
--   in a freshly allocated buffer that is frozen and returned.
--
--   Indices are rotated by half the array length on both sides: the bin used
--   for frequency @j@ is read from position @getII j@, and the sample for
--   time step @i@ is stored at position @getII i@.
idftInternal :: forall arr x s. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr (Complex x)
  -> ST s (arr x)
idftInternal !a = do
  let !sz = size a
      -- Element count as a floating-point value; used both in the angle and
      -- as the 1/N normalisation factor.
      !sCount = fromIntegral sz :: x
      getII !ix = (ix + sz `Prelude.div` 2) `Prelude.mod` sz

  !mut <- new sz :: ST s (Mutable arr s x)

  let go :: Int  -- i value (output sample being computed)
         -> Int  -- j value (frequency bin being folded in)
         -> x    -- accumulator for sample i
         -> ST s ()
      go !i !j !acc = if i == sz then return () else if j < sz
        then do
          let !jj = getII j
              !(real :+ imag) = index a jj
              same :: x
              !same = (-2) * pi * (fromIntegral (i * j)) / sCount
              -- Real part of @bin * exp (-i * same)@, scaled by 1/N and added
              -- to the running sum. Without the @acc +@ only the last bin's
              -- term would reach the output.
              !val = acc + (real * cos same + imag * sin same) / sCount
          go i (j + 1) val
        else do
          let !ii = getII i
          !_ <- write mut ii acc :: ST s ()
          go (i + 1) 0 0

  !_ <- go 0 0 0

  unsafeFreeze mut

-- | Sliding-window DFT update. Given a signal size, the previous window, the
--   newest value, and the transform of the previous window, overwrite that
--   transform with the transform of the new window (the previous window
--   shifted by one sample with the newest value appended) in /O(n)/ time,
--   in-place. The same buffer @f1@ is returned.
--
--   Each bin is updated with the sliding-DFT recurrence
--
--   @f2[k] = exp (2 * pi * i * k / N) * (f1[k] - x1[0] + newest)@
--
--   The sample leaving the window, @x1[0]@, is subtracted and the sample
--   entering it is added.
--
--   NOTE(review): this recurrence is for the plain DFT layout (bin @k@ at
--   index @k@, oldest sample at index 0). 'dft' in this module stores bins at
--   indices rotated by half the length; confirm which layout callers pass.
--   The window @x1@ is read at index 0 unconditionally, so it is assumed to
--   be non-empty.
overlapDFT :: forall arr x s. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => Int   -- ^ N, signal size
  -> Mutable arr s (Complex x) -- ^ x1, original window
  -> Complex x -- ^ newest complex value
  -> Mutable arr s (Complex x) -- ^ f1, previous transform
  -> ST s (Mutable arr s (Complex x)) -- ^ f2, new transform
overlapDFT n x1 x2_N_1 f1 = do
  let !sz = fromIntegral n :: x

  !l <- sizeMutable f1
  !x1_0 <- read x1 0 :: ST s (Complex x)

  let go :: Int -> ST s ()
      go !ix = if ix < l
        then do
          f1_k <- read f1 ix
          let twiddle = cis (fromIntegral ix) sz
              -- Drop the oldest sample's contribution, add the newest one.
              res     = f1_k - x1_0 + x2_N_1
              -- Forced so the buffer receives a value rather than a thunk.
              !fin    = twiddle * res
          !_ <- write f1 ix fin
          go (ix + 1)
        else return ()
  go 0
  return f1