{-# LANGUAGE DeriveAnyClass, DeriveGeneric, LambdaCase, StrictData #-}
module Circuit.Affine
( AffineCircuit (..),
collectInputsAffine,
mapVarsAffine,
evalAffineCircuit,
affineCircuitToAffineMap,
evalAffineMap,
dotProduct,
)
where
import Data.Aeson (FromJSON, ToJSON)
import Data.Map (Map)
import qualified Data.Map as Map
import Protolude
import Text.PrettyPrint.Leijen.Text (Doc, Pretty(..), parens, text,
(<+>))
-- | An affine expression over variables of type @i@ with scalars and
-- constants of type @f@, represented as an expression tree.
--
-- Fields are strict (the module enables @StrictData@; the bangs below just
-- make that explicit).
data AffineCircuit i f
  = -- | Sum of two sub-circuits.
    Add !(AffineCircuit i f) !(AffineCircuit i f)
  | -- | A sub-circuit multiplied by a constant scalar.
    ScalarMul !f !(AffineCircuit i f)
  | -- | A constant value.
    ConstGate !f
  | -- | A reference to an input variable.
    Var !i
  deriving
    ( Read,
      Eq,
      Show,
      Generic,
      NFData,
      FromJSON,
      ToJSON
    )
-- | Lists the variables referenced by a circuit, in left-to-right order.
--
-- Duplicates are kept: a variable appears once per 'Var' occurrence.
-- The traversal threads an accumulator instead of appending sub-results
-- with '++'. Appending would re-copy the left result at every 'Add', which
-- is quadratic for long left-nested 'Add' chains. With the accumulator the
-- traversal is linear.
--
-- NOTE(review): the 'Ord' constraint is not used by the implementation; it
-- is kept so the exported signature does not change.
collectInputsAffine :: Ord i => AffineCircuit i f -> [i]
collectInputsAffine circuit = go circuit []
  where
    -- @go c acc@ prepends the variables of @c@ (in order) onto @acc@.
    go (Add l r) acc = go l (go r acc)
    go (ScalarMul _ x) acc = go x acc
    go (ConstGate _) acc = acc
    go (Var i) acc = i : acc
-- | Renders a circuit as an infix expression, e.g. @2 * (x + y)@.
--
-- Precedences follow the usual arithmetic convention: @*@ (level 7) binds
-- tighter than @+@ (level 6). A sub-expression is parenthesised only when
-- its operator binds more loosely than the context it is rendered in.
instance (Pretty i, Show f) => Pretty (AffineCircuit i f) where
  pretty = prettyPrec 0
    where
      -- Render @e@ inside a context whose operator has precedence @p@.
      prettyPrec :: (Pretty i, Show f) => Int -> AffineCircuit i f -> Doc
      prettyPrec p e =
        case e of
          Var v ->
            pretty v
          ConstGate f ->
            text $ show f
          ScalarMul f e1 ->
            -- The product itself is parenthesised against the outer context.
            -- Its operand is rendered at level 7, so an 'Add' underneath
            -- gets parentheses: @2 * (x + y)@ rather than @2 * x + y@.
            parensPrec 7 p $
              text (show f) <+> text "*" <+> prettyPrec 7 e1
          Add e1 e2 ->
            parensPrec 6 p $
              prettyPrec 6 e1
                <+> text "+"
                <+> prettyPrec 6 e2
      -- Wrap in parentheses when the context binds tighter than @opPrec@.
      parensPrec :: Int -> Int -> Doc -> Doc
      parensPrec opPrec p = if p > opPrec then parens else identity
-- | Renames every variable in a circuit with the given function.
-- Constants, scalars and the tree shape are left untouched.
mapVarsAffine :: (i -> j) -> AffineCircuit i f -> AffineCircuit j f
mapVarsAffine rename = go
  where
    go (Var v) = Var (rename v)
    go (ConstGate k) = ConstGate k
    go (ScalarMul k sub) = ScalarMul k (go sub)
    go (Add a b) = Add (go a) (go b)
-- | Evaluates a circuit directly by walking the expression tree.
--
-- @lookupVar@ resolves each 'Var' against @vars@. A variable it cannot
-- resolve ('Nothing') evaluates to @0@ instead of causing an error.
evalAffineCircuit ::
  Num f =>
  (i -> vars -> Maybe f) ->
  vars ->
  AffineCircuit i f ->
  f
evalAffineCircuit lookupVar vars = go
  where
    go (ConstGate c) = c
    go (Var v) = fromMaybe 0 (lookupVar v vars)
    go (Add a b) = go a + go b
    -- Operand first, scalar second (same operand order as before).
    go (ScalarMul s sub) = go sub * s
-- | Flattens a circuit into the normal form @(c, coeffs)@.
-- @c@ is the constant term and @coeffs@ maps each variable to its
-- coefficient.
--
-- Coefficients of a variable that occurs several times are summed. An
-- entry whose contributions cancel out is kept with value @0@ rather than
-- removed from the map.
affineCircuitToAffineMap ::
  (Num f, Ord i) =>
  AffineCircuit i f ->
  (f, Map i f)
affineCircuitToAffineMap circuit =
  case circuit of
    ConstGate c -> (c, Map.empty)
    Var v -> (0, Map.singleton v 1)
    ScalarMul k sub ->
      let (c, coeffs) = affineCircuitToAffineMap sub
       in (k * c, Map.map (k *) coeffs)
    Add a b ->
      let (ca, coeffsA) = affineCircuitToAffineMap a
          (cb, coeffsB) = affineCircuitToAffineMap b
       in (ca + cb, Map.unionWith (+) coeffsA coeffsB)
-- | Evaluates a flattened affine map @(constant, coefficients)@ on @input@.
--
-- The result is the constant plus 'dotProduct' of the coefficients with
-- @input@. Only keys of @input@ are visited, so a variable absent from
-- @input@ contributes nothing.
evalAffineMap ::
  (Num f, Ord i) =>
  (f, Map i f) ->
  Map i f ->
  f
evalAffineMap (offset, coeffs) input = offset + weighted
  where
    weighted = dotProduct coeffs input
-- | Dot product of two sparse vectors keyed by @i@.
--
-- Iterates over the keys of the second map. A key missing from the first
-- map is treated as having value @0@ there. Terms are accumulated with a
-- strict left fold in ascending key order, starting from @0@.
dotProduct :: (Num f, Ord i) => Map i f -> Map i f -> f
dotProduct lhs rhs = Map.foldlWithKey' addTerm 0 rhs
  where
    addTerm acc k y = acc + y * Map.findWithDefault 0 k lhs