{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Poly.Internal.Dense.DFT
( dft
, inverseDft
) where
import Prelude hiding (recip, fromIntegral)
import Control.Monad.ST
import Data.Bits hiding (shift)
import Data.Foldable
import Data.Semiring (Semiring(..), Ring(..), minus, fromIntegral)
import Data.Field (Field, recip)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
-- | Discrete Fourier transform over an arbitrary 'Ring'.
--
-- Given a root of unity \( \omega \) and a vector \( \{x_j\}_{j=0}^{N-1} \),
-- returns \( \{y_k\}_{k=0}^{N-1} \) with
-- \( y_k = \sum_{j=0}^{N-1} x_j \omega^{jk} \).
-- The transform is computed by a recursive radix-2 decimation-in-time
-- (Cooley-Tukey) scheme, using O(N log N) ring operations.
--
-- Only lengths \( N = 2^n \) are supported. Any other length, including the
-- empty vector, raises an 'error'. The result is only a DFT when @primRoot@
-- is a primitive \( N \)-th root of unity. This is not checked here, and the
-- butterfly step relies on \( \omega^{N/2} = -1 \).
dft
  :: forall a v. (Ring a, G.Vector v a)
  => a   -- ^ primitive \( N \)-th root of unity \( \omega \)
  -> v a -- ^ input \( \{x_j\}_{j=0}^{N-1} \), with \( N = 2^n \)
  -> v a -- ^ output \( \{y_k\}_{k=0}^{N-1} \), in natural (not bit-reversed) order
dft primRoot xs
  | popCount nn /= 1 = error "dft: only vectors of length 2^n are supported"
  | otherwise        = go 0 0
  where
    -- N and n = log2 N (exact, since N is a power of two past the guard).
    nn, n :: Int
    nn = G.length xs
    n  = countTrailingZeros nn

    -- Twiddle table: roots ! i = primRoot^i for 0 <= i < N/2.
    -- It is only forced from the recursive branch of 'go', which requires
    -- n >= 1, so the shift by (n - 1) is never negative when evaluated.
    roots :: v a
    roots = G.iterateN
      (1 `unsafeShiftL` (n - 1))
      (\x -> x `seq` (x `times` primRoot))
      one

    -- go offset shift computes the DFT of the subsequence
    --   xs[offset], xs[offset + 2^shift], xs[offset + 2 * 2^shift], ...
    -- This subsequence has length 2^(n - shift) and uses primRoot^(2^shift)
    -- as its root of unity.
    go :: Int -> Int -> v a
    go !offset !shift
      | shift >= n = G.unsafeSlice offset 1 xs
      | otherwise  = runST $ do
          let halfLen :: Int
              halfLen = 1 `unsafeShiftL` (n - shift - 1)
              -- DFTs of the even- and odd-indexed halves of the subsequence.
              evens = go offset (shift + 1)
              odds  = go (offset + (1 `unsafeShiftL` shift)) (shift + 1)
          ys <- MG.new (halfLen `unsafeShiftL` 1)

          -- k = 0 is handled separately. Its twiddle factor is
          -- roots ! 0 = one, so the multiplication can be skipped.
          let e0 = G.unsafeIndex evens 0
              o0 = G.unsafeIndex odds 0
          MG.unsafeWrite ys 0       $! e0 `plus`  o0
          MG.unsafeWrite ys halfLen $! e0 `minus` o0

          forM_ [1 .. halfLen - 1] $ \k -> do
            let e = G.unsafeIndex evens k
                -- twiddle: (primRoot^(2^shift))^k = roots ! (k * 2^shift)
                o = G.unsafeIndex odds k
                      `times` G.unsafeIndex roots (k `unsafeShiftL` shift)
            MG.unsafeWrite ys k             $! e `plus`  o
            -- Relies on (primRoot^(2^shift))^halfLen = -1.
            MG.unsafeWrite ys (k + halfLen) $! e `minus` o
          G.unsafeFreeze ys
{-# INLINABLE dft #-}
-- | Inverse of 'dft'.
--
-- Given the same root of unity \( \omega \) that was passed to 'dft' and a
-- vector \( \{y_k\}_{k=0}^{N-1} \), returns \( \{x_j\}_{j=0}^{N-1} \) with
-- \( x_j = {1 \over N} \sum_{k=0}^{N-1} y_k \omega^{-jk} \).
-- It runs 'dft' with \( \omega^{-1} \) and then scales every entry by
-- \( N^{-1} \).
--
-- The same length restriction as 'dft' applies: \( N = 2^n \), otherwise
-- 'error'. NOTE(review): \( N \) must be invertible in the field. What
-- 'recip' does on zero depends on the 'Field' instance.
inverseDft
  :: forall a v. (Field a, G.Vector v a)
  => a   -- ^ primitive \( N \)-th root of unity \( \omega \) (same as for 'dft')
  -> v a -- ^ transformed vector \( \{y_k\}_{k=0}^{N-1} \), with \( N = 2^n \)
  -> v a -- ^ recovered vector \( \{x_j\}_{j=0}^{N-1} \)
inverseDft primRoot ys = G.map (`times` invN) (dft (recip primRoot) ys)
  where
    -- 1 / N, with N embedded into the field via the semiring 'fromIntegral'.
    invN :: a
    invN = recip (fromIntegral (G.length ys))
{-# INLINABLE inverseDft #-}