{-# LANGUAGE CPP, DeriveDataTypeable, FlexibleInstances, MultiParamTypeClasses #-}

#if __GLASGOW_HASKELL__ >= 701
{-# LANGUAGE Trustworthy #-}
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 904
#define HAS_UNLIFTED_ARRAY 1
#endif

#if defined(HAS_UNLIFTED_ARRAY)
{-# LANGUAGE MagicHash, UnboxedTuples #-}
#endif

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.STM.TArray
-- Copyright   :  (c) The University of Glasgow 2005
-- License     :  BSD-style (see the file libraries/base/LICENSE)
--
-- Maintainer  :  libraries@haskell.org
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
-- TArrays: transactional arrays, for use in the STM monad.
--
-----------------------------------------------------------------------------

module Control.Concurrent.STM.TArray (
    TArray
) where

import Control.Monad.STM (STM, atomically)
import Data.Typeable (Typeable)
#if defined(HAS_UNLIFTED_ARRAY)
import Control.Concurrent.STM.TVar (readTVar, readTVarIO, writeTVar)
import Data.Array.Base (safeRangeSize, MArray(..))
import Data.Ix (Ix)
import GHC.Conc (STM(..), TVar(..))
import GHC.Exts
import GHC.IO (IO(..))
#else
import Control.Concurrent.STM.TVar (TVar, newTVar, newTVarIO, readTVar, readTVarIO, writeTVar)
import Data.Array (Array, bounds, listArray)
import Data.Array.Base (safeRangeSize, unsafeAt, MArray(..), IArray(numElements))
#endif

-- | 'TArray' is a transactional array, supporting the usual 'MArray'
-- interface for mutable arrays.
--
-- It is conceptually implemented as @Array i (TVar e)@.
#if defined(HAS_UNLIFTED_ARRAY)
data TArray i e = TArray
    !i   -- lower bound of the index range (strict field)
    !i   -- upper bound of the index range (strict field)
    !Int -- number of elements; returned directly by 'getNumElements'
    (Array# (TVar# RealWorld e)) -- frozen primitive array holding the TVar# for each element
    deriving (Typeable)

-- | Two arrays are equal when both are empty, or when they have the same
-- bounds and are backed by the same 'TVar's (reference equality, not a
-- comparison of the stored values).
instance (Eq i, Eq e) => Eq (TArray i e) where
    TArray lo1 hi1 size1 tvars1# == TArray lo2 hi2 size2 tvars2#
        | size1 == 0 = size2 == 0
        -- The bounds are compared before the first slots are inspected, so
        -- the indexing below only runs once both arrays have the same
        -- (non-empty) range.
        | otherwise = lo1 == lo2 && hi1 == hi2 && sameFirstTVar
      where
        -- Each `TArray` owns its `TVar`s, so comparing the TVar stored at
        -- index 0 is enough to tell whether the two arrays are the same one.
        sameFirstTVar :: Bool
        sameFirstTVar = isTrue# (sameTVar# (firstTVar tvars1#) (firstTVar tvars2#))

        -- Unchecked read of slot 0; only valid for a non-empty array.
        firstTVar :: Array# (TVar# RealWorld e) -> TVar# RealWorld e
        firstTVar tvars# = case indexArray# tvars# 0# of
            (# tvar# #) -> tvar#

-- | Allocate a 'TArray' with bounds @b@ in which every slot holds its own
-- freshly created 'TVar#' initialised to @e@.
--
-- The function threads the raw @State# RealWorld@ token by hand so that the
-- same code can be wrapped by either the 'STM' or the 'IO' constructor (see
-- the two 'MArray' instances).
newTArray# :: Ix i => (i, i) -> e -> State# RealWorld -> (# State# RealWorld, TArray i e #)
newTArray# b@(l, u) e = \s1# ->
    case safeRangeSize b of
        -- 'newArray#' needs a fill value, so one TVar is created up front.
        -- It stays in slot 0 (and is allocated even when the size is 0).
        n@(I# n#) -> case newTVar# e s1# of
            (# s2#, initial_tvar# #) -> case newArray# n# initial_tvar# s2# of
                (# s3#, marr# #) ->
                    -- @go i#@ overwrites every slot from @i#@ up to the last
                    -- index with a brand-new TVar, so no two slots end up
                    -- sharing the fill TVar. It writes before it tests for
                    -- the end, hence it must only be entered when slot @i#@
                    -- exists.
                    let go i# = \s4# -> case newTVar# e s4# of
                            (# s5#, tvar# #) -> case writeArray# marr# i# tvar# s5# of
                                s6# -> if isTrue# (i# ==# n# -# 1#) then s6# else go (i# +# 1#) s6#
                    -- Slot 0 already has its own TVar, so the loop starts at
                    -- index 1 and is skipped for arrays of size 0 or 1. The
                    -- array is then frozen in place (no copy) and never
                    -- mutated through @marr#@ again.
                    in case unsafeFreezeArray# marr# (if n <= 1 then s3# else go 1# s3#) of
                        (# s7#, arr# #) -> (# s7#, TArray l u n arr# #)

-- | 'MArray' operations inside an 'STM' transaction: each element access is
-- a 'readTVar' / 'writeTVar' on the 'TVar' stored in that slot.
instance MArray TArray e STM where
    -- Bounds and size are stored in the constructor, so these are pure lookups.
    getBounds (TArray lo hi _ _) = return (lo, hi)

    getNumElements (TArray _ _ size _) = return size

    -- Allocation is shared with the 'IO' instance via 'newTArray#'.
    newArray bnds initial = STM (newTArray# bnds initial)

    -- The index is not bounds-checked here; callers go through the checked
    -- 'MArray' wrappers.
    unsafeRead (TArray _ _ _ tvars#) (I# ix#) =
        case indexArray# tvars# ix# of
            (# slot# #) -> readTVar (TVar slot#)

    unsafeWrite (TArray _ _ _ tvars#) (I# ix#) new =
        case indexArray# tvars# ix# of
            (# slot# #) -> writeTVar (TVar slot#) new

-- | 'MArray' operations directly in 'IO'.
--
-- Writes are slow in `IO`: every 'unsafeWrite' runs its own 'atomically'
-- transaction, whereas reads use 'readTVarIO'.
instance MArray TArray e IO where
    -- Bounds and size are stored in the constructor, so these are pure lookups.
    getBounds (TArray lo hi _ _) = return (lo, hi)

    getNumElements (TArray _ _ size _) = return size

    -- Allocation is shared with the 'STM' instance via 'newTArray#'.
    newArray bnds initial = IO (newTArray# bnds initial)

    -- The index is not bounds-checked here; callers go through the checked
    -- 'MArray' wrappers.
    unsafeRead (TArray _ _ _ tvars#) (I# ix#) =
        case indexArray# tvars# ix# of
            (# slot# #) -> readTVarIO (TVar slot#)

    unsafeWrite (TArray _ _ _ tvars#) (I# ix#) new =
        case indexArray# tvars# ix# of
            (# slot# #) -> atomically (writeTVar (TVar slot#) new)
#else
-- Fallback representation for compilers without the unlifted-array
-- implementation: a boxed 'Array' whose elements are the 'TVar's.
-- The derived 'Eq' compares the bounds and the 'TVar's themselves, not the
-- values stored in them.
newtype TArray i e = TArray (Array i (TVar e)) deriving (Eq, Typeable)

-- | 'MArray' operations inside an 'STM' transaction: each element access is
-- a 'readTVar' / 'writeTVar' on the 'TVar' stored at that index.
instance MArray TArray e STM where
    getBounds (TArray a) = return (bounds a)
    getNumElements (TArray a) = return (numElements a)
    -- Every element gets its own 'TVar': 'rep' runs @newTVar e@ once per slot.
    -- 'rep' returns the list reversed, which is harmless because all the
    -- 'TVar's start with the same value.
    newArray b e = do
        a <- rep (safeRangeSize b) (newTVar e)
        return $ TArray (listArray b a)
    -- The index is not bounds-checked here ('unsafeAt').
    unsafeRead (TArray a) i = readTVar $ unsafeAt a i
    unsafeWrite (TArray a) i e = writeTVar (unsafeAt a i) e

    {-# INLINE newArray #-}

-- | 'MArray' operations directly in 'IO'.
--
-- Writes are slow in `IO`: every 'unsafeWrite' runs its own 'atomically'
-- transaction, whereas reads use 'readTVarIO'.
instance MArray TArray e IO where
    getBounds (TArray a) = return (bounds a)
    getNumElements (TArray a) = return (numElements a)
    -- Every element gets its own 'TVar': 'rep' runs @newTVarIO e@ once per
    -- slot. 'rep' returns the list reversed, which is harmless because all
    -- the 'TVar's start with the same value.
    newArray b e = do
        a <- rep (safeRangeSize b) (newTVarIO e)
        return $ TArray (listArray b a)
    -- The index is not bounds-checked here ('unsafeAt').
    unsafeRead (TArray a) i = readTVarIO $ unsafeAt a i
    unsafeWrite (TArray a) i e = atomically $ writeTVar (unsafeAt a i) e

    {-# INLINE newArray #-}

-- | Like 'replicateM', but uses an accumulator to prevent stack overflows.
-- Unlike 'replicateM', the returned list is in reversed order.
-- This doesn't matter though since this function is only used to create
-- arrays with identical elements.
--
-- A count of zero or less runs the action no times and yields an empty
-- list, matching 'replicateM'.
rep :: Monad m => Int -> m a -> m [a]
rep n m = go n []
    where
      -- @i@ counts the runs still to do; @xs@ collects results, newest first.
      go i xs
          | i <= 0 = return xs
          | otherwise = do
              x <- m
              go (i - 1) (x : xs)
#endif