{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Crypto.Lol.Applications.KeyHomomorphicPRF
( FBT(..), SFBT, SizeFBT, FBTC, singFBT
, PRFKey, PRFParams, PRFState
, genKey, genParams, prf, prfState, prfAmortized, run, runT
, Vector, BitString
, replicate, replicateS, fromList, fromListS, split, splitS
) where
import Control.Applicative ((<$>), (<*>))
import Control.Monad.Identity
import Control.Monad.Random hiding (fromList, split)
import Control.Monad.State
import Crypto.Lol hiding (head, replicate)
import Crypto.Lol.Reflects
import Data.Maybe
import Data.Singletons.TH
import qualified MathObj.Matrix as M
-- Shape of a full binary tree, together with its number of leaves.
-- The splice also generates the singleton type SFBT (constructors SLeaf,
-- SIntern), the type family SizeFBT and its singleton sSizeFBT.
-- A Leaf counts as one bit (the Pos value O, which sPosToInt maps to 1);
-- an internal node has as many bits as its two subtrees combined.
singletons [d|
  data FBT = Leaf | Intern FBT FBT

  sizeFBT :: FBT -> Pos
  sizeFBT Leaf         = O
  sizeFBT (Intern l r) = addPos (sizeFBT l) (sizeFBT r)
  |]
-- | Constraint synonym for a type-level tree shape @t@ whose singleton is
-- implicitly available (see 'singFBT').
type FBTC (t :: FBT) = SingI t
-- | Materialise the singleton for an implicitly known tree shape @t@.
singFBT :: forall t . FBTC t => SFBT t
singFBT = sing :: SFBT t
-- | PRF secret key: a single matrix over @a@. 'genKey' samples it via
-- @randomMtx 1 n@ (rows = 1, columns = @n@). The type parameter @n@ does not
-- appear on the right-hand side; it only tags the key with its dimension.
newtype PRFKey n a = Key { key :: Matrix a }
-- | Sample a uniformly random PRF key with one row and @n@ columns, where
-- @n@ is reflected from the type.
genKey :: forall rq rnd n . (MonadRandom rnd, Random rq, Reflects n Int)
       => rnd (PRFKey n rq)
genKey = Key <$> randomMtx 1 (value @n)
-- | Public PRF parameters. A leaf whose input bit is 'False' uses 'a0' and
-- one whose bit is 'True' uses 'a1' (see 'updateState''). 'genParams'
-- samples both independently. The types @n@ and @gad@ are phantom here;
-- @gad@ selects the gadget used when decomposing matrices.
data PRFParams n gad a = Params
  { a0 :: Matrix a
  , a1 :: Matrix a
  }
-- | Sample public parameters: two independent uniformly random matrices,
-- each with @n@ rows and @n * length (gadget \@gad \@rq)@ columns.
genParams :: forall gad rq rnd n .
             (MonadRandom rnd, Random rq, Reflects n Int, Gadget gad rq)
          => rnd (PRFParams n gad rq)
genParams = Params <$> sampleMtx <*> sampleMtx
  where
    dim       = value @n
    gadLen    = length $ gadget @gad @rq
    -- An action, not a value: each use draws a fresh matrix.
    sampleMtx = randomMtx dim (dim * gadLen)
-- | A @rows@-by-@cols@ matrix whose @rows * cols@ entries are each drawn
-- with 'getRandom', then laid out by 'M.fromList'.
randomMtx :: (MonadRandom rnd, Random a) => Int -> Int -> rnd (Matrix a)
randomMtx rows cols = do
  entries <- replicateM (rows * cols) getRandom
  return (M.fromList rows cols entries)
-- | Cached PRF evaluation: the parameters it was built with, plus the
-- evaluation tree for the most recent input.
data PRFState t n gad rq = PRFState
  { params :: PRFParams n gad rq
  , state' :: PRFState' t n gad rq
  }
-- | Evaluation tree whose shape follows the type-level tree @t@. Each node
-- stores the input bits of its subtree together with the matrix computed
-- for them ('L' for leaves, 'I' for internal nodes with two children).
data PRFState' t n gad rq where
  L :: BitStringMatrix 'Leaf rq -> PRFState' 'Leaf n gad rq
  I :: BitStringMatrix ('Intern l r) rq
    -> PRFState' l n gad rq
    -> PRFState' r n gad rq
    -> PRFState' ('Intern l r) n gad rq
-- | The bits covering tree @t@ (one per leaf, via 'SizeFBT'), paired with
-- the matrix computed from them.
data BitStringMatrix t a = BSM
  { bitString :: BitString (SizeFBT t)
  , matrix    :: Matrix a
  }
-- | Bits and matrix stored at the root of an evaluation tree.
root' :: PRFState' t n gad a -> BitStringMatrix t a
root' node = case node of
  L here     -> here
  I here _ _ -> here

-- | Left child of an internal node ('L' is impossible at this index).
left' :: PRFState' ('Intern l r) n gad a -> PRFState' l n gad a
left' (I _ child _) = child

-- | Right child of an internal node ('L' is impossible at this index).
right' :: PRFState' ('Intern l r) n gad a -> PRFState' r n gad a
right' (I _ _ child) = child
-- | Bits and matrix at the root of a cached PRF state's tree.
root :: PRFState t n gad a -> BitStringMatrix t a
root cached = root' (state' cached)
-- | Build the evaluation tree of shape @t@ for input bits @x@.
--
-- * A leaf stores 'a1' when its bit is 'True' and 'a0' otherwise.
--   Leaves are always rebuilt, because doing so is cheap.
-- * At an internal node, if the previous tree (if any) has a node at this
--   position built from exactly the same bits, that node is returned
--   unchanged.
-- * Otherwise @x@ is split between the two subtrees, each child is updated
--   recursively against the matching child of the previous tree, and the
--   node's matrix is @al * ar'@. Here @al@ is the left child's matrix, and
--   @ar'@ is the right child's matrix decomposed with 'decomposeMatrix'
--   for gadget @gad@, with each entry reduced back into @rq@.
updateState' :: forall gad rq t n . Decompose gad rq
             => SFBT t
             -> PRFParams n gad rq
             -> Maybe (PRFState' t n gad rq)
             -> BitString (SizeFBT t)
             -> PRFState' t n gad rq
updateState' t p st x = case t of
  SLeaf -> L $ BSM x $ if head x then a1 p else a0 p
  -- Cache hit: a single pattern guard replaces the former
  -- fromMaybe-check followed by a partial fromJust.
  SIntern _ _
    | Just old <- st
    , x == bitString (root' old)
    -> old
  SIntern l r ->
    let (xl, xr) = splitS (sSizeFBT l) x
        stl      = updateState' l p (left' <$> st) xl
        str      = updateState' r p (right' <$> st) xr
        al       = matrix $ root' stl
        ar       = matrix $ root' str
        -- Decompose the right matrix so its row count matches al's
        -- column count, then bring the entries back into rq.
        ar'      = reduce <$> decomposeMatrix @gad ar
    in I (BSM x (al * ar')) stl str
-- | Compute the PRF state for input @x@.
--
-- @Left ps@ builds a fresh tree from parameters @ps@. @Right prev@ reuses
-- the parameters stored in @prev@ together with its cached tree.
updateState :: Decompose gad rq
            => SFBT t
            -> Either (PRFParams n gad rq) (PRFState t n gad rq)
            -> BitString (SizeFBT t)
            -> PRFState t n gad rq
updateState t e x = PRFState ps (updateState' t ps previous x)
  where
    (ps, previous) = case e of
      Left fresh -> (fresh, Nothing)
      Right prev -> (params prev, Just (state' prev))
-- | PRF output for an already-computed state: the key times the root
-- matrix, with every entry rescaled from @rq@ to @rp@.
prfCore :: (Ring rq, Rescale rq rp)
        => PRFKey n rq -> PRFState t n gad rq -> Matrix rp
prfCore s st = fmap rescale unrounded
  where unrounded = key s * matrix (root st)
-- | Evaluate the PRF on input bits @x@ from scratch (no cached state).
-- This is the output component of 'prfState'. It is written pointed
-- instead of the former @(fmap . fmap . fmap) fst . prfState@, which
-- obscured the data flow.
prf :: (Rescale rq rp, Decompose gad rq)
    => SFBT t
    -> PRFParams n gad rq
    -> PRFKey n rq
    -> BitString (SizeFBT t)
    -> Matrix rp
prf t p s x = fst (prfState t p s x)
-- | Evaluate the PRF on @x@ from scratch. Returns the output together with
-- the freshly built state, which can seed later amortized evaluations.
prfState :: (Rescale rq rp, Decompose gad rq)
         => SFBT t
         -> PRFParams n gad rq
         -> PRFKey n rq
         -> BitString (SizeFBT t)
         -> (Matrix rp, PRFState t n gad rq)
prfState t p s x = (prfCore s fresh, fresh)
  where fresh = updateState t (Left p) x
-- | Evaluate the PRF on @x@, reusing the tree cached in the monadic state
-- and storing the updated tree back.
--
-- When the state is 'Nothing', the tree is built from @p@. Once a state is
-- cached, its stored parameters are used and the @p@ argument is ignored,
-- so all calls within one run should pass the same parameters. The key @s@
-- is not cached and is applied on every call.
prfAmortized ::
  (Rescale rq rp, Decompose gad rq,
   MonadState (Maybe (PRFState t n gad rq)) m)
  => SFBT t
  -> PRFParams n gad rq
  -> PRFKey n rq
  -> BitString (SizeFBT t)
  -> m (Matrix rp)
prfAmortized t p s x = do
  cached <- get
  let source  = maybe (Left p) Right cached
      updated = updateState t source x
  put (Just updated)
  return (prfCore s updated)
-- | Run a pure computation of 'prfAmortized' calls that starts with no
-- cached state; the final state is discarded.
run :: State (Maybe (PRFState t n gad rq)) a -> a
run computation = runIdentity (runT computation)
-- | Monad-transformer version of 'run': start with no cached state and
-- discard the final state.
runT :: (Monad m) => StateT (Maybe (PRFState t n gad rq)) m a -> m a
runT computation = evalStateT computation Nothing
-- | Non-empty, length-indexed vector: 'Lone' holds exactly one element,
-- and each ':-' adds one more (so index @'O@ means length 1).
data Vector n a where
  Lone :: a -> Vector 'O a
  (:-) :: a -> Vector n a -> Vector ('S n) a
infixr 5 :-

deriving instance Show a => Show (Vector n a)
-- | Element-wise equality. Both arguments share the index @n@, so they are
-- always built from the same constructor and no mixed case is needed.
-- The locals are named x/y rather than a1/a2 to avoid shadowing the
-- 'a1' record selector of 'PRFParams'.
instance Eq a => Eq (Vector n a) where
  Lone x      == Lone y      = x == y
  (x :- rest) == (y :- rest') = x == y && rest == rest'
-- | Enumerates bit vectors in reflected binary Gray code order.
--
-- 'toEnum' takes the index modulo @2^len@ (@len@ = vector length). Indices
-- in the lower half get a leading 'False' and recurse on the same index.
-- Indices in the upper half get a leading 'True' and recurse on the
-- mirrored index. As a result, consecutive indices give vectors that
-- differ in exactly one bit. 'fromEnum' inverts 'toEnum' on @[0, 2^len)@.
-- NOTE(review): Gray order is presumably chosen so that successive inputs
-- fed to 'prfAmortized' share most of their cached tree. Confirm this.
instance PosC n => Enum (Vector n Bool) where
  toEnum = case (sing :: Sing n) of
    -- Length 1: the single bit is the parity of the index.
    SO -> Lone . odd
    SS m -> withSingI m $
      -- thresh = 2^(tail length); num = 2^(full length).
      let thresh = 2^(sPosToInt m)
          num = 2 * thresh
      in \x -> let x' = x `mod` num
               -- Lower half: leading False, recurse on x' directly.
               -- Upper half: leading True, recurse on the mirrored index.
               in if x' < thresh
                    then False :- toEnum x'
                    else True :- toEnum (num - 1 - x')
  fromEnum = case (sing :: Sing n) of
    SO -> \(Lone x) -> if x then 1 else 0
    SS m -> withSingI m $
      -- num = 2^(full length).
      let num :: Int = 2^(1 + sPosToInt m)
      -- A leading True undoes the mirroring applied by toEnum.
      in \(x:-xs) -> if x
                       then num - 1 - fromEnum xs
                       else fromEnum xs
-- | All @2^len@ bit vectors of this length, in the order produced by the
-- 'Enum' instance, starting from the all-'False' vector.
instance PosC n => Enumerable (Vector n Bool) where
  values = case (sing :: Sing n) of
    SO -> [Lone False, Lone True]
    SS m -> withSingI m (take (2^(1 + sPosToInt m)) [replicate False ..])
-- | A non-empty, length-indexed vector of bits.
type BitString n = Vector n Bool
-- | First element of a vector. Total, because every 'Vector' has at least
-- one element.
head :: Vector n a -> a
head v = case v of
  Lone first -> first
  first :- _ -> first
-- | Split off a prefix whose length is given by the implicit singleton for
-- @m@. See 'splitS'.
split :: forall m n a . PosC m
      => Vector (m `AddPos` n) a -> (Vector m a, Vector n a)
split v = splitS (sing :: Sing m) v
-- | Split a vector into a prefix whose length is given by the singleton @m@,
-- plus the remaining elements. The input always has at least two elements
-- (prefix plus a non-empty rest), so the 'Lone' branch is unreachable. GHC
-- cannot prove this through the 'AddPos' type family.
splitS :: SPos m -> Vector (m `AddPos` n) a -> (Vector m a, Vector n a)
splitS m v = case v of
  Lone _ -> error "splitS: internal error; can't split a Lone"
  h :- t -> case m of
    SO    -> (Lone h, t)
    SS m' -> let (front, back) = splitS m' t
             in (h :- front, back)
-- | Vector of implicit length @n@ with every element equal to the argument.
replicate :: forall n a . PosC n => a -> Vector n a
replicate x = replicateS (sing :: Sing n) x
-- | Vector whose length is given by the singleton, filled with one value.
replicateS :: SPos n -> a -> Vector n a
replicateS SO       x = Lone x
replicateS (SS len) x = x :- replicateS len x
-- | Convert a list to a vector of implicit length @n@. Returns 'Nothing'
-- unless the list has exactly that length.
fromList :: forall n a . PosC n => [a] -> Maybe (Vector n a)
fromList xs = fromListS (sing :: Sing n) xs
-- | Convert a list to a vector whose length is given by the singleton.
-- Returns 'Nothing' when the list is shorter or longer than that length.
fromListS :: SPos n -> [a] -> Maybe (Vector n a)
fromListS SO       [x]      = Just (Lone x)
fromListS SO       _        = Nothing
fromListS (SS len) (x:rest) = (x :-) <$> fromListS len rest
fromListS (SS _)   []       = Nothing
-- | Value of a positive-number singleton as an 'Int' ('SO' is 1).
sPosToInt :: SPos n -> Int
sPosToInt p = case p of
  SO      -> 1
  SS prev -> 1 + sPosToInt prev