{-# LANGUAGE TypeOperators, RankNTypes, PatternGuards #-}
module Data.Array.Repa.Algorithms.DFT
( dftP
, idftP
, dftWithRootsP
, dftWithRootsSingleS)
where
import Data.Array.Repa.Algorithms.DFT.Roots as R
import Data.Array.Repa.Algorithms.Complex as R
import Data.Array.Repa as R
import Prelude as P
-- | Discrete Fourier transform of an array whose innermost dimension holds
--   the samples.
--
--   Roots are generated by 'calcRootsOfUnityP' for the extent of the input
--   (presumably the forward roots of unity; see that module). The transform
--   itself is delegated to 'dftWithRootsP'.
dftP :: (Shape sh, Monad m)
     => Array U (sh :. Int) Complex
     -> m (Array U (sh :. Int) Complex)
dftP arr =
  calcRootsOfUnityP (extent arr) >>= \roots -> dftWithRootsP roots arr
{-# INLINE dftP #-}
-- | Inverse discrete Fourier transform.
--
--   Runs 'dftWithRootsP' with the roots from 'calcInverseRootsOfUnityP'.
--   It then divides every element of the result by the innermost length
--   @len@, expressed as the complex number @(len, 0)@.
idftP :: (Shape sh, Monad m)
      => Array U (sh :. Int) Complex
      -> m (Array U (sh :. Int) Complex)
idftP arr = do
  invRoots <- calcInverseRootsOfUnityP (extent arr)
  unscaled <- dftWithRootsP invRoots arr
  computeP $ R.map (/ divisor) unscaled
  where
    -- Normalisation factor: the number of samples per row.
    _ :. len = extent arr
    divisor  = (fromIntegral len, 0)
{-# INLINE idftP #-}
-- | Discrete Fourier transform using a caller-supplied array of roots.
--
--   Every element of the result, at index @sh :. k@, is produced by
--   'dftWithRootsSingleS', and the whole array is computed with 'computeP'.
--
--   Calls 'error' when the innermost length of the roots differs from the
--   innermost length of the input.
dftWithRootsP
  :: (Shape sh, Monad m)
  => Array U (sh :. Int) Complex   -- ^ roots of unity
  -> Array U (sh :. Int) Complex   -- ^ input array
  -> m (Array U (sh :. Int) Complex)
dftWithRootsP rofu arr
  | rLen == vLen
  = computeP $ R.traverse arr id (const (dftWithRootsSingleS rofu arr))
  | otherwise
  = error $ "dftWithRoots: length of vector (" P.++ show vLen P.++ ")"
       P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"
  where
    _ :. rLen = extent rofu
    _ :. vLen = extent arr
{-# INLINE dftWithRootsP #-}
-- | Compute one element of the DFT, sequentially.
--
--   For the output index @ix :. k@ the result is
--
--   > sum [ arrX ! (ix :. n) * rofu ! (ix :. (k * n) `mod` len) | n <- [0 .. len - 1] ]
--
--   where @len@ is the innermost length of @arrX@. Only row @ix@ of each
--   array is read, so every row is transformed independently of the others
--   and each element costs O(len).
--
--   Calls 'error' when the innermost lengths of @rofu@ and @arrX@ differ.
--   Only the innermost length is checked; @rofu@ is assumed to have the
--   same outer shape as @arrX@.
dftWithRootsSingleS
  :: Shape sh
  => Array U (sh :. Int) Complex   -- ^ roots of unity
  -> Array U (sh :. Int) Complex   -- ^ input array
  -> (sh :. Int)                   -- ^ index of the output element
  -> Complex
dftWithRootsSingleS rofu arrX (ix :. k)
  | _ :. rLen <- extent rofu
  , _ :. vLen <- extent arrX
  , rLen /= vLen
  = error $ "dftWithRootsSingle: length of vector (" P.++ show vLen P.++ ")"
       P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"
  | otherwise
  = let _ :. len = extent arrX
        -- Term n of the sum for output k in row ix. The root index
        -- (k * n) `mod` len wraps around within the same row of roots.
        -- Summing over a DIM1 range keeps other rows out of the result.
        term (Z :. n)
          = arrX ! (ix :. n) * rofu ! (ix :. (k * n) `mod` len)
    in  R.sumAllS $ fromFunction (Z :. len) term
{-# INLINE dftWithRootsSingleS #-}