{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE Rank2Types            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
-- |
-- Type classes for vectors which are implemented on top of arrays
-- and support in-place mutation. The API is similar to the one used
-- in the @vector@ package.
module Data.Vector.Fixed.Mutable (
    -- * Mutable vectors
    Arity
  , arity
  , Mutable
  , DimM
  , MVector(..)
  , lengthM
  , read
  , write
  , clone
    -- * Creation
  , replicate
  , replicateM
  , generate
  , generateM
    -- * Loops
  , forI
    -- * Immutable vectors
  , IVector(..)
  , index
  , freeze
  , thaw
    -- * Vector API
  , constructVec
  , inspectVec
  ) where

import Control.Applicative  (Const(..))
import Control.Monad.ST
import Control.Monad.Primitive
import Data.Typeable  (Proxy(..))
import GHC.TypeLits
import Data.Vector.Fixed.Cont (Dim,PeanoNum(..),Peano,Arity,Fun(..),Vector(..),ContVec,arity,apply,accum,length)
import Prelude hiding (read,length,replicate)


----------------------------------------------------------------
-- Type classes
----------------------------------------------------------------

-- | Mutable counterpart of a fixed-length vector type. It maps an
--   immutable vector type constructor @v@ to a mutable one that takes a
--   state token and an element type (used as @Mutable v s a@).
type family Mutable (v :: * -> *) :: * -> * -> *

-- | Dimension (length) of a mutable vector type, as a type-level natural.
type family DimM (v :: * -> * -> *) :: Nat

-- | Type class for mutable vectors. The length of a vector is determined
--   by the type-level dimension @'DimM' v@ (see 'lengthM').
class (Arity (DimM v)) => MVector v a where
  -- | Copy the source vector into the target vector. The two vectors
  --   must not overlap; use 'move' if they might. Since the length is
  --   encoded in the type, no runtime length checks are needed.
  copy :: PrimMonad m
       => v (PrimState m) a    -- ^ Target
       -> v (PrimState m) a    -- ^ Source
       -> m ()
  -- | Copy the source vector into the target vector. Unlike 'copy',
  --   the two vectors may overlap. Since the length is encoded in the
  --   type, no runtime length checks are needed.
  move :: PrimMonad m
       => v (PrimState m) a    -- ^ Target
       -> v (PrimState m) a    -- ^ Source
       -> m ()
  -- | Allocate a new vector. NOTE(review): initial element contents are
  --   not specified by this class; the creation functions in this module
  --   overwrite every element before returning the vector.
  new   :: PrimMonad m => m (v (PrimState m) a)
  -- | Read the value at an index without bounds checks. The index must be
  --   in the range @[0, lengthM v)@.
  unsafeRead  :: PrimMonad m => v (PrimState m) a -> Int -> m a
  -- | Write a value at an index without bounds checks. The index must be
  --   in the range @[0, lengthM v)@.
  unsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()


-- | Length of a mutable vector. The length is taken from the type-level
--   dimension @'DimM' v@, so the argument is never evaluated.
lengthM :: forall v s a. (Arity (DimM v)) => v s a -> Int
lengthM _ = arity (Proxy :: Proxy (DimM v))

-- | Create a copy of a mutable vector. The result is freshly allocated
--   with 'new' and filled from the source with 'move'.
--
--   Examples:
--
--   >>> import Control.Monad.ST (runST)
--   >>> import Data.Vector.Fixed (mk3)
--   >>> import Data.Vector.Fixed.Boxed (Vec3)
--   >>> import qualified Data.Vector.Fixed.Mutable as M
--   >>> let x = runST (do { v <- M.replicate 100; v' <- M.clone v; M.write v' 0 2; M.unsafeFreeze v' }) :: Vec3 Int
--   >>> x
--   fromList [2,100,100]
clone :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE clone #-}
clone v = do
  u <- new
  -- 'move' takes the target first and the source second.
  move u v
  pure u

-- | Read the value at the given index, with bounds checks. Calls 'error'
--   if the index is negative or not less than 'lengthM' of the vector.
read  :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i
  | i < 0 || i >= lengthM v = error "Data.Vector.Fixed.Mutable.read: index out of range"
  | otherwise               = unsafeRead v i

-- | Write a value at the given index, with bounds checks. Calls 'error'
--   if the index is negative or not less than 'lengthM' of the vector.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x
  | i < 0 || i >= lengthM v = error "Data.Vector.Fixed.Mutable.write: index out of range"
  | otherwise               = unsafeWrite v i x


-- | Create a new vector with every element set to the given value.
replicate :: (PrimMonad m, MVector v a) => a -> m (v (PrimState m) a)
{-# INLINE replicate #-}
replicate a = do
  v <- new
  -- 'forI' only yields indices in [0, lengthM v), so the unchecked write is safe.
  forI v $ \i -> unsafeWrite v i a
  pure v

-- | Create a new vector whose elements are produced by running the given
--   monadic action once per element, in increasing index order.
replicateM :: (PrimMonad m, MVector v a) => m a -> m (v (PrimState m) a)
{-# INLINE replicateM #-}
replicateM m = do
  v <- new
  -- 'forI' only yields indices in [0, lengthM v), so the unchecked write is safe.
  forI v $ \i -> unsafeWrite v i =<< m
  pure v

-- | Create a new vector whose element at index @i@ is @f i@.
generate :: (PrimMonad m, MVector v a) => (Int -> a) -> m (v (PrimState m) a)
{-# INLINE generate #-}
generate f = do
  v <- new
  -- 'forI' only yields indices in [0, lengthM v), so the unchecked write is safe.
  forI v $ \i -> unsafeWrite v i $ f i
  pure v

-- | Create a new vector whose element at index @i@ is the result of
--   running @f i@. The actions are run in increasing index order.
generateM :: (PrimMonad m, MVector v a) => (Int -> m a) -> m (v (PrimState m) a)
{-# INLINE generateM #-}
generateM f = do
  v <- new
  -- 'forI' only yields indices in [0, lengthM v), so the unchecked write is safe.
  forI v $ \i -> unsafeWrite v i =<< f i
  pure v

-- | Loop which calls the given action for each index of the vector,
--   from 0 up to @'lengthM' v - 1@, in increasing order.
forI :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (Int -> m ()) -> m ()
{-# INLINE forI #-}
forI v f = go 0
  where
    go i
      | i >= n    = pure ()
      | otherwise = f i >> go (i + 1)
    n = lengthM v


----------------------------------------------------------------
-- Immutable
----------------------------------------------------------------

-- | Type class for immutable vectors. The superclass constraints require
--   that the dimension of @v@ equals the dimension of its mutable
--   counterpart @'Mutable' v@, and that the counterpart is an 'MVector'.
class (Dim v ~ DimM (Mutable v), MVector (Mutable v) a) => IVector v a where
  -- | Convert a mutable vector to an immutable one without copying. The
  --   mutable vector must not be modified afterwards.
  unsafeFreeze :: PrimMonad m => Mutable v (PrimState m) a -> m (v a)
  -- | /O(1)/ Unsafely convert an immutable vector to a mutable one without
  --   copying. Note that this is a very dangerous function and
  --   generally it's only safe to read from the resulting vector. In
  --   this case, the immutable vector could be used safely as well.
  --
  -- Problems with mutation happen because GHC has a lot of freedom to
  -- introduce sharing. As a result, mutable vectors produced by
  -- @unsafeThaw@ may or may not share the same underlying buffer. For
  -- example:
  --
  -- > foo = do
  -- >   let vec = F.generate 10 id
  -- >   mvec <- M.unsafeThaw vec
  -- >   do_something mvec
  --
  -- Here GHC could lift @vec@ outside of foo, which means that all calls to
  -- @do_something@ will use the same buffer, with possibly disastrous
  -- results. Whether such aliasing happens or not depends on the program in
  -- question, optimization levels, and GHC flags.
  --
  -- All in all, attempts to modify a vector produced by @unsafeThaw@
  -- fall out of the domain of software engineering and into the realm of
  -- black magic, dark rituals, and unspeakable horrors. The only
  -- advice that could be given is: "Don't attempt to mutate a vector
  -- produced by @unsafeThaw@ unless you know how to prevent GHC from
  -- aliasing buffers accidentally. We don't."
  unsafeThaw   :: PrimMonad m => v a -> m (Mutable v (PrimState m) a)
  -- | Get the element at the specified index without bounds checks.
  unsafeIndex :: v a -> Int -> a

-- | Get the element at the given index, with bounds checks. Calls 'error'
--   if the index is negative or not less than the vector's length.
index :: IVector v a => v a -> Int -> a
{-# INLINE index #-}
index v i
  | i < 0 || i >= length v = error "Data.Vector.Fixed.Mutable.!: index out of bounds"
  | otherwise              = unsafeIndex v i


-- | Safely convert a mutable vector to an immutable one. The vector is
--   copied with 'clone' before 'unsafeFreeze', so the original mutable
--   vector may still be modified afterwards without affecting the result.
freeze :: (PrimMonad m, IVector v a) => Mutable v (PrimState m) a -> m (v a)
{-# INLINE freeze #-}
freeze v = unsafeFreeze =<< clone v

-- | Safely convert an immutable vector to a mutable one. The buffer
--   obtained via 'unsafeThaw' is copied with 'clone' before being
--   returned, so writes to the result do not touch the immutable vector.
thaw :: (PrimMonad m, IVector v a) => v a -> m (Mutable v (PrimState m) a)
{-# INLINE thaw #-}
thaw v = clone =<< unsafeThaw v



----------------------------------------------------------------
-- Vector API
----------------------------------------------------------------

-- | Generic 'inspect' implementation for array-based vectors. Elements
--   are fed to the function in increasing index order, read with
--   'unsafeIndex'.
inspectVec :: forall v a b. (Arity (Dim v), IVector v a) => v a -> Fun (Peano (Dim v)) a b -> b
{-# INLINE inspectVec #-}
inspectVec v
  = inspect cv
  where
    -- Continuation-based view of @v@. The 'Const' counter carries the
    -- index of the next element to read, starting from 0; 'apply' steps it
    -- exactly @Dim v@ times, so every index stays in range.
    cv :: ContVec (Dim v) a
    cv = apply (\(Const i) -> (unsafeIndex v i, Const (i + 1)))
               (Const 0 :: Const Int (Peano (Dim v)))

-- | Generic 'construct' implementation for array-based vectors. A mutable
--   vector is allocated with 'new', each argument is written at the next
--   index starting from 0 (see 'step'), and the result is frozen with
--   'unsafeFreeze' inside 'runST'.
constructVec :: forall v a. (Arity (Dim v), IVector v a) => Fun (Peano (Dim v)) a (v a)
{-# INLINE constructVec #-}
constructVec =
  accum step
        -- The annotation fixes @v@, which cannot be inferred through the
        -- non-injective type family 'Mutable'.
        (\(T_new _ st) -> runST (unsafeFreeze =<< st) :: v a)
        (T_new 0 new :: T_new v a (Peano (Dim v)))

-- | Accumulator threaded through 'accum' by 'constructVec'. It holds the
--   index at which the next element will be written, and an 'ST' action
--   that allocates the vector and performs all writes so far. The last
--   type parameter counts the elements still to be written: it starts at
--   @Peano (Dim v)@ and drops by one in each 'step'.
data T_new v a n = T_new Int (forall s. ST s (Mutable v s a))

-- | Consume one constructor argument. The accumulated 'ST' action is
--   extended so that it also writes @x@ at the current index, and the
--   index is advanced. The type-level counter drops from @'S n@ to @n@.
step :: (IVector v a) => T_new v a ('S n) -> a -> T_new v a n
step (T_new i st) x = T_new (i + 1) $ do
  mv <- st
  -- The type-level counter guarantees at most @Dim v@ writes, so @i@ is in range.
  unsafeWrite mv i x
  pure mv