{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module HaskellWorks.Data.Vector.Storable
( padded
, foldMap
, mapAccumL
, mmap
, constructSI
, construct2N
, construct64UnzipN
) where
import Control.Monad.ST (ST, runST)
import Data.Monoid (Monoid (..), (<>))
import Data.Vector.Storable (Storable)
import Data.Word
import Foreign.ForeignPtr
import HaskellWorks.Data.Vector.AsVector8
import Prelude hiding (foldMap)
import qualified Data.ByteString as BS
import qualified Data.Vector.Generic as DVG
import qualified Data.Vector.Storable as DVS
import qualified Data.Vector.Storable.Mutable as DVSM
import qualified System.IO.MMap as IO
{-# ANN module ("HLint: ignore Redundant do" :: String) #-}
-- | Right-pad a byte vector with zero bytes so that it holds at least @n@
-- elements.
--
-- When the vector already has @n@ or more elements the pad length is clamped
-- to zero, so nothing is appended.
padded :: Int -> DVS.Vector Word8 -> DVS.Vector Word8
padded n v = v <> DVS.replicate padLength 0
  where
    -- Number of zero bytes still missing; never negative.
    padLength = max 0 (n - DVS.length v)
{-# INLINE padded #-}
-- | Map every element of a storable vector into a monoid and combine the
-- results from left to right with a strict left fold, starting from 'mempty'.
foldMap :: (DVS.Storable a, Monoid m) => (a -> m) -> DVS.Vector a -> m
foldMap f = DVS.foldl' combine mempty
  where
    -- Append the image of the next element to the running result.
    combine acc x = acc <> f x
{-# INLINE foldMap #-}
-- | Thread an accumulator through a storable vector from left to right.
--
-- For each element the step function receives the current accumulator and the
-- element, and yields the next accumulator together with the output element
-- stored at the same index. The result is the final accumulator paired with
-- an output vector of the same length as the input.
mapAccumL :: forall a b c. (Storable b, Storable c)
  => (a -> b -> (a, c))
  -> a
  -> DVS.Vector b
  -> (a, DVS.Vector c)
mapAccumL f a vb = DVS.createT $ do
  out   <- DVSM.unsafeNew len
  final <- loop 0 a out
  return (final, out)
  where
    len = DVS.length vb

    -- Every index below 'len' is written exactly once, so the uninitialised
    -- buffer from 'DVSM.unsafeNew' is fully populated before it is frozen.
    loop :: Int -> a -> DVS.MVector s c -> ST s a
    loop ix acc out
      | ix >= len = return acc
      | otherwise = do
          let (acc', y) = f acc (DVS.unsafeIndex vb ix)
          DVSM.unsafeWrite out ix y
          loop (ix + 1) acc' out
{-# INLINE mapAccumL #-}
-- | Memory-map a file read-only and expose the mapping as a storable vector.
--
-- The mapped region is first wrapped as a vector of bytes and then
-- reinterpreted as a vector of @a@ with 'DVS.unsafeCast'; the data itself is
-- not converted.
mmap :: Storable a => FilePath -> IO (DVS.Vector a)
mmap filepath = do
  -- 'Nothing' is passed as the range argument of 'IO.mmapFileForeignPtr'.
  (fptr, offset, size) <- IO.mmapFileForeignPtr filepath IO.ReadOnly Nothing
  let !bytes = DVS.unsafeFromForeignPtr (fptr :: ForeignPtr Word8) offset size
  return (DVS.unsafeCast bytes)
-- | Build a storable vector of length @n@ from a stateful, index-aware
-- generator.
--
-- The generator is called once per index, in increasing order, with the index
-- and the current state; it yields the next state and the element to store at
-- that index. The result is the final state paired with the vector.
--
-- @n@ is handed to 'DVSM.unsafeNew' without any check.
constructSI :: forall a s. Storable a => Int -> (Int -> s -> (s, a)) -> s -> (s, DVS.Vector a)
constructSI n f state = DVS.createT $ do
  target     <- DVSM.unsafeNew n
  finalState <- fill 0 state target
  return (finalState, target)
  where
    -- Write elements until the index reaches the buffer length, carrying the
    -- state along.
    fill :: Int -> s -> DVSM.MVector t a -> ST t s
    fill ix st target
      | ix < DVSM.length target = do
          let (st', x) = f ix st
          DVSM.unsafeWrite target ix x
          fill (ix + 1) st' target
      | otherwise = return st
-- | Build two storable vectors in a single pass over a list of inputs.
--
-- Two mutable buffers of @nb@ and @nc@ elements are allocated. For every list
-- item, each writer is given the item and the part of its buffer that starts
-- at the current write offset; the 'Int' it returns is added to that offset.
-- Once the list is exhausted each buffer is cut down to its final offset and
-- frozen.
construct2N :: (Storable b, Storable c)
  => Int
  -> (forall s. a -> DVSM.MVector s b -> ST s Int)
  -> Int
  -> (forall s. a -> DVSM.MVector s c -> ST s Int)
  -> [a]
  -> (DVS.Vector b, DVS.Vector c)
construct2N nb fb nc fc as = runST $ do
  bBuf <- DVSM.unsafeNew nb
  cBuf <- DVSM.unsafeNew nc
  (bUsed, cUsed) <- fill fb 0 bBuf fc 0 cBuf as
  bs <- DVG.unsafeFreeze bUsed
  cs <- DVG.unsafeFreeze cUsed
  return (bs, cs)
  where
    -- Walk the list, tracking how many elements of each buffer are in use.
    fill :: (Storable b, Storable c)
      => (forall t. a -> DVSM.MVector t b -> ST t Int)
      -> Int
      -> DVSM.MVector s b
      -> (forall t. a -> DVSM.MVector t c -> ST t Int)
      -> Int
      -> DVSM.MVector s c
      -> [a]
      -> ST s (DVSM.MVector s b, DVSM.MVector s c)
    fill writeB bLen bBuf writeC cLen cBuf items = case items of
      [] -> return (DVSM.take bLen bBuf, DVSM.take cLen cBuf)
      (x : rest) -> do
        bCount <- writeB x (DVSM.drop bLen bBuf)
        cCount <- writeC x (DVSM.drop cLen cBuf)
        fill writeB (bLen + bCount) bBuf writeC (cLen + cCount) cBuf rest
-- | Concatenate the first and the second components of a list of
-- 'BS.ByteString' pairs into two separate byte buffers and reinterpret each
-- buffer as a vector of 'Word64'.
--
-- Both buffers are allocated with @nBytes@ rounded up to a multiple of 8
-- bytes. After the last chunk, up to 8 zero bytes are written as padding, and
-- each result is trimmed to a whole number of 8-byte words before the cast.
--
-- NOTE(review): a chunk larger than the space left in a buffer is passed to
-- 'DVS.copy' with a shorter destination slice; presumably that raises a
-- length-mismatch error, so @nBytes@ is expected to cover the total input
-- length -- confirm against callers.
construct64UnzipN :: Int -> [(BS.ByteString, BS.ByteString)] -> (DVS.Vector Word64, DVS.Vector Word64)
construct64UnzipN nBytes xs = (DVS.unsafeCast ibv, DVS.unsafeCast bpv)
  where
    -- A tuple pattern cannot fail to match (the previous two-element list
    -- pattern was partial), and binding it lazily in the @where@ clause keeps
    -- the buffers from being built until one of the results is demanded.
    ibv, bpv :: DVS.Vector Word8
    (ibv, bpv) = runST $ do
      let capacity = ((nBytes + 7) `div` 8) * 8
      ibmv <- DVSM.new capacity
      bpmv <- DVSM.new capacity
      (ibmvRemaining, bpmvRemaining) <- go ibmv bpmv xs
      ibFrozen <- DVS.unsafeFreeze (DVSM.take (wholeWords (DVSM.length ibmv - ibmvRemaining)) ibmv)
      bpFrozen <- DVS.unsafeFreeze (DVSM.take (wholeWords (DVSM.length bpmv - bpmvRemaining)) bpmv)
      return (ibFrozen, bpFrozen)

    -- Round a byte count down to a multiple of 8 so the cast to 'Word64'
    -- covers only complete words.
    wholeWords :: Int -> Int
    wholeWords used = (used `div` 8) * 8

    -- Copy each pair of chunks to the front of the unused part of its buffer
    -- and continue with what is left. At the end, zero up to 8 bytes of
    -- padding and report how many bytes remain unused after that padding.
    go :: DVSM.MVector s Word8 -> DVSM.MVector s Word8 -> [(BS.ByteString, BS.ByteString)] -> ST s (Int, Int)
    go ibmv bpmv ((ib, bp):ys) = do
      DVS.copy (DVSM.take (BS.length ib) ibmv) (asVector8 ib)
      DVS.copy (DVSM.take (BS.length bp) bpmv) (asVector8 bp)
      go (DVSM.drop (BS.length ib) ibmv) (DVSM.drop (BS.length bp) bpmv) ys
    go ibmv bpmv [] = do
      DVSM.set (DVSM.take 8 ibmv) 0
      DVSM.set (DVSM.take 8 bpmv) 0
      return (DVSM.length (DVSM.drop 8 ibmv), DVSM.length (DVSM.drop 8 bpmv))