{-# LANGUAGE AllowAmbiguousTypes     #-}
{-# LANGUAGE TypeApplications        #-}
{-# LANGUAGE UndecidableInstances    #-}
{-# LANGUAGE UndecidableSuperClasses #-}

module ZkFold.Symbolic.Data.Combinators where

import           Control.Monad                                             (mapM)
import           Data.List                                                 (find, splitAt)
import           Data.List.Split                                           (chunksOf)
import           Data.Maybe                                                (fromMaybe)
import           Data.Proxy                                                (Proxy (..))
import           Data.Ratio                                                ((%))
import           GHC.TypeNats                                              (KnownNat, Natural, natVal)
import           Prelude                                                   (error, pure, ($), (.), (<$>), (<>))
import qualified Prelude                                                   as Haskell

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Symbolic.Compiler
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators    (expansion, horner)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint

-- | A class for isomorphic types.
-- The @Iso dst src@ superclass context ensures that conversions are defined in
-- both directions: every @Iso src dst@ instance requires a matching
-- @Iso dst src@ instance.
--
class Iso dst src => Iso src dst where
    -- | Convert a value of type @src@ into the isomorphic type @dst@.
    from :: src -> dst

-- | Types whose capacity can be increased by adding zero bits at the beginning
-- (i.e. before the higher register).
--
class Extend src dst where
    -- | Widen a value of type @src@ into the higher-capacity type @dst@.
    extend :: src -> dst

-- | Types whose capacity can be reduced by removing higher bits.
--
class Shrink src dst where
    -- | Narrow a value of type @src@ into the lower-capacity type @dst@.
    shrink :: src -> dst



-- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables.
--
-- The high register @hi@ is expanded into @hiBits@ variables, and each low
-- register in @lo@ into @loBits@ variables. Every expansion is reversed before
-- concatenation (presumably turning the least-significant-first output of
-- 'expansion' into most-significant-first order; TODO confirm). The result
-- lists the high register's bits first, then the bits of each low register in
-- the order the registers appear in @lo@.
--
toBits
    :: forall a
    .  ArithmeticCircuit a
    -> [ArithmeticCircuit a]
    -> Natural
    -> Natural
    -> (forall i m. MonadBlueprint i a m => m [i])
toBits hi lo hiBits loBits = do
    -- Keep this step order: these blueprint actions presumably allocate
    -- circuit variables, so reordering them could change the generated circuit.
    lowVars   <- mapM runCircuit lo
    highVar   <- runCircuit hi
    lowChunks <- mapM (expansion loBits) lowVars
    highChunk <- expansion hiBits highVar

    let highBits = Haskell.reverse highChunk
        lowBits  = Haskell.concatMap Haskell.reverse lowChunks
    pure (highBits <> lowBits)


-- | The inverse of @toBits@.
--
-- The first @hiBits@ variables of @bits@ form the high register. The remaining
-- variables are cut into chunks of @loBits@ variables, one chunk per low
-- register. Each group is reversed and folded back into a single variable
-- with 'horner'. Returns the high register followed by the low registers.
--
-- NOTE(review): 'chunksOf' with a chunk size of 0 yields an infinite list of
-- empty chunks, so @loBits == 0@ makes the 'mapM' below diverge.
--
fromBits
    :: forall a
    .  Natural
    -> Natural
    -> (forall i m. MonadBlueprint i a m => [i] -> m [i])
fromBits hiBits loBits bits = do
    let hiCount, loCount :: Haskell.Int
        hiCount = Haskell.fromIntegral hiBits
        loCount = Haskell.fromIntegral loBits
        (highPart, lowPart) = splitAt hiCount bits

    -- Low registers are assembled before the high one; keep this order.
    lowsNew <- mapM (horner . Haskell.reverse) (chunksOf loCount lowPart)
    highNew <- horner (Haskell.reverse highPart)
    pure (highNew : lowsNew)


-- | 'registerSize' plus the base-2 logarithm of 'numberOfRegisters', rounded up.
-- Presumably the widest value that can arise when register-sized values from
-- all registers are added together (TODO confirm against callers).
--
maxOverflow :: forall a n . (Finite a, KnownNat n) => Natural
maxOverflow = regBits + Haskell.ceiling (log2 regCount)
    where
        regBits, regCount :: Natural
        regBits  = registerSize @a @n
        regCount = numberOfRegisters @a @n

-- | The number of bits left for the highest register when every other register
-- holds 'registerSize' bits. Computed as
-- @n -! registerSize * (numberOfRegisters -! 1)@.
--
highRegisterSize :: forall a n . (Finite a, KnownNat n) => Natural
highRegisterSize = totalBits -! regBits * (regCount -! 1)
    where
        totalBits, regBits, regCount :: Natural
        totalBits = getNatural @n
        regBits   = registerSize @a @n
        regCount  = numberOfRegisters @a @n

-- | The number of bits per register: @n@ divided by 'numberOfRegisters',
-- rounded up. The highest register may hold fewer bits (see 'highRegisterSize').
--
registerSize :: forall a n . (Finite a, KnownNat n) => Natural
registerSize = Haskell.ceiling (totalBits % regCount)
    where
        totalBits, regCount :: Natural
        totalBits = getNatural @n
        regCount  = numberOfRegisters @a @n

-- | The number of registers used to store @n@ bits in elements of the field @a@.
--
-- Returns the smallest register count @c@ in @[1 .. 2 ^ bitLimit]@ such that
-- @c@ registers of @maxRegisterSize c@ bits each can hold @n@ bits.
-- Fails with an 'error' when no such count exists.
--
numberOfRegisters :: forall a n . (Finite a, KnownNat n) => Natural
numberOfRegisters = fromMaybe (error "too many bits, field is not big enough")
    $ find (\c -> c * maxRegisterSize c Haskell.>= getNatural @n) [1 .. maxRegisterCount]
    where
        -- Upper bound of the register-count search.
        maxRegisterCount :: Natural
        maxRegisterCount = 2 ^ bitLimit

        -- Number of whole bits a field element can hold. Shared with
        -- 'maxBitsPerFieldElement' so both functions agree on field capacity.
        bitLimit :: Natural
        bitLimit = maxBitsPerFieldElement @a

        -- Register size (in bits) allowed when using @regCount@ registers: half
        -- of the bits that remain after reserving ceil(log2 regCount) bits
        -- (presumably headroom for summing regCount values; TODO confirm).
        maxRegisterSize :: Natural -> Natural
        maxRegisterSize regCount =
            let maxAdded :: Natural
                maxAdded = Haskell.ceiling $ log2 regCount
             in (bitLimit -! maxAdded) `Haskell.div` 2

-- | Base-2 logarithm of a natural number, computed in 'Haskell.Double'.
-- The result is subject to floating-point rounding, both when converting large
-- values to 'Haskell.Double' and inside 'Haskell.logBase'.
--
log2 :: Natural -> Haskell.Double
log2 x = Haskell.logBase 2 (Haskell.fromIntegral x)

-- | Reflect the type-level natural @n@ to a term-level 'Natural',
-- e.g. @getNatural \@8 == 8@.
--
getNatural :: forall n . KnownNat n => Natural
getNatural = natVal (Proxy @n)

-- | The maximum number of bits a Field element can encode: the largest @k@ such
-- that @2 ^ k@ does not exceed the field order (0 when the order is at most 1).
--
-- Computed with exact integer arithmetic. Converting a large order to a
-- floating-point number keeps only 53 significant bits, so an order just below
-- a power of two (e.g. @2 ^ 255 - 19@) rounds up to that power. A
-- floating-point @log2@ could then overstate the capacity by one bit.
--
maxBitsPerFieldElement :: forall p. Finite p => Natural
maxBitsPerFieldElement = floorLog2 0 (order @p)
    where
        -- Halve until the value is at most 1. The number of halvings is
        -- floor (log2 x) for x >= 1.
        floorLog2 :: Natural -> Natural -> Natural
        floorLog2 acc x
            | x Haskell.<= 1    = acc
            | Haskell.otherwise = floorLog2 (acc Haskell.+ 1) (x `Haskell.div` 2)

-- | The maximum number of bits it makes sense to encode in a register:
-- the smaller of the field capacity ('maxBitsPerFieldElement') and the total
-- number of bits @n@. If field elements can encode more bits than required,
-- only the required number is used.
--
maxBitsPerRegister :: forall p n. (Finite p, KnownNat n) => Natural
maxBitsPerRegister
    | required Haskell.< capacity = required
    | Haskell.otherwise           = capacity
    where
        capacity, required :: Natural
        capacity = maxBitsPerFieldElement @p
        required = getNatural @n

-- | The number of bits left for the higher register, assuming every lower
-- register holds exactly 'maxBitsPerFieldElement' bits.
-- This is @n `mod` maxBitsPerFieldElement@, except that a zero remainder means
-- the higher register is completely full ('maxBitsPerFieldElement' bits).
--
highRegisterBits :: forall p n. (Finite p, KnownNat n) => Natural
highRegisterBits
    | remainder Haskell.== 0 = fieldBits
    | Haskell.otherwise      = remainder
    where
        fieldBits, remainder :: Natural
        fieldBits = maxBitsPerFieldElement @p
        remainder = getNatural @n `mod` fieldBits

-- | The lowest possible number of registers to encode @n@ bits using Field
-- elements from @p@, assuming that each register stores the largest possible
-- number of bits ('maxBitsPerRegister'). This is @n@ divided by
-- 'maxBitsPerRegister', rounded up.
--
minNumberOfRegisters :: forall p n. (Finite p, KnownNat n) => Natural
minNumberOfRegisters = (totalBits + perRegister -! 1) `div` perRegister
    where
        totalBits, perRegister :: Natural
        totalBits   = getNatural @n
        perRegister = maxBitsPerRegister @p @n