module Crypto.Lol.Cyclotomic.UCyc
( UCyc, CElt
, mulG, divG
, tGaussian, errorRounded, errorCoset
, embed, twace, coeffsCyc, powBasis, crtSet
, adviseCRT
, liftCyc, scalarCyc
, fmapC, fmapCM
, forceBasis, forcePow, forceDec, forceAny
, U.Basis(..), U.RescaleCyc
) where
import Crypto.Lol.CRTrans
import Crypto.Lol.Cyclotomic.Tensor as T
import qualified Crypto.Lol.Cyclotomic.Utility as U
import Crypto.Lol.Gadget
import Crypto.Lol.LatticePrelude as LP
import Crypto.Lol.Types.FiniteField
import Crypto.Lol.Types.ZPP
import Algebra.Additive as Additive (C)
import Algebra.Ring as Ring (C)
import Algebra.ZeroTestable as ZeroTestable (C)
import Control.Applicative hiding ((*>))
import Control.DeepSeq
import Control.Monad.Identity
import Control.Monad.Random
import Data.Coerce
import Data.Foldable as F
import Data.Maybe
import Data.Traversable
import Data.Typeable
import Test.QuickCheck
import qualified Debug.Trace as DT
-- | A cyclotomic ring element over base ring @r@ (index @m@, tensor type
-- @t@), stored in one of several internal representations. Operations in
-- this module convert between representations as needed (see 'toPow'',
-- 'toDec'', 'toCRT'' and the @force*@ functions).
data UCyc t (m :: Factored) r where
  -- | Coefficients with respect to the powerful basis.
  Pow :: !(t m r) -> UCyc t m r
  -- | Coefficients with respect to the decoding basis.
  Dec :: !(t m r) -> UCyc t m r
  -- | CRT representation over @r@ itself.
  CRTr :: !(t m r) -> UCyc t m r
  -- | CRT representation over the extension ring @CRTExt r@, used when
  -- @r@ itself has no CRT transform.
  CRTe :: !(t m (CRTExt r)) -> UCyc t m r
  -- | A constant (base-ring) element; no tensor is stored.
  Scalar :: !r -> UCyc t m r
  -- | An element of the @l@-th cyclotomic (for @l@ dividing @m@), kept
  -- unembedded until an operation needs it in the @m@-th cyclotomic.
  Sub :: (l `Divides` m) => !(UCyc t l r) -> UCyc t m r
  deriving (Typeable)
-- | Constraints used internally by the instances and conversion functions
-- in this module: the tensor @t@ must hold both @r@ and @CRTExt r@, both
-- rings need CRT support, @r@ must embed into @CRTExt r@, and @r@ must be
-- zero-testable.
type UCCtx t r =
  ( Tensor t
  , TElt t r
  , TElt t (CRTExt r)
  , CRTrans r
  , CRTrans (CRTExt r)
  , CRTEmbed r
  , ZeroTestable r
  )
-- | Constraints on the tensor @t@ and base ring @r@ required by the public
-- functions of this module: the internal @UCCtx@ constraints plus equality
-- and deep-forcing ('NFData') of @r@.
type CElt t r = (UCCtx t r, Eq r, NFData r)
-- | Embeds a base-ring value as a constant cyclotomic element. The value is
-- stored as a 'Scalar'; no tensor is built.
scalarCyc :: (Fact m, CElt t a) => a -> UCyc t m a
scalarCyc x = Scalar x
-- | Equality of ring elements, independent of internal representation.
-- Operands in the same representation are compared directly. Two subring
-- elements are compared after embedding both into index @FLCM l1 l2@. Every
-- other combination (including 'CRTe' vs. 'CRTe') is compared in the
-- powerful basis.
instance (UCCtx t r, Fact m, Eq r) => Eq (UCyc t m r) where
  (Scalar v1) == (Scalar v2) = v1 == v2
  -- presumably @witness entailFullT@ supplies @Eq (t m r)@ for the raw tensor
  (Pow v1) == (Pow v2) = v1 == v2 \\ witness entailFullT v1
  (Dec v1) == (Dec v2) = v1 == v2 \\ witness entailFullT v1
  (CRTr v1) == (CRTr v2) = v1 == v2 \\ witness entailFullT v1
  -- embed both subring elements into the common index FLCM l1 l2
  (Sub (c1 :: UCyc t l1 r)) == (Sub (c2 :: UCyc t l2 r)) =
    (embed' c1 :: UCyc t (FLCM l1 l2) r) == embed' c2
    \\ lcmDivides (Proxy::Proxy l1) (Proxy::Proxy l2)
  -- fallback: compare in the powerful basis
  p1 == p2 = toPow' p1 == toPow' p2
-- | Zero test that avoids conversions where it can. Scalars, subring
-- elements and directly stored tensors are tested in place. A 'CRTe' value
-- is first converted to the powerful basis.
instance (UCCtx t r, Fact m) => ZeroTestable.C (UCyc t m r) where
  isZero u = case u of
    Scalar s -> isZero s
    Sub sub  -> isZero sub
    Pow v    -> isZero v \\ witness entailFullT v
    Dec v    -> isZero v \\ witness entailFullT v
    CRTr v   -> isZero v \\ witness entailFullT v
    CRTe _   -> isZero $ toPow' u
-- | Addition picks a common representation for the two operands:
--
-- * adding a zero 'Scalar' returns the other operand unchanged;
-- * equal representations are added directly;
-- * two subring elements are added in the subring @FLCM m1 m2@;
-- * a scalar is moved into the other operand's representation (or into
--   the subring, for a subring element);
-- * a subring element mixed with a full-ring value is embedded first;
-- * 'Pow'/'Dec' mixes use 'Pow', anything with 'CRTr' uses CRT, and all
--   remaining mixes use 'Pow'.
instance (UCCtx t r, Fact m) => Additive.C (UCyc t m r) where
  zero = Scalar zero

  -- adding a zero scalar is the identity; avoids any conversion
  (Scalar c1) + v2 | isZero c1 = v2
  v1 + (Scalar c2) | isZero c2 = v1
  -- same representation: add directly
  (Scalar c1) + (Scalar c2) = Scalar (c1+c2)
  (Pow v1) + (Pow v2) = Pow $ v1 + v2 \\ witness entailFullT v1
  (Dec v1) + (Dec v2) = Dec $ v1 + v2 \\ witness entailFullT v1
  (CRTr v1) + (CRTr v2) = CRTr $ v1 + v2 \\ witness entailFullT v1
  (CRTe v1) + (CRTe v2) = CRTe $ v1 + v2 \\ witness entailFullT v1
  -- two subring elements: add in the FLCM m1 m2 subring (which divides m)
  -- and keep the result as a subring element
  (Sub (c1 :: UCyc t m1 r)) + (Sub (c2 :: UCyc t m2 r)) =
    (Sub $ (embed' c1 :: UCyc t (FLCM m1 m2) r) + embed' c2)
    \\ lcm2Divides (Proxy::Proxy m1) (Proxy::Proxy m2) (Proxy::Proxy m)
  -- scalar on the left: convert it to the other operand's representation
  p1@(Scalar _) + p2@(Pow _) = toPow' p1 + p2
  p1@(Scalar _) + p2@(Dec _) = toDec' p1 + p2
  p1@(Scalar _) + p2@(CRTr _) = toCRT' p1 + p2
  p1@(Scalar _) + p2@(CRTe _) = toCRT' p1 + p2
  (Scalar v1) + (Sub c2) = Sub $ Scalar v1 + c2
  -- scalar on the right: same as above, mirrored
  p1@(Pow _) + p2@(Scalar _) = p1 + toPow' p2
  p1@(Dec _) + p2@(Scalar _) = p1 + toDec' p2
  p1@(CRTr _) + p2@(Scalar _) = p1 + toCRT' p2
  p1@(CRTe _) + p2@(Scalar _) = p1 + toCRT' p2
  (Sub c1) + (Scalar v2) = Sub $ c1 + Scalar v2
  -- subring element plus a full-ring value: embed into the full ring
  (Sub c1) + c2 = embed' c1 + c2
  c1 + (Sub c2) = c1 + embed' c2
  -- Pow/Dec mix: work in the powerful basis
  p1@(Dec _) + p2@(Pow _) = toPow' p1 + p2
  p1@(Pow _) + p2@(Dec _) = p1 + toPow' p2
  -- anything with CRTr: convert the other operand to CRT.
  -- NOTE(review): a CRTr/CRTe mix would recurse forever here (toCRT' leaves
  -- CRTe unchanged); presumably impossible because a given r only ever
  -- yields one of the two CRT forms — confirm.
  p1@(CRTr _) + p2 = p1 + toCRT' p2
  p1 + p2@(CRTr _) = toCRT' p1 + p2
  -- remaining mixes (involving CRTe): work in the powerful basis
  p1 + p2 = toPow' p1 + toPow' p2

  -- negation stays in the operand's current representation
  negate (Scalar c) = Scalar (negate c)
  negate (Pow v) = Pow $ fmapT negate v
  negate (Dec v) = Dec $ fmapT negate v
  negate (CRTr v) = CRTr $ fmapT negate v
  negate (CRTe v) = CRTe $ fmapT negate v
  negate (Sub c) = Sub $ negate c
-- | Multiplication. A zero 'Scalar' short-circuits. CRT operands multiply
-- directly, and a product of two 'CRTe' values is converted back to the
-- powerful basis. Scalars scale the other operand's stored tensor in
-- place. Two subring elements multiply in the @FLCM m1 m2@ subring, and a
-- subring element mixed with anything else is embedded first. Every
-- remaining combination is multiplied after conversion with 'toCRT''.
instance (UCCtx t r, Fact m) => Ring.C (UCyc t m r) where
  one = Scalar one

  -- multiplying by a zero scalar yields that zero scalar
  v1@(Scalar c1) * _ | isZero c1 = v1
  _ * v2@(Scalar c2) | isZero c2 = v2
  -- CRT representations multiply directly
  (CRTr v1) * (CRTr v2) = CRTr $ v1 * v2 \\ witness entailFullT v1
  -- the CRTe product is converted back to the powerful basis
  (CRTe v1) * (CRTe v2) = toPow' $ CRTe $ v1 * v2 \\ witness entailFullT v1
  -- scalar on the left: scale the other operand's tensor entries
  (Scalar c1) * (Scalar c2) = Scalar $ c1 * c2
  (Scalar c) * (Pow v) = Pow $ fmapT (*c) v
  (Scalar c) * (Dec v) = Dec $ fmapT (*c) v
  (Scalar c) * (CRTr v) = CRTr $ fmapT (*c) v
  s@(Scalar _) * c'@(CRTe _) = s * toPow' c'
  (Scalar c) * (Sub c2) = Sub $ Scalar c * c2
  -- scalar on the right: mirrored
  (Pow v) * (Scalar c) = Pow $ fmapT (*c) v
  (Dec v) * (Scalar c) = Dec $ fmapT (*c) v
  (CRTr v) * (Scalar c) = CRTr $ fmapT (*c) v
  c'@(CRTe _) * s@(Scalar _) = toPow' c' * s
  (Sub c1) * (Scalar c) = Sub $ c1 * Scalar c
  -- two subring elements: multiply in the FLCM m1 m2 subring
  (Sub (c1 :: UCyc t m1 r)) * (Sub (c2 :: UCyc t m2 r)) =
    (Sub $ (embed' c1 :: UCyc t (FLCM m1 m2) r) * embed' c2)
    \\ lcm2Divides (Proxy::Proxy m1) (Proxy::Proxy m2) (Proxy::Proxy m)
  -- subring element times a full-ring value: embed into the full ring
  (Sub c1) * p2 = embed' c1 * p2
  p1 * (Sub c2) = p1 * embed' c2
  -- everything else: multiply in CRT form.
  -- NOTE(review): a CRTr/CRTe mix would recurse forever here; presumably
  -- impossible because a given r only yields one CRT form — confirm.
  p1 * p2 = toCRT' p1 * toCRT' p2

  fromInteger = Scalar . fromInteger
-- | Reduces every coefficient after first moving the element into a
-- representation on which 'fmapC' is defined (via 'forceAny').
instance (Reduce a b, Fact m, CElt t a, CElt t b)
    => Reduce (UCyc t m a) (UCyc t m b) where
  reduce x = fmapC reduce (forceAny x)
-- | The gadget over cyclotomics is the base-ring gadget with each entry
-- made a scalar. Encoding multiplies every gadget entry by the input,
-- after first advising the input into CRT form ('adviseCRT').
instance (Gadget gad zq, Fact m, CElt t zq) => Gadget gad (UCyc t m zq) where
  gadget = fmap (fmap scalarCyc) gadget
  encode s =
    let s' = adviseCRT s
    in fmap (fmap (* s')) gadget
-- | Coefficient-wise decomposition in the powerful basis. Each
-- coefficient's decomposition (a tagged list) is reinterpreted as a
-- 'ZipList' via 'coerce', so 'traverse' combines the per-coefficient lists
-- position by position. The result is one cyclotomic element per position
-- of the decomposition.
instance (Decompose gad zq, Fact m, CElt t zq,
          Reduce (UCyc t m (DecompOf zq)) (UCyc t m zq))
    => Decompose gad (UCyc t m zq) where
  type DecompOf (UCyc t m zq) = UCyc t m (DecompOf zq)
  decompose = fromZL . traverse (toZL . decompose) . forcePow
    -- zero-cost views between a tagged list and a tagged ZipList
    where toZL :: Tagged s [a] -> TaggedT s ZipList a
          toZL = coerce
          fromZL :: TaggedT s ZipList a -> Tagged s [a]
          fromZL = coerce
-- | Coefficient-wise correction in the decoding basis. Every element of the
-- tagged input is forced to 'Dec'. 'sequenceA' (using the 'Applicative'
-- instance of 'UCyc') then turns the collection of elements into a single
-- element whose entries are collections of coefficients. Finally, the
-- base-ring 'correct' is applied to each entry.
-- NOTE(review): assumes peelT/pasteT convert between @Tagged gad (f a)@ and
-- a traversable @TaggedT gad f a@ — confirm against LatticePrelude.
instance (Correct gad zq, Fact m, CElt t zq)
    => Correct gad (UCyc t m zq) where
  correct bs = (correct . pasteT) <$> (sequenceA $ forceDec <$> peelT bs)
-- | Rescales every coefficient after forcing the element into the requested
-- basis.
instance (Rescale a b, CElt t a, CElt t b)
    => U.RescaleCyc (UCyc t) a b where
  rescaleCyc bas x = fmapC rescale (forceBasis (Just bas) x)
-- | Rescales from the product ring @(a,b)@ down to @b@ (presumably the CRT
-- decomposition of @Z_{a*b}@ — confirm). The element is split into its @a@
-- and @b@ components. The @a@ component is lifted in the requested basis
-- and its reduction modulo @b@ is subtracted from the @b@ component. The
-- difference is then scaled by the inverse of @a@'s modulus in @b@, i.e.
-- @(b - [lift a]) / modulus(a)@.
instance (Mod a, Field b, Lift a z, Reduce z b,
          CElt t a, CElt t b, CElt t (a,b), CElt t z)
    => U.RescaleCyc (UCyc t) (a,b) b where
  rescaleCyc bas =
    let aval = proxy modulus (Proxy::Proxy a)
    in \x -> let y = forceAny x
                 a = fmapC fst y
                 b = fmapC snd y
                 z = liftCyc bas a
             in (pure (recip (fromIntegral aval))) * (b - reduce z)
-- | Lifts every coefficient to the target ring, after forcing the element
-- into the requested basis ('U.Pow' or 'U.Dec').
liftCyc :: (Lift b a, Fact m, CElt t a, CElt t b)
        => U.Basis -> UCyc t m b -> UCyc t m a
liftCyc bas c = fmapC lift $ forceBasis (Just bas) c
-- | Converts an element to a CRT representation. Scalars are left as they
-- are, and a subring element is converted within its own subring (it is
-- not embedded). Used, e.g., by 'encode' before the element is multiplied
-- with every gadget entry.
adviseCRT :: (Fact m, CElt t r) => UCyc t m r -> UCyc t m r
adviseCRT u = case u of
  Scalar _ -> u
  Sub c    -> Sub (adviseCRT c)
  _        -> toCRT' u
-- | Multiplies by the element @g@ as implemented by the tensor's @mulG*@
-- operations. The representation is kept, except that a scalar becomes
-- 'Pow' and a subring element is embedded first. Errors if the tensor
-- provides no CRT version for a CRT-form input.
mulG :: (Fact m, CElt t r) => UCyc t m r -> UCyc t m r
mulG u = case u of
  Scalar c -> Pow (mulGPow (scalarPow c))
  Sub c    -> mulG (embed' c)
  Pow v    -> Pow (mulGPow v)
  Dec v    -> Dec (mulGDec v)
  CRTr v   -> CRTr (fromMaybe (error "FC.mulG CRTr") mulGCRT v)
  CRTe v   -> CRTe (fromMaybe (error "FC.mulG CRTe") mulGCRT v)
-- | Divides by the element @g@. In the 'Pow'/'Dec' bases (and for
-- scalars, which become 'Pow') the result is 'Nothing' when the tensor's
-- @divGPow@/@divGDec@ reports failure. CRT-form inputs always yield 'Just'
-- (or an error if the tensor has no CRT version).
divG :: (Fact m, CElt t r) => UCyc t m r -> Maybe (UCyc t m r)
divG u = case u of
  Scalar c -> Pow <$> divGPow (scalarPow c)
  Sub c    -> divG (embed' c)
  Pow v    -> Pow <$> divGPow v
  Dec v    -> Dec <$> divGDec v
  CRTr v   -> Just (CRTr (fromMaybe (error "FC.divG CRTr") divGCRT v))
  CRTe v   -> Just (CRTe (fromMaybe (error "FC.divG CRTe") divGCRT v))
-- | Samples using the tensor's @tGaussianDec@ and returns the sample as
-- decoding-basis coefficients. The parameter is passed through unchanged.
tGaussian :: (Fact m, OrdFloat q, Random q, CElt t q,
              ToRational v, MonadRandom rnd)
          => v -> rnd (UCyc t m q)
tGaussian svar = Dec <$> tGaussianDec svar
-- | Samples a 'Double'-valued 'tGaussian' and rounds every decoding-basis
-- coefficient with @roundMult one@.
errorRounded :: forall v rnd t m z .
                (ToInteger z, Fact m, CElt t z, ToRational v, MonadRandom rnd)
             => v -> rnd (UCyc t m z)
errorRounded svar = do
  sample <- (tGaussian svar :: rnd (UCyc t m Double))
  return (fmapC (roundMult one) sample)
-- | Samples a 'Double'-valued 'tGaussian' with parameter scaled by the
-- square of @zp@'s modulus. The sample is then rounded in the decoding
-- basis to the coset given by the second argument (see 'roundCosetDec').
errorCoset :: forall t m zp z v rnd .
              (Mod zp, z ~ ModRep zp, Lift zp z, Fact m,
               CElt t zp, CElt t z, ToRational v, MonadRandom rnd)
           => v -> UCyc t m zp -> rnd (UCyc t m z)
errorCoset = \svar c ->
  roundCosetDec c <$> (tGaussian (svar * modP * modP) :: rnd (UCyc t m Double))
  where
    -- the modulus of zp, converted into the parameter type v
    modP = fromIntegral $ proxy modulus (Proxy::Proxy zp)
-- | Pairs up the decoding-basis coefficients of the coset representative
-- and the real-valued sample, and applies 'roundCoset' to each pair.
roundCosetDec ::
  (Mod zp, z ~ ModRep zp, Lift zp z, RealField q,
   Fact m, CElt t q, CElt t zp, CElt t z)
  => UCyc t m zp -> UCyc t m q -> UCyc t m z
roundCosetDec c x = pure roundCoset <*> forceDec c <*> forceDec x
-- | Embeds into a larger cyclotomic without touching any tensor. Scalars
-- stay scalars, a subring element is re-wrapped directly (avoiding one
-- 'Sub' nested inside another), and anything else is wrapped in 'Sub'.
embed :: forall t r m m' . (m `Divides` m') => UCyc t m r -> UCyc t m' r
embed (Scalar c) = Scalar c
-- l | m and m | m' give l | m'
embed (Sub (c :: UCyc t l r)) = Sub c
  \\ transDivides (Proxy::Proxy l) (Proxy::Proxy m) (Proxy::Proxy m')
embed c = Sub c
-- | Maps an element of the @m'@-th cyclotomic to the @m@-th (for @m@
-- dividing @m'@) using the tensor's @twace*@ operations (presumably the
-- "tweaked trace" — confirm). Scalars are returned unchanged. A subring
-- element of index @l@ is mapped into the subring of index @FGCD l m@.
-- For 'CRTr' input, the function falls back to the powerful basis when
-- the tensor has no CRT version.
twace :: forall t r m m' . (UCCtx t r, m `Divides` m')
      => UCyc t m' r -> UCyc t m r
twace (Scalar c) = Scalar c
twace (Sub (c :: UCyc t l r)) =
  Sub (twace c :: UCyc t (FGCD l m) r)
  \\ gcdDivides (Proxy::Proxy l) (Proxy::Proxy m)
twace (Pow v) = Pow $ twacePowDec v
twace (Dec v) = Dec $ twacePowDec v
twace x@(CRTr v) =
  fromMaybe (twace $ toPow' x) (CRTr <$> (twaceCRT <*> pure v))
twace (CRTe v) = CRTe $ fromMaybe (error "FC.twace CRTe") twaceCRT v
-- | Splits an element of the @m'@-th cyclotomic into a list of elements of
-- the @m@-th cyclotomic using the tensor's @coeffs@, in the requested
-- basis. The input is converted to that basis first if necessary, and
-- every output is in the same basis.
coeffsCyc :: (m `Divides` m', CElt t r)
          => U.Basis -> UCyc t m' r -> [UCyc t m r]
coeffsCyc U.Pow (Pow v) = LP.map Pow $ coeffs v
coeffsCyc U.Dec (Dec v) = LP.map Dec $ coeffs v
coeffsCyc U.Pow x = coeffsCyc U.Pow $ toPow' x
coeffsCyc U.Dec x = coeffsCyc U.Dec $ toDec' x
-- | The tensor's @powBasisPow@ (presumably the relative powerful basis of
-- the @m'@-th cyclotomic over the @m@-th — confirm), tagged by @m@, with
-- each entry wrapped as a 'Pow' element.
powBasis :: (m `Divides` m', CElt t r) => Tagged m [UCyc t m' r]
powBasis = fmap (fmap Pow) powBasisPow
-- | The relative CRT set of the @m'@-th cyclotomic over the @m@-th, for a
-- base ring @r@ whose modulus is reported by 'modulusZPP' as @(p, e)@.
-- The construction proceeds in three steps:
--
-- * @crtSetDec@ computes the set over @ZPOf r@ for the @p@-free parts
--   @mbar@ and @m'bar@.
-- * Each element is lifted coefficient-wise (decoding basis) into @r@.
-- * Each lifted element is raised to the power @p^(e-1)@ (presumably to
--   turn the mod-@p@ set into a mod-@p^e@ set — confirm) and embedded into
--   the @m'@-th cyclotomic.
crtSet :: forall t m m' r p mbar m'bar .
          (m `Divides` m', ZPP r, p ~ CharOf (ZPOf r),
           mbar ~ PFree p m, m'bar ~ PFree p m',
           CElt t r, CElt t (ZPOf r))
          => Tagged m [UCyc t m' r]
crtSet =
  let (p,e) = proxy modulusZPP (Proxy::Proxy r)
      pp = Proxy::Proxy p
      pm = Proxy::Proxy m
      pm' = Proxy::Proxy m'
  in retag (fmap (embed . (^(p^(e-1))) . Dec . fmapT liftZp) <$>
            (crtSetDec :: Tagged mbar [t m'bar (ZPOf r)]))
     \\ pFreeDivides pp pm pm'
     \\ pSplitTheorems pp pm \\ pSplitTheorems pp pm'
-- | Converts to the requested basis. Given 'Nothing', converts instead to
-- some representation on which 'fmapC' works ('Pow', 'Dec' or 'CRTr'):
-- values already in one of those are returned unchanged, scalars and
-- 'CRTe' values go to 'Pow', and subring elements are embedded first.
forceBasis :: (Fact m, CElt t r) => Maybe U.Basis -> UCyc t m r -> UCyc t m r
forceBasis (Just U.Pow) x = toPow' x
forceBasis (Just U.Dec) x = toDec' x
forceBasis Nothing x@(Scalar _) = toPow' x
forceBasis Nothing (Sub c) = forceBasis Nothing $ embed' c
forceBasis Nothing x@(CRTe _) = toPow' x
forceBasis Nothing x = x
-- | Shorthands for 'forceBasis'. 'forcePow' and 'forceDec' force the
-- powerful and decoding basis respectively. 'forceAny' forces some
-- representation accepted by 'fmapC'.
forcePow, forceDec, forceAny :: (Fact m, CElt t r) => UCyc t m r -> UCyc t m r
forcePow x = forceBasis (Just U.Pow) x
forceDec x = forceBasis (Just U.Dec) x
forceAny x = forceBasis Nothing x
-- | Applies a function to every stored entry of a 'Pow', 'Dec' or 'CRTr'
-- element, keeping the representation. Any other representation is an
-- error; call 'forceBasis' first.
fmapC :: (Fact m, CElt t a, CElt t b) => (a -> b) -> UCyc t m a -> UCyc t m b
fmapC f u = case u of
  Pow v    -> Pow (fmapT f v)
  Dec v    -> Dec (fmapT f v)
  CRTr v   -> CRTr (fmapT f v)
  Scalar _ -> error "can't fmapC on Scalar. Must forceBasis first!"
  Sub _    -> error "can't fmapC on Sub. Must forceBasis first!"
  CRTe _   -> error "can't fmapC on CRTe. Must forceBasis first!"
-- | Monadic counterpart of 'fmapC'. It maps a monadic action over every
-- stored entry of a 'Pow', 'Dec' or 'CRTr' element, keeping the
-- representation. Any other representation is an error; call 'forceBasis'
-- first.
fmapCM :: (Fact m, CElt t a, CElt t b, Monad mon)
       => (a -> mon b) -> UCyc t m a -> mon (UCyc t m b)
fmapCM f u = case u of
  Pow v    -> liftM Pow (fmapTM f v)
  Dec v    -> liftM Dec (fmapTM f v)
  CRTr v   -> liftM CRTr (fmapTM f v)
  Scalar _ -> error "can't fmapCM on Scalar. Must forceBasis first!"
  Sub _    -> error "can't fmapCM on Sub. Must forceBasis first!"
  CRTe _   -> error "can't fmapCM on CRTe. Must forceBasis first!"
-- | Eagerly embeds into a larger cyclotomic, producing a full (non-'Sub')
-- representation except for scalars. 'Pow' and 'Dec' use the tensor's
-- embeddings. 'CRTr' uses @embedCRT@ when available and otherwise goes
-- through the powerful basis. 'CRTe' always goes through the powerful
-- basis. Nested subring elements are embedded directly from their own
-- index.
embed' :: forall t r l m .
          (UCCtx t r, l `Divides` m) => UCyc t l r -> UCyc t m r
embed' (Scalar v) = Scalar v
embed' (Pow v) = Pow $ embedPow v
embed' (Dec v) = Dec $ embedDec v
embed' x@(CRTr v) =
  fromMaybe (embed' $ toPow' x) (CRTr <$> (embedCRT <*> pure v))
embed' x@(CRTe _) = embed' $ toPow' x
-- k | l and l | m give k | m
embed' (Sub (c :: UCyc t k r)) = embed' c
  \\ transDivides (Proxy::Proxy k) (Proxy::Proxy l) (Proxy::Proxy m)
-- | Converts to the powerful ('toPow'') or decoding ('toDec'') basis.
-- Subring elements are embedded first. CRT representations are inverted
-- with @crtInv@, erroring if the tensor does not provide it. 'CRTe' values
-- are mapped back from @CRTExt r@ with @fromExt@. @l@ converts
-- decoding-basis coefficients to powerful-basis ones, and @lInv@ converts
-- in the reverse direction.
toPow', toDec' :: (UCCtx t r, Fact m) => UCyc t m r -> UCyc t m r
toPow' (Scalar c) = Pow $ scalarPow c
toPow' (Sub c) = toPow' $ embed' c
toPow' x@(Pow _) = x
toPow' (Dec v) = Pow $ l v
toPow' (CRTr v) = Pow $ fromMaybe (error "FC.toPow'") crtInv v
toPow' (CRTe v) = Pow $ fmapT fromExt $ fromMaybe (error "FC.toPow'") crtInv v

-- scalars go through the powerful basis
toDec' x@(Scalar _) = toDec' $ toPow' x
toDec' (Sub c) = toDec' $ embed' c
toDec' (Pow v) = Dec $ lInv v
toDec' x@(Dec _) = x
toDec' (CRTr v) = Dec $ lInv $ fromMaybe (error "FC.toDec'") crtInv v
toDec' (CRTe v) = Dec $ lInv $ fmapT fromExt $ fromMaybe (error "FC.toDec'") crtInv v
-- | Converts to a CRT representation. Values already in 'CRTr' or 'CRTe'
-- are returned unchanged, and subring elements are embedded first. When
-- the tensor provides @crt@ and @scalarCRT@ for @r@ (both are 'Maybe'
-- values), the result is 'CRTr'. Otherwise the value is mapped into
-- @CRTExt r@ with @toExt@ and the result is 'CRTe', which errors if the
-- extension has no CRT either.
toCRT' :: forall t m r . (UCCtx t r, Fact m) => UCyc t m r -> UCyc t m r
toCRT' (Sub c) = toCRT' $ embed' c
toCRT' x@(CRTr _) = x
toCRT' x@(CRTe _) = x
-- only Scalar, Pow and Dec reach this equation
toCRT' x = fromMaybe (toCRTe x) (toCRTr <*> pure x)
  where
    -- Just a converter when r has CRT support; Nothing otherwise
    toCRTr = do
      crt' <- crt
      scalarCRT' <- scalarCRT
      return $ \x -> case x of
        (Scalar c) -> CRTr $ scalarCRT' c
        (Pow v) -> CRTr $ crt' v
        (Dec v) -> CRTr $ crt' $ l v
    -- fallback through the CRT extension ring CRTExt r
    toCRTe = let m = proxy valueFact (Proxy::Proxy m)
                 crt' = fromMaybe (error $ "FC.toCRT': no crt': " ++ (show m)) crt :: t m (CRTExt r) -> t m (CRTExt r)
                 scalarCRT' = fromMaybe (error "FC.toCRT': no scalar crt'") scalarCRT :: CRTExt r -> t m (CRTExt r)
             in \x -> case x of
                  (Scalar c) -> CRTe $ scalarCRT' $ toExt c
                  (Pow v) -> CRTe $ crt' $ fmapT toExt v
                  (Dec v) -> CRTe $ crt' $ fmapT toExt $ l v
-- | Mapping is defined through the 'Applicative' instance, so it inherits
-- that instance's restrictions ('Sub' and 'CRTe' are errors).
instance (Tensor t, Fact m) => Functor (UCyc t m) where
  fmap f = (pure f <*>)
-- | Raises the error used by the 'Applicative' instance when it is given
-- operands it will not combine; @name@ identifies the offending
-- constructor.
errApp :: String -> a
errApp name = error $ "UCyc.Applicative: can't/won't handle " ++ name ++
                      "; call forcePow|Dec first"
-- | Application through the tensor's own 'Applicative', defined only when
-- both operands are stored in the same basis ('Pow', 'Dec' or 'CRTr'), or
-- when one side is a 'Scalar' (lifted with the tensor's 'pure'). Mixed
-- 'Pow'/'Dec' operands, 'Sub' and 'CRTe' are errors.
instance (Tensor t, Fact m) => Applicative (UCyc t m) where
  pure = Scalar

  -- same representation
  (Scalar f) <*> (Scalar a) = Scalar $ f a
  (Pow v1) <*> (Pow v2) = Pow $ v1 <*> v2 \\ witness entailIndexT v1
  (Dec v1) <*> (Dec v2) = Dec $ v1 <*> v2 \\ witness entailIndexT v1
  (CRTr v1) <*> (CRTr v2) = CRTr $ v1 <*> v2 \\ witness entailIndexT v1

  -- a scalar on either side is lifted into the other operand's tensor
  (Scalar f) <*> (Pow v) = Pow $ pure f <*> v \\ witness entailIndexT v
  (Scalar f) <*> (Dec v) = Dec $ pure f <*> v \\ witness entailIndexT v
  (Scalar f) <*> (CRTr v) = CRTr $ pure f <*> v \\ witness entailIndexT v

  (Pow v) <*> (Scalar a) = Pow $ v <*> pure a \\ witness entailIndexT v
  (Dec v) <*> (Scalar a) = Dec $ v <*> pure a \\ witness entailIndexT v
  (CRTr v) <*> (Scalar a) = CRTr $ v <*> pure a \\ witness entailIndexT v

  -- unsupported combinations
  (Pow _) <*> (Dec _) = error "UCyc.Applicative: Pow/Dec combo"
  (Dec _) <*> (Pow _) = error "UCyc.Applicative: Pow/Dec combo"
  (Sub _) <*> _ = errApp "Sub"
  _ <*> (Sub _) = errApp "Sub"
  (CRTe _) <*> _ = errApp "CRTe"
  _ <*> (CRTe _) = errApp "CRTe"
-- | Folds over the stored entries: a 'Scalar''s single value, the tensor
-- entries of 'Pow', 'Dec' and 'CRTr', and, for 'Sub', the subring
-- element's own stored representation. 'CRTe' is an error.
instance (Tensor t, Fact m) => Foldable (UCyc t m) where
  foldr f b (Scalar r) = f r b
  foldr f b (Sub c) = F.foldr f b c
  foldr f b (Pow v) = F.foldr f b v \\ witness entailIndexT v
  foldr f b (Dec v) = F.foldr f b v \\ witness entailIndexT v
  foldr f b (CRTr v) = F.foldr f b v \\ witness entailIndexT v
  foldr _ _ (CRTe _) = error "UCyc.Foldable: can't handle CRTe"
-- | Traverses the stored entries, keeping the representation. A 'Sub' is
-- traversed within its subring and stays a 'Sub'. 'CRTe' is an error.
instance (Tensor t, Fact m) => Traversable (UCyc t m) where
  traverse f (Scalar r) = Scalar <$> f r
  traverse f (Sub c) = Sub <$> traverse f c
  traverse f (Pow v) = Pow <$> traverse f v \\ witness entailIndexT v
  traverse f (Dec v) = Dec <$> traverse f v \\ witness entailIndexT v
  traverse f (CRTr v) = CRTr <$> traverse f v \\ witness entailIndexT v
  traverse _ (CRTe _) = error "UCyc.Traversable: can't handle CRTe"
-- | A random element is a random tensor. It is interpreted as 'CRTr'
-- entries when the tensor reports CRT functions for @r@ (@hasCRTFuncs@),
-- and as powerful-basis coefficients otherwise. 'randomR' is unsupported.
instance (Tensor t, Fact m, TElt t r, CRTrans r) => Random (UCyc t m r) where
  random =
    -- choose the constructor once, based on CRT availability for r
    let cons = fromMaybe Pow
                 (proxyT hasCRTFuncs (Proxy::Proxy (t m r)) >> Just CRTr)
    in \g ->
      -- presumably entailFullT supplies Random (t m r) — confirm
      let (v,g') = random g
            \\ proxy entailFullT (Proxy::Proxy (t m r))
      in (cons v, g')

  randomR _ = error "randomR non-sensical for cyclotomic rings"
-- | Debug rendering that names the representation and shows the stored
-- value. Subring elements are not shown, because the instance context
-- lacks the constraints for their index.
instance (Show r, Show (t m r), Show (t m (CRTExt r)))
    => Show (UCyc t m r) where
  show u = case u of
    Scalar c -> "scalar " ++ show c
    Sub _    -> "subring (not showing due to missing constraints)"
    Pow v    -> "powerful basis coeffs " ++ show v
    Dec v    -> "decoding basis coeffs " ++ show v
    CRTr v   -> "CRTr basis coeffs " ++ show v
    CRTe v   -> "CRTe basis coeffs " ++ show v
-- | Generates an arbitrary tensor interpreted as powerful-basis
-- coefficients. Values are never shrunk.
instance (Arbitrary (t m r)) => Arbitrary (UCyc t m r) where
  arbitrary = Pow <$> arbitrary
  shrink _ = []
-- | Fully evaluates the stored value of any representation. A subring
-- element is forced through its own instance.
instance (Tensor t, Fact m, NFData r, TElt t r, TElt t (CRTExt r))
    => NFData (UCyc t m r) where
  rnf u = case u of
    Scalar s -> rnf s
    Sub s    -> rnf s
    Pow v    -> rnf v \\ witness entailFullT v
    Dec v    -> rnf v \\ witness entailFullT v
    CRTr v   -> rnf v \\ witness entailFullT v
    CRTe v   -> rnf v \\ witness entailFullT v