{-# LANGUAGE StandaloneDeriving #-}
module Circuit.Expr
( UnOp (..),
BinOp (..),
Expr (..),
ExprM,
compile,
emit,
imm,
addVar,
addWire,
freshInput,
freshOutput,
rotateList,
runCircuitBuilder,
evalCircuitBuilder,
execCircuitBuilder,
exprToArithCircuit,
evalExpr,
truncRotate,
)
where
import Circuit.Affine
import Circuit.Arithmetic
import Protolude
import Text.PrettyPrint.Leijen.Text hiding ((<$>))
-- | Unary operators. The type index @a@ is the result type; 'EUnOp' applies
-- an operator to an operand of that same type.
data UnOp f a where
  -- | Negation of a field element ('negate' in 'evalExpr'; scalar
  -- multiplication by @-1@ in 'compile').
  UNeg :: UnOp f f
  -- | Boolean negation ('not' in 'evalExpr'; @1 - x@ in 'compile').
  UNot :: UnOp f Bool
  -- | @URot truncBits rotBits@: keep the low @truncBits@ bits of the operand
  -- and rotate them by @rotBits@ positions (see 'truncRotate').
  URot :: Int -> Int -> UnOp f f
-- | Binary operators, indexed by their result type: arithmetic operators
-- produce field elements, logical operators produce booleans. Both operands
-- have the same type as the result (see 'EBinOp').
data BinOp f a where
  -- Arithmetic on field elements.
  BAdd, BSub, BMul :: BinOp f f
  -- Logical connectives on booleans.
  BAnd, BOr, BXor :: BinOp f Bool
-- | Binding strength used by the pretty printer: a higher number binds
-- tighter. Multiplication binds tightest, then addition/subtraction, then
-- the logical connectives.
opPrecedence :: BinOp f a -> Int
opPrecedence op = case op of
  BMul -> 7
  BAdd -> 6
  BSub -> 6
  BAnd -> 5
  BOr -> 5
  BXor -> 5
-- | Typed expression tree over variables of type @i@ and field elements of
-- type @f@. The index @ty@ is the type the expression evaluates to:
-- @f@ for arithmetic expressions, 'Bool' for logical ones.
data Expr i f ty where
  -- | Field constant.
  EConst :: f -> Expr i f f
  -- | Boolean constant (compiled to the field constant 1 or 0).
  EConstBool :: Bool -> Expr i f Bool
  -- | Field-valued variable.
  EVar :: i -> Expr i f f
  -- | Boolean-valued variable; 'evalExpr' reads it as a field value and
  -- treats it as 'True' exactly when it equals 1.
  EVarBool :: i -> Expr i f Bool
  -- | Unary operator application.
  EUnOp :: UnOp f ty -> Expr i f ty -> Expr i f ty
  -- | Binary operator application.
  EBinOp :: BinOp f ty -> Expr i f ty -> Expr i f ty -> Expr i f ty
  -- | Conditional: boolean condition, then-branch, else-branch.
  EIf :: Expr i f Bool -> Expr i f ty -> Expr i f ty -> Expr i f ty
  -- | Equality test between two field-valued expressions.
  EEq :: Expr i f f -> Expr i f f -> Expr i f Bool
-- Show instances must be standalone-derived because the types are GADTs.
-- The operator instances come first since 'Expr' embeds both operator types.
deriving instance (Show f) => Show (UnOp f a)
deriving instance (Show f) => Show (BinOp f a)
deriving instance (Show i, Show f) => Show (Expr i f ty)
-- | Render each binary operator as its infix symbol.
instance Pretty (BinOp f a) where
  pretty BAdd = text "+"
  pretty BSub = text "-"
  pretty BMul = text "*"
  pretty BAnd = text "&&"
  pretty BOr = text "||"
  pretty BXor = text "xor"
-- | Render each unary operator as a prefix keyword; rotation shows its
-- truncation width and rotation amount as @rot(truncBits,rotBits)@.
instance Pretty (UnOp f a) where
  pretty UNeg = text "neg"
  pretty UNot = text "!"
  pretty (URot truncBits rotBits) =
    hcat [text "rot(", pretty truncBits, text ",", pretty rotBits, text ")"]
-- | Precedence-aware rendering of expressions.
--
-- Precedence levels (higher binds tighter): equality 1, @if@ 4, the logical
-- operators 5, @+@/@-@ 6, @*@ 7 (see 'opPrecedence'). A sub-expression is
-- parenthesised when the surrounding context binds tighter than it does.
-- Binary operators are treated as left-associative, so the right operand is
-- rendered one level higher: @a - (b - c)@ keeps its parentheses while
-- @(a - b) - c@ prints as @a - b - c@. Unary operators are always fully
-- parenthesised.
instance (Pretty f, Pretty i, Pretty ty) => Pretty (Expr i f ty) where
  pretty = prettyPrec 0
    where
      -- Render an expression in a context of precedence @p@.
      prettyPrec :: (Pretty f, Pretty i, Pretty ty) => Int -> Expr i f ty -> Doc
      prettyPrec p e =
        case e of
          EVar v ->
            pretty v
          EVarBool v ->
            pretty v
          EConst l ->
            pretty l
          EConstBool b ->
            pretty b
          EUnOp op e1 -> parens (pretty op <+> pretty e1)
          EBinOp op e1 e2 ->
            let q = opPrecedence op
             in parensPrec q p $
                  prettyPrec q e1
                    <+> pretty op
                    -- One level higher on the right: left-associative operators.
                    <+> prettyPrec (q + 1) e2
          EIf b true false ->
            parensPrec 4 p (text "if" <+> pretty b <+> text "then" <+> pretty true <+> text "else" <+> pretty false)
          EEq l r ->
            -- Parenthesise the whole equation, not each side separately, so
            -- that e.g. a conjunction of equalities renders as
            -- "(a = b) && (c = d)".
            parensPrec 1 p (prettyPrec 1 l <+> text "=" <+> prettyPrec 2 r)
      -- Wrap in parentheses when the context binds tighter than the operator.
      parensPrec :: Int -> Int -> Doc -> Doc
      parensPrec opPrec p
        | p > opPrec = parens
        | otherwise = identity
-- | @truncRotate nbits nrots x@ keeps only the low @nbits@ bits of @x@ and
-- rotates them: the bit at index @i@ (for @0 <= i < nbits@) moves to index
-- @(i + nrots) `mod` nbits@. Negative rotations are normalised by 'mod'.
-- When @nbits <= 0@ the result is 'zeroBits'.
--
-- NOTE(review): the 'URot' case of 'compile' rotates the split wires with
-- 'rotateList', which moves list element @i + k@ to position @i@, i.e. the
-- opposite direction to this function if 'Split'/'unsplit' order bits
-- least-significant first (not visible here) — confirm both agree.
truncRotate ::
  (Bits f) =>
  Int ->
  Int ->
  f ->
  f
truncRotate nbits nrots x =
  foldr (\pos acc -> setBit acc pos) zeroBits targets
  where
    -- Destination index of every set bit in the truncated window.
    targets = [(ix + nrots) `mod` nbits | ix <- [0 .. nbits - 1], testBit x ix]
-- | Evaluate an expression directly, without building a circuit.
--
-- @lookupVar@ resolves a variable in the environment @vars@. A variable that
-- cannot be resolved causes a 'panic'. Boolean variables are looked up as
-- field values and are 'True' exactly when equal to 1. Rotation is evaluated
-- with 'truncRotate'.
evalExpr ::
  (Bits f, Num f) =>
  (i -> vars -> Maybe f) ->
  Expr i f ty ->
  vars ->
  ty
evalExpr lookupVar expr vars =
  case expr of
    EConst value -> value
    EConstBool flag -> flag
    EVar name -> fetch name
    EVarBool name -> fetch name == 1
    EUnOp UNeg sub -> negate (evalExpr lookupVar sub vars)
    EUnOp UNot sub -> not (evalExpr lookupVar sub vars)
    EUnOp (URot truncBits rotBits) sub ->
      truncRotate truncBits rotBits (evalExpr lookupVar sub vars)
    EBinOp op lhs rhs ->
      applyBinOp op (evalExpr lookupVar lhs vars) (evalExpr lookupVar rhs vars)
    EIf cond whenTrue whenFalse
      | evalExpr lookupVar cond vars -> evalExpr lookupVar whenTrue vars
      | otherwise -> evalExpr lookupVar whenFalse vars
    EEq lhs rhs -> evalExpr lookupVar lhs vars == evalExpr lookupVar rhs vars
  where
    -- Resolve a variable or abort evaluation.
    fetch name = case lookupVar name vars of
      Just value -> value
      Nothing -> panic "TODO: incorrect var lookup"

-- | Meaning of a binary operator on already-evaluated operands.
applyBinOp :: (Num f) => BinOp f ty -> ty -> ty -> ty
applyBinOp op = case op of
  BAdd -> (+)
  BSub -> (-)
  BMul -> (*)
  BAnd -> (&&)
  BOr -> (||)
  BXor -> \x y -> (x || y) && not (x && y)
-- | Circuit-building monad. The state pairs the circuit built so far (gates
-- are prepended by 'emit', so it is kept newest-first until one of the
-- @*CircuitBuilder@ runners reverses it) with the next fresh wire index.
type ExprM f = State (ArithCircuit f, Int)
-- | Run a builder from an empty circuit and return only the resulting
-- circuit, with gates in emission order.
execCircuitBuilder :: ExprM f a -> ArithCircuit f
execCircuitBuilder = snd . runCircuitBuilder
-- | Run a builder from an empty circuit and return only its result,
-- discarding the circuit.
evalCircuitBuilder :: ExprM f a -> a
evalCircuitBuilder builder = fst (runCircuitBuilder builder)
-- | Run a builder starting from an empty circuit and wire counter 0.
-- Returns the builder's result together with the circuit, whose gates are
-- reversed back into emission order ('emit' prepends them).
runCircuitBuilder :: ExprM f a -> (a, ArithCircuit f)
runCircuitBuilder builder = second finalise (runState builder (ArithCircuit [], 0))
  where
    -- Drop the wire counter and restore emission order.
    finalise (ArithCircuit gates, _) = ArithCircuit (reverse gates)
-- | Return the current wire counter and advance it by one.
fresh :: ExprM f Int
fresh = gets snd <* modify (second (+ 1))
-- | Allocate a new intermediate wire with a fresh index.
imm :: ExprM f Wire
imm = fmap IntermediateWire fresh
-- | Allocate a new input wire with a fresh index.
freshInput :: ExprM f Wire
freshInput = fmap InputWire fresh
-- | Allocate a new output wire with a fresh index.
freshOutput :: ExprM f Wire
freshOutput = fmap OutputWire fresh
-- | Emit a multiplication gate of the two operands into a fresh
-- intermediate wire and return that wire.
mulToImm :: Either Wire (AffineCircuit Wire f) -> Either Wire (AffineCircuit Wire f) -> ExprM f Wire
mulToImm lhs rhs =
  imm >>= \out -> out <$ emit (Mul (addVar lhs) (addVar rhs) out)
-- | Append a gate to the circuit under construction. Gates are stored
-- newest-first; the @*CircuitBuilder@ runners reverse them at the end.
emit :: Gate Wire f -> ExprM f ()
emit gate = modify (first prepend)
  where
    prepend (ArithCircuit gates) = ArithCircuit (gate : gates)
-- | Rotate a list left by @steps@ positions: element @(j + steps) mod n@ of
-- the input becomes element @j@ of the output, where @n@ is the length.
--
-- The step count is reduced modulo the length, so rotations larger than the
-- list wrap around and negative rotations rotate right. Previously, negative
-- steps were a silent no-op because 'drop' ignores negative counts. The
-- empty list is returned unchanged (this also avoids 'mod' by zero).
rotateList :: Int -> [a] -> [a]
rotateList _ [] = []
rotateList steps xs = drop k xs ++ take k xs
  where
    k = steps `mod` length xs
-- | View a compiled result as an affine circuit: a bare wire becomes a
-- 'Var', and an affine circuit is returned unchanged.
addVar :: Either Wire (AffineCircuit Wire f) -> AffineCircuit Wire f
addVar = either Var identity
-- | Ensure a compiled result lives on a single wire. A wire is returned
-- as-is. An affine circuit is materialised by multiplying it by the
-- constant 1 into a fresh intermediate wire.
addWire :: Num f => Either Wire (AffineCircuit Wire f) -> ExprM f Wire
addWire = either pure materialise
  where
    materialise combo = do
      out <- imm
      emit (Mul (ConstGate 1) combo out)
      pure out
-- | Compile an expression into gates of the circuit under construction.
--
-- The result is either a wire holding the value, or an affine combination
-- of wires and constants for which no gate has been emitted. Addition,
-- subtraction, negation, constants and 'UNot' stay affine. Multiplication,
-- the logical connectives, rotation, 'EIf' and 'EEq' emit gates and allocate
-- intermediate wires. Sub-expressions are compiled left to right.
--
-- Logical operators use arithmetic encodings (@x*y@, @x+y-xy@, @x+y-2xy@,
-- @1-x@). These match the boolean operators only when the operands are 0 or
-- 1; nothing here constrains them to be.
compile :: Num f => Expr Wire f ty -> ExprM f (Either Wire (AffineCircuit Wire f))
compile expr = case expr of
  EConst n -> pure (Right (ConstGate n))
  EConstBool b -> pure (Right (ConstGate (if b then 1 else 0)))
  EVar v -> pure (Left v)
  EVarBool v -> pure (Left v)
  EUnOp op e1 -> compile e1 >>= compileUnOp op
  EBinOp op e1 e2 -> do
    lhs <- addVar <$> compile e1
    rhs <- addVar <$> compile e2
    compileBinOp op lhs rhs
  EIf cond thenE elseE -> do
    condV <- addVar <$> compile cond
    thenV <- addVar <$> compile thenE
    elseV <- addVar <$> compile elseE
    thenPart <- imm
    elsePart <- imm
    -- cond * then + (1 - cond) * else
    emit (Mul condV thenV thenPart)
    emit (Mul (Add (ConstGate 1) (ScalarMul (-1) condV)) elseV elsePart)
    pure (Right (Add (Var thenPart) (Var elsePart)))
  EEq lhs rhs -> do
    diffWire <- compile (EBinOp BSub lhs rhs) >>= addWire
    freeWire <- imm
    outWire <- imm
    emit (Equal diffWire freeWire outWire)
    -- The result is 1 - out. Presumably 'Equal' sets out to 1 exactly when
    -- the difference is non-zero (defined in Circuit.Arithmetic; confirm).
    pure (Right (Add (ConstGate 1) (ScalarMul (-1) (Var outWire))))
  where
    -- Compile a unary operator applied to an already-compiled operand.
    compileUnOp ::
      Num f =>
      UnOp f ty ->
      Either Wire (AffineCircuit Wire f) ->
      ExprM f (Either Wire (AffineCircuit Wire f))
    compileUnOp op operand = case op of
      UNeg -> pure (Right (ScalarMul (-1) (addVar operand)))
      UNot -> pure (Right (Add (ConstGate 1) (ScalarMul (-1) (addVar operand))))
      URot truncBits rotBits -> do
        input <- addWire operand
        bitWires <- replicateM truncBits imm
        emit (Split input bitWires)
        pure (Right (unsplit (rotateList rotBits bitWires)))

    -- Compile a binary operator applied to two already-compiled operands.
    compileBinOp ::
      Num f =>
      BinOp f ty ->
      AffineCircuit Wire f ->
      AffineCircuit Wire f ->
      ExprM f (Either Wire (AffineCircuit Wire f))
    compileBinOp op lhs rhs = case op of
      BAdd -> pure (Right (Add lhs rhs))
      BSub -> pure (Right (Add lhs (ScalarMul (-1) rhs)))
      BMul -> Left <$> multiplied
      BAnd -> Left <$> multiplied
      BOr -> Right . combine (-1) <$> multiplied
      BXor -> Right . combine (-2) <$> multiplied
      where
        -- The product lhs * rhs on a fresh intermediate wire.
        multiplied = mulToImm (Right lhs) (Right rhs)
        -- lhs + rhs + k * (lhs * rhs)
        combine k w = Add (Add lhs rhs) (ScalarMul k (Var w))
-- | Compile an expression over integer variable indices into the circuit,
-- writing its value to the given output wire. Each variable index @n@ is
-- mapped to @InputWire n@ first.
exprToArithCircuit ::
  Num f =>
  Expr Int f ty ->
  Wire ->
  ExprM f ()
exprToArithCircuit = exprToArithCircuit' . mapVarsExpr InputWire
-- | Compile a wire-labelled expression and copy its value onto @output@ by
-- multiplying it by the constant 1.
exprToArithCircuit' :: Num f => Expr Wire f ty -> Wire -> ExprM f ()
exprToArithCircuit' expr output =
  compile expr >>= \result -> emit (Mul (ConstGate 1) (addVar result) output)
-- | Relabel every variable in an expression (both field and boolean
-- variables), leaving constants and operators untouched.
mapVarsExpr :: (i -> j) -> Expr i f ty -> Expr j f ty
mapVarsExpr rename expr = case expr of
  EConst c -> EConst c
  EConstBool b -> EConstBool b
  EVar v -> EVar (rename v)
  EVarBool v -> EVarBool (rename v)
  EUnOp op inner -> EUnOp op (mapVarsExpr rename inner)
  EBinOp op lhs rhs ->
    EBinOp op (mapVarsExpr rename lhs) (mapVarsExpr rename rhs)
  EIf cond thenE elseE ->
    EIf
      (mapVarsExpr rename cond)
      (mapVarsExpr rename thenE)
      (mapVarsExpr rename elseE)
  EEq lhs rhs -> EEq (mapVarsExpr rename lhs) (mapVarsExpr rename rhs)