-- | Utility functions: roots of unity, vector slicing, primes and
-- prime factorisation, modular arithmetic and permutation helpers.
module Numeric.FFT.Utils
  ( omega
  , slicevecs
  , slicemvecs
  , primes
  , isPrime
  , allFactors
  , factors
  , primitiveRoot
  , invModN
  , log2
  , isPow2
  , dupperm
  , (%.%)
  , compositions
  , makeComp
  , multisetPerms
  , backpermuteM
  ) where

import Prelude hiding (all, concatMap, dropWhile, enumFromTo, filter, head,
                       length, map, maximum, null, reverse)
import qualified Prelude as P
import qualified Control.Monad as CM
import Control.Monad.ST
import Data.Bits
import Data.Complex
import Data.Vector.Unboxed
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Data.List (nub)
import qualified Data.List as L

import Numeric.FFT.Types

-- | Roots of unity: @omega n@ is @cis (2 * pi / n)@, i.e. the point on
-- the unit circle at angle @2 * pi / n@.
omega :: Int -> Complex Double
omega n = cis $ 2 * pi / fromIntegral n

-- | Slice a vector @v@ into equally sized parts, each of length @m@.
-- Only @length v `div` m@ complete parts are produced; any trailing
-- elements that do not fill a whole part are dropped.  @m@ must be
-- non-zero.
slicevecs :: Int -> VCD -> VVCD
slicevecs m v = V.generate (length v `div` m) $ \i -> slice (i * m) m v

-- | Slice a mutable vector @v@ into equally sized parts, each of
-- length @m@.  Only @MV.length v `div` m@ complete parts are produced.
-- @m@ must be non-zero.
slicemvecs :: Int -> MVCD a -> VMVCD a
slicemvecs m v = V.generate (MV.length v `div` m) $ \i -> MV.slice (i * m) m v

-- | Determine primitive roots modulo n.
--
-- From Wikipedia (https://en.wikipedia.org/wiki/Primitive_root_modulo_n):
--
-- No simple general formula to compute primitive roots modulo n is
-- known. There are however methods to locate a primitive root that
-- are faster than simply trying out all candidates. If the
-- multiplicative order of a number m modulo n is equal to phi(n) (the
-- order of Z_n^x), then it is a primitive root. In fact the converse
-- is true: if m is a primitive root modulo n, then the multiplicative
-- order of m is phi(n). We can use this to test for primitive roots.
-- [Here, phi(n) is Euler's totient function, and Z_n^x is the
-- multiplicative group of integers modulo n.]
--
-- First, compute phi(n). Then determine the different prime factors
-- of phi(n), say p1, ..., pk. Now, for every element m of Z_n^x,
-- compute
--
--   m^(phi(n) / pi) mod n   for i = 1, ..., k
--
-- using a fast algorithm for modular exponentiation such as
-- exponentiation by squaring. A number m for which these k results
-- are all different from 1 is a primitive root.
--
-- [In our case, n is restricted to being prime, and phi(p) = p - 1
-- for prime p.]
--
-- The smallest primitive root is returned.  Calls 'error' if the
-- argument is not prime.
primitiveRoot :: Int -> Int
primitiveRoot p
  | isPrime p =
      maybe (error "primitiveRoot: no primitive root found for prime modulus")
            id $
        find isRoot $ enumFromTo 1 (p - 1)
  | otherwise = error "Attempt to take primitive root of non-prime value"
  where
    -- Euler's totient function for prime values.
    tot = p - 1
    -- Powers to check: the totient divided by each distinct prime
    -- factor of the totient.
    totpows = map (tot `div`) $ fromList $ nub $ toList $ allFactors tot
    -- All powers are different from 1 => primitive root.
    isRoot g = all (/= 1) $ map (expt p g) totpows

-- | Fast exponentiation modulo n by squaring: @expt n b pow@ is
-- @b ^ pow@ reduced modulo @n@.  Intermediate products are computed as
-- 'Integer' so that squaring cannot overflow 'Int'.  @pow@ must be
-- non-negative.
expt :: Int -> Int -> Int -> Int
expt n b pow = fromIntegral (go pow)
  where
    modulus, base :: Integer
    modulus = fromIntegral n
    base = fromIntegral b

    go :: Int -> Integer
    go 0 = 1
    go e
      | odd e = (base * go (e - 1)) `mod` modulus
      | otherwise = let h = go (e `div` 2) in (h * h) `mod` modulus

-- | Find inverse element in multiplicative integer group modulo n:
-- @invModN n g@ is the smallest @iv@ in @[1 .. n - 1]@ with
-- @(g * iv) `mod` n == 1@.  The search is a linear scan.  Calls
-- 'error' if no such element exists.
invModN :: Int -> Int -> Int
invModN n g =
  maybe (error "invModN: element has no inverse for this modulus") id $
    find isInverse $ enumFromTo 1 (n - 1)
  where
    isInverse iv = (g * iv) `mod` n == 1

-- | Prime sieve from Haskell wiki: the infinite ascending list of
-- primes.
primes :: Integral a => [a]
primes = 2 : oddPrimes
  where
    oddPrimes = sieve [3, 5 ..] 9 oddPrimes

    -- @q@ is the square of the prime @p@ at the head of the (lazily
    -- matched) prime list; candidates below @q@ are known to be prime.
    -- On reaching @q@, multiples of @p@ are filtered out and the bound
    -- moves on to the square of the next prime.
    sieve (x : xs) q ps@(~(p : t))
      | x < q = x : sieve xs q ps
      | otherwise =
          sieve [c | c <- xs, c `rem` p /= 0] (P.head t ^ (2 :: Int)) t
    sieve [] _ _ = []

-- | Naive primality testing by trial division with the primes up to
-- the square root of the argument.  Values below 2 are not prime.
isPrime :: Integral a => a -> Bool
isPrime n = n >= 2 && P.all indivisible candidates
  where
    -- @p <= n `div` p@ is @p * p <= n@ written without the product, so
    -- it cannot overflow a bounded type.
    candidates = P.takeWhile (\p -> p <= n `div` p) primes
    indivisible p = n `mod` p /= 0

-- | Simple prime factorisation: all prime factors of the argument in
-- ascending order, with multiplicity.  Values below 2 have no prime
-- factors and give an empty vector.
allFactors :: (Integral a, Unbox a) => a -> Vector a
allFactors n
  | n < 2 = fromList []
  | otherwise = fromList $ go n primes
  where
    go cur pss@(p : ps)
      -- p * p > cur and no smaller prime divides cur, so cur is prime.
      | p > cur `div` p = [cur]
      | cur `mod` p == 0 = p : go (cur `div` p) pss
      | otherwise = go cur ps
    -- Not reachable: the list of primes is infinite.
    go cur [] = [cur]

-- | Simple prime factorisation: small factors only; largest/last
-- factor picked out as "special".
factors :: (Integral a, Unbox a) => a -> (a, Vector a)
factors n
  -- Values below 2 have no prime factorisation (and would otherwise
  -- never terminate): return them unchanged with no small factors.
  | n < 2 = (n, fromList [])
  | otherwise = let (lst, rest) = go n primes in (lst, fromList rest)
  where
    go cur pss@(p : ps)
      -- p * p > cur and no smaller prime divides cur, so cur is the
      -- last (largest) prime factor.
      | p > cur `div` p = (cur, [])
      | cur `mod` p == 0 =
          let (lst, rest) = go (cur `div` p) pss in (lst, p : rest)
      | otherwise = go cur ps
    -- Not reachable: the list of primes is infinite.
    go cur [] = (cur, [])

-- | Base-2 logarithm, rounded down.  Calls 'error' for arguments
-- below 1.
log2 :: Int -> Int
log2 n
  | n < 1 = error "log2: argument must be positive"
  | otherwise = go 0 n
  where
    go acc 1 = acc
    go acc k = go (acc + 1) (k `div` 2)

-- | Check for powers of two.  Zero and negative values are not powers
-- of two.
isPow2 :: Int -> Bool
isPow2 n = n > 0 && n .&. (n - 1) == 0

-- | Duplicate a sub-permutation to fill a given vector length: the
-- sub-permutation @p@ is repeated @n `div` length p@ times, with the
-- entries of the @k@-th copy offset by @k * length p@.  An empty
-- sub-permutation gives an empty result.
dupperm :: Int -> VI -> VI
dupperm n p
  | null p = p
  | otherwise = concatMap offsetCopy $ enumFromN 0 (n `div` sublen)
  where
    sublen = length p
    offsetCopy k = map (+ sublen * k) p

-- | Composition of permutations: element @i@ of @p1 %.% p2@ is
-- @p2 ! (p1 ! i)@.
(%.%) :: VI -> VI -> VI
p1 %.% p2 = backpermute p2 p1

-- | Generate all compositions of a given integer: every way of
-- grouping adjacent entries of its ascending prime factor list into
-- products (see 'makeComp').  There are @2 ^ (k - 1)@ of them for an
-- integer with @k@ prime factors.  Values below 2 give an empty
-- result.
compositions :: Int -> V.Vector (Vector Int)
compositions n
  | n < 2 = V.empty
  | otherwise =
      V.reverse $ V.map (makeComp fs) $ V.enumFromN 0 (2 ^ (length fs - 1))
  where
    fs = allFactors n

-- | Generate a single composition of a given integer.
--
-- @fs@ is the factor list and @i@ selects the composition: reading the
-- low @length fs - 1@ bits of @i@ from most to least significant, a
-- set bit multiplies the next factor into the current product and a
-- clear bit starts a new product.
makeComp :: Vector Int -> Int -> Vector Int
makeComp fs i = fromList $ merge (toList fs)
  where
    k = length fs

    -- One flag per gap between adjacent factors.
    ops :: [Bool]
    ops = [testBit i b | b <- [k - 2, k - 3 .. 0]]

    merge :: [Int] -> [Int]
    merge [] = []
    merge (f : rest) = go f (P.zip rest ops)

    go :: Int -> [(Int, Bool)] -> [Int]
    go acc [] = [acc]
    go acc ((x, joinPrev) : more)
      | joinPrev = go (acc * x) more
      | otherwise = acc : go x more

-- | Generate all distinct permutations of a multiset in lexicographic
-- order.  The input is sorted first, so the order of its elements does
-- not matter.
multisetPerms :: Vector Int -> [Vector Int]
multisetPerms idp = start : L.unfoldr next start
  where
    start = fromList $ L.sort $ toList idp
    next v = fmap (\p -> (p, p)) (permStep v)

-- | Step from one permutation to the next one in lexicographic order,
-- or 'Nothing' if the input has no ascending adjacent pair (i.e. it is
-- already the last permutation).
permStep :: Vector Int -> Maybe (Vector Int)
permStep v
  | null ascents = Nothing
  | otherwise = Just reversedTail
  where
    n = length v
    -- Positions i with v ! i < v ! (i + 1).
    ascents = filter (\i -> v ! i < v ! (i + 1)) $ enumFromN 0 (n - 1)
    -- Last ascent position.
    k = maximum ascents
    -- Last position holding a value greater than the one at k.
    l = maximum $ filter (\i -> v ! k < v ! i) $ enumFromN 0 n
    -- v with the entries at k and l exchanged.
    swapped = generate n pick
      where
        pick i
          | i == k = v ! l
          | i == l = v ! k
          | otherwise = v ! i
    -- The swapped vector with everything after position k reversed.
    reversedTail = generate n $ \i ->
      if i <= k then swapped ! i else swapped ! (n - i + k)

-- | ST monad version of vector permutation: for each @i@ in
-- @[0 .. n - 1]@, writes element @perm ! i@ of @vin@ to position @i@
-- of @vout@.  The reads and writes on the mutable vectors use the
-- unchecked @unsafeRead@ / @unsafeWrite@ operations, so the caller
-- must ensure that @n@ and the entries of @perm@ are in range.
backpermuteM :: Int -> VI -> MVCD s -> MVCD s -> ST s ()
backpermuteM n perm vin vout =
  CM.forM_ [0 .. n - 1] $ \i -> do
    idx <- indexM perm i
    MV.unsafeRead vin idx >>= MV.unsafeWrite vout i