{-# LANGUAGE ForeignFunctionInterface #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} module Data.Layout.Vector ( Codec , compile , StorableVector (..) , encodeVectors , decodeVector ) where import Control.Monad (when) import Data.ByteString (ByteString) import qualified Data.ByteString as B import Data.ByteString.Internal (ByteString(..)) import qualified Data.Vector.Storable as V import Data.Word (Word8, Word32) import Foreign.C (CInt (..)) import Foreign.ForeignPtr () import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr) import Foreign.Marshal.Alloc (alloca) import Foreign.Ptr (Ptr, plusPtr, castPtr) import Foreign.Storable (Storable, peek, poke, sizeOf) import System.IO.Unsafe (unsafePerformIO) import Data.Layout.ForeignPtr (mallocPlainForeignPtrBytes) import qualified Data.Layout.Language as L import Data.Layout.Internal (Layout(..), ByteOrder(..)) ------------------------------------------------------------------------ -- Vector Encoding Operations -- | Abstracts over vectors of storable types to allow -- calling `encodeVectors`. The `SV` constructor provides -- proof that the `V.Vector` contains `Storable` elements. data StorableVector where SV :: Storable a => V.Vector a -> StorableVector -- | Creates a strict `ByteString` by interleaving multiple `V.Vector`s. encodeVectors :: [(Codec, StorableVector)] -> ByteString encodeVectors [] = B.empty encodeVectors xs@((codec, SV vec):_) = unsafePerformIO $ do -- allocate memory for the bytestring bp <- mallocPlainForeignPtrBytes bstrBytes -- encode each vector mapM_ (go (DstPtr bp 0)) xs -- return the bytestring return (PS bp 0 bstrBytes) where bstrBytes = encodeReps codec vec * encodedSize codec go dp (c, SV v) = encodeVector c dp v encodeVector :: forall a. Storable a => Codec -> DstPtr -> V.Vector a -> IO () encodeVector codec dstPtr vec = do -- check that the vector type matches the codec checkValueSize "encodeVector" codec vec -- get a pointer to the vector elements let vp = castForeignPtr (fst (V.unsafeToForeignPtr0 vec)) -- encode the vector encode codec n dstPtr (SrcPtr vp 0) where n = encodeReps codec vec encodeReps :: Storable a => Codec -> V.Vector a -> Int encodeReps c v = repetitions "Vector" (vectorSize v) (decodedSize c) ------------------------------------------------------------------------ -- Vector Decoding Operations -- | Creates a `V.Vector` by decoding a strict `ByteString` decodeVector :: forall a. Storable a => Codec -> ByteString -> V.Vector a decodeVector codec bstr@(PS bp bpOff _) = unsafePerformIO $ do -- check that the vector type matches the codec checkValueSize "decodeVector" codec (V.empty :: V.Vector a) -- allocate memory for the vector vp <- mallocPlainForeignPtrBytes vectorBytes -- decode the bytestring in to the vector decode codec n (DstPtr vp 0) (SrcPtr bp bpOff) -- return the vector return (V.unsafeFromForeignPtr0 (castForeignPtr vp) vectorElems) where n = repetitions "ByteString" (B.length bstr) (encodedSize codec) vectorBytes = n * decodedSize codec vectorElems = n * valueCount codec ------------------------------------------------------------------------ -- Vector Utils vectorSize :: forall a. Storable a => V.Vector a -> Int vectorSize v = V.length v * sizeOf (undefined :: a) -- | Ensures the value size of the codec matches the vector type. checkValueSize :: forall a m. (Storable a, Monad m) => String -> Codec -> V.Vector a -> m () checkValueSize fn codec _ = when (codecValueSize /= vectorElemSize) (error errorMsg) where codecValueSize = valueSize codec vectorElemSize = sizeOf (undefined :: a) errorMsg = concat [ "Data.Layout.Vector.", fn, ": " , "Value size mismatch. The value size of a codec (" , show codecValueSize, " bytes) did not match the size of " , "individual elements (", show vectorElemSize, " bytes) in " , "the corresponding vector. This means that the wrong type " , "of vector is being used for a given codec." ] repetitions :: String -> Int -> Int -> Int repetitions sourceName sourceBytes codecBytes = if leftover /= 0 then error msg else n where (n, leftover) = sourceBytes `quotRem` codecBytes msg = concat [ "Data.Layout.Vector.encodeReps: " , "The source ", sourceName, " is not a multiple of " , show codecBytes, " bytes, as required by the Codec. " , show sourceBytes, " bytes were provided, which leaves " , show leftover, " bytes unused." ] ------------------------------------------------------------------------ -- Ptr Operations data DstPtr = DstPtr {-# UNPACK #-} !(ForeignPtr Word8) {-# UNPACK #-} !Int data SrcPtr = SrcPtr {-# UNPACK #-} !(ForeignPtr Word8) {-# UNPACK #-} !Int -- | Contains the information required to encode or decode a -- `V.Vector` from its arbitrary layout in a strict `ByteString`. data Codec = Codec { encode :: Int -> DstPtr -> SrcPtr -> IO () , decode :: Int -> DstPtr -> SrcPtr -> IO () , encodedSize :: Int , decodedSize :: Int , valueCount :: Int , valueSize :: Int } -- | Compiles a data layout in to a codec capable of encoding -- and decoding data stored in the layout. compile :: Layout -> Codec compile layout = Codec { encode = runCodec copyInfo c_encode , decode = runCodec copyInfo c_decode , encodedSize = L.size layout , decodedSize = L.valueSizeN layout , valueCount = L.valueCount layout , valueSize = L.valueSize1 layout } where copyInfo = buildCopyInfo layout runCodec :: CopyInfo -> CodecFn -> Int -> DstPtr -> SrcPtr -> IO () runCodec info c_codec_fn reps (DstPtr dstFP dstOff) (SrcPtr srcFP srcOff) = -- unbox foreign pointers for use withForeignPtr dstFP $ \dst0 -> do withForeignPtr srcFP $ \src0 -> do -- add the offset let dst = dst0 `plusPtr` dstOff src = src0 `plusPtr` srcOff -- get a pointer to the offsets V.unsafeWith (ciOffsets info) $ \offsets -> do -- decode the data err <- c_codec_fn dst src (fromIntegral reps) (ciNumOffsets info) offsets (ciNumValues info) (ciValueSize info) (ciSwapBytes info) -- check error code case err of 0 -> return () 1 -> error ("runCodec: invalid value size: " ++ show (ciValueSize info)) _ -> error "runCodec: unknown error" ------------------------------------------------------------------------ -- CopyInfo data CopyInfo = CopyInfo { ciOffsets :: V.Vector CInt , ciNumValues :: CInt , ciValueSize :: CInt , ciSwapBytes :: CInt } deriving (Show) ciNumOffsets :: CopyInfo -> CInt ciNumOffsets = fromIntegral . V.length . ciOffsets type SkipCopyOp = CInt -- | Build copy instructions for the specified layout. buildCopyInfo :: Layout -> CopyInfo buildCopyInfo layout = CopyInfo { ciOffsets, ciNumValues, ciValueSize, ciSwapBytes } where ciNumValues = copySize `quot` ciValueSize ciValueSize = fromIntegral (L.valueSize1 layout) ciSwapBytes = if needsByteSwap layout then 1 else 0 (copySize, ciOffsets) = (splitOps . optimize . toSkipCopyOps) layout -- convert from layout to skip/copy operations toSkipCopyOps :: Layout -> V.Vector SkipCopyOp toSkipCopyOps = go where go v@(Value _) = V.singleton (copyOp (L.valueSize1 v)) go (Offset n xs) = skipOp n `V.cons` go xs go (Repeat n xs) = V.concat (replicate n (go xs)) go (Group n xs) = go xs `V.snoc` skipOp (n - L.size xs) -- positive number means skip 'n' bytes skipOp = fromIntegral -- negative number means copy 'n' bytes copyOp n = fromIntegral (-n) -- squash copies and skips together, remove no-ops optimize :: V.Vector SkipCopyOp -> V.Vector SkipCopyOp optimize = skips where skips = sumWhile (> 0) copies copies = sumWhile (<= 0) skips sumWhile p k xs | V.null xs = V.empty | otherwise = let (ys, zs) = V.span p xs in case V.sum ys of 0 -> k zs s -> s `V.cons` k zs -- ensures first and last operations are skips or no-ops, -- then strips all copies out returns a tuple of the form -- (copy size, skip operations) splitOps :: V.Vector SkipCopyOp -> (CInt, V.Vector CInt) splitOps = split . head0 . last0 where head0 xs | V.head xs < 0 = 0 `V.cons` xs | otherwise = xs last0 xs | V.last xs < 0 = xs `V.snoc` 0 | otherwise = xs split :: V.Vector SkipCopyOp -> (CInt, V.Vector CInt) split xs = (-copyOp, V.filter isSkip xs) where copyOp = V.head (V.dropWhile (>= 0) xs) isSkip x | x == copyOp = False -- remove | x >= 0 = True -- keep | otherwise = error $ "buildCopyInfo: invalid copy operation " ++ "(expected <" ++ show (-copyOp) ++ " bytes>," ++ " actual <" ++ show (-x) ++ " bytes>)" ------------------------------------------------------------------------ -- Endian check needsByteSwap :: Layout -> Bool needsByteSwap x = case L.byteOrder x of NoByteOrder -> False LittleEndian -> hostIsBigEndian BigEndian -> hostIsLittleEndian endianCheck :: Word8 endianCheck = unsafePerformIO $ alloca $ \p -> do poke p (0x01020304 :: Word32) peek (castPtr p :: Ptr Word8) hostIsLittleEndian :: Bool hostIsLittleEndian = endianCheck == 4 hostIsBigEndian :: Bool hostIsBigEndian = endianCheck == 1 ------------------------------------------------------------------------ -- FFI type CodecFn = Ptr Word8 -- ^ destination memory area -> Ptr Word8 -- ^ source memory area -> CInt -- ^ number of times to repeat the encode/decode -> CInt -- ^ number of skip operations in 'offsets' -> Ptr CInt -- ^ offset / skip list -> CInt -- ^ number of values to copy in between each skip -> CInt -- ^ size of a single value in bytes -> CInt -- ^ non-zero to swap the byte order of values -> IO CInt -- ^ zero on success, non-zero otherwise foreign import ccall unsafe "data_layout_encode" c_encode :: CodecFn foreign import ccall unsafe "data_layout_decode" c_decode :: CodecFn