module ZkFold.Base.Algebra.Basic.DFT (genericDft) where
import Control.Monad (forM_)
import qualified Data.STRef as ST
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Prelude hiding (mod, sum, (*), (+), (-), (/), (^))
import qualified Prelude as P
import ZkFold.Base.Algebra.Basic.Class
-- | Radix-2 Cooley–Tukey discrete Fourier transform over an arbitrary 'Ring'.
--
-- @genericDft n wn v@ splits @v@ into its even- and odd-indexed entries,
-- transforms each half recursively with @wn * wn@, and combines them with the
-- butterfly @(e + w*o, e - w*o)@, where @w@ runs through @one, wn, wn^2, ...@.
-- The result has length @2^n@ for @n >= 1@.
--
-- The butterfly relies on @wn^(2^(n-1)) == -1@. That holds when @wn@ is a
-- primitive @2^n@-th root of unity, which this function does not check. In
-- that case entry @j@ of the result is @sum_i v_i * wn^(i*j)@, i.e. the
-- polynomial with coefficient vector @v@ evaluated at @wn^j@.
--
-- Input length handling:
--
--   * shorter than @2^n@: missing trailing coefficients are treated as 'zero';
--   * longer than @2^n@: for @n >= 1@ only the first @2^n@ entries are used,
--     and for @n == 0@ the input is returned unchanged.
--
-- Calls 'error' if @n@ is negative.
genericDft
    :: forall a
     . Ring a
    => Integer     -- ^ @n@: log2 of the transform length
    -> a           -- ^ @wn@: root of unity (see the note above)
    -> V.Vector a  -- ^ @v@: input coefficients, lowest index first
    -> V.Vector a
genericDft 0 _ v = v
genericDft n wn v
    | n P.< 0 = error "genericDft: negative log2 length"
    | otherwise = V.create $ do
        result <- VM.new (2 P.* halfLen)
        -- Running twiddle factor wn^k. The strict modify keeps it evaluated
        -- instead of accumulating a chain of (* wn) thunks.
        wRef <- ST.newSTRef one
        forM_ [0 .. halfLen P.- 1] $ \k -> do
            w <- ST.readSTRef wRef
            let e = coeff a0Hat k
                o = w * coeff a1Hat k
            -- Force each output so the vector does not retain the whole
            -- recursion tree (a0Hat / a1Hat at every level) as thunks.
            VM.write result k $! e + o
            VM.write result (k P.+ halfLen) $! e - o
            ST.modifySTRef' wRef (* wn)
        pure result
  where
    -- Even- and odd-indexed input entries.
    a0 = V.ifilter (\i _ -> P.even i) v
    a1 = V.ifilter (\i _ -> P.odd i) v

    -- The half-size sub-transforms use the squared root.
    wn2 = wn * wn
    a0Hat = genericDft (n P.- 1) wn2 a0
    a1Hat = genericDft (n P.- 1) wn2 a1

    halfLen :: Int
    halfLen = 2 P.^ (n P.- 1)

    -- Entry k of a sub-transform. When the input was shorter than 2^n, the
    -- n = 0 base case can return a vector shorter than one element, so any
    -- index past the end is read as a zero coefficient instead of out of
    -- bounds.
    coeff xs k
        | k P.< V.length xs = V.unsafeIndex xs k
        | otherwise = zero