module ZkFold.Base.Algebra.Basic.DFT (genericDft) where

import           Control.Monad                   (forM_)
import qualified Data.STRef                      as ST
import qualified Data.Vector                     as V
import qualified Data.Vector.Mutable             as VM
import           Prelude                         hiding (mod, sum, (*), (+), (-), (/), (^))
import qualified Prelude                         as P

import           ZkFold.Base.Algebra.Basic.Class

-- | Generic radix-2 (Cooley-Tukey) FFT. Computes the direct or the inverse
-- transform depending on whether @wn@ is a root of unity of order @2^n@ or its
-- inverse.
--
-- The inverse transform is NOT scaled by @1/2^n@; callers must apply that
-- scaling themselves.
--
-- Requires the vector to be of length @2^n@. For @n > 0@ an input shorter than
-- @2^n@ is rejected with an 'error' (instead of reading out of bounds); for
-- @n = 0@ the input is returned unchanged.
--
genericDft
    :: forall a
     . Ring a
    => Integer     -- ^ @n@: log2 of the transform length
    -> a           -- ^ @wn@: root of unity of order @2^n@ (or its inverse)
    -> V.Vector a  -- ^ input coefficients, expected length @2^n@
    -> V.Vector a
genericDft 0 _ v  = v
genericDft n wn v
    -- Guard against undefined behaviour: the loop below uses 'V.unsafeIndex'
    -- on the sub-transforms, which are only long enough if @v@ is.
    | V.length v P.< fullLen =
        error $ "genericDft: expected a vector of length 2^" ++ show n
             ++ " = " ++ show fullLen ++ ", got " ++ show (V.length v)
    | otherwise = V.create $ do
        result <- VM.new fullLen
        wRef   <- ST.newSTRef one
        forM_ [0 .. halfLen P.- 1] $ \k -> do
            w <- ST.readSTRef wRef
            let evenPart = a0Hat `V.unsafeIndex` k
                oddPart  = w * (a1Hat `V.unsafeIndex` k)
            -- Butterfly: combine the two half-size transforms.
            VM.write result k               (evenPart + oddPart)
            VM.write result (k P.+ halfLen) (evenPart - oddPart)
            -- Strict update so the twiddle factor does not accumulate a
            -- chain of unevaluated @w * wn@ thunks.
            ST.modifySTRef' wRef (* wn)
        pure result
  where
    -- Even- and odd-indexed elements of the input.
    a0 = V.ifilter (\i _ -> i `P.mod` 2 == 0) v
    a1 = V.ifilter (\i _ -> i `P.mod` 2 == 1) v

    -- The half-size sub-transforms use the squared root of unity.
    wn2 = wn * wn

    a0Hat = genericDft (n P.- 1) wn2 a0
    a1Hat = genericDft (n P.- 1) wn2 a1

    fullLen, halfLen :: Int
    fullLen = 2 P.^ n
    halfLen = 2 P.^ (n P.- 1)