-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2022 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.TypedByteArray
    ( Data.TypedByteArray.replicate
    , MutableTypedByteArray
    , Prim
    , TypedByteArray
    , fromList
    , toList
    , generate
    , newTypedByteArray
    , unsafeFreezeTypedByteArray
    , unsafeIndex
    , writeTypedByteArray
    , null
    , length
    , foldr
    ) where

import Prelude hiding (foldr, length, null)

import Control.DeepSeq (NFData (rnf))
import Control.Monad.Primitive (PrimMonad (PrimState))
import Control.Monad.ST (runST)
import Data.Primitive (ByteArray (ByteArray), MutableByteArray, Prim, byteArrayFromList,
                       indexByteArray, newByteArray, sizeOf, unsafeFreezeByteArray, writeByteArray)

import qualified Data.Primitive as Primitive


-- | A 'ByteArray' tagged with the type @a@ of the elements it stores.
--
-- The phantom parameter lets the element-reading and element-writing
-- functions in this module take the element type from the array's type
-- instead of from an annotation at every call site. Being a @newtype@, the
-- tag has no runtime representation.
newtype TypedByteArray a = TypedByteArray ByteArray
    deriving (Eq, Show)

-- | A 'MutableByteArray' tagged with the type @a@ of the elements it stores.
--
-- The second parameter @s@ is the state token of the underlying mutable
-- array; the functions in this module use it as @PrimState m@ for the monad
-- @m@ the array is manipulated in.
newtype MutableTypedByteArray a s
    = MutableTypedByteArray (MutableByteArray s)

-- | The payload is a 'ByteArray' of raw bytes with no Haskell values inside,
-- so bringing it to weak head normal form (by matching its constructor)
-- already evaluates it fully.
instance NFData (TypedByteArray a) where
    rnf typed = case typed of
        TypedByteArray (ByteArray _) -> ()

{-# INLINE newTypedByteArray #-}
-- | Allocate a mutable array with room for @count@ elements of type @a@.
--
-- The size in bytes is @count@ times the 'sizeOf' one element. The contents
-- are not initialised here, so callers must write each element before
-- reading it.
newTypedByteArray :: forall a m. (Prim a, PrimMonad m) => Int -> m (MutableTypedByteArray a (PrimState m))
newTypedByteArray count = do
    let elementBytes = sizeOf (undefined :: a)
    raw <- newByteArray (count * elementBytes)
    pure (MutableTypedByteArray raw)

{-# INLINE fromList #-}
-- | Pack the elements of a list, in order, into a new 'TypedByteArray'.
fromList :: Prim a => [a] -> TypedByteArray a
fromList elems = TypedByteArray (byteArrayFromList elems)

{-# INLINE toList #-}
-- | The elements of the array as a list, from index 0 upwards.
toList :: Prim a => TypedByteArray a -> [a]
toList typed = foldr (:) [] typed

-- | Read the element at index @ix@.
--
-- The index is /not/ bounds-checked. Reading outside
-- @[0 .. 'length' arr - 1]@ is undefined behaviour.
{-# INLINE unsafeIndex #-}
unsafeIndex :: Prim a => TypedByteArray a -> Int -> a
unsafeIndex (TypedByteArray bytes) ix = indexByteArray bytes ix

{-# INLINE generate #-}
-- | Construct a 'TypedByteArray' of the given length by applying the function to each index in @[0..n-1]@.
generate :: Prim a => Int -> (Int -> a) -> TypedByteArray a
generate !n f = runST $ do
    -- Storage for n elements; every slot is written by the loop below,
    -- so freezing afterwards never exposes uninitialised bytes.
    marr <- newTypedByteArray n
    let fill ix = ix `seq` writeTypedByteArray marr ix (f ix)
    intLoop 0 n fill
    unsafeFreezeTypedByteArray marr

-- | Allocate a mutable array of @n@ elements, with every element set to
-- @value@.
--
-- The fill is done by 'Primitive.setByteArray'. It takes its offset and
-- count in elements of type @a@ (not bytes) and fills the whole range with a
-- single primitive call, instead of a per-element write loop in Haskell.
{-# INLINE replicate #-}
replicate :: (Prim a, PrimMonad m) => Int -> a -> m (MutableTypedByteArray a (PrimState m))
replicate n value = do
    arr@(MutableTypedByteArray bytes) <- newTypedByteArray n
    -- Offset 0, count n: covers exactly the n elements just allocated.
    Primitive.setByteArray bytes 0 n value
    pure arr

{-# INLINE writeTypedByteArray #-}
-- | Store @x@ at index @ix@ of a mutable array.
--
-- The index is not bounds-checked. Writing outside the allocated elements is
-- undefined behaviour.
writeTypedByteArray :: (Prim a, PrimMonad m) => MutableTypedByteArray a (PrimState m) -> Int -> a -> m ()
writeTypedByteArray (MutableTypedByteArray bytes) ix x = writeByteArray bytes ix x

{-# INLINE unsafeFreezeTypedByteArray #-}
-- | View a mutable array as an immutable 'TypedByteArray' without copying.
--
-- Both values share the same storage, so the mutable array must not be
-- written to after it has been frozen.
unsafeFreezeTypedByteArray :: PrimMonad m => MutableTypedByteArray a (PrimState m) -> m (TypedByteArray a)
unsafeFreezeTypedByteArray (MutableTypedByteArray mbytes) = do
    frozen <- unsafeFreezeByteArray mbytes
    pure (TypedByteArray frozen)

{-# INLINE intLoop #-}
-- | Run @body k@ for every @k@ from @from@ up to, but excluding, @to@, in
-- increasing order. Does nothing when @from >= to@.
--
-- The bounds and the counter are strict, so the loop does not build thunks.
intLoop :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
intLoop !from !to body = step from
    where
        step !k =
            if k < to
                then body k >> step (k + 1)
                else pure ()

{-# INLINE null #-}
-- | 'True' when the array holds no elements.
--
-- This checks the byte size directly, so no 'Prim' constraint is needed. It
-- relies on every element type having a non-zero size.
null :: TypedByteArray a -> Bool
null (TypedByteArray bytes) = 0 == Primitive.sizeofByteArray bytes

{-# INLINE length #-}
-- | Number of elements of type @a@ stored in the array: the byte size divided
-- by the size of one element.
length :: forall a. Prim a => TypedByteArray a -> Int
length (TypedByteArray bytes) = totalBytes `quot` bytesPerElement
    where
        totalBytes = Primitive.sizeofByteArray bytes
        -- Same element count that 'Primitive.foldrByteArray' derives, so
        -- 'length' agrees with the number of elements 'foldr' visits.
        bytesPerElement = sizeOf (undefined :: a)

{-# INLINE foldr #-}
-- | Right fold over the elements of the array, in index order.
foldr :: Prim a => (a -> b -> b) -> b -> TypedByteArray a -> b
foldr step z (TypedByteArray bytes) = Primitive.foldrByteArray step z bytes