{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module HaskellWorks.Data.Vector.AsVector8ns
  ( AsVector8ns(..)
  ) where

import Control.Monad.ST
import Data.Word
import Foreign.ForeignPtr
import HaskellWorks.Data.Vector.AsVector8

import qualified Data.ByteString              as BS
import qualified Data.ByteString.Internal     as BS
import qualified Data.ByteString.Lazy         as LBS
import qualified Data.Vector.Storable         as DVS
import qualified Data.Vector.Storable.Mutable as DVSM

#if !MIN_VERSION_base(4,13,0)
import Control.Applicative ((<$>)) -- Fix warning in ghc >= 9.2
#endif

-- | Values that can be cut into a list of storable 'Word8' vectors measured
-- in units of a caller-supplied chunk size.
class AsVector8ns a where
  -- | @asVector8ns n x@ represents @x@ as a list of vectors of @n@-byte
  -- chunks.  The last vector is also sized in @n@-byte chunks, with any
  -- shortfall filled by trailing zeros.
  asVector8ns
    :: Int                  -- ^ chunk size @n@, in bytes
    -> a                    -- ^ value to convert
    -> [DVS.Vector Word8]   -- ^ the chunked vectors

-- | A lazy 'LBS.ByteString' is converted by splitting it into its strict
-- chunks and delegating to the @[BS.ByteString]@ instance.
instance AsVector8ns LBS.ByteString where
  asVector8ns n = asVector8ns n . LBS.toChunks
  {-# INLINE asVector8ns #-}

-- | A list of strict chunks is re-chunked by 'bytestringsToVectors'.
instance AsVector8ns [BS.ByteString] where
  asVector8ns = bytestringsToVectors
  {-# INLINE asVector8ns #-}

-- | Re-chunk strict 'BS.ByteString's into storable 'Word8' vectors whose
-- lengths are positive multiples of @n@.
--
-- Each input chunk @cs@ is handled as follows (empty chunks are skipped):
--
-- * @length cs@ is a multiple of @n@: @cs@ is emitted as a single vector
--   via 'asVector8';
-- * @length cs > n@ but not a multiple: the longest prefix whose length is a
--   multiple of @n@ is emitted, and the remaining bytes are pushed back onto
--   the front of the input;
-- * @0 < length cs < n@: 'buildOneVector' gathers exactly @n@ bytes from
--   @cs@ and the chunks after it, zero-padding if the input runs out.
--
-- NOTE(review): the "multiple of @n@" property assumes 'asVector8'
-- preserves the byte length of its argument; confirm in "HaskellWorks.Data.Vector.AsVector8".
--
-- The chunk size must be positive.  A non-positive @n@ is reported with
-- 'error' as soon as there is a chunk to process (an empty input list still
-- yields @[]@).
bytestringsToVectors :: Int -> [BS.ByteString] -> [DVS.Vector Word8]
bytestringsToVectors n = go
  where
    go :: [BS.ByteString] -> [DVS.Vector Word8]
    go [] = []
    go bss@(cs:css)
      | n <= 0    = error $ "HaskellWorks.Data.Vector.AsVector8ns.bytestringsToVectors: "
                          ++ "chunk size must be positive, got " ++ show n
      | csz == 0  = go css
      | csz < n   = case DVS.createT (buildOneVector n bss) of
          -- buildOneVector always yields n bytes here because cs is
          -- non-empty; the length check is purely defensive.
          (dss, ws)
            | DVS.length ws > 0 -> ws : go dss
            | otherwise         -> []
      | r == 0    = asVector8 cs : go css
      | otherwise = asVector8 (BS.take p cs) : go (BS.drop p cs : css)
      where
        csz    = BS.length cs
        (q, r) = csz `divMod` n
        p      = q * n    -- largest multiple of n not exceeding csz
{-# INLINE bytestringsToVectors #-}

buildOneVector :: forall s. Int -> [BS.ByteString] -> ST s ([BS.ByteString], DVS.MVector s Word8)
buildOneVector :: Int -> [ByteString] -> ST s ([ByteString], MVector s Word8)
buildOneVector Int
n [ByteString]
ss = case (ByteString -> Bool) -> [ByteString] -> [ByteString]
forall a. (a -> Bool) -> [a] -> [a]
dropWhile ((Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0) (Int -> Bool) -> (ByteString -> Int) -> ByteString -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ByteString -> Int
BS.length) [ByteString]
ss of
  [] -> ([],) (MVector s Word8 -> ([ByteString], MVector s Word8))
-> ST s (MVector s Word8) -> ST s ([ByteString], MVector s Word8)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int -> ST s (MVector (PrimState (ST s)) Word8)
forall (m :: * -> *) a.
(PrimMonad m, Storable a) =>
Int -> m (MVector (PrimState m) a)
DVSM.new Int
0
  [ByteString]
cs -> do
    MVector s Word8
v64 <- Int -> ST s (MVector (PrimState (ST s)) Word8)
forall (m :: * -> *) a.
(PrimMonad m, Storable a) =>
Int -> m (MVector (PrimState m) a)
DVSM.unsafeNew Int
n
    let v8 :: MVector s Word8
v8 = MVector s Word8 -> MVector s Word8
forall a b s.
(Storable a, Storable b) =>
MVector s a -> MVector s b
DVSM.unsafeCast MVector s Word8
v64
    [ByteString]
rs  <- [ByteString] -> MVector s Word8 -> ST s [ByteString]
go [ByteString]
cs MVector s Word8
v8
    ([ByteString], MVector s Word8)
-> ST s ([ByteString], MVector s Word8)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ByteString]
rs, MVector s Word8
v64)
  where go :: [BS.ByteString] -> DVSM.MVector s Word8 -> ST s [BS.ByteString]
        go :: [ByteString] -> MVector s Word8 -> ST s [ByteString]
go [ByteString]
ts MVector s Word8
v = if MVector s Word8 -> Int
forall a s. Storable a => MVector s a -> Int
DVSM.length MVector s Word8
v Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0
          then case [ByteString]
ts of
            (ByteString
u:[ByteString]
us) -> if ByteString -> Int
BS.length ByteString
u Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= MVector s Word8 -> Int
forall a s. Storable a => MVector s a -> Int
DVSM.length MVector s Word8
v
              then case Int -> MVector s Word8 -> (MVector s Word8, MVector s Word8)
forall a s.
Storable a =>
Int -> MVector s a -> (MVector s a, MVector s a)
DVSM.splitAt (ByteString -> Int
BS.length ByteString
u) MVector s Word8
v of
                (MVector s Word8
va, MVector s Word8
vb) -> do
                  MVector (PrimState (ST s)) Word8
-> MVector (PrimState (ST s)) Word8 -> ST s ()
forall (m :: * -> *) a.
(PrimMonad m, Storable a) =>
MVector (PrimState m) a -> MVector (PrimState m) a -> m ()
DVSM.copy MVector s Word8
MVector (PrimState (ST s)) Word8
va (ByteString -> MVector s Word8
forall s. ByteString -> MVector s Word8
byteStringToVector8 ByteString
u)
                  [ByteString] -> MVector s Word8 -> ST s [ByteString]
go [ByteString]
us MVector s Word8
vb
              else case Int -> ByteString -> (ByteString, ByteString)
BS.splitAt (MVector s Word8 -> Int
forall a s. Storable a => MVector s a -> Int
DVSM.length MVector s Word8
v) ByteString
u of
                (ByteString
ua, ByteString
ub) -> do
                  MVector (PrimState (ST s)) Word8
-> MVector (PrimState (ST s)) Word8 -> ST s ()
forall (m :: * -> *) a.
(PrimMonad m, Storable a) =>
MVector (PrimState m) a -> MVector (PrimState m) a -> m ()
DVSM.copy MVector s Word8
MVector (PrimState (ST s)) Word8
v (ByteString -> MVector s Word8
forall s. ByteString -> MVector s Word8
byteStringToVector8 ByteString
ua)
                  [ByteString] -> ST s [ByteString]
forall (m :: * -> *) a. Monad m => a -> m a
return (ByteString
ubByteString -> [ByteString] -> [ByteString]
forall a. a -> [a] -> [a]
:[ByteString]
us)
            [] -> do
              MVector (PrimState (ST s)) Word8 -> Word8 -> ST s ()
forall (m :: * -> *) a.
(PrimMonad m, Storable a) =>
MVector (PrimState m) a -> a -> m ()
DVSM.set MVector s Word8
MVector (PrimState (ST s)) Word8
v Word8
0
              [ByteString] -> ST s [ByteString]
forall (m :: * -> *) a. Monad m => a -> m a
return []
          else [ByteString] -> ST s [ByteString]
forall (m :: * -> *) a. Monad m => a -> m a
return [ByteString]
ts
        {-# INLINE go #-}
{-# INLINE buildOneVector #-}

-- | View the bytes of a strict 'BS.ByteString' as a storable 'DVSM.MVector'
-- without copying.  The vector is built from the ByteString's own foreign
-- pointer, offset and length, so both share one buffer.
--
-- NOTE: writing through the returned vector would mutate memory that
-- belongs to an immutable ByteString.  In the code shown it is only used as
-- the source argument of 'DVSM.copy'.
byteStringToVector8 :: BS.ByteString -> DVSM.MVector s Word8
byteStringToVector8 bs = case BS.toForeignPtr bs of
  (fptr, off, len) -> DVSM.unsafeFromForeignPtr (castForeignPtr fptr) off len
{-# INLINE byteStringToVector8 #-}