----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.Counter.Lifted.ST
-- Copyright   :  (c) Sergey Vinokurov 2022
-- License     :  Apache-2.0 (see LICENSE)
-- Maintainer  :  serg.foo@gmail.com
--
-- Counters that support some atomic operations. Safe to use from
-- multiple threads and likely faster than using IORef or TVar for the
-- same operation (terms and conditions apply).
----------------------------------------------------------------------------

{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE UnboxedTuples #-}

module Control.Concurrent.Counter.Lifted.ST
  ( Counter

  -- * Create
  , new

  -- * Read/write
  , get
  , set
  , cas

  -- * Arithmetic operations
  , add
  , sub

  -- * Bitwise operations
  , and
  , or
  , xor
  , nand
  ) where

import Prelude hiding (and, or)

import GHC.Exts (Int(..), Int#, State#)
import GHC.ST

import qualified Control.Concurrent.Counter.Unlifted as Unlifted

-- | Memory location that supports a select few atomic operations.
--
-- Isomorphic to @STRef s Int@.
--
-- A lifted box around 'Unlifted.Counter': the operations in this module
-- unwrap it and delegate to the corresponding functions from
-- "Control.Concurrent.Counter.Unlifted".
data Counter s = Counter (Unlifted.Counter s)

-- | Pointer equality: the comparison is delegated to
-- 'Unlifted.sameCounter' on the wrapped unlifted counters.
instance Eq (Counter s) where
  (==) (Counter lhs) (Counter rhs) = Unlifted.sameCounter lhs rhs


{-# INLINE new #-}
-- | Allocate a fresh counter whose value starts out as the given 'Int'.
new :: Int -> ST s (Counter s)
new (I# start#) = ST alloc
  where
    -- Run the unlifted allocation, then box the result in the lifted wrapper.
    alloc s0 =
      case Unlifted.new start# s0 of
        (# s1, raw #) -> (# s1, Counter raw #)


{-# INLINE get #-}
-- | Atomically read the value currently stored in the counter.
get
  :: Counter s
  -> ST s Int
get (Counter raw) = ST $ \s0 ->
  case Unlifted.get raw s0 of
    (# s1, n# #) -> (# s1, I# n# #)

{-# INLINE set #-}
-- | Atomically overwrite the counter with a new value.
set
  :: Counter s
  -> Int     -- ^ Value to store
  -> ST s ()
set (Counter raw) (I# n#) = ST write
  where
    -- The unlifted primitive yields only the new state token; pair it with ().
    write s0 =
      case Unlifted.set raw n# s0 of
        (# s1 #) -> (# s1, () #)

{-# INLINE cas #-}
-- | Atomic compare-and-swap: write the new value only if the counter
-- currently holds the expected old value. Returns the value the counter
-- held before the operation, so the write happened exactly when the
-- result equals the expected value.
cas
  :: Counter s
  -> Int -- ^ Expected old value
  -> Int -- ^ New value
  -> ST s Int
cas (Counter raw) (I# expected#) (I# desired#) = ST $ \s0 ->
  case Unlifted.cas raw expected# desired# s0 of
    (# s1, prev# #) -> (# s1, I# prev# #)

{-# INLINE add #-}
-- | Atomically add an amount to the counter and return its old value.
add
  :: Counter s
  -> Int      -- ^ Amount to add
  -> ST s Int -- ^ Counter value before the addition
add = toST Unlifted.add

{-# INLINE sub #-}
-- | Atomically subtract an amount from the counter and return its old value.
sub
  :: Counter s
  -> Int      -- ^ Amount to subtract
  -> ST s Int -- ^ Counter value before the subtraction
sub = toST Unlifted.sub


{-# INLINE and #-}
-- | Atomically replace the counter's value with the bitwise AND of the old
-- value and the given operand. Returns the old counter value.
and
  :: Counter s
  -> Int      -- ^ Operand combined with the current value
  -> ST s Int -- ^ Counter value before the update
and = toST Unlifted.and

{-# INLINE or #-}
-- | Atomically replace the counter's value with the bitwise OR of the old
-- value and the given operand. Returns the old counter value.
or
  :: Counter s
  -> Int      -- ^ Operand combined with the current value
  -> ST s Int -- ^ Counter value before the update
or = toST Unlifted.or

{-# INLINE xor #-}
-- | Atomically replace the counter's value with the bitwise XOR of the old
-- value and the given operand. Returns the old counter value.
xor
  :: Counter s
  -> Int      -- ^ Operand combined with the current value
  -> ST s Int -- ^ Counter value before the update
xor = toST Unlifted.xor

{-# INLINE nand #-}
-- | Atomically replace the counter's value with the bitwise NAND of the old
-- value and the given operand. Returns the old counter value.
nand
  :: Counter s
  -> Int      -- ^ Operand combined with the current value
  -> ST s Int -- ^ Counter value before the update
nand = toST Unlifted.nand

{-# INLINE toST #-}
-- | Lift an unlifted read-modify-write primitive into an 'ST' action over
-- the lifted 'Counter'. The primitive takes the unlifted counter, an 'Int#'
-- operand and a state token, and yields the counter's old value, which is
-- re-boxed as an 'Int'.
--
-- Only @f@ is bound on the left-hand side. GHC inlines an INLINE function
-- once it is applied to as many arguments as its definition binds
-- syntactically, so partial applications such as @toST Unlifted.add@ can
-- still inline.
toST
  :: (Unlifted.Counter s -> Int# -> State# s -> (# State# s, Int# #))
  -> Counter s -> Int -> ST s Int
toST f = \(Counter raw) (I# operand#) ->
  let step s0 =
        case f raw operand# s0 of
          (# s1, prev# #) -> (# s1, I# prev# #)
  in ST step