{-# LANGUAGE DeriveAnyClass, DeriveGeneric, LambdaCase, ScopedTypeVariables, StrictData #-}

-- | Definition of arithmetic circuits: one with a single
-- multiplication gate with affine inputs and another variant with an
-- arbitrary number of such gates.
module Circuit.Arithmetic
  ( Gate (..),
    mapVarsGate,
    collectInputsGate,
    outputWires,
    ArithCircuit (..),
    fetchVars,
    generateRoots,
    validArithCircuit,
    Wire (..),
    evalGate,
    evalArithCircuit,
    unsplit,
  )
where

import Circuit.Affine
  ( AffineCircuit (..),
    collectInputsAffine,
    evalAffineCircuit,
    mapVarsAffine,
  )
import Data.Aeson (FromJSON, ToJSON)
import Protolude
import Text.PrettyPrint.Leijen.Text as PP
  ( Pretty (..),
    hsep,
    list,
    parens,
    text,
    vcat,
  )

-- | Wires can be labeled in the ways given in this data type.
data Wire
  = InputWire Int
  | IntermediateWire Int
  | OutputWire Int
  deriving (Show, Eq, Ord, Generic, NFData, ToJSON, FromJSON)

instance Pretty Wire where
  pretty (InputWire v) = text "input_" <> pretty v
  pretty (IntermediateWire v) = text "imm_" <> pretty v
  pretty (OutputWire v) = text "output_" <> pretty v

-- | A single gate of an arithmetic circuit.
--
-- * 'Mul' multiplies the values of two affine circuits and writes the
--   product to 'mulOutput'.
-- * 'Equal' writes 0 to 'eqOutput' if the value at 'eqInput' is 0 and
--   1 otherwise; 'eqMagic' is an auxiliary wire that 'evalGate' fills
--   with 0 or the reciprocal of the input, respectively.
-- * 'Split' writes the bits of the value at 'splitInput' to
--   'splitOutputs', bit @k@ going to the @k@-th output wire.
data Gate i f
  = Mul
      { mulLeft :: AffineCircuit i f,
        mulRight :: AffineCircuit i f,
        mulOutput :: i
      }
  | Equal
      { eqInput :: i,
        eqMagic :: i,
        eqOutput :: i
      }
  | Split
      { splitInput :: i,
        splitOutputs :: [i]
      }
  deriving (Show, Eq, Generic, NFData, FromJSON, ToJSON)

-- | Collect the inputs of both affine sub-circuits of a 'Mul' gate
-- (left operand first).
--
-- Only 'Mul' gates are supported: any other gate results in a panic.
collectInputsGate :: Ord i => Gate i f -> [i]
collectInputsGate = \case
  Mul l r _ -> collectInputsAffine l ++ collectInputsAffine r
  _ -> panic "collectInputsGate: only supports mul gates"

-- | List output wires of a gate
outputWires :: Gate i f -> [i]
outputWires = \case
  Mul _ _ out -> [out]
  Equal _ _ out -> [out]
  Split _ outs -> outs

instance (Pretty i, Show f) => Pretty (Gate i f) where
  pretty (Mul l r o) =
    hsep
      [ pretty o,
        text ":=",
        parens (pretty l),
        text "*",
        parens (pretty r)
      ]
  pretty (Equal i _ o) =
    hsep
      [ pretty o,
        text ":=",
        pretty i,
        text "== 0 ? 0 : 1"
      ]
  pretty (Split inp outputs) =
    hsep
      [ PP.list (map pretty outputs),
        text ":=",
        text "split",
        pretty inp
      ]

-- | Apply mapping to variable names, i.e. rename variables. (Ideally
-- the mapping is injective.)
mapVarsGate :: (i -> j) -> Gate i f -> Gate j f
mapVarsGate f = \case
  Mul l r o -> Mul (mapVarsAffine f l) (mapVarsAffine f r) (f o)
  Equal i j o -> Equal (f i) (f j) (f o)
  Split i os -> Split (f i) (fmap f os)

-- | Evaluate a single gate.
--
-- Panics if the input wire of an 'Equal' or 'Split' gate has no value
-- in the given context.
evalGate ::
  (Bits f, Fractional f) =>
  -- | lookup a value at a wire
  (i -> vars -> Maybe f) ->
  -- | update a value at a wire
  (i -> f -> vars -> vars) ->
  -- | context before evaluation
  vars ->
  -- | gate
  Gate i f ->
  -- | context after evaluation
  vars
evalGate lookupVar updateVar vars gate =
  case gate of
    Mul l r outputWire ->
      let lval = evalAffineCircuit lookupVar vars l
          rval = evalAffineCircuit lookupVar vars r
          res = lval * rval
       in updateVar outputWire res vars
    Equal i m outputWire ->
      case lookupVar i vars of
        Nothing -> panic "evalGate: the impossible happened"
        Just inp ->
          let res = if inp == 0 then 0 else 1
              -- The "magic" wire carries the inverse of a non-zero input.
              mid = if inp == 0 then 0 else recip inp
           in updateVar outputWire res $ updateVar m mid vars
    Split i os ->
      case lookupVar i vars of
        Nothing -> panic "evalGate: the impossible happened"
        Just inp ->
          let bool2val True = 1
              bool2val False = 0
              -- Write bit number @ix@ of the input to the @ix@-th output.
              setWire env (ix, currentOut) =
                updateVar currentOut (bool2val $ testBit inp ix) env
           in -- Strict left fold over the index-tagged outputs, so no
              -- chain of suspended updates is built up.
              foldl' setWire vars (zip [0 ..] os)

-- | A circuit is a list of gates along with their output wire labels
-- (which can be intermediate or actual outputs).
newtype ArithCircuit f = ArithCircuit [Gate Wire f]
  deriving (Eq, Show, Generic, NFData, FromJSON, ToJSON)

instance Show f => Pretty (ArithCircuit f) where
  pretty (ArithCircuit gs) = vcat . map pretty $ gs

-- | Check whether an arithmetic circuit does not refer to
-- intermediate wires before they are defined and whether output wires
-- are not used as input wires.
--
-- The gates are checked in order and the traversal stops at the first
-- gate that violates one of the conditions.
validArithCircuit :: ArithCircuit f -> Bool
validArithCircuit (ArithCircuit gates) = go [] gates
  where
    -- @definedWires@ holds the output wires of all gates seen so far.
    go _ [] = True
    go definedWires (gate : rest) =
      all isNotInput outs
        && all (validWire definedWires) (fetchVarsGate gate)
        && go (outs ++ definedWires) rest
      where
        outs = outputWires gate

    -- A gate must never write to an input wire.
    isNotInput (InputWire _) = False
    isNotInput (OutputWire _) = True
    isNotInput (IntermediateWire _) = True

    -- A gate may read inputs and already defined intermediate wires,
    -- but never an output wire.
    validWire _ (InputWire _) = True
    validWire _ (OutputWire _) = False
    validWire definedWires i@(IntermediateWire _) = i `elem` definedWires

    fetchVarsGate (Mul l r _) = fetchVars l ++ fetchVars r
    -- We can ignore the magic variable "m", as it is filled in when
    -- evaluating the circuit.
    fetchVarsGate (Equal i _ _) = [i]
    fetchVarsGate (Split i _) = [i]

-- | List every wire referenced by an affine circuit, from left to
-- right; a wire that occurs several times is listed several times.
fetchVars :: AffineCircuit Wire f -> [Wire]
fetchVars (Var i) = [i]
fetchVars (ConstGate _) = []
fetchVars (ScalarMul _ c) = fetchVars c
fetchVars (Add l r) = fetchVars l ++ fetchVars r

-- | Generate enough roots for a circuit: one per 'Mul' gate, two per
-- 'Equal' gate and one plus the number of output wires per 'Split'
-- gate. The result contains one list of roots per gate, in gate order.
generateRoots ::
  Applicative m =>
  m f ->
  ArithCircuit f ->
  m [[f]]
generateRoots _ (ArithCircuit []) = pure []
generateRoots takeRoot (ArithCircuit (gate : gates)) =
  case gate of
    Mul {} ->
      (\r rs -> [r] : rs)
        <$> takeRoot
        <*> generateRoots takeRoot (ArithCircuit gates)
    Equal {} ->
      (\r0 r1 rs -> [r0, r1] : rs)
        <$> takeRoot
        <*> takeRoot
        <*> generateRoots takeRoot (ArithCircuit gates)
    Split _ outputs ->
      (\r0 rOutputs rRest -> (r0 : rOutputs) : rRest)
        <$> takeRoot
        <*> traverse (const takeRoot) outputs
        <*> generateRoots takeRoot (ArithCircuit gates)

-- | Evaluate an arithmetic circuit on a given environment containing
-- the inputs. Outputs the entire environment (outputs, intermediate
-- values and inputs).
evalArithCircuit ::
  forall f vars.
  (Bits f, Fractional f) =>
  -- | lookup a value at a wire
  (Wire -> vars -> Maybe f) ->
  -- | update a value at a wire
  (Wire -> f -> vars -> vars) ->
  -- | circuit to evaluate
  ArithCircuit f ->
  -- | input variables
  vars ->
  -- | input and output variables
  vars
evalArithCircuit lookupVar updateVar (ArithCircuit gates) vars =
  foldl' (evalGate lookupVar updateVar) vars gates

-- | Turn a binary expansion back into a single value: the wire at
-- position @k@ of the list is scaled by @2 ^ k@ and all scaled wires
-- are summed up, starting from the constant 0.
unsplit ::
  Num f =>
  -- | (binary) wires containing a binary expansion,
  -- little-endian (least significant bit first)
  [Wire] ->
  AffineCircuit Wire f
unsplit = foldl' addBit (ConstGate 0) . zip [0 :: Integer ..]
  where
    addBit acc (ix, wire) = Add acc (ScalarMul (2 ^ ix) (Var wire))