{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables        #-}

module Data.Record.Anon.Internal.Util.StrictArray (
    StrictArray -- opaque
    -- * Reads
  , (!)
    -- * Conversion
  , fromList
  , fromListN
  , fromLazy
  , toLazy
    -- * Non-monadic combinators
  , (//)
  , update
  , backpermute
  , zipWith
    -- * Monadic combinators
  , mapM
  , zipWithM
  ) where

import Prelude hiding (mapM, zipWith)

import Control.Monad (forM_)
import Data.Primitive.SmallArray

import qualified Control.Monad as Monad
import qualified Data.Foldable as Foldable

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | Strict array
--
-- Implemented as a wrapper around a 'SmallArray'.
--
-- NOTE: None of the operations on 'StrictArray' do any bounds checking.
--
-- NOTE: 'StrictArray' is implemented as a newtype around 'SmallArray', which in
-- turn is defined as
--
-- > data SmallArray a = SmallArray (SmallArray# a)
--
-- Furthermore, 'Canonical' is a newtype around 'StrictArray', which is then used
-- in 'Record' as
--
-- > data Record (f :: k -> Type) (r :: Row k) = Record {
-- >       recordCanon :: {-# UNPACK #-} !(Canonical f)
-- >     , ..
-- >     }
--
-- This means that 'Record' will have /direct/ access (no pointers) to the
-- 'SmallArray#'.
newtype StrictArray a = WrapLazy {
      -- The underlying array. The constructor is not exported (the type is
      -- opaque), so values are only built by the functions in this module,
      -- which force each element to WHNF before storing it.
      unwrapLazy :: SmallArray a
    }
  -- All instances are inherited unchanged from 'SmallArray'.
  deriving newtype (Show, Eq, Foldable, Semigroup, Monoid)

{-------------------------------------------------------------------------------
  Reads
-------------------------------------------------------------------------------}

-- | Read the element at the given (zero-based) index
--
-- No bounds checking is performed: the index must be within the array.
(!) :: StrictArray a -> Int -> a
(!) = indexSmallArray . unwrapLazy

{-------------------------------------------------------------------------------
  Conversion
-------------------------------------------------------------------------------}

-- | Construct a 'StrictArray' from a list
--
-- The list is traversed twice: once to compute its length and once (by
-- 'fromListN') to copy the elements into the array.
fromList :: [a] -> StrictArray a
fromList as = fromListN (length as) as

-- | Construct a 'StrictArray' of the specified size from a list
--
-- Every element is forced to WHNF before it is written to the array.
--
-- The list is expected to contain exactly @n@ elements:
--
-- * If it contains fewer, the remaining slots are left as 'undefined'.
-- * If it contains more, the excess elements are ignored (the index range is
--   bounded by @n@, so we never write past the end of the array).
fromListN :: Int -> [a] -> StrictArray a
fromListN n as = WrapLazy $ runSmallArray $ do
    -- Placeholder value; overwritten below for every index that has a
    -- corresponding list element.
    r <- newSmallArray n undefined
    forM_ (zip [0 .. n - 1] as) $ \(i, !a) ->
      writeSmallArray r i a
    return r

-- | Turn a lazy 'SmallArray' into a 'StrictArray'
--
-- Walks over the array from index 0 upwards, forcing every element to WHNF,
-- and then wraps the /same/ array; the array is not copied.
fromLazy :: forall a. SmallArray a -> StrictArray a
fromLazy v = go 0
  where
    go :: Int -> StrictArray a
    go i
      | i < sizeofSmallArray v
      = let !_a = indexSmallArray v i in go (succ i)

      | otherwise
      = WrapLazy v

-- | Expose the underlying 'SmallArray'
--
-- This only unwraps the newtype; the array is not copied.
toLazy :: StrictArray a -> SmallArray a
toLazy = unwrapLazy

{-------------------------------------------------------------------------------
  Non-monadic combinators
-------------------------------------------------------------------------------}

-- | Strict map: every result is forced to WHNF before it is stored
instance Functor StrictArray where
  fmap f (WrapLazy as) = WrapLazy $ runSmallArray $ do
      -- Placeholder value; every slot is overwritten by the loop below.
      r <- newSmallArray newSize undefined
      forArrayM_ as $ \i a -> writeSmallArray r i $! f a
      return r
    where
      newSize :: Int
      newSize = sizeofSmallArray as

-- | Bulk update from a list of @(index, value)@ pairs
--
-- The original array is copied and the updates are then applied in list
-- order, so if the same index occurs more than once the last value wins. Each
-- new value is forced to WHNF before it is written.
--
-- No bounds checking is performed: every index must be within the array.
(//) :: StrictArray a -> [(Int, a)] -> StrictArray a
(//) (WrapLazy as) as' = WrapLazy $ runSmallArray $ do
    r <- thawSmallArray as 0 newSize
    forM_ as' $ \(i, !a) -> writeSmallArray r i a
    return r
  where
    newSize :: Int
    newSize = sizeofSmallArray as

-- | Bulk update from an array of @(index, value)@ pairs
--
-- Like '(//)', but the updates are themselves given as a 'StrictArray'. The
-- original array is copied and the updates are applied in order of their
-- position in the update array, so later entries win. Each new value is
-- forced to WHNF before it is written.
--
-- No bounds checking is performed: every index must be within the array.
update :: StrictArray a -> StrictArray (Int, a) -> StrictArray a
update (WrapLazy as) (WrapLazy as') = WrapLazy $ runSmallArray $ do
    r <- thawSmallArray as 0 newSize
    -- The position of the pair within the update array is irrelevant; only
    -- the target index @j@ stored inside the pair is used.
    forArrayM_ as' $ \_i (j, !a) -> writeSmallArray r j a
    return r
  where
    newSize :: Int
    newSize = sizeofSmallArray as

-- | Select elements by index
--
-- Element @i@ of the result is the element of the first array found at the
-- index stored in position @i@ of the second array; the result therefore has
-- the same size as the array of indices. Each selected element is forced to
-- WHNF before it is written.
--
-- No bounds checking is performed: every index must be within the first array.
backpermute :: StrictArray a -> StrictArray Int -> StrictArray a
backpermute (WrapLazy as) (WrapLazy is) = WrapLazy $ runSmallArray $ do
    -- Placeholder value; every slot is overwritten by the loop below.
    r <- newSmallArray newSize undefined
    forArrayM_ is $ \i j -> writeSmallArray r i $! indexSmallArray as j
    return r
  where
    newSize :: Int
    newSize = sizeofSmallArray is

-- | Combine two arrays elementwise
--
-- The result has the size of the shorter of the two arrays; surplus elements
-- of the longer one are ignored. Each result is forced to WHNF before it is
-- written.
zipWith :: (a -> b -> c) -> StrictArray a -> StrictArray b -> StrictArray c
zipWith f (WrapLazy as) (WrapLazy bs) = WrapLazy $ runSmallArray $ do
    -- Placeholder value; every slot is overwritten by the loop below.
    r <- newSmallArray newSize undefined
    forM_ [0 .. newSize - 1] $ \i -> do
      let !c = f (indexSmallArray as i) (indexSmallArray bs i)
      writeSmallArray r i c
    return r
  where
    newSize :: Int
    newSize = min (sizeofSmallArray as) (sizeofSmallArray bs)

{-------------------------------------------------------------------------------
  Applicative combinators

  NOTE: The combinators here do two traversals, first collecting all elements
  of the array in memory (as a list), and then constructing the new array. The
  alternative is to use 'traverseSmallArrayP', but it is only sound with
  certain monads. Since this restriction would leak out to users of the library
  (through the monadic combinators on 'Record'), we prefer to avoid it.
-------------------------------------------------------------------------------}

-- | Effectful map
--
-- The elements are first converted to a list, the function is applied to each
-- of them in order (via 'traverse'), and the resulting list is then turned
-- back into an array with 'fromListN'.
mapM :: forall m a b.
     Applicative m
  => (a -> m b) -> StrictArray a -> m (StrictArray b)
mapM f (WrapLazy as) =
    fromListN newSize <$>
      traverse f (Foldable.toList as)
  where
    -- 'traverse' preserves the number of elements, so the result has the same
    -- size as the input.
    newSize :: Int
    newSize = sizeofSmallArray as

-- | Effectful elementwise combination of two arrays
--
-- Both arrays are first converted to lists, which are then zipped using
-- 'Monad.zipWithM'; the resulting list is turned back into an array with
-- 'fromListN'. The result has the size of the shorter of the two arrays.
zipWithM ::
     Applicative m
  => (a -> b -> m c) -> StrictArray a -> StrictArray b -> m (StrictArray c)
zipWithM f (WrapLazy as) (WrapLazy bs) =
    fromListN newSize <$>
      Monad.zipWithM f (Foldable.toList as) (Foldable.toList bs)
  where
    newSize :: Int
    newSize = min (sizeofSmallArray as) (sizeofSmallArray bs)

{-------------------------------------------------------------------------------
  Internal auxiliary
-------------------------------------------------------------------------------}

-- | Run an action for every element of the array, in order of increasing index
--
-- The action is passed both the (zero-based) index and the element at that
-- index. Results are discarded.
forArrayM_ :: forall m a. Monad m => SmallArray a -> (Int -> a -> m ()) -> m ()
forArrayM_ arr f = go 0
  where
    go :: Int -> m ()
    go i
      | i < sizeofSmallArray arr
      = f i (indexSmallArray arr i) >> go (succ i)

      | otherwise
      = return ()