-- |
-- Module:      Data.Poly.Internal.Dense.DFT
-- Copyright:   (c) 2020 Andrew Lelechenko
-- Licence:     BSD3
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Discrete Fourier transform.
--

{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.Poly.Internal.Dense.DFT
  ( dft
  , inverseDft
  ) where

import Prelude hiding (recip, fromIntegral)
import Control.Monad.ST
import Data.Bits hiding (shift)
import Data.Foldable
import Data.Semiring (Semiring(..), Ring(..), minus, fromIntegral)
import Data.Field (Field, recip)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG

-- | <https://en.wikipedia.org/wiki/Fast_Fourier_transform Discrete Fourier transform>
-- \( y_k = \sum_{j=0}^{N-1} x_j \sqrt[N]{1}^{jk} \).
--
-- The transform is computed by a recursive radix-2 (Cooley-Tukey) scheme.
-- Calls 'error' unless the length of the input is a power of two;
-- in particular the empty vector is rejected.
--
-- @since 0.5.0.0
dft
  :: (Ring a, G.Vector v a)
  => a   -- ^ primitive root \( \sqrt[N]{1} \), otherwise behaviour is undefined
  -> v a -- ^ \( \{ x_k \}_{k=0}^{N-1} \) (currently only \( N = 2^n \) is supported)
  -> v a -- ^ \( \{ y_k \}_{k=0}^{N-1} \)
dft primRoot (xs :: v a)
  | popCount nn /= 1 = error "dft: only vectors of length 2^n are supported"
  | otherwise = go 0 0
  where
    -- N, the length of the input, and n = log2 N
    -- (meaningful once popCount nn == 1 has been checked).
    nn = G.length xs
    n = countTrailingZeros nn

    -- Twiddle factors: roots ! j == primRoot ^ j for 0 <= j < N / 2.
    -- The length is written as nn `unsafeShiftR` 1 rather than
    -- 1 `unsafeShiftL` (n - 1): for N = 1 the latter shifts by a negative
    -- amount, for which unsafeShiftL is undefined. The value is unchanged
    -- for every N >= 2.
    roots :: v a
    roots = G.iterateN
      (nn `unsafeShiftR` 1)
      (\x -> x `seq` (x `times` primRoot))
      one

    -- go offset shift computes the DFT of the 2^(n - shift) elements
    -- xs[offset], xs[offset + 2^shift], xs[offset + 2 * 2^shift], ...
    -- with respect to the root primRoot^(2^shift).
    go !offset !shift
      | shift >= n = G.unsafeSlice offset 1 xs
      | otherwise = runST $ do
        let halfLen = 1 `unsafeShiftL` (n - shift - 1)
            -- Transforms of the even- and odd-indexed halves
            -- of the current subsequence.
            ys0 = go offset                              (shift + 1)
            ys1 = go (offset + (1 `unsafeShiftL` shift)) (shift + 1)
        ys <- MG.new (halfLen `unsafeShiftL` 1)

        -- This corresponds to k = 0 in the loop below.
        -- It improves performance by avoiding multiplication
        -- by roots V.! 0 = 1.
        let y00 = G.unsafeIndex ys0 0
            y10 = G.unsafeIndex ys1 0
        MG.unsafeWrite ys 0       $! y00 `plus`  y10
        MG.unsafeWrite ys halfLen $! y00 `minus` y10

        -- Butterfly with w = primRoot^(2^shift), so w^k = roots ! (k * 2^shift):
        --   y_k           = e_k + w^k * o_k
        --   y_{k+halfLen} = e_k - w^k * o_k   (the subtraction assumes w^halfLen = -1)
        forM_ [1 .. halfLen - 1] $ \k -> do
          let y0 = G.unsafeIndex ys0 k
              y1 = G.unsafeIndex ys1 k `times`
                   G.unsafeIndex roots (k `unsafeShiftL` shift)
          MG.unsafeWrite ys k             $! y0 `plus`  y1
          MG.unsafeWrite ys (k + halfLen) $! y0 `minus` y1
        G.unsafeFreeze ys
{-# INLINABLE dft #-}

-- | Inverse <https://en.wikipedia.org/wiki/Fast_Fourier_transform discrete Fourier transform>
-- \( x_k = {1\over N} \sum_{j=0}^{N-1} y_j \sqrt[N]{1}^{-jk} \).
--
-- Applies the forward 'dft' with the reciprocal of the given root and then
-- multiplies every coefficient by \( 1/N \). The length restriction of
-- 'dft' (a power of two) applies here as well.
--
-- @since 0.5.0.0
inverseDft
  :: (Field a, G.Vector v a)
  => a   -- ^ primitive root \( \sqrt[N]{1} \), otherwise behaviour is undefined
  -> v a -- ^ \( \{ y_k \}_{k=0}^{N-1} \) (currently only \( N = 2^n \) is supported)
  -> v a -- ^ \( \{ x_k \}_{k=0}^{N-1} \)
inverseDft primRoot ys = G.map scale unscaled
  where
    -- Forward transform taken with respect to the inverse root.
    unscaled = dft (recip primRoot) ys
    -- 1 / N, computed once and shared by every coefficient.
    invLen = recip (fromIntegral (G.length ys))
    scale y = y `times` invLen
{-# INLINABLE inverseDft #-}