module Data.Vector.Generic.Mutable (
MVector(..),
length, null,
slice, init, tail, take, drop, splitAt,
unsafeSlice, unsafeInit, unsafeTail, unsafeTake, unsafeDrop,
overlaps,
new, unsafeNew, replicate, replicateM, clone,
grow, unsafeGrow,
clear,
read, write, swap,
unsafeRead, unsafeWrite, unsafeSwap,
set, copy, move, unsafeCopy, unsafeMove,
mstream, mstreamR,
unstream, unstreamR,
munstream, munstreamR,
transform, transformR,
fill, fillR,
unsafeAccum, accum, unsafeUpdate, update, reverse,
unstablePartition, unstablePartitionStream, partitionStream
) where
import qualified Data.Vector.Fusion.Stream as Stream
import Data.Vector.Fusion.Stream ( Stream, MStream )
import qualified Data.Vector.Fusion.Stream.Monadic as MStream
import Data.Vector.Fusion.Stream.Size
import Data.Vector.Fusion.Util ( delay_inline )
import Control.Monad.Primitive ( PrimMonad, PrimState )
import Prelude hiding ( length, null, replicate, reverse, map, read,
take, drop, splitAt, init, tail )
#include "vector.h"
-- | Class of mutable vectors. A value of type @v s a@ is a mutable vector
-- with state token @s@ and elements of type @a@.
--
-- Only 'basicLength', 'basicUnsafeSlice', 'basicOverlaps',
-- 'basicUnsafeNew', 'basicUnsafeRead' and 'basicUnsafeWrite' lack a
-- default below; the other methods default to implementations built from
-- those (the 'basicUnsafeMove' default additionally uses 'clone').
class MVector v a where
  -- | Length of the mutable vector.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. No range
  -- checks are performed.
  basicUnsafeSlice :: Int  -- ^ starting index
                   -> Int  -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. The length is not
  -- checked.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. The length is not checked.
  basicUnsafeReplicate :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Read the element at the given index. No bounds checks are performed.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Write the element at the given index. No bounds checks are performed.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Clear the vector. The default implementation does nothing; instances
  -- may override it (presumably to drop references to elements; confirm
  -- per instance).
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set every element of the vector to the given value.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy the second vector into the first. The default assumes the two
  -- do not overlap and that the target is at least as long as the source.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a  -- ^ target
                                 -> v (PrimState m) a  -- ^ source
                                 -> m ()

  -- | Move the contents of the second vector into the first. Unlike
  -- 'basicUnsafeCopy', the default handles overlapping vectors.
  basicUnsafeMove :: PrimMonad m => v (PrimState m) a  -- ^ target
                                 -> v (PrimState m) a  -- ^ source
                                 -> m ()

  -- | Return a new vector that is longer by the given number of elements
  -- and whose prefix holds a copy of the original contents.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate, then fill with basicSet.
  basicUnsafeReplicate n x = do
    v <- basicUnsafeNew n
    basicSet v x
    return v

  basicClear _ = return ()

  -- Default: write element 0, then repeatedly double the initialised
  -- prefix with non-overlapping block copies ([0,i) into [i,2i)) until
  -- the unfilled tail is no longer than the filled prefix; a final copy
  -- of the remaining n - i elements completes the vector.
  basicSet !v x
      | n == 0    = return ()
      | otherwise = do
          basicUnsafeWrite v 0 x
          do_set 1
    where
      !n = basicLength v

      do_set i
        | 2*i < n   = do
            basicUnsafeCopy (basicUnsafeSlice i i v)
                            (basicUnsafeSlice 0 i v)
            do_set (2*i)
        | otherwise =
            basicUnsafeCopy (basicUnsafeSlice i (n - i) v)
                            (basicUnsafeSlice 0 (n - i) v)

  -- Default: element-by-element copy over the length of the source.
  basicUnsafeCopy !dst !src = do_copy 0
    where
      !n = basicLength src

      do_copy i
        | i < n     = do
            x <- basicUnsafeRead src i
            basicUnsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  -- Default: when the vectors overlap, copy the source into a fresh
  -- buffer first so that no element is overwritten before it is read.
  basicUnsafeMove !dst !src
    | basicOverlaps dst src = do
        srcCopy <- clone src
        basicUnsafeCopy dst srcCopy
    | otherwise = basicUnsafeCopy dst src

  -- Default: allocate the larger vector and copy the old contents into
  -- its first n slots.
  basicUnsafeGrow v by = do
      v' <- basicUnsafeNew (n+by)
      basicUnsafeCopy (basicUnsafeSlice 0 n v') v
      return v'
    where
      n = basicLength v
-- | Write an element at index @i@. If @i@ is not below the current length,
-- the vector is first enlarged with 'enlarge'. Returns the vector that now
-- holds the element: either the argument itself or the enlarged copy.
unsafeAppend1 :: (PrimMonad m, MVector v a)
              => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
unsafeAppend1 vec idx elt =
  if idx < length vec
    then unsafeWrite vec idx elt >> return vec
    else do
      grown <- enlarge vec
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" idx (length grown)
        $ unsafeWrite grown idx elt
      return grown
-- | Write an element just before position @i@, that is at index @i - 1@.
-- When @i@ is 0 the vector is first grown at the front with
-- 'enlargeFront'. Returns the vector that holds the element together with
-- the index at which it was written.
unsafePrepend1 :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
unsafePrepend1 v i x
  | i /= 0    = do
      let i' = i - 1
      unsafeWrite v i' x
      return (v, i')
  | otherwise = do
      -- enlargeFront returns the new vector and the number of slots added
      -- at its front; the old contents start at that offset.
      (v', j) <- enlargeFront v
      let i' = j - 1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" i' (length v')
        $ unsafeWrite v' i' x
      return (v', i')
-- | Stream the elements of a mutable vector from first to last. Each
-- element is read with 'unsafeRead' when the stream reaches it. The stream
-- carries an exact size hint equal to the vector length.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
mstream v = v `seq` len `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v
    step k
      | k >= len  = return Nothing
      | otherwise = do
          elt <- unsafeRead v k
          return (Just (elt, k + 1))
-- | Write the elements of a monadic stream into the vector, starting at
-- index 0, and return the prefix that was written. Each write is covered
-- only by an INTERNAL_CHECK index check.
fill :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
fill v s = v `seq` (MStream.foldM put 0 s >>= \count -> return (unsafeSlice 0 count v))
  where
    put i x = do
      INTERNAL_CHECK(checkIndex) "fill" i (length v)
        $ unsafeWrite v i x
      return (i+1)
-- | Run a stream transformer over the elements of a vector and write the
-- results back into the same vector with 'fill'. Returns the written
-- prefix.
transform :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
transform f v = fill v $ f $ mstream v
-- | Stream the elements of a mutable vector from last to first. The stream
-- carries an exact size hint equal to the vector length.
mstreamR :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
mstreamR v = v `seq` n `seq` (MStream.unfoldrM get n `MStream.sized` Exact n)
  where
    n = length v

    -- The seed is one past the next index to read; it counts down to 0.
    get i
        | j >= 0    = do
            x <- unsafeRead v j
            return $ Just (x, j)
        | otherwise = return Nothing
      where
        j = i - 1
-- | Write the elements of a monadic stream into the vector from the back
-- towards the front, so the first stream element lands at the last index.
-- Returns the suffix of the vector that was written.
fillR :: (PrimMonad m, MVector v a)
      => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
fillR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n - i) v
  where
    n = length v

    -- Writes at i - 1 and returns it as the next seed.
    put i x = do
        INTERNAL_CHECK(checkIndex) "fillR" j n
          $ unsafeWrite v j x
        return j
      where
        j = i - 1
-- | Like 'transform', but streams the vector from last to first
-- ('mstreamR') and writes the results back from the end ('fillR').
-- Returns the written suffix.
transformR :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
transformR f v = fillR v . f $ mstreamR v
-- | Create a new mutable vector holding the elements of a pure stream, by
-- lifting the stream into the monad and delegating to 'munstream'.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
unstream = munstream . Stream.liftStream
-- | Create a new mutable vector from a monadic stream. If the size hint of
-- the stream has an upper bound, a buffer of that size is allocated up
-- front ('munstreamMax'); otherwise the buffer grows on demand
-- ('munstreamUnknown').
munstream :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
munstream s = maybe (munstreamUnknown s) (munstreamMax s) (upperBound (MStream.size s))
-- | Fill a freshly allocated vector of length @n@ from a stream that
-- yields at most @n@ elements. Returns the prefix actually written.
munstreamMax
  :: (PrimMonad m, MVector v a) => MStream m a -> Int -> m (v (PrimState m) a)
munstreamMax s n = do
    buf <- INTERNAL_CHECK(checkLength) "munstreamMax" n
             $ unsafeNew n
    written <- MStream.foldM' (step buf) 0 s
    return $ INTERNAL_CHECK(checkSlice) "munstreamMax" 0 written n
           $ unsafeSlice 0 written buf
  where
    -- Store one element at position i and advance the write index.
    step buf i x = do
      INTERNAL_CHECK(checkIndex) "munstreamMax" i n
        $ unsafeWrite buf i x
      return (i+1)
-- | Build a vector from a stream of unknown size. Starts from an empty
-- buffer and appends each element with 'unsafeAppend1', which enlarges the
-- buffer when it is full. Returns the prefix that was written.
munstreamUnknown
  :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
munstreamUnknown s = do
    empty0 <- unsafeNew 0
    (buf, count) <- MStream.foldM push (empty0, 0) s
    return $ INTERNAL_CHECK(checkSlice) "munstreamUnknown" 0 count (length buf)
           $ unsafeSlice 0 count buf
  where
    push (acc, k) x = unsafeAppend1 acc k x >>= \acc' -> return (acc', k+1)
-- | Create a new mutable vector from a pure stream, filled from the back
-- (see 'munstreamR').
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
unstreamR = munstreamR . Stream.liftStream
-- | Create a new mutable vector from a monadic stream, writing elements
-- from the back. Uses a preallocated buffer ('munstreamRMax') when the
-- size hint has an upper bound, and a front-growing buffer
-- ('munstreamRUnknown') otherwise.
munstreamR :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
munstreamR s = maybe (munstreamRUnknown s) (munstreamRMax s) (upperBound (MStream.size s))
-- | Fill a freshly allocated vector of length @n@ from the back with the
-- elements of a stream that yields at most @n@ elements; the first stream
-- element ends up at the highest written index. Returns the suffix that
-- was written.
munstreamRMax
  :: (PrimMonad m, MVector v a) => MStream m a -> Int -> m (v (PrimState m) a)
munstreamRMax s n = do
    v <- INTERNAL_CHECK(checkLength) "munstreamRMax" n
           $ unsafeNew n
    let put i x = do
          let i' = i - 1
          INTERNAL_CHECK(checkIndex) "munstreamRMax" i' n
            $ unsafeWrite v i' x
          return i'
    i <- MStream.foldM' put n s
    return $ INTERNAL_CHECK(checkSlice) "munstreamRMax" i (n - i) n
           $ unsafeSlice i (n - i) v
-- | Build a vector from a stream of unknown size by prepending each
-- element with 'unsafePrepend1', which grows the buffer at the front when
-- needed. The first stream element ends up last. Returns the written
-- suffix of the final buffer.
munstreamRUnknown
  :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
munstreamRUnknown s = do
    v <- unsafeNew 0
    (v', i) <- MStream.foldM put (v, 0) s
    let n = length v'
    return $ INTERNAL_CHECK(checkSlice) "munstreamRUnknown" i (n - i) n
           $ unsafeSlice i (n - i) v'
  where
    put (v, i) x = unsafePrepend1 v i x
-- | Length of the mutable vector.
length :: MVector v a => v s a -> Int
length v = basicLength v
-- | Check whether the vector has length 0.
null :: MVector v a => v s a -> Bool
null = (== 0) . length
-- | Yield the part of the vector that starts at index @i@ and has length
-- @n@, without copying. The range is validated with the BOUNDS_CHECK
-- slice check before delegating to 'unsafeSlice'.
slice :: MVector v a => Int -> Int -> v s a -> v s a
slice start count vec
  = BOUNDS_CHECK(checkSlice) "slice" start count (length vec)
  $ unsafeSlice start count vec
-- | Yield the first @n@ elements without copying. @n@ is clamped to the
-- range from 0 to the vector length.
take :: MVector v a => Int -> v s a -> v s a
take n v = unsafeSlice 0 kept v
  where
    kept = min (length v) (max 0 n)
-- | Drop the first @n@ elements without copying. A negative @n@ is
-- treated as 0, and an @n@ beyond the length yields an empty vector.
drop :: MVector v a => Int -> v s a -> v s a
drop n v = unsafeSlice (min m n') (max 0 (m - n')) v
  where
    n' = max n 0
    m  = length v
-- | Split the vector at index @n@ without copying, giving the equivalent
-- of @(take n v, drop n v)@. @n@ is clamped to the range from 0 to the
-- vector length.
splitAt :: MVector v a => Int -> v s a -> (v s a, v s a)
splitAt n v = ( unsafeSlice 0 m v
              , unsafeSlice m (max 0 (len - n')) v
              )
  where
    m   = min n' len
    n'  = max n 0
    len = length v
-- | All elements except the last, without copying. The range is checked
-- by 'slice', which is what rejects an empty vector.
init :: MVector v a => v s a -> v s a
init v = slice 0 (length v - 1) v
-- | All elements except the first, without copying. The range is checked
-- by 'slice', which is what rejects an empty vector.
tail :: MVector v a => v s a -> v s a
tail v = slice 1 (length v - 1) v
-- | Yield a part of the mutable vector without copying it. The range is
-- covered only by UNSAFE_CHECK, so callers must pass a valid slice.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
unsafeSlice start count vec
  = UNSAFE_CHECK(checkSlice) "unsafeSlice" start count (length vec)
  $ basicUnsafeSlice start count vec
-- | All elements except the last, without copying. The vector must not be
-- empty; this is covered only by UNSAFE_CHECK inside 'unsafeSlice'.
unsafeInit :: MVector v a => v s a -> v s a
unsafeInit v = unsafeSlice 0 (length v - 1) v
-- | All elements except the first, without copying. The vector must not
-- be empty; this is covered only by UNSAFE_CHECK inside 'unsafeSlice'.
unsafeTail :: MVector v a => v s a -> v s a
unsafeTail v = unsafeSlice 1 (length v - 1) v
-- | The first @n@ elements, without copying. @n@ is not clamped and must
-- lie within the vector.
unsafeTake :: MVector v a => Int -> v s a -> v s a
unsafeTake count = unsafeSlice 0 count
-- | All but the first @n@ elements, without copying. @n@ is not clamped
-- and must lie within the vector.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
unsafeDrop n v = unsafeSlice n (length v - n) v
-- | Check whether two mutable vectors overlap.
overlaps :: MVector v a => v s a -> v s a -> Bool
overlaps v1 v2 = basicOverlaps v1 v2
-- | Create a mutable vector of the given length. The length is validated
-- with the BOUNDS_CHECK length check.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
new len
  = BOUNDS_CHECK(checkLength) "new" len
  $ unsafeNew len
-- | Create a mutable vector of the given length. The length is covered
-- only by UNSAFE_CHECK.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
unsafeNew len
  = UNSAFE_CHECK(checkLength) "unsafeNew" len
  $ basicUnsafeNew len
-- | Create a mutable vector of the given length, with every element set
-- to the given value. A negative length is clamped to 0. (@delay_inline@
-- wraps @max@, presumably to delay its inlining; confirm against
-- Data.Vector.Fusion.Util.)
replicate :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
replicate len = basicUnsafeReplicate (delay_inline max 0 len)
-- | Create a mutable vector of the given length by running the action
-- that many times and collecting the results via 'munstream'.
replicateM :: (PrimMonad m, MVector v a) => Int -> m a -> m (v (PrimState m) a)
replicateM count action = munstream $ MStream.replicateM count action
-- | Create a copy of a mutable vector in freshly allocated storage.
clone :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m (v (PrimState m) a)
clone src = unsafeNew (length src) >>= \dst -> unsafeCopy dst src >> return dst
-- | Grow a vector at the end by the given number of elements, copying the
-- old contents into the prefix of a new vector. The increment is
-- validated with the BOUNDS_CHECK length check.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
grow vec delta
  = BOUNDS_CHECK(checkLength) "grow" delta
  $ unsafeGrow vec delta
-- | Grow a vector at the front by the given number of elements. The old
-- contents are placed after the new slots ('unsafeGrowFront'). The
-- increment is validated with the BOUNDS_CHECK length check.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
growFront vec delta
  = BOUNDS_CHECK(checkLength) "growFront" delta
  $ unsafeGrowFront vec delta
-- | Number of slots added by 'enlarge' and 'enlargeFront': the current
-- length, but at least 1. Each enlargement therefore roughly doubles the
-- length, and an empty vector still grows.
enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = max (length v) 1
-- | Grow the vector at the end by 'enlarge_delta' slots.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
enlarge v =
  let delta = enlarge_delta v
  in unsafeGrow v delta
-- | Grow the vector at the front by 'enlarge_delta' slots. Returns the new
-- vector together with the number of slots added, which is also the index
-- where the old contents now begin.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
enlargeFront v =
  let by = enlarge_delta v
  in unsafeGrowFront v by >>= \v' -> return (v', by)
-- | Grow the vector at the end by the given number of elements via
-- 'basicUnsafeGrow'. The increment is covered only by UNSAFE_CHECK.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
unsafeGrow vec delta
  = UNSAFE_CHECK(checkLength) "unsafeGrow" delta
  $ basicUnsafeGrow vec delta
-- | Grow the vector at the front by @by@ elements. A new vector of length
-- @by + length v@ is allocated, and the old contents are copied into
-- the slots starting at index @by@. The increment is covered only by
-- UNSAFE_CHECK.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
unsafeGrowFront v by
    = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
    $ basicUnsafeNew (by + len) >>= \v' ->
        basicUnsafeCopy (basicUnsafeSlice by len v') v >> return v'
  where
    len = length v
-- | Clear the vector with 'basicClear'; the class default does nothing.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
clear v = basicClear v
-- | Read the element at the given index, with a BOUNDS_CHECK index check.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
read vec ix
  = BOUNDS_CHECK(checkIndex) "read" ix (length vec)
  $ unsafeRead vec ix
-- | Write the element at the given index, with a BOUNDS_CHECK index check.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
write vec ix val
  = BOUNDS_CHECK(checkIndex) "write" ix (length vec)
  $ unsafeWrite vec ix val
-- | Swap the elements at two indices. Both indices get a BOUNDS_CHECK
-- index check.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
swap v i j
    = BOUNDS_CHECK(checkIndex) "swap" i len
    $ BOUNDS_CHECK(checkIndex) "swap" j len
    $ unsafeSwap v i j
  where
    len = length v
-- | Replace the element at the given index and return the previous one.
-- The index gets a BOUNDS_CHECK index check.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
exchange vec ix val
  = BOUNDS_CHECK(checkIndex) "exchange" ix (length vec)
  $ unsafeExchange vec ix val
-- | Read the element at the given index. The index is covered only by
-- UNSAFE_CHECK.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
unsafeRead vec ix
  = UNSAFE_CHECK(checkIndex) "unsafeRead" ix (length vec)
  $ basicUnsafeRead vec ix
-- | Write the element at the given index. The index is covered only by
-- UNSAFE_CHECK.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
unsafeWrite vec ix val
  = UNSAFE_CHECK(checkIndex) "unsafeWrite" ix (length vec)
  $ basicUnsafeWrite vec ix val
-- | Swap the elements at two indices. Both elements are read before either
-- is written. The indices are covered only by UNSAFE_CHECK.
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
unsafeSwap v i j
    = UNSAFE_CHECK(checkIndex) "unsafeSwap" i len
    $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j len
    $ do
        atI <- unsafeRead v i
        atJ <- unsafeRead v j
        unsafeWrite v i atJ
        unsafeWrite v j atI
  where
    len = length v
-- | Replace the element at the given index and return the previous one.
-- The index is covered only by UNSAFE_CHECK.
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
unsafeExchange vec ix val
  = UNSAFE_CHECK(checkIndex) "unsafeExchange" ix (length vec)
  $ unsafeRead vec ix >>= \old -> unsafeWrite vec ix val >> return old
-- | Set every element of the vector to the given value ('basicSet').
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
set vec val = basicSet vec val
-- | Copy the second vector into the first. BOUNDS_CHECK verifies that the
-- two do not overlap and have the same length.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
copy dst src
    = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
    $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
    $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
-- | Move the contents of the second vector into the first. Only equal
-- lengths are verified (BOUNDS_CHECK); overlap is allowed and is left to
-- 'unsafeMove'.
move :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
move dst src
    = BOUNDS_CHECK(check) "move" "length mismatch" sameLength
    $ unsafeMove dst src
  where
    sameLength = length dst == length src
-- | Copy the second vector into the first. Equal lengths and
-- non-overlap are covered only by UNSAFE_CHECK. Both vectors are forced
-- before 'basicUnsafeCopy' is called.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a  -- ^ target
                                         -> v (PrimState m) a  -- ^ source
                                         -> m ()
unsafeCopy dst src
    = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
    $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
    $ (dst `seq` src `seq` basicUnsafeCopy dst src)
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
-- | Move the contents of the second vector into the first via
-- 'basicUnsafeMove'. Equal lengths are covered only by UNSAFE_CHECK. Both
-- vectors are forced before the move.
unsafeMove :: (PrimMonad m, MVector v a) => v (PrimState m) a  -- ^ target
                                         -> v (PrimState m) a  -- ^ source
                                         -> m ()
unsafeMove dst src
    = UNSAFE_CHECK(check) "unsafeMove" "length mismatch" sameLength
    $ (dst `seq` src `seq` basicUnsafeMove dst src)
  where
    sameLength = length dst == length src
-- | For each (index, value) pair in the stream, combine the current element
-- at that index with the value using @f@ and store the result. Each index
-- gets a BOUNDS_CHECK index check.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
accum f !v s = Stream.mapM_ step s
  where
    !len = length v
    step (ix, b) =
      (BOUNDS_CHECK(checkIndex) "accum" ix len $ unsafeRead v ix)
        >>= \old -> unsafeWrite v ix (f old b)
-- | For each (index, value) pair in the stream, write the value at that
-- index. Each index gets a BOUNDS_CHECK index check.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
update !v s = Stream.mapM_ store s
  where
    !len = length v
    store (ix, val) = BOUNDS_CHECK(checkIndex) "update" ix len
                    $ unsafeWrite v ix val
-- | Like 'accum', but indices are covered only by UNSAFE_CHECK. For each
-- (index, value) pair, the element at that index is combined with the
-- value using @f@.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    upd (i, b) = do
      a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i n
         $ unsafeRead v i
      unsafeWrite v i (f a b)
    !n = length v
-- | Like 'update', but indices are covered only by UNSAFE_CHECK. Writes
-- each value of the stream at its paired index.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    upd (i, b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i n
               $ unsafeWrite v i b
    !n = length v
-- | Reverse the vector in place by swapping elements pairwise, working
-- from both ends towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
reverse !v = reverse_loop 0 (length v - 1)
  where
    reverse_loop i j
      | i < j     = do
          unsafeSwap v i j
          reverse_loop (i + 1) (j - 1)
      | otherwise = return ()
-- | Reorder the vector in place so that elements satisfying the predicate
-- come first. Returns the number of such elements. The relative order of
-- elements is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
unstablePartition f !v = from_left 0 (length v)
  where
    -- Invariant: elements in [0,i) satisfy f, elements in [j,length v)
    -- do not. Scan forward from i for an element that does not satisfy f.
    from_left :: Int -> Int -> m Int
    from_left i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v i
          if f x
            then from_left (i+1) j
            else from_right i (j - 1)

    -- Additionally, element i is known not to satisfy f. Scan backward
    -- from j for an element that does, and swap it into position i.
    from_right :: Int -> Int -> m Int
    from_right i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v j
          if f x
            then do
              y <- unsafeRead v i
              unsafeWrite v i x
              unsafeWrite v j y
              from_left (i+1) j
            else from_right i (j - 1)
-- | Split a stream into two new mutable vectors: elements that satisfy
-- the predicate and elements that do not. With a known upper bound on
-- the stream size a single buffer is used ('unstablePartitionMax'), which
-- stores rejected elements in reverse order. Otherwise 'partitionUnknown'
-- is used.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
unstablePartitionStream f s =
  maybe (partitionUnknown f s) (unstablePartitionMax f s) (upperBound (Stream.size s))
-- | Partition a stream of at most @n@ elements into one buffer of length
-- @n@. Accepted elements are written from the front, rejected ones from
-- the back, so the rejected half comes out in reverse stream order.
-- Returns slices of the shared buffer.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
unstablePartitionMax f s n = do
    v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
           $ unsafeNew n
    -- i: next free slot at the front; j: first filled slot at the back.
    let put (i, j) x
          | f x       = do
              unsafeWrite v i x
              return (i+1, j)
          | otherwise = do
              unsafeWrite v (j - 1) x
              return (i, j - 1)
    (i, j) <- Stream.foldM' put (0, n) s
    return (unsafeSlice 0 i v, unsafeSlice j (n - j) v)
-- | Split a stream into two new mutable vectors: elements that satisfy
-- the predicate and elements that do not. Both halves keep stream order.
-- Uses 'partitionMax' when the size hint has an upper bound, and
-- 'partitionUnknown' otherwise.
partitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
partitionStream f s =
  maybe (partitionUnknown f s) (partitionMax f s) (upperBound (Stream.size s))
-- | Partition a stream of at most @n@ elements into one buffer of length
-- @n@, keeping stream order in both halves. Accepted elements are written
-- from the front and rejected ones from the back; the back half is then
-- reversed in place. Returns slices of the shared buffer.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Stream a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
partitionMax f s n = do
    v <- INTERNAL_CHECK(checkLength) "partitionMax" n
           $ unsafeNew n
    -- i: next free slot at the front; j: first filled slot at the back.
    let put (i, j) x
          | f x       = do
              unsafeWrite v i x
              return (i+1, j)
          | otherwise = do
              let j' = j - 1
              unsafeWrite v j' x
              return (i, j')
    (i, j) <- Stream.foldM' put (0, n) s
    INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
      $ return ()
    let l = unsafeSlice 0 i v
        r = unsafeSlice j (n - j) v
    -- Rejected elements were written back to front, i.e. in reverse
    -- stream order; restore the original order.
    reverse r
    return (l, r)
-- | Partition a stream of unknown size into two separately grown buffers
-- by appending each element with 'unsafeAppend1' to the buffer for its
-- side. Both halves keep stream order. Returns the written prefixes.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
partitionUnknown f s = do
    yes0 <- unsafeNew 0
    no0  <- unsafeNew 0
    (yes, nYes, no, nNo) <- Stream.foldM' step (yes0, 0, no0, 0) s
    INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
      $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
      $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    step (yes, nYes, no, nNo) x =
      if f x
        then unsafeAppend1 yes nYes x >>= \yes' -> return (yes', nYes + 1, no, nNo)
        else unsafeAppend1 no nNo x >>= \no' -> return (yes, nYes, no', nNo + 1)