{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module HaskellWorks.Data.Vector.Storable
( padded
, foldMap
, mapAccumL
, mmap
, constructSI
, construct2N
) where
import Control.Monad.ST (ST, runST)
import Data.Monoid (Monoid (..), (<>))
import Data.Vector.Storable (Storable)
import Data.Word
import Foreign.ForeignPtr
import Prelude hiding (foldMap)
import qualified Data.Vector.Generic as DVG
import qualified Data.Vector.Storable as DVS
import qualified Data.Vector.Storable.Mutable as DVSM
import qualified System.IO.MMap as IO
{-# ANN module ("HLint: ignore Redundant do" :: String) #-}
-- | Right-pad a byte vector with zero bytes until it holds at least @n@
-- elements.
--
-- When the input already has @n@ or more elements the pad length is clamped
-- to zero, so the result has the same contents as the input.
padded :: Int -> DVS.Vector Word8 -> DVS.Vector Word8
padded n v = v DVS.++ DVS.replicate padLen 0
  where
    -- Number of zero bytes still missing; never negative.
    padLen = max 0 (n - DVS.length v)
{-# INLINE padded #-}
-- | Map every element of a storable vector into a monoid and combine the
-- results from left to right.
--
-- The accumulation is a strict left fold starting from 'mempty'.
foldMap :: (DVS.Storable a, Monoid m) => (a -> m) -> DVS.Vector a -> m
foldMap f = DVS.foldl' step mempty
  where
    -- Append the image of the next element to the running total.
    step acc x = acc <> f x
{-# INLINE foldMap #-}
-- | Left-to-right map over a storable vector that threads an accumulator.
--
-- For each element, @f@ receives the current accumulator and the element and
-- yields the next accumulator together with the output element stored at the
-- same index. The result pairs the final accumulator with the output vector,
-- which has the same length as the input.
mapAccumL :: forall a b c. (Storable b, Storable c)
  => (a -> b -> (a, c))
  -> a
  -> DVS.Vector b
  -> (a, DVS.Vector c)
mapAccumL f a vb = DVS.createT $ do
  out      <- DVSM.unsafeNew len
  finalAcc <- loop 0 a out
  return (finalAcc, out)
  where
    len = DVS.length vb

    -- Walk the indices in order; every slot of the uninitialised output
    -- buffer is written exactly once before the loop ends.
    loop :: Int -> a -> DVS.MVector s c -> ST s a
    loop ix acc out
      | ix >= len = return acc
      | otherwise = do
          let (acc', y) = f acc (DVS.unsafeIndex vb ix)
          DVSM.unsafeWrite out ix y
          loop (ix + 1) acc' out
{-# INLINE mapAccumL #-}
-- | Memory-map a file in read-only mode and expose its contents as a storable
-- vector.
--
-- The mapping is first viewed as a 'Word8' vector and then reinterpreted as
-- the requested element type with 'DVS.unsafeCast'. Nothing here checks that
-- the mapped byte count is a multiple of the element size.
-- 'Nothing' is passed as the range argument (presumably this maps the whole
-- file; confirm against the mmap package documentation).
mmap :: Storable a => FilePath -> IO (DVS.Vector a)
mmap filepath = do
  (bytesPtr, byteOffset, byteCount) <- IO.mmapFileForeignPtr filepath IO.ReadOnly Nothing
  -- Force the byte view before returning, as the cast itself is left lazy.
  let !bytes = DVS.unsafeFromForeignPtr bytesPtr byteOffset byteCount :: DVS.Vector Word8
  return (DVS.unsafeCast bytes)
-- | Build a storable vector of @n@ elements from a stateful, index-aware
-- generator.
--
-- The generator @f@ is called once per index in ascending order with the
-- index and the current state. It yields the next state and the element to
-- store at that index. The result pairs the final state with the vector.
constructSI :: forall a s. Storable a => Int -> (Int -> s -> (s, a)) -> s -> (s, DVS.Vector a)
constructSI n f state = DVS.createT $ do
  buf      <- DVSM.unsafeNew n
  finalSt  <- fill 0 state buf
  return (finalSt, buf)
  where
    -- Populate the buffer slot by slot, threading the generator state.
    fill :: Int -> s -> DVSM.MVector t a -> ST t s
    fill ix st buf
      | ix >= DVSM.length buf = return st
      | otherwise = do
          let (st', x) = f ix st
          DVSM.unsafeWrite buf ix x
          fill (ix + 1) st' buf
-- | Build two storable vectors in a single pass over a list of inputs.
--
-- Two buffers of capacity @nb@ and @nc@ are allocated. For every input, each
-- writer is handed the input and the not-yet-filled tail of its buffer, and
-- returns how many elements it wrote. The reported counts are summed, and each
-- resulting vector is truncated to the total reported for its buffer.
--
-- This function does not compare the reported counts with the remaining
-- capacity, so the writers are responsible for staying within the slice they
-- are given.
construct2N :: (Storable b, Storable c)
  => Int                                              -- ^ capacity of the first buffer
  -> (forall s. a -> DVSM.MVector s b -> ST s Int)    -- ^ writer for the first buffer
  -> Int                                              -- ^ capacity of the second buffer
  -> (forall s. a -> DVSM.MVector s c -> ST s Int)    -- ^ writer for the second buffer
  -> [a]                                              -- ^ inputs, consumed in order
  -> (DVS.Vector b, DVS.Vector c)
construct2N nb fb nc fc as = runST $ do
  bufB <- DVSM.unsafeNew nb
  bufC <- DVSM.unsafeNew nc
  (usedB, usedC) <- fill fb 0 bufB fc 0 bufC as
  frozenB <- DVG.unsafeFreeze usedB
  frozenC <- DVG.unsafeFreeze usedC
  return (frozenB, frozenC)
  where
    -- Feed each input to both writers, tracking how much of each buffer has
    -- been consumed; once the inputs run out, return the filled prefixes.
    fill :: (Storable b, Storable c)
      => (forall t. a -> DVSM.MVector t b -> ST t Int)
      -> Int
      -> DVSM.MVector s b
      -> (forall t. a -> DVSM.MVector t c -> ST t Int)
      -> Int
      -> DVSM.MVector s c
      -> [a]
      -> ST s (DVSM.MVector s b, DVSM.MVector s c)
    fill writeB offB targetB writeC offC targetC items = case items of
      [] -> return (DVSM.take offB targetB, DVSM.take offC targetC)
      x : rest -> do
        wroteB <- writeB x (DVSM.drop offB targetB)
        wroteC <- writeC x (DVSM.drop offC targetC)
        fill writeB (offB + wroteB) targetB writeC (offC + wroteC) targetC rest