{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
module Streamly.Memory.Ring
( Ring(..)
, new
, unsafeInsert
, unsafeFoldRing
, unsafeFoldRingM
, unsafeFoldRingFullM
, unsafeEqArray
, unsafeEqArrayN
) where
import Control.Exception (assert)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, touchForeignPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Ptr (plusPtr, minusPtr, castPtr)
import Foreign.Storable (Storable(..))
import GHC.ForeignPtr (mallocPlainForeignPtrAlignedBytes)
import GHC.Ptr (Ptr(..))
import Prelude hiding (length, concat)
import Control.Monad.IO.Class (MonadIO(..))
import qualified Streamly.Internal.Memory.Array.Types as A
-- | A circular buffer of 'Storable' elements backed by a single allocation.
--
-- The ring does not record its own head position; operations on it take the
-- current head as a separate 'Ptr' argument.
data Ring a = Ring
    { ringStart :: !(ForeignPtr a)
      -- ^ Owner of the buffer; its address is the first element slot.
    , ringBound :: !(Ptr a)
      -- ^ Address one byte past the end of the buffer.
    }
-- | Allocate a ring with room for @count@ elements of type @a@.
--
-- The buffer is obtained from 'mallocPlainForeignPtrAlignedBytes' using the
-- element's 'alignment'. The result pairs the ring with a pointer to the first
-- slot of the buffer, which callers can use as the initial ring head.
{-# INLINE new #-}
new :: forall a. Storable a => Int -> IO (Ring a, Ptr a)
new count = do
    fptr <- mallocPlainForeignPtrAlignedBytes bytes (alignment (undefined :: a))
    let begin = unsafeForeignPtrToPtr fptr
    pure (Ring { ringStart = fptr, ringBound = begin `plusPtr` bytes }, begin)
  where
    -- Total buffer size in bytes.
    bytes = count * sizeOf (undefined :: a)
-- | Step a ring pointer forward by one element.
--
-- When the step reaches or passes 'ringBound', the result wraps back to the
-- first slot of the buffer.
{-# INLINE advance #-}
advance :: forall a. Storable a => Ring a -> Ptr a -> Ptr a
advance Ring{..} ringHead
    | next >= ringBound = unsafeForeignPtrToPtr ringStart
    | otherwise = next
  where
    next = ringHead `plusPtr` sizeOf (undefined :: a)
-- | Store @newVal@ at @ringHead@ and return the head advanced by one element,
-- wrapping at the end of the buffer.
--
-- Nothing checks that @ringHead@ actually points into this ring.
{-# INLINE unsafeInsert #-}
unsafeInsert :: Storable a => Ring a -> Ptr a -> a -> IO (Ptr a)
unsafeInsert rb ringHead newVal =
    poke ringHead newVal >> pure (advance rb ringHead)
-- | Compare the first @n@ bytes of the ring against the first @n@ bytes of the
-- array.
--
-- The ring bytes are read starting at the head @rh@, running to the end of the
-- buffer, and then continuing from the start of the buffer. Returns 'True' when
-- the compared bytes are identical.
--
-- @n@ is a byte count. The only bounds check is an 'assert' that the array is
-- at least as large as the ring buffer.
{-# INLINE unsafeEqArrayN #-}
unsafeEqArrayN :: Ring a -> Ptr a -> A.Array a -> Int -> Bool
unsafeEqArrayN Ring{..} rh A.Array{..} n =
    let !res = A.unsafeInlineIO $ do
            let rs = unsafeForeignPtrToPtr ringStart
                as = unsafeForeignPtrToPtr aStart
            assert (aBound `minusPtr` as >= ringBound `minusPtr` rs) (return ())
            -- Bytes from the head to the end of the ring buffer.
            let len = ringBound `minusPtr` rh
            r1 <- A.memcmp (castPtr rh) (castPtr as) (min len n)
            -- Compare the wrapped-around part only if the first part matched
            -- and the requested length extends past the end of the buffer.
            r2 <- if r1 && n > len
                then A.memcmp (castPtr rs) (castPtr (as `plusPtr` len))
                              (min (rh `minusPtr` rs) (n - len))
                else return r1
            -- The raw pointers above came from 'unsafeForeignPtrToPtr'. Keep
            -- both buffers alive until every comparison has finished.
            touchForeignPtr ringStart
            touchForeignPtr aStart
            return r2
    in res
-- | Compare the entire ring buffer against the array, byte for byte.
--
-- The ring is read starting at the head @rh@, running to the end of the buffer,
-- and then continuing from the start of the buffer up to @rh@. Returns 'True'
-- when all bytes match.
--
-- The only bounds check is an 'assert' that the array is at least as large as
-- the ring buffer.
{-# INLINE unsafeEqArray #-}
unsafeEqArray :: Ring a -> Ptr a -> A.Array a -> Bool
unsafeEqArray Ring{..} rh A.Array{..} =
    let !res = A.unsafeInlineIO $ do
            let rs = unsafeForeignPtrToPtr ringStart
                as = unsafeForeignPtrToPtr aStart
            assert (aBound `minusPtr` as >= ringBound `minusPtr` rs)
                   (return ())
            -- Bytes from the head to the end of the ring buffer.
            let len = ringBound `minusPtr` rh
            r1 <- A.memcmp (castPtr rh) (castPtr as) len
            -- Skip the wrapped-around comparison when the first part
            -- already differs.
            r2 <- if r1
                then A.memcmp (castPtr rs) (castPtr (as `plusPtr` len))
                              (rh `minusPtr` rs)
                else return False
            -- The raw pointers above came from 'unsafeForeignPtrToPtr'. Keep
            -- both buffers alive until every comparison has finished.
            touchForeignPtr ringStart
            touchForeignPtr aStart
            return r2
    in res
-- | Strict left fold over the elements stored from the first slot of the ring
-- buffer up to, but not including, @ptr@.
--
-- @ptr@ must be reachable from the start of the buffer in whole-element steps.
-- Otherwise the loop never sees @cur == ptr@ and reads past the buffer.
{-# INLINE unsafeFoldRing #-}
unsafeFoldRing :: forall a b. Storable a
    => Ptr a -> (b -> a -> b) -> b -> Ring a -> b
unsafeFoldRing ptr f z Ring{..} =
    let !result = A.unsafeInlineIO $
            withForeignPtr ringStart $ \begin -> loop z begin
    in result
  where
    elemSize = sizeOf (undefined :: a)
    loop !acc !cur
        | cur /= ptr = do
            v <- peek cur
            loop (f acc v) (cur `plusPtr` elemSize)
        | otherwise = return acc
-- | Like 'withForeignPtr', but for any 'MonadIO' monad.
--
-- Runs the action on the raw pointer and then touches the 'ForeignPtr', so the
-- pointer stays alive until the action has returned.
withForeignPtrM :: MonadIO m => ForeignPtr a -> (Ptr a -> m b) -> m b
withForeignPtrM fp fn =
    fn (unsafeForeignPtrToPtr fp) <* liftIO (touchForeignPtr fp)
-- | Monadic left fold over the elements stored from the first slot of the ring
-- buffer up to, but not including, @ptr@.
--
-- @ptr@ must be reachable from the start of the buffer in whole-element steps.
-- Each element is read with 'A.unsafeInlineIO' and forced before @f@ runs.
{-# INLINE unsafeFoldRingM #-}
unsafeFoldRingM :: forall m a b. (MonadIO m, Storable a)
    => Ptr a -> (b -> a -> m b) -> b -> Ring a -> m b
unsafeFoldRingM ptr f z Ring {..} =
    withForeignPtrM ringStart $ \begin -> loop begin z
  where
    step = sizeOf (undefined :: a)
    loop !cur !acc
        | cur == ptr = return acc
        | otherwise = do
            let !v = A.unsafeInlineIO (peek cur)
            f acc v >>= loop (cur `plusPtr` step)
-- | Monadic left fold over every slot of the ring exactly once.
--
-- The fold starts at the head @rh@, wraps around at the end of the buffer, and
-- stops when the next position would be @rh@ again. The first element is read
-- before any termination check, so the ring must contain at least one slot.
{-# INLINE unsafeFoldRingFullM #-}
unsafeFoldRingFullM :: forall m a b. (MonadIO m, Storable a)
    => Ptr a -> (b -> a -> m b) -> b -> Ring a -> m b
unsafeFoldRingFullM rh f z rb@Ring {..} =
    withForeignPtrM ringStart $ const (loop z rh)
  where
    loop !acc !cur = do
        let !v = A.unsafeInlineIO (peek cur)
            next = advance rb cur
        acc' <- f acc v
        if next == rh then return acc' else loop acc' next