module Agda.Utils.Pointer
  ( Ptr, newPtr, derefPtr, setPtr
  , updatePtr, updatePtrM
  ) where

import Control.DeepSeq
import Control.Concurrent.MVar

import Data.Function
import Data.Hashable
import Data.IORef

import System.IO.Unsafe

import Agda.Utils.Impossible

-- | A mutable cell paired with an identity tag.
--
-- The 'Eq', 'Ord' and 'Hashable' instances look at 'ptrTag' only; the value
-- stored behind 'ptrRef' never takes part in comparisons or hashing.
data Ptr a = Ptr
  { ptrTag :: !Integer
    -- ^ Identity of the pointer.
  , ptrRef :: !(IORef a)
    -- ^ Mutable cell holding the current value.
  }

-- | Global counter from which 'newPtr' draws the tags of new pointers.
--
-- @unsafePerformIO@ together with @NOINLINE@ is the standard way to create a
-- single top-level mutable variable. Without @NOINLINE@, GHC could inline the
-- definition and create several independent counters, and then two distinct
-- pointers could receive the same tag. 'Integer' is unbounded, so tags never
-- wrap around.
{-# NOINLINE freshVar #-}
freshVar :: MVar Integer
freshVar = unsafePerformIO $ newMVar 0

-- | Return the current counter value and advance the counter by one.
--
-- 'modifyMVar' is exception-safe. If an (asynchronous) exception arrives
-- while the counter is taken, the old value is put back. A bare
-- 'takeMVar' / 'putMVar' pair could leave 'freshVar' empty, which would make
-- every later call block forever. The successor is forced before it is
-- stored so that the counter never builds up a chain of @(+ 1)@ thunks.
fresh :: IO Integer
fresh = modifyMVar freshVar $ \ current ->
  let next = current + 1
  in  next `seq` return (next, current)

-- | Allocate a new pointer holding the given value.
--
-- Each allocation draws a new tag from 'fresh' and creates a new 'IORef'.
-- @NOINLINE@ keeps GHC from inlining the @unsafePerformIO@ into call sites.
{-# NOINLINE newPtr #-}
newPtr :: a -> Ptr a
newPtr x = unsafePerformIO $
  fresh >>= \ tag -> Ptr tag <$> newIORef x

-- | Read the value currently stored behind a pointer.
--
-- The read goes through 'unsafePerformIO'. It therefore observes the cell's
-- contents at the moment the result is demanded, not when the call was built.
derefPtr :: Ptr a -> a
derefPtr p = unsafePerformIO (readIORef (ptrRef p))

-- | Apply a function to the value stored behind a pointer, in place, and
-- return that same pointer (same tag, same cell).
--
-- The write happens through 'unsafePerformIO' when the returned pointer is
-- forced. 'modifyIORef' is the lazy variant: the new contents are stored as
-- an unevaluated application of @f@ to the old value.
{-# NOINLINE updatePtr #-}
updatePtr :: (a -> a) -> Ptr a -> Ptr a
updatePtr f p = unsafePerformIO $ do
  modifyIORef (ptrRef p) f
  return p

-- | Overwrite the value stored behind a pointer and return the same pointer.
--
-- The new value is forced to weak head normal form before the update is
-- performed.
setPtr :: a -> Ptr a -> Ptr a
setPtr x = x `seq` updatePtr (const x)

-- | Effectful counterpart of 'updatePtr': run @f@ on the current value and
--   write each result back into the same pointer with 'setPtr'.
--
--   If @f a@ contains many copies of @a@, they will all be the same pointer
--   in the result. If the function is well-behaved (i.e. preserves the
--   implicit equivalence), this shouldn't matter.
updatePtrM :: Functor f => (a -> f a) -> Ptr a -> f (Ptr a)
updatePtrM f p = fmap (\ new -> setPtr new p) (f (derefPtr p))

-- | Renders a pointer as @#tag{value}@, showing the value currently stored.
instance Show a => Show (Ptr a) where
  show p = concat ["#", show (ptrTag p), "{", show (derefPtr p), "}"]

-- | Maps over the stored value. The result lives in a newly allocated pointer
-- (with a new tag), and the argument pointer is not written to.
instance Functor Ptr where
  fmap f p = newPtr (f (derefPtr p))

-- | A pointer holds exactly one value, and that value is what gets folded.
instance Foldable Ptr where
  foldMap f p = f (derefPtr p)

-- | Traverses the stored value and wraps the result in a newly allocated
-- pointer.
instance Traversable Ptr where
  traverse f p = fmap newPtr (f (derefPtr p))

-- | Pointer identity: two pointers are equal iff they carry the same tag. The
-- stored values are never inspected.
instance Eq (Ptr a) where
  p == q = ptrTag p == ptrTag q

-- | Orders pointers by tag, consistent with pointer identity.
instance Ord (Ptr a) where
  compare p q = compare (ptrTag p) (ptrTag q)

-- | Hashes the tag only, consistent with pointer identity.
instance Hashable (Ptr a) where
  hashWithSalt salt p = hashWithSalt salt (ptrTag p)

-- | Forces only the pointer record itself to weak head normal form. The value
-- stored behind it is not evaluated.
instance NFData (Ptr a) where
  rnf p = p `seq` ()