{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module HaskellWorks.Data.Vector.Storable
  ( padded
  , foldMap
  , mapAccumL
  , mmap
  , constructSI
  , construct2N
  , construct64UnzipN
  , unzipFromListN2
  ) where

import Control.Monad.ST                   (ST, runST)
import Data.Vector.Storable               (Storable)
import Data.Word
import Foreign.ForeignPtr
import HaskellWorks.Data.Vector.AsVector8
import Prelude                            hiding (abs, foldMap)

import qualified Data.ByteString              as BS
import qualified Data.Vector.Generic          as DVG
import qualified Data.Vector.Storable         as DVS
import qualified Data.Vector.Storable.Mutable as DVSM
import qualified System.IO.MMap               as IO

{- HLINT ignore "Redundant do" -}

-- | Right-pad a byte vector with zero bytes until it is at least @n@ bytes
-- long.  A vector that already has @n@ or more bytes gets nothing appended.
padded :: Int -> DVS.Vector Word8 -> DVS.Vector Word8
padded n v = v <> DVS.replicate padding 0
  where -- Never ask for a negative number of padding bytes.
        padding = max 0 (n - DVS.length v)
{-# INLINE padded #-}

-- | Map every element of a storable vector into a monoid and combine the
-- results left to right, accumulating with a strict left fold.
foldMap :: (DVS.Storable a, Monoid m) => (a -> m) -> DVS.Vector a -> m
foldMap f = DVS.foldl' step mempty
  where step acc x = acc <> f x
{-# INLINE foldMap #-}

-- | Thread an accumulator through a storable vector from left to right.
--
-- Returns the final accumulator together with a new vector holding the
-- per-element outputs of @f@.  The output vector has the same length as the
-- input vector.
mapAccumL :: forall a b c. (Storable b, Storable c)
  => (a -> b -> (a, c))
  -> a
  -> DVS.Vector b
  -> (a, DVS.Vector c)
mapAccumL f a vb = DVS.createT $ do
  out <- DVSM.unsafeNew len
  acc <- loop 0 a out
  return (acc, out)
  where len = DVS.length vb
        -- Every index in [0, len) is written exactly once, so the
        -- uninitialised buffer from 'DVSM.unsafeNew' ends up fully populated.
        loop :: Int -> a -> DVS.MVector s c -> ST s a
        loop i acc out
          | i >= len  = return acc
          | otherwise = do
              let (acc', y) = f acc (DVS.unsafeIndex vb i)
              DVSM.unsafeWrite out i y
              loop (i + 1) acc' out
{-# INLINE mapAccumL #-}

-- | Memory-map a file read-only and view its bytes as a storable vector.
--
-- The byte length of the mapping is divided by the element size.  If the
-- file size is not a multiple of the element size, the trailing bytes that
-- cannot form a whole element are not included in the vector.
mmap :: Storable a => FilePath -> IO (DVS.Vector a)
mmap filepath = do
  (fptr, offset, size) <- IO.mmapFileForeignPtr filepath IO.ReadOnly Nothing
  -- View the mapping as raw bytes first, then reinterpret the element type.
  let !bytes = DVS.unsafeFromForeignPtr (fptr :: ForeignPtr Word8) offset size
  return (DVS.unsafeCast bytes)

-- | Construct a vector of length @n@ statefully, with access to the index.
--
-- Indices are visited in order from 0 to @n - 1@.  Element @i@ and the
-- state for index @i + 1@ both come from @f i s@.  Returns the state after
-- the last element together with the vector.
--
-- A negative @n@ is treated as 0.  The result is then an empty vector and
-- the initial state, unchanged.
constructSI :: forall a s. Storable a => Int -> (Int -> s -> (s, a)) -> s -> (s, DVS.Vector a)
constructSI n f state = DVS.createT $ do
  -- 'DVSM.unsafeNew' errors on a negative length, so clamp first
  -- (the same convention 'padded' uses for negative counts).
  mv <- DVSM.unsafeNew (max 0 n)
  state' <- go 0 state mv
  return (state', mv)
  where go :: Int -> s -> DVSM.MVector t a -> ST t s
        go i s mv
          | i < DVSM.length mv = do
              let (s', a) = f i s
              DVSM.unsafeWrite mv i a
              go (i + 1) s' mv
          | otherwise = return s

-- | Build two storable vectors in a single pass over a list.
--
-- Buffers with capacity @nb@ and @nc@ are allocated up front.  For each list
-- element:
--
-- * @fb@ receives the element and the still-unused tail of the first buffer,
--   and returns how many slots it filled.
-- * @fc@ does the same for the second buffer.
--
-- The results are the filled prefixes of the two buffers.
--
-- NOTE(review): @fb@ and @fc@ are presumably expected to write only within
-- the slice they receive and to return a count no larger than its length.
-- That contract is not checked here.
construct2N :: forall b c a. (Storable b, Storable c)
  => Int
  -> (forall s. a -> DVSM.MVector s b -> ST s Int)
  -> Int
  -> (forall s. a -> DVSM.MVector s c -> ST s Int)
  -> [a]
  -> (DVS.Vector b, DVS.Vector c)
construct2N nb fb nc fc as = runST $ do
  bufB <- DVSM.unsafeNew nb
  bufC <- DVSM.unsafeNew nc
  (usedB, usedC) <- fill 0 bufB 0 bufC as
  frozenB <- DVG.unsafeFreeze (DVSM.take usedB bufB)
  frozenC <- DVG.unsafeFreeze (DVSM.take usedC bufC)
  return (frozenB, frozenC)
  where
    -- Walk the list, tracking how many slots of each buffer are in use.
    fill :: Int -> DVSM.MVector s b -> Int -> DVSM.MVector s c -> [a] -> ST s (Int, Int)
    fill bn _   cn _   []     = return (bn, cn)
    fill bn mbs cn mcs (d:ds) = do
      bi <- fb d (DVSM.drop bn mbs)
      ci <- fc d (DVSM.drop cn mcs)
      fill (bn + bi) mbs (cn + ci) mcs ds

-- | Split a list of byte-string pairs into two buffers and view each buffer
-- as a vector of 'Word64'.
--
-- * The first components are concatenated in list order into one buffer.
-- * The second components are concatenated the same way into the other.
-- * Each result is zero-padded up to the next multiple of 8 bytes, so it
--   holds a whole number of 'Word64's.
--
-- @nBytes@ is the capacity of each buffer in bytes.  It is rounded up to a
-- multiple of 8, and a negative value is treated as 0.
--
-- NOTE(review): the combined length of either component across the whole
-- list must not exceed that capacity.  Otherwise 'DVS.copy' fails, because
-- the destination slice is shorter than the source.
construct64UnzipN :: Int -> [(BS.ByteString, BS.ByteString)] -> (DVS.Vector Word64, DVS.Vector Word64)
construct64UnzipN nBytes xs = runST $ do
  -- 'DVSM.new' zero-initialises, so unwritten bytes are already 0.
  ibmv <- DVSM.new capacity
  bpmv <- DVSM.new capacity
  (ibRemaining, bpRemaining) <- go ibmv bpmv xs
  ibv <- freezePadded (capacity - ibRemaining) ibmv
  bpv <- freezePadded (capacity - bpRemaining) bpmv
  return (DVS.unsafeCast ibv, DVS.unsafeCast bpv)
  where
    -- Buffer size in bytes: nBytes rounded up to whole Word64s.  Clamp a
    -- negative request so it is never passed on to 'DVSM.new'.
    capacity :: Int
    capacity = roundUp8 (max 0 nBytes)

    roundUp8 :: Int -> Int
    roundUp8 k = ((k + 7) `div` 8) * 8

    -- Copy each component into the unused tail of its buffer.  Returns the
    -- number of bytes still unused in each buffer.
    go :: DVSM.MVector s Word8 -> DVSM.MVector s Word8 -> [(BS.ByteString, BS.ByteString)] -> ST s (Int, Int)
    go ibmv bpmv ((ib, bp):ys) = do
      DVS.copy (DVSM.take (BS.length ib) ibmv) (asVector8 ib)
      DVS.copy (DVSM.take (BS.length bp) bpmv) (asVector8 bp)
      go (DVSM.drop (BS.length ib) ibmv) (DVSM.drop (BS.length bp) bpmv) ys
    go ibmv bpmv [] = return (DVSM.length ibmv, DVSM.length bpmv)

    -- Zero the bytes between the data and the next 8-byte boundary, then
    -- freeze exactly that rounded-up prefix.  Because capacity is a multiple
    -- of 8 and used <= capacity, the rounded length never exceeds capacity.
    freezePadded :: Int -> DVSM.MVector s Word8 -> ST s (DVS.Vector Word8)
    freezePadded used mv = do
      let len = roundUp8 used
      DVSM.set (DVSM.slice used (len - used) mv) 0
      DVS.unsafeFreeze (DVSM.take len mv)

-- | Unzip at most @n@ pairs from the front of a list into two storable
-- vectors.
--
-- Both results have length @min n (length abs)@.  The list is consumed only
-- as far as its first @n@ pairs; nothing after them is forced.  A negative
-- @n@ is treated as 0, which gives two empty vectors.
unzipFromListN2 :: (Storable a, Storable b) => Int -> [(a, b)] -> (DVS.Vector a, DVS.Vector b)
unzipFromListN2 n abs = runST $ do
  mas <- DVSM.unsafeNew cap
  mbs <- DVSM.unsafeNew cap
  len <- go 0 mas mbs abs
  as <- DVG.unsafeFreeze (DVSM.take len mas)
  bs <- DVG.unsafeFreeze (DVSM.take len mbs)
  return (as, bs)
  where -- 'DVSM.unsafeNew' errors on a negative length, so clamp it
        -- (the same convention 'padded' uses for negative counts).
        cap = max 0 n
        -- Returns the number of pairs written.  The count is checked before
        -- the list is pattern-matched, so the list is never forced past
        -- the first cap pairs.
        go :: (Storable c, Storable d)
          => Int
          -> DVSM.MVector s c
          -> DVSM.MVector s d
          -> [(c, d)]
          -> ST s Int
        go i _   _   _ | i >= cap = return i
        go i _   _   []           = return i
        go i mvc mvd ((c, d):cds) = do
          DVSM.write mvc i c
          DVSM.write mvd i d
          go (i + 1) mvc mvd cds