{-# LANGUAGE BangPatterns #-}
module SDR.VectorUtils (
mapAccumMV,
stride,
fill,
copyInto,
vUnfoldr,
vUnfoldrM
) where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Fusion.Bundle as VFB
import Data.Vector.Fusion.Stream.Monadic as VFSM
{-# INLINE mapAccumMV #-}
-- | Monadic accumulating map over a fusion 'VFSM.Stream'.
--
-- For every element @x@ produced by the input stream, @func acc x@ is run;
-- it returns the next accumulator and the element to emit. Both results are
-- forced to WHNF before the output element is yielded, so the accumulator
-- does not build up thunks across steps. 'VFB.Skip' steps of the input are
-- passed through with the accumulator unchanged, and the output stream ends
-- when the input does.
--
-- The function is marked @INLINE@ (like the other combinators in this
-- module) so that the step function can be fused with its producer and
-- consumer at call sites in other modules.
mapAccumMV :: (Monad m)
           => (acc -> x -> m (acc, y))
           -> acc
           -> VFSM.Stream m x
           -> VFSM.Stream m y
mapAccumMV func z (VFSM.Stream step s0) = VFSM.Stream step' (s0, z)
  where
    {-# INLINE step' #-}
    -- State of the output stream: (input stream state, current accumulator).
    step' (s, acc) = do
        r <- step s
        case r of
            VFB.Yield x s' -> do
                (!acc', !res) <- func acc x
                return $ VFB.Yield res (s', acc')
            VFB.Skip s'    -> return $ VFB.Skip (s', acc)
            VFB.Done       -> return VFB.Done
{-# INLINE stride #-}
-- | Take every @str@-th element of @inv@, starting at index 0.
--
-- The result holds the elements at indices @0, str, 2*str, ...@ that are
-- smaller than @length inv@. An empty input always yields an empty result.
-- A non-positive stride on a non-empty input is rejected with 'error'.
-- Previously such a stride either never terminated (@str == 0@) or indexed
-- out of bounds (@str < 0@).
stride :: VG.Vector v a
       => Int
       -> v a
       -> v a
stride str inv
    | len == 0  = VG.empty
    | str <= 0  = error "SDR.VectorUtils.stride: stride must be positive"
    -- Every index k * str is < len, so unsafeIndex is in bounds here.
    | otherwise = VG.generate outLen (\k -> VG.unsafeIndex inv (k * str))
  where
    len    = VG.length inv
    -- ceiling (len / str), written so that it cannot overflow Int even for
    -- very large strides (unlike (len + str - 1) `quot` str).
    outLen = (len - 1) `quot` str + 1
{-# INLINE fill #-}
-- | Write the elements of a 'VFB.Bundle' into a mutable vector, placing the
-- first element at index 0 and advancing one slot per element.
--
-- Writes use 'VGM.unsafeWrite', so no bounds checking is done. The caller
-- must make sure the destination has at least as many slots as the bundle
-- produces elements. Slots past the bundle's length are left untouched.
fill :: (PrimMonad m, Functor m, VGM.MVector vm a)
     => VFB.Bundle v a
     -> vm (PrimState m) a
     -> m ()
fill bundle dst = VFB.foldM' writeNext 0 bundle >> return ()
  where
    -- Store one element at the current index and hand back the next index.
    writeNext ix el = VGM.unsafeWrite dst ix el >> return (ix + 1)
{-# INLINE copyInto #-}
-- | Copy an immutable vector into the front of a mutable vector.
--
-- The elements of @src@ are written to indices @0 .. length src - 1@ of
-- @dst@, and any remaining slots of @dst@ are left untouched. 'fill' writes
-- without bounds checks, so when @src@ is longer than @dst@ this calls
-- 'error' rather than writing past the end of the destination buffer.
copyInto :: (PrimMonad m, VGM.MVector vm a, VG.Vector v a)
         => vm (PrimState m) a
         -> v a
         -> m ()
copyInto dst src
    | VG.length src > VGM.length dst =
        error "SDR.VectorUtils.copyInto: source vector is longer than destination"
    | otherwise = fill (VG.stream src) dst
{-# INLINE vUnfoldr #-}
-- | Build a vector of @size@ elements by repeatedly applying a pure step
-- function, and return it together with the final accumulator.
--
-- Element @i@ is the first component of the @(i+1)@-th application of
-- @func@, and the accumulator is threaded through in order.
--
-- After each step, the pair returned by @func@ and the new accumulator are
-- forced to WHNF. Without this, a long run (especially into a boxed vector)
-- accumulates a chain of unevaluated accumulator thunks. The initial
-- accumulator itself is not forced.
vUnfoldr :: VG.Vector v x
         => Int
         -> (acc -> (x, acc))
         -> acc
         -> (v x, acc)
vUnfoldr size func acc0 = runST $ do
    vect <- VGM.new size
    let go offset acc
            | offset >= size = return acc
            | otherwise      = case func acc of
                -- Strict match: evaluate this step before moving on.
                (res, !acc') -> do
                    VGM.write vect offset res
                    go (offset + 1) acc'
    accEnd <- go 0 acc0
    vect'  <- VG.unsafeFreeze vect
    return (vect', accEnd)
{-# INLINE vUnfoldrM #-}
-- | Monadic counterpart of @vUnfoldr@. It builds a vector of @size@ elements
-- by running the step action @size@ times, threading the accumulator
-- through, and returns the vector together with the final accumulator.
--
-- Element @i@ is the first component of the result of the @(i+1)@-th run
-- of @func@. The new accumulator is forced to WHNF after every step so that
-- no chain of accumulator thunks builds up. The initial accumulator itself
-- is not forced.
vUnfoldrM :: (PrimMonad m, VG.Vector v x)
          => Int
          -> (acc -> m (x, acc))
          -> acc
          -> m (v x, acc)
vUnfoldrM size func acc0 = do
    vect <- VGM.new size
    let go offset acc
            | offset >= size = return acc
            | otherwise      = do
                -- The tuple match forces the pair; the bang additionally
                -- forces the new accumulator before the next step.
                (res, !acc') <- func acc
                VGM.write vect offset res
                go (offset + 1) acc'
    accEnd <- go 0 acc0
    vect'  <- VG.unsafeFreeze vect
    return (vect', accEnd)