{-# Language GADTs, DataKinds, TypeOperators, BangPatterns #-}
{-# Language PatternGuards #-}
{-# Language TypeApplications, ScopedTypeVariables #-}
{-# Language Rank2Types, RoleAnnotations #-}
{-# Language CPP #-}
#if __GLASGOW_HASKELL__ >= 805
{-# Language NoStarIsType #-}
#endif
module Data.Parameterized.Vector
( Vector
, fromList
, toList
, length
, nonEmpty
, lengthInt
, elemAt
, elemAtMaybe
, elemAtUnsafe
, insertAt
, insertAtMaybe
, uncons
, slice
, Data.Parameterized.Vector.take
, zipWith
, zipWithM
, zipWithM_
, interleave
, shuffle
, reverse
, rotateL
, rotateR
, shiftL
, shiftR
, singleton
, cons
, snoc
, generate
, generateM
, joinWithM
, joinWith
, splitWith
, splitWithA
, split
, join
, append
) where
import qualified Data.Vector as Vector
import Data.Functor.Compose
import Data.Coerce
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as MVector
import Control.Monad.ST
import Data.Functor.Identity
import Data.Parameterized.NatRepr
import Data.Parameterized.NatRepr.Internal
import Data.Proxy
import Prelude hiding (length,reverse,zipWith)
import Numeric.Natural
import Data.Parameterized.Utils.Endian
-- | A vector of elements of type @a@ whose length @n@ is tracked at the
-- type level.
--
-- The only constructor wraps a boxed "Data.Vector" vector and carries the
-- constraint @1 <= n@, so the type-level length is always at least 1 and
-- matching on the constructor brings that fact into scope.  The constructor
-- does not itself check that the wrapped vector has exactly @n@ elements;
-- that invariant is maintained by the functions in this module (for
-- example, 'fromList' checks the length before constructing a value).
data Vector n a where
  Vector :: (1 <= n) => !(Vector.Vector a) -> Vector n a

-- The length index is nominal, so a 'Vector' can never be coerced to a
-- different length.  The element type is representational, so elements can
-- be coerced between representationally-equal types (see 'coerceVec').
type role Vector nominal representational
-- | Two vectors of the same type-level length are equal when the wrapped
-- vectors are equal, i.e. when their elements are pairwise equal.
instance Eq a => Eq (Vector n a) where
  Vector lhs == Vector rhs = lhs == rhs
-- | Shows a vector using the 'Show' instance of the wrapped "Data.Vector"
-- vector.
instance Show a => Show (Vector n a) where
  show v = case v of
    Vector inner -> show inner
-- | Convert a vector to a list, preserving element order.
toList :: Vector n a -> [a]
toList v = case v of
  Vector elems -> Vector.toList elems
{-# Inline toList #-}
-- | The length of a vector, as a 'NatRepr' for its type-level length.
length :: Vector n a -> NatRepr n
length (Vector xs) = NatRepr count
  where
    -- Runtime length of the wrapped vector, widened to 'Natural'.
    count :: Natural
    count = fromIntegral (Vector.length xs)
{-# INLINE length #-}
-- | The length of a vector as an 'Int'.
lengthInt :: Vector n a -> Int
lengthInt v = case v of
  Vector elems -> Vector.length elems
{-# Inline lengthInt #-}
-- | Get the element at a statically known index.
--
-- The constraint @i + 1 <= n@ guarantees that the index is in range.
elemAt :: ((i+1) <= n) => NatRepr i -> Vector n a -> a
elemAt n (Vector xs) = xs Vector.! widthVal n
{-# INLINE elemAt #-}

-- | Get the element at the given index, or 'Nothing' if the index is
-- negative or not less than the length of the vector.
elemAtMaybe :: Int -> Vector n a -> Maybe a
elemAtMaybe n (Vector xs) = xs Vector.!? n
{-# INLINE elemAtMaybe #-}
-- | Get the element at a runtime index.
--
-- The index is not checked against the type-level length.  The lookup
-- itself still goes through the bounds-checked 'Vector.!', so an
-- out-of-range index raises an error.
elemAtUnsafe :: Int -> Vector n a -> a
elemAtUnsafe i v = case v of
  Vector xs -> xs Vector.! i
{-# INLINE elemAtUnsafe #-}
-- | Replace the element at a statically known index with a new value.
--
-- Despite its name this does not grow the vector: the length stays @n@.
-- The constraint @i + 1 <= n@ keeps the index in range, which is why the
-- unchecked 'Vector.unsafeUpd' is used.
insertAt :: ((i + 1) <= n) => NatRepr i -> a -> Vector n a -> Vector n a
insertAt i x (Vector xs) = Vector (Vector.unsafeUpd xs updates)
  where
    updates = [(widthVal i, x)]
-- | Replace the element at a runtime index with a new value.
--
-- Returns 'Nothing' if the index is negative or not less than the length
-- of the vector; the length of the vector never changes.
insertAtMaybe :: Int -> a -> Vector n a -> Maybe (Vector n a)
insertAtMaybe n a (Vector xs)
  | n < 0 || n >= Vector.length xs = Nothing
  | otherwise = Just (Vector (Vector.unsafeUpd xs [(n, a)]))
-- | Evidence that the type-level length of a vector is at least 1,
-- recovered from the constructor's @1 <= n@ constraint.
nonEmpty :: Vector n a -> LeqProof 1 n
nonEmpty v = case v of
  Vector _ -> LeqProof
{-# Inline nonEmpty #-}
-- | Split a vector into its first element and the rest.
--
-- The second component is @Left Refl@ when the vector has exactly one
-- element (so @n ~ 1@).  Otherwise it is @Right@ of the remaining @n - 1@
-- elements.  'Vector.head' relies on the invariant that the wrapped vector
-- is non-empty.
uncons :: forall n a. Vector n a -> (a, Either (n :~: 1) (Vector (n-1) a))
uncons v@(Vector xs) = (Vector.head xs, mbTail)
  where
    mbTail :: Either (n :~: 1) (Vector (n - 1) a)
    mbTail = case testStrictLeq (knownNat @1) (length v) of
               -- Here @1 + 1 <= n@.  Combining that with @1 <= 1@ via
               -- 'leqSub2' yields the @1 <= n - 1@ that the 'Vector'
               -- constructor needs for the tail.
               Left n2_leq_n ->
                 do LeqProof <- return (leqSub2 n2_leq_n (leqRefl (knownNat @1)))
                    return (Vector (Vector.tail xs))
               -- Here @1 ~ n@: there is no tail.
               Right Refl -> Left Refl
{-# Inline uncons #-}
-- | Build a vector of length @n@ from a list.
--
-- Returns 'Nothing' unless the list has exactly @n@ elements.
fromList :: (1 <= n) => NatRepr n -> [a] -> Maybe (Vector n a)
fromList n xs =
  let v = Vector.fromList xs
  in if widthVal n == Vector.length v
       then Just (Vector v)
       else Nothing
{-# INLINE fromList #-}
-- | Extract @w@ consecutive elements starting at index @i@.
--
-- The constraint @i + w <= n@ keeps the slice inside the vector, and
-- @1 <= w@ keeps the result non-empty.
slice :: (i + w <= n, 1 <= w) =>
         NatRepr i ->
         NatRepr w ->
         Vector n a -> Vector w a
slice i w (Vector xs) = Vector (Vector.slice start count xs)
  where
    start = widthVal i
    count = widthVal w
{-# INLINE slice #-}
-- | Take the first @n@ elements of a vector with @n + x@ elements.
take :: forall n x a. (1 <= n) => NatRepr n -> Vector (n + x) a -> Vector n a
take n v =
  -- @n <= n + x@ is exactly what 'slice' needs to start at offset 0.
  case leqAdd (leqRefl (Proxy @n)) (Proxy @x) of
    LeqProof -> slice (knownNat @0) n v
-- | Maps a function over every element; the length is unchanged.
instance Functor (Vector n) where
  fmap g v = case v of
    Vector xs -> Vector (Vector.map g xs)
  {-# Inline fmap #-}
-- | Folds visit the elements in index order.
--
-- Besides 'foldMap', the right fold and 'null' delegate directly to the
-- wrapped "Data.Vector" vector.  They are then not rebuilt from 'foldMap'
-- by the class defaults.
instance Foldable (Vector n) where
  foldMap f (Vector xs) = foldMap f xs
  foldr f z (Vector xs) = foldr f z xs
  null (Vector xs) = Vector.null xs
-- | Traverses the elements in index order, keeping the length.
instance Traversable (Vector n) where
  traverse g (Vector xs) = fmap Vector (traverse g xs)
  {-# Inline traverse #-}
-- | Combine two vectors of the same length element by element.
zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f v w = case (v, w) of
  (Vector xs, Vector ys) -> Vector (Vector.zipWith f xs ys)
{-# Inline zipWith #-}
-- | Monadic 'zipWith': runs the combining action on each pair of elements
-- and collects the results into a vector of the same length.
zipWithM :: Monad m => (a -> b -> m c) ->
                       Vector n a -> Vector n b -> m (Vector n c)
zipWithM step (Vector lefts) (Vector rights) =
  fmap Vector (Vector.zipWithM step lefts rights)
{-# Inline zipWithM #-}
-- | Like 'zipWithM', but runs the actions only for their effects.
zipWithM_ :: Monad m => (a -> b -> m ()) -> Vector n a -> Vector n b -> m ()
zipWithM_ step v w = case (v, w) of
  (Vector lefts, Vector rights) -> Vector.zipWithM_ step lefts rights
{-# Inline zipWithM_ #-}
-- | Interleave two vectors of the same length.
--
-- Even indices of the result come from the first vector and odd indices
-- from the second.  The result therefore starts with the first element of
-- the first vector.
interleave ::
  forall n a. (1 <= n) => Vector n a -> Vector n a -> Vector (2 * n) a
interleave (Vector evens) (Vector odds) =
  case leqMulPos (Proxy @2) (Proxy @n) of
    LeqProof -> Vector (Vector.generate total pick)
  where
    total = Vector.length evens + Vector.length odds
    pick i
      | even i = evens Vector.! half
      | otherwise = odds Vector.! half
      where
        half = i `div` 2
-- | Rearrange the elements of a vector: element @i@ of the result is
-- element @f i@ of the input.
--
-- @f@ must map every index of the vector to a valid index.  Otherwise the
-- bounds-checked lookup raises an error.
shuffle :: (Int -> Int) -> Vector n a -> Vector n a
shuffle f (Vector xs) =
  Vector (Vector.generate (Vector.length xs) ((xs Vector.!) . f))
{-# Inline shuffle #-}
-- | Reverse the order of the elements of a vector.
--
-- Delegates to 'Vector.reverse' on the wrapped vector.  This avoids
-- rebuilding the vector through 'shuffle' with one bounds-checked lookup
-- per element.
reverse :: forall a n. (1 <= n) => Vector n a -> Vector n a
reverse (Vector xs) = Vector (Vector.reverse xs)
-- | Rotate a vector towards lower indices by @k@ places: element @i@ of the
-- result is the input element at index @i + k@, taken modulo the length.
rotateL :: Int -> Vector n a -> Vector n a
rotateL !k v = len `seq` shuffle (\i -> (i + k) `mod` len) v
  where
    len = lengthInt v
{-# Inline rotateL #-}
-- | Rotate a vector towards higher indices by @k@ places: element @i@ of
-- the result is the input element at index @i - k@, taken modulo the
-- length.
rotateR :: Int -> Vector n a -> Vector n a
rotateR !k v = len `seq` shuffle (\i -> (i - k) `mod` len) v
  where
    len = lengthInt v
{-# Inline rotateR #-}
-- | Shift the elements towards lower indices by @x@ places and fill the
-- vacated high positions with @a@.  Element @i@ of the result is element
-- @i + x@ of the input when that index is in range, and @a@ otherwise.
--
-- A shift amount at least the length of the vector yields a vector of all
-- @a@.  A negative amount shifts towards higher indices instead, so
-- @shiftL (-k)@ behaves like @shiftR k@ for @k >= 0@.
shiftL :: Int -> a -> Vector n a -> Vector n a
shiftL !x a (Vector xs) = Vector ys
  where
    !len = Vector.length xs
    ys = Vector.generate len pick
    -- The range test @0 <= i + x < len@ is written as @-i <= x < len - i@.
    -- For @0 <= i < len@ neither bound can overflow.  Computing @i + x@
    -- first can wrap around for large @x@ (e.g. @maxBound@), producing a
    -- negative, out-of-bounds index.
    pick i
      | x >= negate i && x < len - i = xs Vector.! (i + x)
      | otherwise = a
{-# Inline shiftL #-}
-- | Shift the elements towards higher indices by @x@ places and fill the
-- vacated low positions with @a@.  Element @i@ of the result is element
-- @i - x@ of the input when that index is in range, and @a@ otherwise.
--
-- A shift amount at least the length of the vector yields a vector of all
-- @a@.  A negative amount shifts towards lower indices instead, so
-- @shiftR (-k)@ behaves like @shiftL k@ for @k >= 0@.
shiftR :: Int -> a -> Vector n a -> Vector n a
shiftR !x a (Vector xs) = Vector ys
  where
    !len = Vector.length xs
    ys = Vector.generate len pick
    -- The range test @0 <= i - x < len@ is written as @i - len < x <= i@.
    -- For @0 <= i < len@ neither bound can overflow.  Computing @i - x@
    -- first would index past the end for negative @x@ and can wrap around
    -- for @x = minBound@.
    pick i
      | x <= i && x > i - len = xs Vector.! (i - x)
      | otherwise = a
{-# Inline shiftR #-}
-- | Concatenate two vectors; the elements of the first argument come first.
append :: Vector m a -> Vector n a -> Vector (m + n) a
append front@(Vector xs) back@(Vector ys)
  -- Both lengths are at least 1, so their sum is too.
  | LeqProof <- leqAddPos (length front) (length back)
  = Vector (xs Vector.++ ys)
{-# Inline append #-}
-- | A vector holding exactly one element.
singleton :: forall a. a -> Vector 1 a
singleton = Vector . Vector.singleton
-- | Proof that one more than the length of a vector is at least 1.  This is
-- needed to build the result of 'cons' and 'snoc'.
leqLen :: forall n a. Vector n a -> LeqProof 1 (n + 1)
leqLen v = leqTrans oneLeN nLeSucc
  where
    -- From the constructor's constraint.
    oneLeN :: LeqProof 1 n
    oneLeN = nonEmpty v
    -- @n <= n@ extended by adding 1 on the right.
    nLeSucc :: LeqProof n (n + 1)
    nLeSucc = leqAdd (leqRefl (length v)) (knownNat @1)
-- | Prepend an element to a vector.
cons :: forall n a. a -> Vector n a -> Vector (n+1) a
cons a v =
  case (leqLen v, v) of
    (LeqProof, Vector xs) -> Vector (Vector.cons a xs)
-- | Append an element to the end of a vector.
snoc :: forall n a. Vector n a -> a -> Vector (n+1) a
snoc v a =
  case (leqLen v, v) of
    (LeqProof, Vector xs) -> Vector (Vector.snoc xs a)
-- | A vector of @n + 1@ elements indexed by @n@ rather than @n + 1@, with
-- the element type first.  @generate'@ builds its result in this shape
-- because its recursion is indexed by @n@.
newtype Vector' a n = MkVector' (Vector (n+1) a)
-- | Unwrap a @Vector'@ into the underlying vector of @n + 1@ elements.
unVector' :: Vector' a n -> Vector (n+1) a
unVector' v = case v of
  MkVector' inner -> inner
-- | 'snoc' lifted to @Vector'@: appending an element increments the index.
snoc' :: forall a m. Vector' a m -> a -> Vector' a (m+1)
snoc' v x = MkVector' (snoc (unVector' v) x)
-- | Worker for 'generate': builds a vector of @h + 1@ elements, starting
-- from @gen 0@ and appending @gen (m + 1)@ at each step.  The result is
-- wrapped in @Vector'@ so that the recursion can be indexed by @h@.
generate' :: forall h a
           . NatRepr h
          -> (forall n. (n <= h) => NatRepr n -> a)
          -> Vector' a h
generate' h gen =
  case isZeroOrGT1 h of
    -- h is 0: the result is just the base vector holding @gen 0@.
    Left Refl -> base
    -- h is at least 1: recurse with 'natRecBounded' from @base@ using
    -- @step@ (presumably once per index up to @h - 1@ -- see
    -- 'natRecBounded').  The equation @h - 1 + 1 = h@ turns the resulting
    -- @Vector' a (h - 1 + 1)@ into the required @Vector' a h@.
    Right LeqProof ->
      case (minusPlusCancel h (knownNat @1) :: h - 1 + 1 :~: h) of { Refl ->
        natRecBounded (decNat h) (decNat h) base step
      }
  where
    -- A single element: @gen 0@.
    base :: Vector' a 0
    base = MkVector' $ singleton (gen (knownNat @0))
    -- Append @gen (m + 1)@ to a vector of @m + 1@ elements.  The proof
    -- @m + 1 <= h@ is required to call @gen@ at index @m + 1@.
    step :: forall m. (1 <= h, m <= h - 1)
         => NatRepr m -> Vector' a m -> Vector' a (m + 1)
    step m v =
      case minusPlusCancel h (knownNat @1) :: h - 1 + 1 :~: h of { Refl ->
      case (leqAdd2 (LeqProof :: LeqProof m (h-1))
                    (LeqProof :: LeqProof 1 1) :: LeqProof (m+1) h) of { LeqProof ->
        snoc' v (gen (incNat m))
      }}
-- | Build a vector of @h + 1@ elements by calling the generator with the
-- indices from 0 to @h@.
generate :: forall h a
          . NatRepr h
         -> (forall n. (n <= h) => NatRepr n -> a)
         -> Vector (h + 1) a
generate h gen =
  case generate' h gen of
    MkVector' v -> v
-- | Monadic version of 'generate'.  It runs the actions produced by
-- 'generate' in index order and collects their results.
generateM :: forall m h a. (Monad m)
          => NatRepr h
          -> (forall n. (n <= h) => NatRepr n -> m a)
          -> m (Vector (h + 1) a)
generateM h gen = sequenceA (generate h gen)
-- | Coerce the elements of a vector without changing its length.  This is
-- permitted because the element type of 'Vector' has a representational
-- role.
coerceVec :: Coercible a b => Vector n a -> Vector n b
coerceVec v = coerce v
-- | Combine a vector of @n@ values of size @w@ into one value of size
-- @n * w@, using a monadic join function.
--
-- The join function receives the size of its right operand together with
-- the two operands.  The fold associates to the right: for elements
-- @x0, x1, x2@ the result is @jn _ x0 (jn _ x1 x2)@.  The effects of the
-- inner joins run before those of the outer ones.
joinWithM ::
  forall m f n w.
  (1 <= w, Monad m) =>
  (forall l. (1 <= l) => NatRepr l -> f w -> f l -> m (f (w + l)))
  -> NatRepr w
  -> Vector n (f w)
  -> m (f (n * w))
joinWithM jn w = fmap fst . go
  where
    -- Returns the joined value together with a 'NatRepr' for its size, so
    -- that size can be passed to the enclosing use of the join function.
    go :: forall l. Vector l (f w) -> m (f (l * w), NatRepr (l * w))
    go exprs =
      case uncons exprs of
        -- A single element already has size @1 * w@.
        (a, Left Refl) -> return (a, w)
        (a, Right rest) ->
          -- Collect the arithmetic facts that let the result of the join,
          -- of size @w + (l - 1) * w@, be viewed as size @l * w@.
          case nonEmpty rest of { LeqProof ->
          case leqMulPos (length rest) w of { LeqProof ->
          case nonEmpty exprs of { LeqProof ->
          case lemmaMul w (length exprs) of { Refl -> do
            (res, sz) <- go rest
            joined <- jn sz a res
            return (joined, addNat w sz)
          }}}}
-- | Pure version of 'joinWithM': for elements @x0, x1, x2@ the result is
-- @jn _ x0 (jn _ x1 x2)@.
joinWith ::
  forall f n w.
  (1 <= w) =>
  (forall l. (1 <= l) => NatRepr l -> f w -> f l -> f (w + l))
  -> NatRepr w
  -> Vector n (f w)
  -> f (n * w)
joinWith jn w v = runIdentity (joinWithM (\l x y -> Identity (jn l x y)) w v)
{-# Inline joinWith #-}
-- | Split a value of size @n * w@ into @n@ pieces of size @w@.
--
-- The selector is called with the total size, the offset @i@ of a piece,
-- and the original value.  It must return the piece of size @w@ starting
-- at that offset.  Offsets @0, w, 2w, ...@ are visited in increasing order.
-- With 'LittleEndian' the piece at offset 0 goes to index 0 of the result.
-- With 'BigEndian' it goes to the last index, and later pieces fill the
-- result towards index 0.
splitWith :: forall f w n.
  (1 <= w, 1 <= n) =>
  Endian ->
  (forall i. (i + w <= n * w) =>
             NatRepr (n * w) -> NatRepr i -> f (n * w) -> f w)
  ->
  NatRepr n -> NatRepr w -> f (n * w) -> Vector n (f w)
splitWith endian select n w val = Vector (Vector.create initializer)
  where
    len = widthVal n

    -- First result index to write, and how to move to the next one.
    start :: Int
    next :: Int -> Int
    (start,next) = case endian of
                     LittleEndian -> (0, succ)
                     BigEndian -> (len - 1, pred)

    initializer :: forall s. ST s (MVector s (f w))
    initializer =
      do LeqProof <- return (leqMulPos n w)
         LeqProof <- return (leqMulMono n w)
         v <- MVector.new len
         -- Write the piece at offset i into slot loc, then continue with
         -- offset i + w.  The loop stops once the next piece would extend
         -- past n * w, i.e. after the offsets 0 .. (n - 1) * w.
         let fill :: Int -> NatRepr i -> ST s ()
             fill loc i =
               let end = addNat i w in
               case testLeq end inLen of
                 Just LeqProof ->
                   do MVector.write v loc (select inLen i val)
                      fill (next loc) end
                 Nothing -> return ()
         fill start (knownNat @0)
         return v

    -- The total size n * w.
    inLen :: NatRepr (n * w)
    inLen = natMultiply n w
{-# Inline splitWith #-}
-- | A variant of 'splitWith' whose selector returns its piece inside an
-- 'Applicative' @f@.  The selector's effects are sequenced in the index
-- order of the result vector.
splitWithA :: forall f g w n. (Applicative f, 1 <= w, 1 <= n) =>
  Endian ->
  (forall i. (i + w <= n * w) =>
             NatRepr (n * w) -> NatRepr i -> g (n * w) -> f (g w))
  ->
  NatRepr n -> NatRepr w -> g (n * w) -> f (Vector n (g w))
splitWithA e select n w val = traverse getCompose $
  splitWith @(Compose f g) e select' n w $ Compose (pure val)
  where
    -- Adapt the selector to work on 'Compose' f g.  'splitWith' only ever
    -- passes back the value it was given (here Compose (pure val)), so the
    -- argument is ignored and val is used directly.
    select' :: (forall i. (i + w <= n * w)
            => NatRepr (n * w) -> NatRepr i -> Compose f g (n * w) -> Compose f g w)
    select' nw i _ = Compose $ select nw i val
-- | 'Vector' with its two type arguments swapped.  This lets it serve as
-- the size-indexed @f@ in 'splitWith' and 'joinWith' (see 'split' and
-- 'join').
newtype Vec a n = Vec (Vector n a)
-- | 'slice' lifted to 'Vec': take @w@ elements starting at offset @i@.
-- The length argument is unused.  It exists so the function has the shape
-- of the selector expected by 'splitWith'.
vSlice :: (i + w <= l, 1 <= w) =>
  NatRepr w -> NatRepr l -> NatRepr i -> Vec a l -> Vec a w
vSlice w _ i v = case v of
  Vec xs -> Vec (slice i w xs)
{-# Inline vSlice #-}
-- | 'append' lifted to 'Vec'.  The size argument is unused.  It exists so
-- the function has the shape of the join function expected by 'joinWith'.
vAppend :: NatRepr n -> Vec a m -> Vec a n -> Vec a (m + n)
vAppend _ (Vec front) (Vec back) = Vec (front `append` back)
{-# Inline vAppend #-}
-- | Split a vector of @n * w@ elements into @n@ consecutive vectors of @w@
-- elements each.  The first result holds the first @w@ input elements.
split :: (1 <= w, 1 <= n) =>
         NatRepr n
      -> NatRepr w
      -> Vector (n * w) a
      -> Vector n (Vector w a)
split n w flat = coerceVec pieces
  where
    pieces = splitWith LittleEndian (vSlice w) n w (Vec flat)
{-# Inline split #-}
-- | Concatenate @n@ vectors of @w@ elements each into a single vector of
-- @n * w@ elements, in order.
join :: (1 <= w) => NatRepr w -> Vector n (Vector w a) -> Vector (n * w) a
join w xs =
  case joinWith vAppend w (coerceVec xs) of
    Vec ys -> ys
{-# Inline join #-}