{-# LANGUAGE AllowAmbiguousTypes          #-}
{-# LANGUAGE DeriveAnyClass               #-}
{-# LANGUAGE DerivingStrategies           #-}
{-# LANGUAGE NoGeneralisedNewtypeDeriving #-}
{-# LANGUAGE TypeApplications             #-}
{-# LANGUAGE TypeOperators                #-}
{-# LANGUAGE UndecidableInstances         #-}

module ZkFold.Symbolic.Data.ByteString
    ( ByteString(..)
    , ShiftBits (..)
    , ToWords (..)
    , Concat (..)
    , Truncate (..)
    ) where

import           Control.DeepSeq                                           (NFData)
import           Control.Monad                                             (mapM, replicateM, zipWithM)
import           Data.Bits                                                 as B
import qualified Data.ByteString                                           as Bytes
import           Data.List                                                 (foldl, reverse, unfoldr)
import           Data.List.Split                                           (chunksOf)
import           Data.Maybe                                                (Maybe (..))
import           Data.Proxy                                                (Proxy (..))
import           Data.String                                               (IsString (..))
import           GHC.Generics                                              (Generic)
import           GHC.Natural                                               (naturalFromInteger)
import           GHC.TypeNats                                              (Mod, Natural, natVal)
import           Prelude                                                   (Bool (..), Integer, drop, error, fmap,
                                                                            length, otherwise, pure, take, type (~),
                                                                            ($), (.), (<$>), (<), (<>), (==), (>=))
import qualified Prelude                                                   as Haskell
import           Test.QuickCheck                                           (Arbitrary (..), chooseInteger)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field                           (Zp)
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Prelude                                            (replicate, replicateA)
import           ZkFold.Symbolic.Compiler
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
import           ZkFold.Symbolic.Data.Bool                                 (BoolType (..))
import           ZkFold.Symbolic.Data.Combinators
import           ZkFold.Symbolic.Data.UInt


-- | A ByteString which stores @n@ bits and uses elements of @a@ as registers, one element per register.
-- Bit layout is Big-endian: the first register of the list holds the most significant bit.
--
newtype ByteString (n :: Natural) a = ByteString [a]
    deriving stock (Haskell.Show, Haskell.Eq, Generic)
    -- GeneralisedNewtypeDeriving is disabled in this module, so NFData comes from
    -- DeriveAnyClass, i.e. from the class's default method.
    deriving anyclass (NFData)

-- | String literals are first parsed with the 'IsString' instance of 'Bytes.ByteString'
-- and then converted with the @FromConstant Bytes.ByteString@ instance of this type.
instance
    ( FromConstant Natural a
    , Concat (ByteString 8 a) (ByteString n a)
    ) => IsString (ByteString n a) where
    fromString str = fromConstant (fromString @Bytes.ByteString str)

-- | Each byte of the input becomes an 8-bit ByteString; the pieces are then joined with 'concat'
-- in the order the bytes appear.
instance
    ( FromConstant Natural a
    , Concat (ByteString 8 a) (ByteString n a)
    ) => FromConstant Bytes.ByteString (ByteString n a) where
    fromConstant bytes = concat (fmap byteToWord (Bytes.unpack bytes))
      where
        -- Widen one byte to a Natural and encode it as an 8-bit ByteString.
        byteToWord w = fromConstant @Natural @(ByteString 8 a) (Haskell.fromIntegral (Haskell.toInteger w))

-- | A class for data types that support bit shift and bit cyclic shift (rotation) operations.
--
-- An instance implements either the signed method of each family ('shiftBits', 'rotateBits')
-- or both of its directional methods; the remaining methods fall back to the defaults below.
--
class ShiftBits a where
    {-# MINIMAL (shiftBits | (shiftBitsL, shiftBitsR)), (rotateBits | (rotateBitsL, rotateBitsR)) #-}

    -- | Shift by a signed amount: a non-negative argument shifts left,
    -- a negative argument shifts right by its absolute value.
    --
    shiftBits :: a -> Integer -> a
    shiftBits x k
      | k >= 0    = shiftBitsL x (Haskell.fromIntegral k)
      | otherwise = shiftBitsR x (Haskell.fromIntegral (negate k))

    -- | Shift left by the given number of bits.
    shiftBitsL :: a -> Natural -> a
    shiftBitsL x k = shiftBits x (Haskell.toInteger k)

    -- | Shift right by the given number of bits.
    shiftBitsR :: a -> Natural -> a
    shiftBitsR x k = shiftBits x (negate (Haskell.toInteger k))

    -- | Rotate by a signed amount: a non-negative argument rotates left,
    -- a negative argument rotates right by its absolute value.
    --
    rotateBits :: a -> Integer -> a
    rotateBits x k
      | k >= 0    = rotateBitsL x (Haskell.fromIntegral k)
      | otherwise = rotateBitsR x (Haskell.fromIntegral (negate k))

    -- | Rotate left (cyclically) by the given number of bits.
    rotateBitsL :: a -> Natural -> a
    rotateBitsL x k = rotateBits x (Haskell.toInteger k)

    -- | Rotate right (cyclically) by the given number of bits.
    rotateBitsR :: a -> Natural -> a
    rotateBitsR x k = rotateBits x (negate (Haskell.toInteger k))


-- | Describes types which can be split into words of equal size.
-- The parameters have to be different types: a ByteString stores its length at the type level,
-- so the words obtained by splitting it have a different type from the original value.
--
class ToWords a b where
    -- | Split a value into a list of words.
    toWords :: a -> [b]


-- | Describes types which can be made by concatenating several words of equal length.
--
class Concat a b where
    -- | Join a list of words into a single value.
    concat :: [a] -> b


-- | Describes types that can be truncated by dropping several bits from the end
-- (i.e. the bits stored in the lower registers).
--
class Truncate a b where
    -- | Drop trailing bits to obtain a smaller value.
    truncate :: a -> b


-- | Read the registers as binary digits, most significant first, and accumulate their value.
instance (ToConstant a Natural) => ToConstant (ByteString n a) Natural where
    toConstant (ByteString registers) = Haskell.foldl step 0 registers
        where
            -- Horner's scheme in base 2: double the running total and add the next register.
            step :: Natural -> a -> Natural
            step acc r = toConstant r + 2 * acc


instance (FromConstant Natural a, KnownNat n) => FromConstant Natural (ByteString n a) where

    -- | Pack a ByteString using one field element per bit.
    -- @fromConstant@ discards bits after @n@.
    -- If the constant is @2^n@ or larger, only the part modulo @2^n@ will be converted into a ByteString.
    --
    fromConstant c = ByteString (reverse lsbFirst)
        where
            -- Binary digits of the reduced constant, least significant first,
            -- followed by an endless supply of zero registers used as padding.
            digits = unfoldr (toBase 2) (c `Haskell.mod` (2 Haskell.^ getNatural @n)) <> Haskell.repeat (fromConstant @Natural 0)

            -- Exactly @n@ registers; reversing them gives the big-endian layout.
            lsbFirst = take (Haskell.fromIntegral (value @n)) digits

-- | One step of expanding a number into @base@-ary digits, shaped for use with 'unfoldr'.
-- Returns 'Nothing' once the number is zero; otherwise yields the least significant digit
-- (converted with 'fromConstant') together with the remaining quotient.
--
toBase :: FromConstant Natural a => Natural -> Natural -> Maybe (a, Natural)
toBase radix number
    | number == 0 = Nothing
    | otherwise   = Just (fromConstant digit, rest)
    where
        (rest, digit) = number `divMod` radix


-- | Integers, negative ones included, are reduced modulo @2^n@ first; the non-negative
-- remainder is then packed through the 'Natural' instance.
instance (FromConstant Natural a, KnownNat n) => FromConstant Integer (ByteString n a) where
    fromConstant i = fromConstant (naturalFromInteger (i `Haskell.mod` (2 ^ getNatural @n)))

-- | Reinterpret the bits of a ByteString as an unsigned integer of the same width,
-- going through their 'Natural' value.
instance (Finite (Zp p), KnownNat n) => Iso (ByteString n (Zp p)) (UInt n (Zp p)) where
    from = fromConstant @Natural . toConstant

-- | Lay out the value of an unsigned integer as a ByteString of the same width,
-- going through its 'Natural' value.
instance (Finite (Zp p), KnownNat n) => Iso (UInt n (Zp p)) (ByteString n (Zp p)) where
    from ui = fromConstant (toConstant ui :: Natural)

-- | Random ByteStrings: each of the @n@ registers receives a bit drawn with
-- 'chooseInteger' from the inclusive range [0, 1].
instance (Finite (Zp p), KnownNat n) => Arbitrary (ByteString n (Zp p)) where
    arbitrary = fmap ByteString (replicateA (value @n) randomBit)
        where
            -- 0 .. 2^1 - 1, i.e. a single bit.
            randomBit = fromConstant <$> chooseInteger (0, 1)

instance (Finite (Zp p), KnownNat n) => ShiftBits (ByteString n (Zp p)) where

    -- | Shift by @s@ bits: left when @s@ is non-negative, right when it is negative.
    -- Bits moved past either end are lost and the vacated positions are filled with zeros.
    shiftBits b s
        -- Shifting by at least the width of the ByteString discards every bit. Handling this
        -- case up front keeps the 'Haskell.Int' conversion below from overflowing (and wrapping)
        -- for huge @s@, and avoids building an enormous intermediate 'Natural'.
      | Haskell.abs s >= Haskell.toInteger (getNatural @n) = fromConstant (0 :: Natural)
      | otherwise = fromConstant $ shift (toConstant @_ @Natural b) (Haskell.fromIntegral s) `Haskell.mod` (2 Haskell.^ (getNatural @n))

    -- | @Data.Bits.rotate@ works exactly as @Data.Bits.shift@ for @Natural@, we have to rotate bits manually.
    rotateBitsL b s
        -- An empty ByteString has nothing to rotate; without this guard the reduction
        -- @s `mod` 0@ below would throw a division-by-zero exception.
      | getNatural @n == 0 = b
      | s == 0 = b
       -- Rotations by k * n + p bits where n is the length of the ByteString are equivalent to rotations by p bits.
      | s >= getNatural @n = rotateBitsL b (s `Haskell.mod` getNatural @n)
      | otherwise = fromConstant $ d + m
        where
            nat :: Natural
            nat = toConstant b

            bound :: Natural
            bound = 2 Haskell.^ getNatural @n

            -- After shifting left by @s@, the quotient holds the @s@ bits that overflowed
            -- (they wrap around to the low end) and the remainder holds the shifted rest.
            d, m :: Natural
            (d, m) = (nat `shiftL` Haskell.fromIntegral s) `Haskell.divMod` bound

    -- | Rotate right by @s@ bits; see 'rotateBitsL' for why this is done manually.
    rotateBitsR b s
        -- An empty ByteString has nothing to rotate; without this guard the reduction
        -- @s `mod` 0@ below would throw a division-by-zero exception.
      | getNatural @n == 0 = b
      | s == 0 = b
       -- Rotations by k * n + p bits where n is the length of the ByteString are equivalent to rotations by p bits.
      | s >= getNatural @n = rotateBitsR b (s `Haskell.mod` getNatural @n)
      | otherwise = fromConstant $ d + m
        where
            nat :: Natural
            nat = toConstant b

            intS :: Haskell.Int
            intS = Haskell.fromIntegral s

            -- The low @s@ bits wrap around to the top of the ByteString.
            m :: Natural
            m = (nat `Haskell.mod` (2 Haskell.^ intS)) `shiftL` (Haskell.fromIntegral (getNatural @n) Haskell.- intS)

            -- The remaining high bits move down by @s@ positions.
            d :: Natural
            d = nat `shiftR` intS


instance (Finite (Zp p), KnownNat n) => BoolType (ByteString n (Zp p)) where
    -- | The ByteString with every bit cleared.
    false = fromConstant (0 :: Natural)

    -- | A ByteString with all bits set to 1 is the unity for bitwise and.
    true = not false

    -- | Bitwise not, computed as @(2^n - 1) -! x@ so that each of the @n@ bits is flipped.
    -- @Data.Bits.complement@ is undefined for @Natural@, we have to flip bits manually.
    not b = fromConstant @Natural (allOnes -! toConstant b)
      where
        -- The mask with all @n@ bits set.
        allOnes :: Natural
        allOnes = (2 ^ natVal (Proxy @n)) -! 1

    -- | Bitwise or
    l || r = fromConstant @Natural (toConstant l .|. toConstant r)

    -- | Bitwise and
    l && r = fromConstant @Natural (toConstant l .&. toConstant r)

    -- | Bitwise xor
    xor l r = fromConstant @Natural (toConstant l `B.xor` toConstant r)


-- | A ByteString of length @n@ can only be split into words of length @wordSize@ if all of the following conditions are met:
-- 1. @wordSize@ is not greater than @n@;
-- 2. @wordSize@ is not zero;
-- 3. The bytestring is not empty;
-- 4. @wordSize@ divides @n@.
--
-- The resulting words are listed most significant first.
--
instance
  ( KnownNat wordSize
  , KnownNat n
  , Finite (Zp p)
  , wordSize <= n
  , 1 <= wordSize
  , 1 <= n
  , Mod n wordSize ~ 0
  ) => ToWords (ByteString n (Zp p)) (ByteString wordSize (Zp p)) where

    toWords bs = reverse (fmap fromConstant (take wordCount digits))
      where
        wordBits :: Natural
        wordBits = getNatural @wordSize

        -- How many words the ByteString splits into.
        wordCount :: Haskell.Int
        wordCount = Haskell.fromIntegral (getNatural @n `Haskell.div` wordBits)

        -- Base-2^wordSize digits of the whole value, least significant first, zero-padded forever.
        digits :: [Natural]
        digits = unfoldr (toBase (2 Haskell.^ wordBits)) (toConstant bs) <> Haskell.repeat (fromConstant @Natural 0)

-- | Unfortunately, Haskell does not support dependent types yet,
-- so we have no possibility to infer the exact type of the result
-- (the list can contain an arbitrary number of words).
-- We can only impose some restrictions on @n@ and @m@.
--
-- Words are taken most significant first. The combined value is converted with 'fromConstant',
-- so bits beyond @n@ are discarded and missing high bits are zero.
--
instance
  ( KnownNat n
  , KnownNat m
  , m <= n
  , Mod n m ~ 0
  , Finite (Zp p)
  ) => Concat (ByteString m (Zp p)) (ByteString n (Zp p)) where

    concat ws = fromConstant @Natural (foldl append 0 ws)
        where
            wordBits :: Haskell.Int
            wordBits = Haskell.fromIntegral (getNatural @m)

            -- Make room by shifting the accumulator left one word, then add the next word into the low bits.
            append :: Natural -> ByteString m (Zp p) -> Natural
            append acc w = toConstant w + (acc `shift` wordBits)


-- | Only a bigger ByteString can be truncated into a smaller one.
-- The @m - n@ least significant bits are dropped; the remaining high bits form the result.
--
instance
  ( KnownNat m
  , KnownNat n
  , n <= m
  , Finite (Zp p)
  ) => Truncate (ByteString m (Zp p)) (ByteString n (Zp p)) where

    truncate bs = fromConstant @Natural (toConstant bs `shiftR` droppedBits)
        where
            -- Number of low-order bits to discard.
            droppedBits :: Haskell.Int
            droppedBits = Haskell.fromIntegral (getNatural @m Haskell.- getNatural @n)

-- | Only a smaller ByteString can be extended into a bigger one.
-- The numeric value is kept; the new high-order bits are zero.
--
instance
  ( KnownNat n
  , m <= n
  , Finite (Zp p)
  ) => Extend (ByteString m (Zp p)) (ByteString n (Zp p)) where

    extend bs = fromConstant (toConstant bs :: Natural)

--------------------------------------------------------------------------------

-- | A ByteString circuit is represented by its @n@ register circuits, one per bit.
instance (Arithmetic a, KnownNat n) => SymbolicData a (ByteString n (ArithmeticCircuit a)) where
    pieces (ByteString registers) = registers

    -- Reject any list whose length differs from the declared width @n@.
    restore registers
      | Haskell.fromIntegral (length registers) Haskell./= value @n = error "ByteString: invalid number of values"
      | otherwise = ByteString registers

    typeSize = value @n


-- | Regroup the bit circuits of a ByteString into the registers of a UInt of the same width.
instance (Arithmetic a, KnownNat n) => Iso (ByteString n (ArithmeticCircuit a)) (UInt n (ArithmeticCircuit a)) where
    from (ByteString bits) = assemble (circuits solve)
        where
            -- The first circuit returned by 'solve' is passed as UInt's standalone register
            -- (presumably the high register, sized by 'highRegisterSize' — confirm against UInt);
            -- the rest become the register list in reverse order.
            assemble (hi : rest) = UInt (Haskell.reverse rest) hi
            assemble _           = error "Iso ByteString UInt : unreachable"

            solve :: forall i m. MonadBlueprint i a m => m [i]
            solve = do
                inputs <- mapM runCircuit bits
                fromBits (highRegisterSize @a @n) (registerSize @a @n) inputs

instance (Arithmetic a, KnownNat n) => Iso (UInt n (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where
    -- Unpack the registers into bits with 'toBits': the last 'UInt' field is
    -- passed first, followed by the register list in reverse order.
    from (UInt regs r) = ByteString (circuits solve)
        where
            solve :: forall i m. MonadBlueprint i a m => m [i]
            solve = toBits r (Haskell.reverse regs) highSize regSize

            highSize :: Natural
            highSize = highRegisterSize @a @n

            regSize :: Natural
            regSize = registerSize @a @n

instance (Arithmetic a, KnownNat n) => ShiftBits (ByteString n (ArithmeticCircuit a)) where
    -- Shift the bit list by @s@ positions, filling the vacated positions with
    -- freshly assigned zero variables. A shift of @n@ or more in either
    -- direction yields 'false' (all zeros) without building new circuitry.
    shiftBits bs@(ByteString oldBits) s
      | s == 0                    = bs
      | Haskell.abs s >= bitCount = false
      | otherwise                 = ByteString $ circuits solve
      where
        bitCount :: Integer
        bitCount = Haskell.fromIntegral (getNatural @n)

        solve :: forall i m. MonadBlueprint i a m => m [i]
        solve = do
            inputs  <- mapM runCircuit oldBits
            padding <- replicateM (Haskell.fromIntegral $ Haskell.abs s) $ newAssigned (Haskell.const zero)
            pure (place padding inputs)

        -- Negative @s@: zeros are prepended and the first @n@ elements kept.
        -- Positive @s@: zeros are appended and the first @s@ elements dropped.
        place pad xs
          | s < 0     = take (Haskell.fromIntegral bitCount) (pad <> xs)
          | otherwise = drop (Haskell.fromIntegral s) (xs <> pad)

    -- rotateBits does not even require operations on the circuit: it only
    -- reorders the existing bits. Negative or out-of-range offsets are first
    -- normalised with 'Haskell.mod' (an offset of 0 returns the input as is),
    -- so every remaining rotation is a left rotation by an offset in [1, n).
    rotateBits bs@(ByteString bits) s
      | s == 0                     = bs
      | (s < 0) || (s >= bitCount) = rotateBits bs (s `Haskell.mod` bitCount) -- Always perform a left rotation
      | otherwise                  = ByteString (drop k bits <> take k bits)
      where
        bitCount :: Integer
        bitCount = Haskell.fromIntegral (getNatural @n)

        k :: Haskell.Int
        k = Haskell.fromIntegral s


-- | Combine two ByteStrings of the same width position by position.
--
-- For every pair of corresponding input bits @x@ and @y@, the output bit is
-- a fresh variable assigned the polynomial @cons x y@. Pairing is done with
-- 'zipWithM', so extra bits in the longer list (if any) are ignored.
--
-- TODO: Should this be exposed to users? Could they do something malicious
-- with it? AFAIK there are checks that constrain each bit to 0 or 1.
--
bitwiseOperation
    :: forall a n
    .  Arithmetic a
    => ByteString n (ArithmeticCircuit a)
    -> ByteString n (ArithmeticCircuit a)
    -> (forall i. i -> i -> ClosedPoly i a)
    -> ByteString n (ArithmeticCircuit a)
bitwiseOperation (ByteString bits1) (ByteString bits2) cons = ByteString (circuits solve)
  where
    solve :: forall i m. MonadBlueprint i a m => m [i]
    solve = do
        xs <- mapM runCircuit bits1
        ys <- mapM runCircuit bits2
        zipWithM combine xs ys

    -- One output bit: a new variable assigned the polynomial @cons x y@.
    combine :: forall i m. MonadBlueprint i a m => i -> i -> m i
    combine x y = newAssigned (cons x y)


instance (Arithmetic a, KnownNat n) => BoolType (ByteString n (ArithmeticCircuit a)) where
    -- All @n@ bits are the constant 'zero' circuit.
    false = ByteString $ replicate (value @n) zero

    -- Obtained by flipping every bit of 'false'.
    true = not false

    -- Each output bit is a freshly assigned variable equal to @1 - b@, where
    -- @b@ is the corresponding input bit.
    not (ByteString bits) = ByteString (fmap negateBit bits)
        where
            negateBit wire = circuit $ do
                v <- runCircuit wire
                newAssigned (\w -> one - w v)

    -- For bits in {0, 1}: p OR q = p + q - p*q.
    l || r = bitwiseOperation l r (\p q v -> v p + v q - v p * v q)

    -- For bits in {0, 1}: p AND q = p*q.
    l && r = bitwiseOperation l r (\p q v -> v p * v q)

    -- For bits in {0, 1}: p XOR q = p + q - 2*p*q (the doubled product is
    -- written as a sum of two identical products).
    xor l r = bitwiseOperation l r (\p q v -> v p + v q - (v p * v q + v p * v q))


instance
  ( KnownNat wordSize
  , 1 <= wordSize
  , 1 <= n
  , Mod n wordSize ~ 0
  ) => ToWords (ByteString n (ArithmeticCircuit a)) (ByteString wordSize (ArithmeticCircuit a)) where

    -- Split the bits into consecutive groups of @wordSize@ bits, preserving
    -- their order; the @Mod n wordSize ~ 0@ constraint makes the groups even.
    toWords (ByteString bits) = fmap ByteString (chunksOf wordLen bits)
      where
        wordLen :: Haskell.Int
        wordLen = Haskell.fromIntegral (value @wordSize)

instance
  ( Mod n m ~ 0
  ) => Concat (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where

    -- Join the words' bit lists end to end, in list order. The number of
    -- words is not checked against @n / m@ here.
    concat bs = ByteString [ x | ByteString xs <- bs, x <- xs ]

instance
  ( KnownNat n
  , n <= m
  ) => Truncate (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where

    -- Keep only the first @n@ bits of the list.
    -- NOTE(review): 'extend' for this representation prepends zeros, which
    -- suggests bits are stored most-significant first; if so, this keeps the
    -- high-order bits rather than the value modulo 2^n. Confirm this matches
    -- the 'Zp' instance and the callers' intent.
    truncate (ByteString bits) = ByteString (take keep bits)
      where
        keep :: Haskell.Int
        keep = Haskell.fromIntegral (getNatural @n)

instance
  ( KnownNat m
  , KnownNat n
  , m <= n
  , Arithmetic a
  ) => Extend (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where

    -- Prepend @n - m@ freshly assigned zero bits in front of the existing
    -- bits. The existing bits are run first, then the zeros are allocated.
    extend (ByteString oldBits) = ByteString (circuits solve)
      where
        solve :: forall i m'. MonadBlueprint i a m' => m' [i]
        solve = do
            vars <- mapM runCircuit oldBits
            (<> vars) <$> replicateM padWidth (newAssigned (Haskell.const zero))

        padWidth :: Haskell.Int
        padWidth = Haskell.fromIntegral (getNatural @n Haskell.- getNatural @m)