{-# LANGUAGE FunctionalDependencies #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Control.CanAquire
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--------------------------------------------------------------------------------
module Control.CanAquire(
      runAcquire
    , CanAquire(..)
    , HasIndex(..)

    , replaceByIndex, labelWithIndex
    , I
    ) where

import           Control.Monad.ST.Strict
import           Control.Monad.State.Strict
import           Data.Proxy (Proxy (..))
import           Data.Reflection
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import           Unsafe.Coerce (unsafeCoerce)

--------------------------------------------------------------------------------

-- | Run a computation on something that can acquire @a@'s.
--
-- Every element of the input is swapped for an index of type @'I' s a@
-- (its position in traversal order), while the elements themselves are
-- stored in a vector that is reified at the type @s@. Inside @alg@ an
-- index can therefore be turned back into its element with 'aquire'.
-- Since @alg@ must work for every @s@, it cannot rely on which @s@ it is
-- handed.
runAcquire         :: forall t a b. Traversable t
                   => (forall s. CanAquire (I s a) a => t (I s a) -> b)
                   -> t a -> b
runAcquire alg pts =
    let (store, positions) = replaceByIndex pts
    in  reify store (\px -> alg (retag px positions))
  where
    -- Reinterpret the Int positions as indices tagged with the proxy's
    -- type @s@. The proxy is never inspected; it only fixes @s@ so that
    -- the 'Reifies' constraint introduced by 'reify' can be used.
    --
    -- 'unsafeCoerce' stands in for @fmap I@ to avoid an extra traversal:
    -- 'I' is a newtype over 'Int', so the elements share a runtime
    -- representation. 'Data.Coerce.coerce' is rejected by GHC here because
    -- it cannot assume the parameter of an arbitrary @t@ is
    -- representational.
    -- NOTE(review): this assumes the runtime layout of @t x@ does not depend
    -- on @x@ (e.g. @t@ is not built from type families/GADTs) -- confirm.
    retag   :: proxy s -> t Int -> t (I s a)
    retag _ = unsafeCoerce

-- | Index types @i@ from which a value of type @a@ can be acquired.
-- Every such index must also expose an 'Int' index through its
-- 'HasIndex' superclass.
class (HasIndex i Int) => CanAquire i a where
  -- | Obtain the value of type @a@ that the index refers to.
  aquire :: i -> a

-- | Types whose values can act as indices. The index type @i@ is
-- uniquely determined by @t@ (functional dependency @t -> i@).
class HasIndex t i | t -> i where
  -- | Extract the index carried by a value.
  indexOf :: t -> i

--------------------------------------------------------------------------------

-- | Replace every element by its position in traversal order
-- (0, 1, 2, ...). Returns a vector holding the original elements --
-- laid out so that indexing it at an element's position gives back that
-- element -- together with the structure of positions.
replaceByIndex     :: forall t a. Traversable t => t a -> (V.Vector a, t Int)
replaceByIndex ts' = runST build
  where
    (labelled, n) = labelWithIndex ts'

    -- Allocate exactly n slots, store every element at its own label
    -- (labels are 0 .. n-1, one per element), and keep the label as the
    -- new element of the structure. The mutable vector is not touched
    -- after freezing, so the unsafe freeze is fine.
    build :: ST s (V.Vector a, t Int)
    build = do
      slots   <- MV.new n
      indices <- traverse (\(i, x) -> i <$ MV.write slots i x) labelled
      frozen  <- V.unsafeFreeze slots
      pure (frozen, indices)

-- | Label each element with its position in traversal order, starting
-- at 0. Returns the labelled collection together with its size (the
-- number of elements visited).
labelWithIndex :: Traversable t => t a -> (t (Int, a), Int)
labelWithIndex = flip runState 0 . traverse lbl
  where
    -- Hand out the current counter as the label and advance it. The strict
    -- 'State' monad is only strict in the state *pair*, not in the state
    -- value, so the new counter is forced with '$!'. This keeps the state
    -- an evaluated 'Int' instead of a chain of (+1) thunks over a long
    -- traversal.
    lbl   :: a -> State Int (Int, a)
    lbl x = do i <- get
               put $! i + 1
               pure (i, x)


--------------------------------------------------------------------------------

-- | An index, represented by a plain 'Int'. The phantom parameter @s@
-- tags the index with the type at which its backing vector is reified
-- (see 'runAcquire'); @a@ is the type of the value it refers to.
-- The 'Enum' instance is derived from the underlying 'Int'.
newtype I (s :: *) a = I Int
  deriving ( Eq
           , Ord
           , Enum
           )

-- | Shows only the wrapped 'Int' (no constructor), passing the
-- surrounding precedence through unchanged.
instance Show (I s a) where
  showsPrec prec idx = case idx of
    I ix -> showsPrec prec ix

-- | The index of an 'I' is the 'Int' it wraps.
instance HasIndex (I s a) Int where
  indexOf idx = case idx of I ix -> ix

-- | An @'I' s a@ acquires its value by indexing the vector reified at
-- type @s@. A proper 'Proxy' selects @s@ (rather than a bottom value).
-- Indexing uses the bounds-checked 'V.!'. An index that does not point
-- into the vector -- e.g. one obtained through the derived 'Enum'
-- instance ('succ', 'toEnum') -- therefore raises an error instead of
-- reading out of bounds.
instance Reifies s (V.Vector a) => I s a `CanAquire` a where
  aquire (I i) = reflect (Proxy :: Proxy s) V.! i