module ZkFold.Base.Data.Sparse.Vector where

import           Data.Map                         (Map, empty, filter, fromList, map, toList)
import           Data.These                       (These (..))
import           Data.Zip                         (Semialign (..), Zip (..))
import           Prelude                          hiding (Num (..), filter, length, map, sum, zip, zipWith, (/))
import           Test.QuickCheck                  (Arbitrary (..))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field  (Zp)
import           ZkFold.Base.Algebra.Basic.Number (KnownNat)
import           ZkFold.Base.Data.ByteString      (Binary (..))

-- | A sparse vector whose indices are elements of @'Zp' size@.
--
-- Only the indices present in the underlying 'Map' carry a stored value.
-- The additive structure defined in this module treats an absent index
-- as 'zero': the additive identity is the empty map, and '+' keeps
-- entries present in only one operand unchanged.
--
-- NOTE: the derived 'Eq' compares the underlying maps structurally. A
-- vector holding an explicit 'zero' entry is therefore /not/ equal to one
-- in which that index is simply absent.
newtype SVector size a = SVector
    { fromSVector :: Map (Zp size) a
      -- ^ The underlying index-to-value map.
    }
    deriving (Show, Eq)

-- | Serialises a vector as the list of its @(index, value)@ pairs, in the
-- order produced by 'toList' on the underlying map (ascending index
-- order).
--
-- Decoding rebuilds the map with 'fromList'. If a stream contains the
-- same index more than once, the last occurrence silently wins; no
-- validation of the decoded pairs is performed.
instance (Binary a, KnownNat n) => Binary (SVector n a) where
    put v = put (toList (fromSVector v))

    get = fmap (SVector . fromList) get

-- | Folds over the stored values only, in the underlying map's
-- (ascending-index) order. Absent indices contribute nothing.
instance Foldable (SVector size) where
    foldr step acc = foldr step acc . fromSVector

-- | Applies a function to every stored value. The set of stored indices
-- is left unchanged; absent indices stay absent.
instance Functor (SVector size) where
    fmap g = SVector . map g . fromSVector

-- | Index-wise alignment, delegated to the 'Semialign' instance of the
-- underlying 'Map'. An index present in only one operand yields 'This'
-- or 'That'; an index present in both yields 'These'.
--
-- The 'KnownNat' constraint presumably supplies the index ordering that
-- the 'Map' instance needs (TODO confirm against 'Zp').
instance KnownNat size => Semialign (SVector size) where
    align xs ys = SVector (align (fromSVector xs) (fromSVector ys))

    alignWith g xs ys = SVector (alignWith g (fromSVector xs) (fromSVector ys))

-- | Index-wise zipping, delegated to the 'Zip' instance of the underlying
-- 'Map'. Only indices present in /both/ operands appear in the result.
instance KnownNat size => Zip (SVector size) where
    zip xs ys = SVector (zip (fromSVector xs) (fromSVector ys))

    zipWith g xs ys = SVector (zipWith g (fromSVector xs) (fromSVector ys))

-- | Random sparse vectors, generated and shrunk through the underlying
-- 'Map'.
--
-- Generated maps may contain explicit 'zero' values, so generated vectors
-- are not necessarily in the zero-free form produced by '+'.
instance (KnownNat size, Arbitrary a) => Arbitrary (SVector size a) where
    arbitrary = SVector <$> arbitrary

    -- Without this, QuickCheck falls back to the default 'shrink' (no
    -- candidates) and reports whatever random vector first failed.
    -- Shrinking the underlying map drops and shrinks entries, giving
    -- minimal counterexamples. It needs only the constraints 'arbitrary'
    -- already uses.
    shrink (SVector m) = SVector <$> shrink m

-- | Index-wise addition.
--
-- An index stored in only one operand keeps that operand's value, and an
-- index stored in both gets the sum. Every resulting entry equal to
-- 'zero' is then removed, so the result never stores explicit zeros.
instance (KnownNat size, AdditiveMonoid a, Eq a) => AdditiveSemigroup (SVector size a) where
    SVector xs + SVector ys = SVector (filter (/= zero) (alignWith merge xs ys))
      where
        merge (This x)    = x
        merge (That y)    = y
        merge (These x y) = x + y

-- | Operator alias for the index-wise '+' on sparse vectors.
(.+) :: (KnownNat size, AdditiveMonoid a, Eq a) => SVector size a -> SVector size a -> SVector size a
u .+ v = u + v

-- | Scales every stored value by the same scalar.
--
-- Unlike '+', no zero-filtering happens here. Entries that scale to
-- 'zero' remain stored in the underlying map.
instance Scale c a => Scale c (SVector size a) where
    scale k = SVector . map (scale k) . fromSVector

-- | The additive identity is the vector with no stored entries.
instance (KnownNat size, AdditiveMonoid a, Eq a) => AdditiveMonoid (SVector size a) where
    zero = SVector { fromSVector = empty }

-- | Negation is applied to each stored value. The set of stored indices
-- is unchanged.
instance (KnownNat size, AdditiveGroup a, Eq a) => AdditiveGroup (SVector size a) where
    negate (SVector m) = SVector (map negate m)

-- | Operator alias for subtraction of sparse vectors.
(.-) :: (KnownNat size, AdditiveGroup a, Eq a) => SVector size a -> SVector size a -> SVector size a
u .- v = u - v

-- | Element-wise (Hadamard) product.
--
-- Implemented with 'zipWith', so only indices stored in /both/ operands
-- appear in the result. This is consistent with treating absent entries
-- as zero. Products that happen to equal zero are not filtered out.
(.*) :: (KnownNat size, MultiplicativeSemigroup a) => SVector size a -> SVector size a -> SVector size a
u .* v = zipWith (*) u v

-- | Element-wise division.
--
-- Implemented with 'zipWith', so a quotient is computed only at indices
-- stored in /both/ operands.
--
-- NOTE(review): an index stored in the numerator but absent (i.e. zero)
-- in the divisor is silently dropped from the result. No division-by-zero
-- error is raised. Confirm that callers rely on this behaviour.
(./) :: (KnownNat size, MultiplicativeGroup a) => SVector size a -> SVector size a -> SVector size a
u ./ v = zipWith (/) u v