{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module HaskellWorks.Data.Vector.AsVector64s
  ( AsVector64s(..)
  ) where

import Control.Monad.ST
import Data.Word
import Foreign.ForeignPtr

import qualified Data.ByteString              as BS
import qualified Data.ByteString.Internal     as BS
import qualified Data.ByteString.Lazy         as LBS
import qualified Data.Vector.Storable         as DVS
import qualified Data.Vector.Storable.Mutable as DVSM

#if !MIN_VERSION_base(4,13,0)
import Control.Applicative ((<$>)) -- Fix warning in ghc >= 9.2
#endif

-- | Conversion of byte-oriented values into fixed-size chunks of 'Word64' words.
class AsVector64s t where
  -- | @asVector64s n x@ packs the bytes of @x@ into a list of 'DVS.Vector's, each of
  -- which holds @n@ 'Word64' words.  The final vector also has the full chunk size:
  -- whatever space remains after the last input byte is filled with trailing zeros.
  asVector64s :: Int -> t -> [DVS.Vector Word64]

-- | A lazy 'LBS.ByteString' is handled by splitting it into its strict chunks and
-- delegating to the @['BS.ByteString']@ instance.
instance AsVector64s LBS.ByteString where
  asVector64s n lbs = asVector64s n (LBS.toChunks lbs)
  {-# INLINE asVector64s #-}

-- | A list of strict chunks is packed directly by 'bytestringsToVectors'.
instance AsVector64s [BS.ByteString] where
  asVector64s chunkSize chunks = bytestringsToVectors chunkSize chunks
  {-# INLINE asVector64s #-}

-- | Pack a list of strict 'BS.ByteString' chunks into consecutive 'DVS.Vector's of
-- @n@ 'Word64' words each.
--
-- Input chunk boundaries are ignored: bytes are taken in order across chunks, and
-- the last vector is zero-padded up to the full size (the padding is performed by
-- 'buildOneVector').  Production stops as soon as 'buildOneVector' yields an empty
-- vector.  That happens when no non-empty input remains, and also whenever @n == 0@.
-- An input made only of empty chunks therefore gives @[]@.  Vectors are produced
-- lazily, one per step.
bytestringsToVectors :: Int -> [BS.ByteString] -> [DVS.Vector Word64]
bytestringsToVectors n = loop
  where
    loop :: [BS.ByteString] -> [DVS.Vector Word64]
    loop chunks
      | DVS.null vec = []
      | otherwise    = vec : loop rest
      where
        -- Run the ST builder for one vector; 'rest' holds the unconsumed chunks.
        (rest, vec) = DVS.createT (buildOneVector n chunks)
{-# INLINE bytestringsToVectors #-}

-- | Fill one freshly allocated vector of @n@ 'Word64' words from the front of a list
-- of strict 'BS.ByteString' chunks.  The unconsumed chunks are returned alongside it.
--
-- * Leading empty chunks are skipped.  If nothing is left, the result is @([], v)@,
--   where @v@ is a zero-length vector; this tells the caller to stop.
-- * Otherwise the @n@-word buffer is viewed as @8 * n@ bytes via 'DVSM.unsafeCast'.
--   Bytes are copied in from the chunks in order.  A chunk that does not fit is
--   split, and its unused tail is returned at the head of the leftover list.
-- * If the chunks run out before the buffer is full, the remaining bytes are set to
--   zero.  Every byte of the uninitialised ('DVSM.unsafeNew') buffer is therefore
--   written.
--
-- Because bytes are written through a 'Word8' view, the resulting 'Word64' values
-- follow the host machine's byte order.
--
-- NOTE(review): @n@ is passed straight to 'DVSM.unsafeNew'; presumably callers
-- guarantee @n >= 0@ — confirm.
buildOneVector :: forall s. Int -> [BS.ByteString] -> ST s ([BS.ByteString], DVS.MVector s Word64)
buildOneVector n ss = case dropWhile BS.null ss of
  [] -> do
    -- No data left: signal the end with a zero-length vector.
    emptyVec <- DVSM.new 0
    pure ([], emptyVec)
  chunks -> do
    words64  <- DVSM.unsafeNew n
    leftover <- fill chunks (DVSM.unsafeCast words64)
    pure (leftover, words64)
  where
    -- Copy bytes from the chunk list into the byte buffer; return what was not used.
    fill :: [BS.ByteString] -> DVSM.MVector s Word8 -> ST s [BS.ByteString]
    fill remaining buf
      | DVSM.null buf = pure remaining
    fill [] buf = do
      -- Input exhausted before the buffer filled: zero-pad the tail.
      DVSM.set buf 0
      pure []
    fill (chunk:more) buf
      | chunkLen <= DVSM.length buf = do
          -- The whole chunk fits: copy it and continue with the rest of the buffer.
          let (front, back) = DVSM.splitAt chunkLen buf
          DVSM.copy front (byteStringToVector8 chunk)
          fill more back
      | otherwise = do
          -- Only a prefix fits: fill the buffer and hand back the chunk's tail.
          let (taken, kept) = BS.splitAt (DVSM.length buf) chunk
          DVSM.copy buf (byteStringToVector8 taken)
          pure (kept : more)
      where
        chunkLen = BS.length chunk
    {-# INLINE fill #-}
{-# INLINE buildOneVector #-}

-- | View the bytes of a strict 'BS.ByteString' as a mutable 'Word8' vector.
--
-- The vector is built from the ByteString's own 'ForeignPtr', offset and length, so
-- it aliases the ByteString's buffer.  Writing through it would mutate a supposedly
-- immutable ByteString, so it must only be read from.  In this module it is used
-- only as the source argument of 'DVSM.copy'.
byteStringToVector8 :: BS.ByteString -> DVSM.MVector s Word8
byteStringToVector8 bs = DVSM.unsafeFromForeignPtr (castForeignPtr fptr) off len
  where
    (fptr, off, len) = BS.toForeignPtr bs
{-# INLINE byteStringToVector8 #-}