module Agda.Utils.Pointer
  ( Ptr, newPtr, derefPtr, setPtr
  , updatePtr, updatePtrM
  ) where

import Control.DeepSeq
import Control.Concurrent.MVar

import Data.Function (on)
import Data.Hashable
import Data.IORef

import System.IO.Unsafe

-- | A mutable cell that is read and updated from pure code.
--
--   A pointer's identity is its 'ptrTag', the integer handed out by the tag
--   counter when the pointer is created; 'ptrRef' holds the current value.
data Ptr a = Ptr
  { ptrTag :: !Integer    -- ^ identity of this pointer
  , ptrRef :: !(IORef a)  -- ^ mutable cell holding the pointed-to value
  }

{-# NOINLINE freshVar #-}
-- | Process-wide counter from which pointer tags are drawn; starts at 0.
--
--   Created with 'unsafePerformIO' at top level. The NOINLINE pragma keeps
--   GHC from inlining the definition, so there is a single shared counter.
freshVar :: MVar Integer
freshVar = unsafePerformIO (newMVar 0)

-- | Return the next unused pointer tag and advance the counter.
--
--   The read-increment-write goes through 'modifyMVar'. It masks
--   asynchronous exceptions and puts the old value back if the update
--   throws. The previous @takeMVar@/@putMVar@ pair could be interrupted
--   between the two calls, leaving 'freshVar' empty. Every later call would
--   then block forever.
--
--   The incremented counter is forced before it is stored, so no chain of
--   @(+ 1)@ thunks builds up inside the 'MVar'.
fresh :: IO Integer
fresh = modifyMVar freshVar $ \current ->
  let next = current + 1
  in  next `seq` return (next, current)

{-# NOINLINE newPtr #-}
-- | Allocate a new pointer, with a fresh tag, initialised to the given value.
--
--   The allocation runs under 'unsafePerformIO'. The NOINLINE pragma keeps
--   the body from being inlined into call sites.
newPtr :: a -> Ptr a
newPtr x = unsafePerformIO $ do
  tag  <- fresh
  cell <- newIORef x
  return (Ptr tag cell)

-- | Read the value currently stored behind a pointer.
--
--   The read is performed through 'unsafePerformIO' when the result is
--   forced, so it observes whatever the cell holds at that moment.
derefPtr :: Ptr a -> a
derefPtr (Ptr _ cell) = unsafePerformIO (readIORef cell)

{-# NOINLINE updatePtr #-}
-- | Apply a function to the value behind a pointer, in place.
--
--   The result is the argument pointer itself (same tag, same 'IORef'), so
--   every copy of it sees the new value. The write happens when the result
--   is forced.
--
--   NOTE(review): 'modifyIORef' is the lazy variant. The cell holds the
--   unevaluated @f old@, which keeps @old@ alive until it is read and
--   forced. Switching to the strict @modifyIORef'@ would change when @f@ is
--   evaluated; confirm with callers before changing.
updatePtr :: (a -> a) -> Ptr a -> Ptr a
updatePtr f p = unsafePerformIO $ do
  modifyIORef (ptrRef p) f
  return p

-- | Overwrite the value behind a pointer, in place (via 'updatePtr').
--
--   The new value is forced to weak head normal form as soon as the partial
--   application @setPtr x@ is evaluated.
setPtr :: a -> Ptr a -> Ptr a
setPtr x = x `seq` updatePtr (const x)

-- | Effectful variant of 'updatePtr'.
--
--   Runs @f@ on the current value and writes each result back into the same
--   pointer with 'setPtr'. If @f a@ contains many copies of @a@, they all
--   become the same pointer in the result. If the function is well-behaved
--   (i.e. preserves the implicit equivalence), this shouldn't matter.
updatePtrM :: Functor f => (a -> f a) -> Ptr a -> f (Ptr a)
updatePtrM f p = fmap (\new -> setPtr new p) (f (derefPtr p))

-- | Shows a pointer as @#tag{value}@.
instance Show a => Show (Ptr a) where
  show p = concat ["#", show (ptrTag p), "{", show (derefPtr p), "}"]

-- | Mapping reads the current value and allocates a new pointer (with a
--   fresh tag) for the result; the source pointer is not modified.
instance Functor Ptr where
  fmap f p = newPtr (f (derefPtr p))

-- | A pointer holds exactly one element: its current value.
instance Foldable Ptr where
  foldMap f p = f (derefPtr p)

-- | Traverses the current value and wraps the result in a freshly allocated
--   pointer.
instance Traversable Ptr where
  traverse f p = fmap newPtr (f (derefPtr p))

-- | Pointers are equal when their tags are equal; the stored values are not
--   inspected.
instance Eq (Ptr a) where
  p == q = ptrTag p == ptrTag q

-- | Orders pointers by tag, consistent with the 'Eq' instance.
instance Ord (Ptr a) where
  compare p q = compare (ptrTag p) (ptrTag q)

-- | Hashes only the tag, matching the tag-based 'Eq' instance.
instance Hashable (Ptr a) where
  hashWithSalt salt p = hashWithSalt salt (ptrTag p)

-- | Forces the 'Ptr' constructor only. Its strict fields make that evaluate
--   the tag and the 'IORef', but not the value stored inside the cell.
instance NFData (Ptr a) where
  rnf p = p `seq` ()