{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators    #-}

module ZkFold.Base.Data.Vector where

import           Data.Bifunctor                   (first)
import           Data.These                       (These (..))
import           Data.Zip                         (Semialign (..), Zip (..))
import           Numeric.Natural                  (Natural)
import           Prelude                          hiding (length, replicate, sum, zip, zipWith, (*))
import           System.Random                    (Random (..))
import           Test.QuickCheck                  (Arbitrary (..))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.ByteString      (Binary (..))
import           ZkFold.Prelude                   (length, replicate)

-- | A list of elements tagged with a type-level length @size@.
--
-- The length is only a phantom index: 'toVector' checks that the list has
-- exactly @size@ elements, but the raw 'Vector' constructor does not, so code
-- that applies the constructor directly is responsible for keeping the list
-- length equal to @size@.
newtype Vector (size :: Natural) a = Vector [a]
    deriving
        ( Eq
        , Show
        , Functor
        , Foldable
        , Traversable
        )

-- | Smart constructor: wrap a list as a @Vector size@ only when its length is
-- exactly the type-level @size@; otherwise return 'Nothing'.
toVector :: forall size a . KnownNat size => [a] -> Maybe (Vector size a)
toVector xs =
    if length xs /= value @size
        then Nothing
        else Just (Vector xs)

-- | Unwrap a vector to its underlying list of elements.
fromVector :: Vector size a -> [a]
fromVector v = case v of
    Vector xs -> xs

-- | Dot product: multiply the two vectors position by position and add up
-- the products (using the 'Semiring' operations of the element type).
vectorDotProduct :: forall size a . Semiring a => Vector size a -> Vector size a -> a
vectorDotProduct (Vector xs) (Vector ys) =
    sum [ x * y | (x, y) <- zip xs ys ]

-- | Flatten a vector of @m@ inner vectors (each of length @n@) into a single
-- vector of length @m * n@, keeping the inner vectors in order.
concat :: Vector m (Vector n a) -> Vector (m * n) a
concat (Vector rows) = Vector (concatMap fromVector rows)

-- | Serialise a vector exactly like its underlying list.
--
-- NOTE(review): 'get' does not check the decoded list length against @n@;
-- doing so would require a 'KnownNat' constraint on this instance.
instance Binary a => Binary (Vector n a) where
    put (Vector xs) = put xs
    get = fmap Vector get

-- | Fixed-length applicative: 'pure' fills all @size@ slots with one value,
-- and '<*>' applies functions to arguments position by position.
instance KnownNat size => Applicative (Vector size) where
    pure = Vector . replicate (value @size)

    Vector fs <*> Vector xs = Vector (zipWith id fs xs)

-- | Align two vectors position by position. Every position is paired as
-- 'These'; no 'This'/'That' values are produced, because the underlying lists
-- are zipped (and so truncated to the shorter one).
instance Semialign (Vector size) where
    align (Vector xs) = Vector . zipWith These xs . fromVector

-- | Position-wise zipping of vectors, delegating to zipping of the
-- underlying lists.
instance Zip (Vector size) where
    zip = zipWith (,)

    zipWith f (Vector xs) (Vector ys) = Vector (zipWith f xs ys)

-- | Generate a vector of exactly @size@ elements, each drawn independently
-- with the element type's 'arbitrary' generator.
instance (Arbitrary a, KnownNat size) => Arbitrary (Vector size a) where
    arbitrary = fmap Vector (sequenceA (replicate (value @size) arbitrary))

-- | Random vectors of exactly @size@ elements. Each element is drawn with the
-- element type's own 'Random' instance, threading the generator from the
-- first position to the last; the generator after the last draw is returned.
instance (Random a, KnownNat size) => Random (Vector size a) where
    -- Draw @size@ elements. The generator is threaded through a direct
    -- recursion, which is O(size); the previous foldl with @as' ++ [a]@
    -- re-copied the accumulated prefix on every step (O(size^2)). The draw
    -- order and the returned generator are unchanged.
    random g0 = first Vector (go (value @size) g0)
      where
        go 0 g = ([], g)
        go k g =
            let (x, g')     = random g
                (rest, g'') = go (k - 1) g'
            in  (x : rest, g'')

    -- Draw element i within the bounds taken from position i of the two bound
    -- vectors. Only the first @size@ bounds are used. If either bound vector
    -- has fewer than @size@ elements (possible only when the length invariant
    -- was bypassed via the raw constructor), fail with an explicit message
    -- instead of a bare partial @head@/@tail@ error.
    randomR (Vector los, Vector his) g0 = first Vector (go (value @size) los his g0)
      where
        go 0 _ _ g = ([], g)
        go k (lo : los') (hi : his') g =
            let (x, g')     = randomR (lo, hi) g
                (rest, g'') = go (k - 1) los' his' g'
            in  (x : rest, g'')
        go _ _ _ _ =
            error "ZkFold.Base.Data.Vector.randomR: bound vectors are shorter than the vector size"