----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.Counter.Unlifted
-- Copyright   :  (c) Sergey Vinokurov 2022
-- License     :  Apache-2.0 (see LICENSE)
-- Maintainer  :  serg.foo@gmail.com
--
-- Counters that support some atomic operations. Safe to use from
-- multiple threads and likely faster than using 'Data.IORef.IORef' or
-- 'Control.Concurrent.STM.TVar.TVar' for the same operation (terms and
-- conditions apply).
--
-- This module defines an unlifted newtype wrapper and the corresponding
-- operations. Because the wrapper is unlifted, it is not suitable for use
-- with e.g. monads, or for storing in other data structures that expect
-- lifted types. For general use, start with the lifted
-- 'Control.Concurrent.Counter.Counter' type.
----------------------------------------------------------------------------

{-# LANGUAGE CPP                  #-}
{-# LANGUAGE GHCForeignImportPrim #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE MagicHash            #-}
{-# LANGUAGE UnboxedTuples        #-}
{-# LANGUAGE UnliftedFFITypes     #-}
{-# LANGUAGE UnliftedNewtypes     #-}

module Control.Concurrent.Counter.Unlifted
  ( Counter

  -- * Create
  , new

  -- * Read/write
  , get
  , set
  , cas

  -- * Arithmetic operations
  , add
  , sub

  -- * Bitwise operations
  , and
  , or
  , xor
  , nand

  -- * Compare
  , sameCounter

  ) where

import Prelude hiding (and, or)

import GHC.Exts

#include "MachDeps.h"
#ifndef SIZEOF_HSINT
#error "MachDeps.h didn't define SIZEOF_HSINT"
#endif

#define ADD_HASH(x) x#

#if defined(USE_CMM) && SIZEOF_HSINT == 8

-- | Memory location that supports a select few atomic operations.
--
-- In this configuration the counter is an opaque unlifted heap object;
-- presumably it is only ever created and accessed through the
-- @foreign import prim@ functions of this module (verify against the
-- primitive implementations). @s@ is a phantom type parameter matching
-- the @State# s@ token threaded through every operation.
newtype Counter s = Counter (Any :: UnliftedType)

-- | Create a new counter holding the given initial value.
--
-- Bound via @foreign import prim@ to the external primitive
-- @stg_newCounterzh@; its implementation is not part of this module.
foreign import prim "stg_newCounterzh"
  new :: Int# -> State# s -> (# State# s, Counter s #)

-- | Atomically read the current value of the counter.
--
-- Bound to the external primitive @stg_atomicGetCounterzh@.
foreign import prim "stg_atomicGetCounterzh"
  get :: Counter s -> State# s -> (# State# s, Int# #)

-- | Atomically overwrite the value of the counter with the given one.
--
-- Bound to the external primitive @stg_atomicSetCounterzh@. Only the new
-- state token is returned, wrapped in an unboxed 1-tuple.
foreign import prim "stg_atomicSetCounterzh"
  set :: Counter s -> Int# -> State# s -> (# State# s #)

-- | Atomically add an amount to the counter and return the value it held
-- before the addition.
--
-- Bound to the external primitive @stg_atomicAddCounterzh@.
foreign import prim "stg_atomicAddCounterzh"
  add :: Counter s -> Int# -> State# s -> (# State# s, Int# #)

-- | Atomically subtract an amount from the counter and return the value it
-- held before the subtraction.
--
-- Bound to the external primitive @stg_atomicSubCounterzh@.
foreign import prim "stg_atomicSubCounterzh"
  sub :: Counter s -> Int# -> State# s -> (# State# s, Int# #)

-- | Atomically replace the counter value with the bitwise AND of the old
-- value and the argument. Returns the value held before the operation.
--
-- Bound to the external primitive @stg_atomicAndCounterzh@.
foreign import prim "stg_atomicAndCounterzh"
  and :: Counter s -> Int# -> State# s -> (# State# s, Int# #)

-- | Atomically replace the counter value with the bitwise OR of the old
-- value and the argument. Returns the value held before the operation.
--
-- Bound to the external primitive @stg_atomicOrCounterzh@.
foreign import prim "stg_atomicOrCounterzh"
  or :: Counter s -> Int# -> State# s -> (# State# s, Int# #)

-- | Atomically replace the counter value with the bitwise XOR of the old
-- value and the argument. Returns the value held before the operation.
--
-- Bound to the external primitive @stg_atomicXorCounterzh@.
foreign import prim "stg_atomicXorCounterzh"
  xor :: Counter s -> Int# -> State# s -> (# State# s, Int# #)

-- | Atomically replace the counter value with the bitwise NAND of the old
-- value and the argument. Returns the value held before the operation.
--
-- Bound to the external primitive @stg_atomicNandCounterzh@.
foreign import prim "stg_atomicNandCounterzh"
  nand :: Counter s -> Int# -> State# s -> (# State# s, Int# #)

-- | Atomic compare-and-swap: write the new value only if the counter
-- currently holds the expected old value.
--
-- Arguments after the counter are, in order, the expected old value and
-- the new value. Returns the value the counter held before the operation,
-- so the swap took place exactly when the result equals the expected
-- value. Bound to the external primitive @stg_casCounterzh@.
foreign import prim "stg_casCounterzh"
  cas :: Counter s -> Int# -> Int# -> State# s -> (# State# s, Int# #)

-- | Test whether both arguments denote the very same counter, by comparing
-- the addresses of their underlying unlifted heap objects.
sameCounter :: Counter s -> Counter s -> Bool
sameCounter (Counter lhs) (Counter rhs) =
  isTrue# (lhs `reallyUnsafePtrEquality#` rhs)

#else

-- | Memory location that supports a select few atomic operations.
--
-- Represented as a mutable byte array; every operation in this module
-- reads or writes the single Int# stored at offset 0 of that array.
newtype Counter s = Counter (MutableByteArray# s)

{-# INLINE new #-}
-- | Create a new counter holding the given initial value.
--
-- Allocates a fresh byte array exactly one 'Int' wide and stores the
-- initial value at offset 0 before handing the counter out.
new :: Int# -> State# s -> (# State# s, Counter s #)
new initVal = \s1 ->
  case newByteArray# ADD_HASH(SIZEOF_HSINT) s1 of
    (# s2, arr #) ->
      -- A plain (non-atomic) write is enough here: the array was just
      -- allocated and no other thread can observe it until it is returned.
      case writeIntArray# arr 0# initVal s2 of
        s3 -> (# s3, Counter arr #)


{-# INLINE get #-}
-- | Atomically read the current value of the counter.
get :: Counter s -> State# s -> (# State# s, Int# #)
get (Counter arr) = atomicReadIntArray# arr 0#

{-# INLINE set #-}
-- | Atomically overwrite the value of the counter with the given one.
set :: Counter s -> Int# -> State# s -> (# State# s #)
set (Counter arr) n = \s1 ->
  case atomicWriteIntArray# arr 0# n s1 of
    -- The primop yields a bare state token; wrap it in the unboxed
    -- 1-tuple promised by the signature.
    s2 -> (# s2 #)

{-# INLINE cas #-}
-- | Atomic compare-and-swap: write the new value only if the counter
-- currently holds the expected old value.
--
-- Returns the value the counter held before the operation, so the swap
-- took place exactly when the result equals the expected value.
cas
  :: Counter s
  -> Int# -- ^ Expected old value
  -> Int# -- ^ New value
  -> State# s
  -> (# State# s, Int# #)
cas (Counter arr) = casIntArray# arr 0#

{-# INLINE add #-}
-- | Atomically add an amount to the counter and return the value it held
-- before the addition.
add :: Counter s -> Int# -> State# s -> (# State# s, Int# #)
add (Counter arr) = fetchAddIntArray# arr 0#

{-# INLINE sub #-}
-- | Atomically subtract an amount from the counter and return the value it
-- held before the subtraction.
sub :: Counter s -> Int# -> State# s -> (# State# s, Int# #)
sub (Counter arr) = fetchSubIntArray# arr 0#


{-# INLINE and #-}
-- | Atomically replace the counter value with the bitwise AND of the old
-- value and the argument. Returns the value held before the operation.
and :: Counter s -> Int# -> State# s -> (# State# s, Int# #)
and (Counter arr) = fetchAndIntArray# arr 0#

{-# INLINE or #-}
-- | Atomically replace the counter value with the bitwise OR of the old
-- value and the argument. Returns the value held before the operation.
or :: Counter s -> Int# -> State# s -> (# State# s, Int# #)
or (Counter arr) = fetchOrIntArray# arr 0#

{-# INLINE xor #-}
-- | Atomically replace the counter value with the bitwise XOR of the old
-- value and the argument. Returns the value held before the operation.
xor :: Counter s -> Int# -> State# s -> (# State# s, Int# #)
xor (Counter arr) = fetchXorIntArray# arr 0#

{-# INLINE nand #-}
-- | Atomically replace the counter value with the bitwise NAND of the old
-- value and the argument. Returns the value held before the operation.
nand :: Counter s -> Int# -> State# s -> (# State# s, Int# #)
nand (Counter arr) = fetchNandIntArray# arr 0#

-- | Check whether two counters are the same counter, i.e. whether they share
-- the same underlying mutable byte array.
sameCounter :: Counter s -> Counter s -> Bool
sameCounter (Counter x) (Counter y) =
  isTrue# (sameMutableByteArray# x y)

#endif