{-# LANGUAGE StandaloneDeriving #-}

module Circuit.Expr
  ( UnOp (..),
    BinOp (..),
    Expr (..),
    ExprM,
    compile,
    emit,
    imm,
    addVar,
    addWire,
    freshInput,
    freshOutput,
    rotateList,
    runCircuitBuilder,
    evalCircuitBuilder,
    execCircuitBuilder,
    exprToArithCircuit,
    evalExpr,
    truncRotate,
  )
where

import Circuit.Affine
import Circuit.Arithmetic
import Protolude
import Text.PrettyPrint.Leijen.Text hiding ((<$>))

-- | Unary operators. The second type parameter is the type of both the
-- operand and the result: the field @f@ or 'Bool'.
data UnOp f a where
  UNeg :: UnOp f f
  UNot :: UnOp f Bool
  -- | # truncate bits, # rotate bits
  URot :: Int -> Int -> UnOp f f

-- | Binary operators. The second type parameter is the type of both
-- operands and of the result: the field @f@ or 'Bool'.
data BinOp f a where
  BAdd :: BinOp f f
  BSub :: BinOp f f
  BMul :: BinOp f f
  BAnd :: BinOp f Bool
  BOr :: BinOp f Bool
  BXor :: BinOp f Bool

-- | Precedence level used by the pretty-printer to decide where
-- parentheses are needed; a sub-expression is parenthesised when the
-- precedence of its context is strictly greater than its own.
opPrecedence :: BinOp f a -> Int
opPrecedence BOr = 5
opPrecedence BXor = 5
opPrecedence BAnd = 5
opPrecedence BSub = 6
opPrecedence BAdd = 6
opPrecedence BMul = 7

-- | Expression data type of (arithmetic) expressions over a field @f@
-- with variable names/indices coming from @i@.
data Expr i f ty where
  EConst :: f -> Expr i f f
  EConstBool :: Bool -> Expr i f Bool
  EVar :: i -> Expr i f f
  EVarBool :: i -> Expr i f Bool
  EUnOp :: UnOp f ty -> Expr i f ty -> Expr i f ty
  EBinOp :: BinOp f ty -> Expr i f ty -> Expr i f ty -> Expr i f ty
  EIf :: Expr i f Bool -> Expr i f ty -> Expr i f ty -> Expr i f ty
  EEq :: Expr i f f -> Expr i f f -> Expr i f Bool

deriving instance (Show i, Show f) => Show (Expr i f ty)

deriving instance (Show f) => Show (BinOp f a)

deriving instance (Show f) => Show (UnOp f a)

instance Pretty (BinOp f a) where
  pretty op = case op of
    BAdd -> text "+"
    BSub -> text "-"
    BMul -> text "*"
    BAnd -> text "&&"
    BOr -> text "||"
    BXor -> text "xor"

instance Pretty (UnOp f a) where
  pretty op = case op of
    UNeg -> text "neg"
    UNot -> text "!"
    URot truncBits rotBits ->
      text "rot(" <> pretty truncBits <> text "," <> pretty rotBits <> ")"

instance (Pretty f, Pretty i, Pretty ty) => Pretty (Expr i f ty) where
  pretty = prettyPrec 0
    where
      -- Render an expression that appears in a context of precedence @p@.
      prettyPrec :: (Pretty f, Pretty i, Pretty ty) => Int -> Expr i f ty -> Doc
      prettyPrec p e =
        case e of
          EVar v ->
            pretty v
          EVarBool v ->
            pretty v
          EConst l ->
            pretty l
          EConstBool b ->
            pretty b
          -- TODO correct precedence
          EUnOp op e1 -> parens (pretty op <+> pretty e1)
          EBinOp op e1 e2 ->
            parensPrec (opPrecedence op) p $
              prettyPrec (opPrecedence op) e1
                <+> pretty op
                <+> prettyPrec (opPrecedence op) e2
          EIf b true false ->
            parensPrec
              4
              p
              ( text "if" <+> pretty b
                  <+> text "then"
                  <+> pretty true
                  <+> text "else"
                  <+> pretty false
              )
          -- TODO correct precedence
          EEq l r ->
            parensPrec 1 p (pretty l) <+> text "=" <+> parensPrec 1 p (pretty r)

-- | Wrap a document in parentheses when the surrounding precedence @p@ is
-- strictly greater than the operator's own precedence.
parensPrec :: Int -> Int -> Doc -> Doc
parensPrec opPrec p = if p > opPrec then parens else identity

-------------------------------------------------------------------------------
-- Evaluator
-------------------------------------------------------------------------------

-- | Truncate a number to the given number of bits and perform a right
-- rotation (assuming small-endianness) within the truncation.
--
-- Concretely: every set bit at index @ix < nbits@ of the input ends up at
-- index @(ix + nrots) `mod` nbits@ of the result, and bits at index
-- @>= nbits@ are dropped. For @nbits <= 0@ the result is 'zeroBits'.
truncRotate ::
  (Bits f) =>
  -- | number of bits to truncate to
  Int ->
  -- | number of bits to rotate by
  Int ->
  f ->
  f
truncRotate nbits nrots x =
  foldr
    ( \ix rest ->
        if testBit x ix
          then setBit rest ((ix + nrots) `mod` nbits)
          else rest
    )
    zeroBits
    [0 .. nbits - 1]

-- | Evaluate arithmetic expressions directly, given an environment.
--
-- Calls 'panic' when a variable is not found by the lookup function.
-- Boolean variables are considered true exactly when their value is @1@.
evalExpr ::
  (Bits f, Num f) =>
  -- | variable lookup
  (i -> vars -> Maybe f) ->
  -- | expression to evaluate
  Expr i f ty ->
  -- | input values
  vars ->
  -- | resulting value
  ty
evalExpr lookupVar expr vars = case expr of
  EConst f -> f
  EConstBool b -> b
  EVar i -> case lookupVar i vars of
    Just v -> v
    Nothing -> panic "TODO: incorrect var lookup"
  EVarBool i -> case lookupVar i vars of
    Just v -> v == 1
    Nothing -> panic "TODO: incorrect var lookup"
  EUnOp UNeg e1 ->
    negate $ evalExpr lookupVar e1 vars
  EUnOp UNot e1 ->
    not $ evalExpr lookupVar e1 vars
  EUnOp (URot truncBits rotBits) e1 ->
    truncRotate truncBits rotBits $ evalExpr lookupVar e1 vars
  EBinOp op e1 e2 ->
    (evalExpr lookupVar e1 vars) `apply` (evalExpr lookupVar e2 vars)
    where
      apply = case op of
        BAdd -> (+)
        BSub -> (-)
        BMul -> (*)
        BAnd -> (&&)
        BOr -> (||)
        BXor -> \x y -> (x || y) && not (x && y)
  EIf b true false ->
    if evalExpr lookupVar b vars
      then evalExpr lookupVar true vars
      else evalExpr lookupVar false vars
  EEq lhs rhs -> evalExpr lookupVar lhs vars == evalExpr lookupVar rhs vars

-------------------------------------------------------------------------------
-- Circuit Builder
-------------------------------------------------------------------------------

-- | Circuit-building monad: the state is the circuit built so far (gates
-- are stored most-recent-first while building) together with the counter
-- from which fresh wire indices are drawn.
type ExprM f a = State (ArithCircuit f, Int) a

-- | Run a builder from an empty circuit and return only the circuit, with
-- gates in the order they were emitted.
execCircuitBuilder :: ExprM f a -> ArithCircuit f
execCircuitBuilder = snd . runCircuitBuilder

-- | Run a builder from an empty circuit and return only its result.
evalCircuitBuilder :: ExprM f a -> a
evalCircuitBuilder = fst . runCircuitBuilder

-- | Run a builder from an empty circuit with the wire counter at 0,
-- returning its result together with the circuit, gates in emission order.
runCircuitBuilder :: ExprM f a -> (a, ArithCircuit f)
runCircuitBuilder m =
  second (reverseCircuit . fst) $ runState m (ArithCircuit [], 0)
  where
    -- 'emit' prepends gates, so undo that to recover emission order.
    reverseCircuit (ArithCircuit cs) = ArithCircuit $ reverse cs

-- | Return the current wire counter and increment it.
fresh :: ExprM f Int
fresh = do
  v <- gets snd
  modify (second (+ 1))
  pure v

-- | Fresh intermediate variables
imm :: ExprM f Wire
imm = IntermediateWire <$> fresh

-- | Fresh input variables
freshInput :: ExprM f Wire
freshInput = InputWire <$> fresh

-- | Fresh output variables
freshOutput :: ExprM f Wire
freshOutput = OutputWire <$> fresh

-- | Multiply two wires or affine circuits to an intermediate variable
mulToImm ::
  Either Wire (AffineCircuit Wire f) ->
  Either Wire (AffineCircuit Wire f) ->
  ExprM f Wire
mulToImm l r = do
  o <- imm
  emit $ Mul (addVar l) (addVar r) o
  pure o

-- | Add a gate (of any kind) to the circuit under construction.
emit :: Gate Wire f -> ExprM f ()
emit c = modify $ first (\(ArithCircuit cs) -> ArithCircuit (c : cs))

-- | Rotate a list cyclically: the first @steps@ elements are moved to the
-- end, so position @k@ of the result holds element @k + steps@ (modulo the
-- length) of the input. For @steps <= 0@ the list is returned unchanged.
--
-- On a small-endian list of bits this moves every bit @steps@ places
-- towards the least significant end.
rotateList :: Int -> [a] -> [a]
rotateList steps x = take (length x) $ drop steps $ cycle x

-- | Turn a wire into an affine circuit, or leave it be
addVar :: Either Wire (AffineCircuit Wire f) -> AffineCircuit Wire f
addVar (Left w) = Var w
addVar (Right c) = c

-- | Turn an affine circuit into a wire, or leave it be.
--
-- An affine circuit is materialised by multiplying it with the constant 1
-- into a fresh intermediate wire.
addWire :: Num f => Either Wire (AffineCircuit Wire f) -> ExprM f Wire
addWire (Left w) = pure w
addWire (Right c) = do
  mulOut <- imm
  emit $ Mul (ConstGate 1) c mulOut
  pure mulOut

-- | Compile an expression into gates, returning either the wire holding
-- the result or an affine circuit over already-emitted wires.
compile :: Num f => Expr Wire f ty -> ExprM f (Either Wire (AffineCircuit Wire f))
compile expr = case expr of
  EConst n -> pure . Right $ ConstGate n
  EConstBool b -> pure . Right $ ConstGate (if b then 1 else 0)
  EVar v -> pure . Left $ v
  EVarBool v -> pure . Left $ v
  EUnOp op e1 -> do
    e1Out <- compile e1
    case op of
      UNeg -> pure . Right $ ScalarMul (-1) (addVar e1Out)
      -- NOT(x) = 1 - x
      UNot -> pure . Right $ Add (ConstGate 1) (ScalarMul (-1) (addVar e1Out))
      URot truncBits rotBits -> do
        inp <- addWire e1Out
        outputs <- replicateM truncBits imm
        emit $ Split inp outputs
        -- NOTE(review): 'rotateList' makes position k read wire
        -- k + rotBits, whereas 'truncRotate' (used by 'evalExpr') sends
        -- bit ix to ix + nrots. Whether the two agree depends on the bit
        -- order of Split/unsplit, which is defined elsewhere -- confirm
        -- that evaluator and circuit rotate in the same direction.
        pure . Right $ unsplit (rotateList rotBits outputs)
  EBinOp op e1 e2 -> do
    e1Out <- addVar <$> compile e1
    e2Out <- addVar <$> compile e2
    -- Emits one multiplication gate for the two operands each time it is
    -- run and yields the fresh wire carrying the product.
    let mulOperands = mulToImm (Right e1Out) (Right e2Out)
    case op of
      BAdd -> pure . Right $ Add e1Out e2Out
      BMul -> Left <$> mulOperands
      -- SUB(x, y) = x + (-y)
      BSub -> pure . Right $ Add e1Out (ScalarMul (-1) e2Out)
      -- AND(input1, input2) = input1 * input2
      BAnd -> Left <$> mulOperands
      BOr -> do
        -- OR(input1, input2) = (input1 + input2) - (input1 * input2)
        tmp1 <- mulOperands
        pure . Right $ Add (Add e1Out e2Out) (ScalarMul (-1) (Var tmp1))
      BXor -> do
        -- XOR(input1, input2) = (input1 + input2) - 2 * (input1 * input2)
        tmp1 <- mulOperands
        pure . Right $ Add (Add e1Out e2Out) (ScalarMul (-2) (Var tmp1))
  -- IF(cond, true, false) = (cond*true) + ((!cond) * false)
  EIf cond true false -> do
    condOut <- addVar <$> compile cond
    trueOut <- addVar <$> compile true
    falseOut <- addVar <$> compile false
    tmp1 <- imm
    tmp2 <- imm
    emit $ Mul condOut trueOut tmp1
    emit $ Mul (Add (ConstGate 1) (ScalarMul (-1) condOut)) falseOut tmp2
    pure . Right $ Add (Var tmp1) (Var tmp2)
  -- EQ(lhs, rhs) = 1 - (output of the Equal gate applied to lhs - rhs)
  EEq lhs rhs -> do
    lhsSubRhs <- compile (EBinOp BSub lhs rhs)
    eqInWire <- addWire lhsSubRhs
    eqFreeWire <- imm
    eqOutWire <- imm
    emit $ Equal eqInWire eqFreeWire eqOutWire
    -- eqOutWire == 0 if lhs == rhs, so we need to return
    -- 1 - eqOutWire instead.
    pure . Right $ Add (ConstGate 1) (ScalarMul (-1) (Var eqOutWire))

-- | Translate an arithmetic expression to an arithmetic circuit.
--
-- Variable indices of the expression are interpreted as input wires.
exprToArithCircuit ::
  Num f =>
  -- | expression to compile
  Expr Int f ty ->
  -- | Wire to assign the output of the expression to
  Wire ->
  ExprM f ()
exprToArithCircuit expr output =
  exprToArithCircuit' (mapVarsExpr InputWire expr) output

-- | Compile an expression over wires and connect its result to the given
-- wire through a multiplication by the constant 1.
exprToArithCircuit' :: Num f => Expr Wire f ty -> Wire -> ExprM f ()
exprToArithCircuit' expr output = do
  exprOut <- compile expr
  emit $ Mul (ConstGate 1) (addVar exprOut) output

-- | Apply function to variable names.
mapVarsExpr ::
  -- | renaming applied to every variable occurrence
  (i -> j) ->
  -- | expression whose variables are renamed
  Expr i f ty ->
  -- | same expression tree, with constants and operators untouched
  Expr j f ty
-- Leaves: only the two variable forms are affected by the renaming.
mapVarsExpr rename (EVar name) = EVar (rename name)
mapVarsExpr rename (EVarBool name) = EVarBool (rename name)
mapVarsExpr _ (EConst val) = EConst val
mapVarsExpr _ (EConstBool flag) = EConstBool flag
-- Inner nodes: rebuild the node around the renamed children.
mapVarsExpr rename (EUnOp op arg) =
  EUnOp op (mapVarsExpr rename arg)
mapVarsExpr rename (EBinOp op lhs rhs) =
  EBinOp op (mapVarsExpr rename lhs) (mapVarsExpr rename rhs)
mapVarsExpr rename (EIf cond onTrue onFalse) =
  EIf
    (mapVarsExpr rename cond)
    (mapVarsExpr rename onTrue)
    (mapVarsExpr rename onFalse)
mapVarsExpr rename (EEq lhs rhs) =
  EEq (mapVarsExpr rename lhs) (mapVarsExpr rename rhs)