{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module ZkFold.Symbolic.Data.Combinators where
import Control.Monad (mapM)
import Data.List (find, splitAt)
import Data.List.Split (chunksOf)
import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy (..))
import Data.Ratio ((%))
import GHC.TypeNats (KnownNat, Natural, natVal)
import Prelude (error, pure, ($), (.), (<$>), (<>))
import qualified Prelude as Haskell
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (expansion, horner)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
-- | Two-way conversion between @a@ and @b@.
--
-- The superclass constraint @Iso b a@ means an instance @Iso a b@ is only
-- accepted when the opposite direction @Iso b a@ is defined as well.
class (Iso b a) => Iso a b where
    -- | Convert a value of type @a@ into type @b@.
    from :: a -> b
-- | One-way conversion from @a@ to @b@, exposed as 'extend'.
--
-- NOTE(review): the name suggests a widening conversion (into a type with
-- more capacity); the class itself imposes no such law.
class Extend a b where
    -- | Convert a value of type @a@ into type @b@.
    extend :: a -> b
-- | One-way conversion from @a@ to @b@, exposed as 'shrink'.
--
-- NOTE(review): the name suggests a narrowing conversion (into a type with
-- less capacity); the class itself imposes no such law.
class Shrink a b where
    -- | Convert a value of type @a@ into type @b@.
    shrink :: a -> b
-- | Flatten a high register circuit and a list of low register circuits
-- into a single list of bit variables.
--
-- * @hi@     – circuit for the high register.
-- * @lo@     – circuits for the low registers.
-- * @hiBits@ – number of bits the high register is expanded into.
-- * @loBits@ – number of bits each low register is expanded into.
--
-- Every register is expanded with 'expansion', and each register's bit
-- list is reversed. The result is the high register's bits followed by
-- the low registers' bits, in the order of @lo@.
-- NOTE(review): assuming 'expansion' yields the least significant bit
-- first, the output is most-significant-first — confirm.
toBits
    :: forall a
    .  ArithmeticCircuit a
    -> [ArithmeticCircuit a]
    -> Natural
    -> Natural
    -> (forall i m. MonadBlueprint i a m => m [i])
toBits hi lo hiBits loBits = do
    -- Keep this step order: low circuits, high circuit, low expansions,
    -- high expansion.
    lowVars   <- mapM runCircuit lo
    highVar   <- runCircuit hi
    lowChunks <- mapM (expansion loBits) lowVars
    highChunk <- expansion hiBits highVar
    let highFirst = Haskell.reverse highChunk
        lowAfter  = Haskell.concatMap Haskell.reverse lowChunks
    pure (highFirst <> lowAfter)
-- | Reassemble register variables from a flat list of bit variables (the
-- reverse direction of 'toBits').
--
-- * @hiBits@ – how many leading entries of @bits@ form the high register.
-- * @loBits@ – size of each group the remaining entries are cut into,
--   one group per low register.
--
-- Each group is reversed and folded with 'horner'. The result lists the
-- high register variable first, followed by the low register variables
-- in order.
fromBits
    :: forall a
    .  Natural
    -> Natural
    -> (forall i m. MonadBlueprint i a m => [i] -> m [i])
fromBits hiBits loBits bits =
    let (highPart, lowPart) = splitAt (Haskell.fromIntegral hiBits) bits
        lowGroups           = chunksOf (Haskell.fromIntegral loBits) lowPart
    in do
        -- Low registers are folded before the high one.
        lowRegs <- mapM (horner . Haskell.reverse) lowGroups
        highReg <- horner (Haskell.reverse highPart)
        pure (highReg : lowRegs)
-- | 'registerSize' plus @ceiling (log2 r)@, where @r@ is 'numberOfRegisters'
-- for @n@ bits over field @a@.
-- NOTE(review): the extra bits look like headroom for summing @r@
-- register-sized values — confirm.
maxOverflow :: forall a n . (Finite a, KnownNat n) => Natural
maxOverflow = size + carryBits
  where
    size :: Natural
    size = registerSize @a @n

    carryBits :: Natural
    carryBits = Haskell.ceiling (log2 (numberOfRegisters @a @n))
-- | Bits remaining for the highest register: @n@ minus the bits held by
-- the other @numberOfRegisters - 1@ registers of 'registerSize' bits each,
-- using '-!' for both subtractions.
highRegisterSize :: forall a n . (Finite a, KnownNat n) => Natural
highRegisterSize = totalBits -! (registerSize @a @n * otherRegisters)
  where
    totalBits, otherRegisters :: Natural
    totalBits      = getNatural @n
    otherRegisters = numberOfRegisters @a @n -! 1
-- | Bits per register: @n@ divided by 'numberOfRegisters', rounded up.
registerSize :: forall a n . (Finite a, KnownNat n) => Natural
registerSize = Haskell.ceiling bitsPerRegister
  where
    bitsPerRegister = getNatural @n % numberOfRegisters @a @n
-- | The first register count @c@ in @[1 .. 2 ^ bitLimit]@ such that
-- @c * maxRegisterSize c >= n@.
--
-- * @bitLimit@ is @floor (log2 (order \@a))@.
-- * @maxRegisterSize c = floor ((bitLimit - ceiling (log2 c)) / 2)@.
--
-- NOTE(review): the halving presumably leaves room for the product of two
-- registers to stay below the field order — confirm.
--
-- Calls 'error' with "too many bits, field is not big enough" when no
-- count in the range qualifies.
numberOfRegisters :: forall a n . (Finite a, KnownNat n) => Natural
numberOfRegisters =
    fromMaybe (error "too many bits, field is not big enough") $
        find fits [1 .. maxRegisterCount]
  where
    totalBits :: Natural
    totalBits = getNatural @n

    fits :: Natural -> Haskell.Bool
    fits c = c * maxRegisterSize c Haskell.>= totalBits

    maxRegisterCount :: Natural
    maxRegisterCount = 2 ^ bitLimit

    -- floor (log2 order), computed exactly. Converting a large field order
    -- to Double first (as 'log2' does) can round an order just below a
    -- power of two up to that power, giving a result one too large; orders
    -- beyond the Double range would become infinite.
    bitLimit :: Natural
    bitLimit = bitLength (order @a) -! 1

    maxRegisterSize :: Natural -> Natural
    maxRegisterSize regCount =
        let maxAdded :: Natural
            maxAdded = Haskell.ceiling $ log2 regCount
         in Haskell.floor $ (bitLimit -! maxAdded) % 2

    -- Number of binary digits of a natural number (0 for 0).
    bitLength :: Natural -> Natural
    bitLength = go 0
      where
        go :: Natural -> Natural -> Natural
        go acc 0 = acc
        go acc x = go (acc Haskell.+ 1) (x `Haskell.div` 2)
-- | Binary logarithm of a natural number, as a 'Haskell.Double'.
--
-- The argument is converted to 'Haskell.Double' before the logarithm is
-- taken, so the result is subject to floating-point rounding.
log2 :: Natural -> Haskell.Double
log2 x = Haskell.logBase 2 (Haskell.fromIntegral x)
-- | Reflect the type-level natural @n@ to a term-level 'Natural'.
-- Use with a type application, e.g. @getNatural \@8@.
getNatural :: forall n . KnownNat n => Natural
getNatural = natVal (Proxy @n)
-- | The largest @b@ such that every @b@-bit number is smaller than the
-- field order, i.e. @floor (log2 order)@ for the order of @p@.
--
-- Computed with exact integer arithmetic. Converting a large order to
-- 'Haskell.Double' first (as 'log2' does) can round an order just below
-- a power of two up to that power, giving a result one too large; orders
-- beyond the Double range would become infinite.
maxBitsPerFieldElement :: forall p. Finite p => Natural
maxBitsPerFieldElement = bitLength (order @p) -! 1
  where
    -- Number of binary digits of a natural number (0 for 0).
    bitLength :: Natural -> Natural
    bitLength = go 0
      where
        go :: Natural -> Natural -> Natural
        go acc 0 = acc
        go acc x = go (acc Haskell.+ 1) (x `Haskell.div` 2)
-- | Bits per register: no more than fit into one field element
-- ('maxBitsPerFieldElement') and no more than the @n@ bits in total.
maxBitsPerRegister :: forall p n. (Finite p, KnownNat n) => Natural
maxBitsPerRegister = Haskell.min totalBits fieldBits
  where
    fieldBits, totalBits :: Natural
    fieldBits = maxBitsPerFieldElement @p
    totalBits = getNatural @n
-- | Bits in the highest register when @n@ bits are split into
-- field-element-sized chunks: @n mod maxBitsPerFieldElement@, or a full
-- 'maxBitsPerFieldElement' when that remainder is zero.
highRegisterBits :: forall p n. (Finite p, KnownNat n) => Natural
highRegisterBits
    | remainder Haskell.== 0 = fieldBits
    | Haskell.otherwise      = remainder
  where
    fieldBits, remainder :: Natural
    fieldBits = maxBitsPerFieldElement @p
    remainder = getNatural @n `mod` fieldBits
-- | The fewest registers of 'maxBitsPerRegister' bits that together hold
-- @n@ bits: @ceiling (n / maxBitsPerRegister)@.
--
-- For @n = 0@, 'maxBitsPerRegister' (the minimum of the field-element
-- width and @n@) is itself 0, and the ceiling division would divide by
-- zero. Zero bits need zero registers, so that case returns 0 directly.
minNumberOfRegisters :: forall p n. (Finite p, KnownNat n) => Natural
minNumberOfRegisters
    | totalBits Haskell.== 0 = 0
    | Haskell.otherwise      = ((totalBits + perRegister) -! 1) `div` perRegister
  where
    totalBits, perRegister :: Natural
    totalBits   = getNatural @n
    perRegister = maxBitsPerRegister @p @n