{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module Crypto.Lol.Cyclotomic.Tensor
( TensorPowDec(..)
, TensorG(..)
, TensorCRT(..)
, TensorGaussian(..)
, TensorGSqNorm(..)
, TensorCRTSet(..)
, hasCRTFuncs
, scalarCRT, mulGCRT, divGCRT, crt, crtInv, twaceCRT, embedCRT
, Kron, indexK, gCRTK, gInvCRTK, twCRTs
, zmsToIndexFact
, indexInfo
, extIndicesPowDec, extIndicesCRT, extIndicesCoeffs
, baseIndicesPow, baseIndicesDec, baseIndicesCRT
, digitRev
)
where
import Crypto.Lol.CRTrans
import Crypto.Lol.Prelude as LP hiding (lift, (*>))
import Crypto.Lol.Types.IFunctor
import Algebra.Module as Module (C)
import Control.Applicative
import Control.Monad.Random
import Data.Singletons.Prelude
import Data.Traversable
import Data.Tuple (swap)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
-- | Tensor operations over a base ring @r@ at a cyclotomic index @m@, in the
-- \"Pow\" and \"Dec\" bases named by the methods.  Every @t m@ with a 'Fact'
-- index must also be 'Applicative' and 'Traversable'.
class ( forall m . Fact m => (Applicative (t m), Traversable (t m))
      , IFunctor t
      , IFElt t r
      , Additive r
      )
  => TensorPowDec t r where
  -- | Turn a scalar into a tensor at index @m@ (\"Pow\" basis, per the name).
  scalarPow :: Fact m => r -> t m r
  -- | Basis change from \"Pow\" to \"Dec\" (per the name).
  powToDec :: Fact m => t m r -> t m r
  -- | Basis change from \"Dec\" to \"Pow\" (per the name).
  decToPow :: Fact m => t m r -> t m r
  -- | Map from index @m'@ down to a divisor @m@ (\"twace\", Pow/Dec bases).
  twacePowDec :: (m `Divides` m') => t m' r -> t m r
  -- | Map from index @m@ up to a multiple @m'@ (\"Pow\" basis).
  embedPow :: (m `Divides` m') => t m r -> t m' r
  -- | Map from index @m@ up to a multiple @m'@ (\"Dec\" basis).
  embedDec :: (m `Divides` m') => t m r -> t m' r
  -- | Split a tensor at index @m'@ into a list of tensors at index @m@.
  coeffs :: (m `Divides` m') => t m' r -> [t m r]
  -- | A list of tensors at index @m'@, tagged by the divisor @m@
  -- (presumably a relative powerful basis, going by the name; confirm).
  powBasisPow :: (m `Divides` m') => Tagged m [t m' r]
-- | Extension of 'TensorPowDec' with multiplication/division by the element
-- named @g@ in the method names, in both the \"Pow\" and \"Dec\" bases.
class TensorPowDec t r => TensorG t r where
  -- | Multiply by @g@ (\"Pow\" basis).
  mulGPow :: Fact m => t m r -> t m r
  -- | Multiply by @g@ (\"Dec\" basis).
  mulGDec :: Fact m => t m r -> t m r
  -- | Divide by @g@ (\"Pow\" basis); the 'Maybe' presumably signals a
  -- non-divisible input.
  divGPow :: Fact m => t m r -> Maybe (t m r)
  -- | Divide by @g@ (\"Dec\" basis); the 'Maybe' presumably signals a
  -- non-divisible input.
  divGDec :: Fact m => t m r -> Maybe (t m r)
-- | Tensors with CRT-basis operations over @r@, produced inside the monad
-- @mon@ shared with the 'CRTrans' superclass.  At every 'Fact' index @m@,
-- @t m r@ is an @r@-module.
class ( TensorPowDec t r
      , CRTrans mon r
      , forall m . Fact m => (Module.C r) (t m r)
      )
  => TensorCRT t mon r where
  -- | The CRT bundle at index @m@.  Its five components are projected out by
  -- 'scalarCRT', 'mulGCRT', 'divGCRT', 'crt' and 'crtInv' respectively.
  crtFuncs :: Fact m
           => mon ( r -> t m r      -- scalarCRT
                  , t m r -> t m r  -- mulGCRT
                  , t m r -> t m r  -- divGCRT
                  , t m r -> t m r  -- crt
                  , t m r -> t m r  -- crtInv
                  )
  -- | CRT-basis maps between a divisor @m@ and @m'@.  The first component is
  -- used by 'twaceCRT', the second by 'embedCRT'.
  crtExtFuncs :: (m `Divides` m')
              => mon ( t m' r -> t m r  -- twaceCRT
                     , t m r -> t m' r  -- embedCRT
                     )
-- | Tensors over @q@ that can be sampled from a \"tweaked\" Gaussian in the
-- \"Dec\" basis (going by the method name).
class TensorGaussian t q where
  -- | Sample a tensor at index @m@ in the random monad @rnd@.  The parameter
  -- @v@ is any 'ToRational' value; its precise meaning (e.g. a scaled
  -- variance) is instance-defined.  TODO: confirm against instances.
  tweakedGaussianDec
    :: (ToRational v, Fact m, MonadRandom rnd)
    => v
    -> rnd (t m q)
-- | Tensors with a squared-norm style reduction (named for @g@) on
-- \"Dec\"-basis input.
class TensorGSqNorm t r where
  -- | Reduce a tensor at index @m@ to a single value of the base ring.
  gSqNormDec
    :: Fact m
    => t m r
    -> r
-- | Tensors over @fp@ providing a \"CRT set\" in the \"Dec\" basis (per the
-- method name).
class TensorPowDec t fp => TensorCRTSet t fp where
  -- | A list of tensors at index @m'@, tagged by the divisor @m@.  Requires
  -- the characteristic of @fp@ (as a factored index) to be coprime to @m'@.
  crtSetDec
    :: (m `Divides` m', Coprime (PToF (CharOf fp)) m')
    => Tagged m [t m' fp]
-- | Run 'crtFuncs' at index @m@ and throw the functions away.
--
-- The resulting action carries only the effect of 'crtFuncs' (for example,
-- failure when @mon@ is 'Maybe'-like), so it can be sequenced in front of
-- other CRT computations as an availability check.
hasCRTFuncs :: forall t m r mon . (TensorCRT t mon r, Fact m) => mon ()
{-# INLINABLE hasCRTFuncs #-}
hasCRTFuncs =
  crtFuncs @t @mon @r @m >>= \(_,_,_,_,_) -> return ()
-- | The first component of 'crtFuncs': embeds a scalar as a CRT-basis tensor.
scalarCRT :: (TensorCRT t mon r, Fact m) => mon (r -> t m r)
{-# INLINABLE scalarCRT #-}
scalarCRT = firstOf5 <$> crtFuncs
  where
    firstOf5 (embed,_,_,_,_) = embed
-- | Projections of components 2-5 of 'crtFuncs': @mulGCRT@, @divGCRT@, @crt@
-- and @crtInv@, in that order.  Each simply selects its function from the
-- bundle, so the resulting action has exactly the effect of 'crtFuncs'.
mulGCRT, divGCRT, crt, crtInv ::
  (TensorCRT t mon r, Fact m) => mon (t m r -> t m r)
-- All four accessors are INLINABLE, like every other CRT accessor in this
-- module, so that callers in other modules can specialise them.
{-# INLINABLE mulGCRT #-}
{-# INLINABLE divGCRT #-}
{-# INLINABLE crt #-}
{-# INLINABLE crtInv #-}
-- second tuple component
mulGCRT = (\(_,f,_,_,_) -> f) <$> crtFuncs
-- third tuple component
divGCRT = (\(_,_,f,_,_) -> f) <$> crtFuncs
-- fourth tuple component
crt = (\(_,_,_,f,_) -> f) <$> crtFuncs
-- fifth tuple component
crtInv = (\(_,_,_,_,f) -> f) <$> crtFuncs
-- | The CRT-basis map from index @m'@ down to its divisor @m@, i.e. the first
-- component of 'crtExtFuncs'.
--
-- 'hasCRTFuncs' is run at @m'@ and then at @m@ first, so the action also
-- carries the effect of 'crtFuncs' at both indices.
twaceCRT :: forall t m m' mon r . (TensorCRT t mon r, m `Divides` m')
         => mon (t m' r -> t m r)
{-# INLINABLE twaceCRT #-}
twaceCRT =
  (hasCRTFuncs @t @m' @r *> hasCRTFuncs @t @m @r)
    *> (fst <$> crtExtFuncs)
-- | The CRT-basis map from index @m@ up to its multiple @m'@, i.e. the second
-- component of 'crtExtFuncs'.
--
-- 'hasCRTFuncs' is run at @m'@ and then at @m@ first, so the action also
-- carries the effect of 'crtFuncs' at both indices.
embedCRT :: forall t m m' mon r . (TensorCRT t mon r, m `Divides` m')
         => mon (t m r -> t m' r)
-- INLINABLE, matching its twin 'twaceCRT', so it can be specialised across
-- modules.
{-# INLINABLE embedCRT #-}
embedCRT = hasCRTFuncs @t @m' @r *>
           hasCRTFuncs @t @m @r *>
           (snd <$> crtExtFuncs)
-- | Build a 'Kron' for index @m@ by evaluating @mat@ at every prime power in
-- @m@'s factorization ('sUnF' of its singleton).
--
-- The factor for the first prime power in that list becomes the outermost
-- 'MKron' component, which 'indexK' treats as the least-significant index
-- digit.  The factors are computed in reverse list order (tail first).
fKron :: forall m r mon . (Fact m, Monad mon)
         => (forall pp . (PPow pp) => TaggedT pp mon (KronC r))
         -> mon (Kron r)
fKron mat = build (sUnF (sing :: SFactored m))
  where
    build :: Sing (pplist :: [PrimePower]) -> mon (Kron r)
    build SNil = return MNil
    build (SCons spp spps) = do
      restK  <- build spps
      factor <- withWitnessT mat spp
      return (MKron restK factor)
-- | Lift a per-prime matrix builder to a prime power @pp@.
--
-- The matrix @mat@ produced for the prime @p@ underlying @pp@ is stacked
-- vertically @d@ times, where @d = valuePPow \@pp `div` valuePrime \@p@
-- (presumably @p^(e-1)@ for @pp = p^e@; confirm).  The width is unchanged,
-- and entry @(i, j)@ of the result is entry @(i `mod` h, j)@ of @mat@.
ppKron :: forall pp r mon . (PPow pp, Monad mon)
          => (forall p . (Prime p) => TaggedT p mon (KronC r))
          -> TaggedT pp mon (KronC r)
ppKron mat = tagT $ case (sing :: SPrimePower pp) of
  -- bring the prime @p@ of @pp@ into scope as a type variable
  pp@(SPP (STuple2 (sp :: Sing p) _)) -> do
    -- the h-by-w matrix for the prime p
    (MC h w f) <- withWitnessT mat sp
    -- number of vertical copies: value of pp divided by p
    let d = withSingI pp (valuePPow @pp) `div` withSingI sp (valuePrime @p)
    -- h*d rows, same width; row index wraps modulo h
    return $ MC (h*d) w (f . (`mod` h))
-- | A Kronecker-indecomposable matrix over @r@: @MC rows cols entry@, where
-- @entry i j@ is the element at 0-based row @i@ and column @j@ (as used by
-- 'indexK').
data KronC r
  = MC Int               -- number of rows
       Int               -- number of columns
       (Int -> Int -> r) -- entry function

-- | A Kronecker product of zero or more 'KronC' factors, as a snoc-list:
-- @MKron rest c@ is the product @rest@ followed by the factor @c@.
data Kron r
  = MNil
  | MKron (Kron r) (KronC r)
-- | Element @(i, j)@ of the matrix represented by a 'Kron'.
--
-- The outermost factor supplies the least-significant digit of each index
-- (@i `mod` rows@, @j `mod` cols@), and the quotients index into the rest of
-- the product.  'MNil' is the 1-by-1 matrix containing 'LP.one'; any other
-- index into it is an error.
indexK :: Ring r => Kron r -> Int -> Int -> r
indexK MNil i j
  | i == 0 && j == 0 = LP.one
  | otherwise =
      error $ "indexK MNil out of bounds: i = " ++ show i ++ ", j = " ++ show j
indexK (MKron rest (MC rows cols entry)) i j =
  indexK rest iHi jHi * entry iLo jLo
  where
    (iHi, iLo) = i `divMod` rows
    (jHi, jLo) = j `divMod` cols
-- | 'Kron' for index @m@, assembled by 'fKron' from the per-prime-power
-- 'gCRTPPow' matrices.  Every factor has width 1, so the result is a single
-- column.
gCRTK :: forall m mon r . (Fact m, CRTrans mon r) => mon (Kron r)
gCRTK = fKron @m gCRTPPow

-- | Like 'gCRTK', but assembled from 'gInvCRTPPow'.
gInvCRTK :: forall m mon r . (Fact m, CRTrans mon r) => mon (Kron r)
gInvCRTK = fKron @m gInvCRTPPow
-- | 'Kron' for index @m@, assembled by 'fKron' from the per-prime-power
-- 'twCRTsPPow' square matrices.
twCRTs
  :: forall m mon r . (Fact m, CRTrans mon r)
  => mon (Kron r)
twCRTs = fKron @m twCRTsPPow
-- | Square @phi@-by-@phi@ matrix (@phi = totientPPow \@pp@) for the prime
-- power @pp@.
--
-- Entry @(j, i)@ is
-- @wPow (indexToPowPPow j * negate (indexToZmsPPow i)) * gCRT i 0@, where
-- @wPow@ comes from 'crtInfo' and @gCRT@ is the single column of 'gCRTPPow'.
twCRTsPPow :: forall pp mon r .
  (PPow pp, CRTrans mon r) => TaggedT pp mon (KronC r)
twCRTsPPow = do
  (wPow, _) <- crtInfo
  MC _ _ gCRT <- gCRTPPow
  let entry j i = wPow (jToPow j * negate (iToZms i)) * gCRT i 0
  return (MC phi phi entry)
  where
    phi    = totientPPow @pp
    iToZms = indexToZmsPPow @pp
    jToPow = indexToPowPPow @pp
-- | 'gCRTPrime' lifted to the prime power @pp@ via 'ppKron'.
gCRTPPow :: (PPow pp, CRTrans mon r) => TaggedT pp mon (KronC r)
gCRTPPow = ppKron gCRTPrime

-- | 'gInvCRTPrime' lifted to the prime power @pp@ via 'ppKron'.
gInvCRTPPow :: (PPow pp, CRTrans mon r) => TaggedT pp mon (KronC r)
gInvCRTPPow = ppKron gInvCRTPrime
-- | A @(p-1)@-by-1 column for the prime @p@.
--
-- For @p == 2@ the single entry is 'one'.  Otherwise row @i@ is
-- @one - wPow (i+1)@, with @wPow@ from 'crtInfo' (presumably powers of a
-- root of unity; confirm against 'CRTrans').
gCRTPrime :: forall p mon r .
  (Prime p, CRTrans mon r) => TaggedT p mon (KronC r)
gCRTPrime = do
  (wPow, _) <- crtInfo
  let entry
        | p == 2    = const (const one)
        | otherwise = \i _ -> one - wPow (i+1)
  return (MC (p-1) 1 entry)
  where
    p = valuePrime @p

-- | A @(p-1)@-by-1 column for the prime @p@.
--
-- For @p == 2@ the single entry is 'one'.  Otherwise row @i@ is
-- @phatinv * sum [j * wPow ((i+1)*(p-1-j)) | j <- [1..p-1]]@, with @wPow@
-- and @phatinv@ from 'crtInfo'.
gInvCRTPrime :: forall p mon r .
  (Prime p, CRTrans mon r) => TaggedT p mon (KronC r)
gInvCRTPrime = do
  (wPow, phatinv) <- crtInfo
  let column i =
        phatinv * sum [fromIntegral j * wPow ((i+1)*(p-1-j)) | j <- [1..p-1]]
      entry
        | p == 2    = const (const one)
        | otherwise = const . column
  return (MC (p-1) 1 entry)
  where
    p = valuePrime @p
-- | Base-@p@ digit reversal of an @e@-digit number.
--
-- The least-significant base-@p@ digit of @j@ becomes the most significant.
-- The input must lie in @[0, p^e)@ with @e >= 0@.  Any other input now raises
-- a descriptive error instead of an anonymous pattern-match failure.
-- Because the check fires during recursion, the reported @(p,e)@ and @j@
-- are the remaining exponent and quotient at that point.
digitRev :: PP -> Int -> Int
digitRev (_,0) 0 = 0
digitRev (p,e) j
  | e >= 1 = let (q,r) = j `divMod` p
             in r * (p^(e-1)) + digitRev (p,e-1) q
  | otherwise =
      error $ "digitRev: invalid arguments (p,e) = " ++ show (p,e)
              ++ ", j = " ++ show j
-- | 'indexToPow' specialised to the prime power @pp@ (via 'ppPPow').
indexToPowPPow :: forall pp . PPow pp => Int -> Int
indexToPowPPow = indexToPow (ppPPow @pp)

-- | 'indexToZms' specialised to the prime power @pp@ (via 'ppPPow').
indexToZmsPPow :: forall pp . PPow pp => Int -> Int
indexToZmsPPow = indexToZms (ppPPow @pp)
-- | 'zmsToIndex' applied to the prime-power factorization of @m@
-- ('ppsFact').
zmsToIndexFact :: forall m . Fact m => Int -> Int
zmsToIndexFact = zmsToIndex factorization
  where
    factorization = ppsFact @m
-- | For the prime power @(p, e)@, split @j@ as @(hi, lo) = j `divMod` (p-1)@
-- and return @p^(e-1) * lo + digitRev (p, e-1) hi@.
indexToPow :: PP -> Int -> Int
indexToPow (p,e) j = p^(e-1) * lo + digitRev (p,e-1) hi
  where
    (hi, lo) = j `divMod` (p-1)
-- | Map a linear index to an integer not divisible by @p@.
--
-- With @(hi, lo) = i `divMod` (p-1)@ the result is @p*hi + lo + 1@, so the
-- low digit lands in @[1, p-1]@.  'zmsToIndexPP' is the inverse.
indexToZms :: PP -> Int -> Int
indexToZms (p,_) i = p*hi + lo + 1
  where
    (hi, lo) = i `divMod` (p-1)
-- | Convert an integer to a linear index across several prime powers.
--
-- Each prime power contributes one mixed-radix digit,
-- @zmsToIndexPP pp (i `mod` valuePP pp)@, with radix @totientPP pp@.  The
-- first list element is the least significant digit.
zmsToIndex :: [PP] -> Int -> Int
zmsToIndex [] _ = 0
zmsToIndex (pp:pps) i = digit + totientPP pp * zmsToIndex pps i
  where
    digit = zmsToIndexPP pp (i `mod` valuePP pp)
-- | Inverse of 'indexToZms' on integers not divisible by @p@.
--
-- With @(hi, lo) = i `divMod` p@ the result is @(p-1)*hi + lo - 1@.
zmsToIndexPP :: PP -> Int -> Int
zmsToIndexPP (p,_) i = (p-1)*hi + lo - 1
  where
    (hi, lo) = i `divMod` p
{-# INLINE toIndexPair #-}
{-# INLINE fromIndexPair #-}
-- | Split a linear extension-ring index into a pair @(i1, i0)@.
--
-- The argument is a list of @(phi, phi')@ totient pairs, one per merged
-- prime power.  In the result, @i0@ indexes the base ring (radix @phi@ per
-- prime power) and @i1@ indexes the extension (radix @phi' `div` phi@).
-- Each prime power is one mixed-radix digit, first list element least
-- significant.  A non-zero leftover index once the list is exhausted means
-- the input was out of range and raises an error.
toIndexPair :: [(Int,Int)] -> Int -> (Int,Int)
-- | Inverse of 'toIndexPair': recombine @(i1, i0)@ into a linear index.
-- A non-zero leftover pair once the list is exhausted raises an error.
fromIndexPair :: [(Int,Int)] -> (Int,Int) -> Int

toIndexPair [] 0 = (0,0)
toIndexPair [] i' =
  error $ "toIndexPair: index out of range (leftover " ++ show i' ++ ")"
toIndexPair ((phi,phi'):rest) i' =
  let (i'q,i'r) = i' `divMod` phi'
      -- split this prime power's digit into (extension, base) parts
      (i'rq,i'rr) = i'r `divMod` phi
      (i'q1,i'q0) = toIndexPair rest i'q
  in (i'rq + i'q1*(phi' `div` phi), i'rr + i'q0*phi)

fromIndexPair [] (0,0) = 0
fromIndexPair [] pr =
  error $ "fromIndexPair: index out of range (leftover " ++ show pr ++ ")"
fromIndexPair ((phi,phi'):rest) (i1,i0) =
  let (i0q,i0r) = i0 `divMod` phi
      (i1q,i1r) = i1 `divMod` (phi' `div` phi)
      i = fromIndexPair rest (i1q,i0q)
  -- rebuild this prime power's digit, then shift the higher digits by phi'
  in (i0r + i1r*phi) + i*phi'
-- | Indexing data for the extension from @m@ to @m'@.
--
-- The four components are:
--
-- 1. merged triples @(p, e, e')@, with @e@ the exponent of @p@ in @m@
--    (0 if absent) and @e'@ its exponent in @m'@ ('mergePPs');
-- 2. @totientFact \@m@;
-- 3. @totientFact \@m'@;
-- 4. per-prime-power totient pairs of the triples ('totients').
indexInfo :: forall m m' . (m `Divides` m')
             => ([(Int,Int,Int)], Int, Int, [(Int,Int)])
indexInfo = (merged, totientFact @m, totientFact @m', totients merged)
  where
    merged = mergePPs (ppsFact @m) (ppsFact @m')
-- | Vector of length @phi(m)@ whose entry @i0@ is
-- @fromIndexPair tots (0, i0)@: where the @i0@th base-ring element lands in
-- the extension's linear ordering, at extension index 0.
extIndicesPowDec :: forall m m' . (m `Divides` m') => U.Vector Int
{-# INLINABLE extIndicesPowDec #-}
extIndicesPowDec = U.generate phi (\i0 -> fromIndexPair tots (0, i0))
  where
    (_, phi, _, tots) = indexInfo @m @m'
-- | Vector of length @phi(m')@ in which each run of @b = phi(m') `div` phi(m)@
-- consecutive entries shares one base index.
--
-- Entry @i'@ is @fromIndexPair tots (i' `mod` b, i' `div` b)@.
extIndicesCRT :: forall m m' . (m `Divides` m') => U.Vector Int
extIndicesCRT = U.generate phi' toLinear
  where
    (_, phi, phi', tots) = indexInfo @m @m'
    blockSize = phi' `div` phi
    toLinear i' =
      let (baseIx, extIx) = i' `divMod` blockSize
      in fromIndexPair tots (extIx, baseIx)
-- | Tabulate @f mpps@ over every extension index @[0, phi(m'))@, where
-- @mpps@ are the merged prime-power triples from 'indexInfo'.
baseWrapper :: forall m m' a . (m `Divides` m', U.Unbox a)
               => ([(Int,Int,Int)] -> Int -> a)
               -> U.Vector a
baseWrapper f = U.generate phi' (f mpps)
  where
    (mpps, _, phi', _) = indexInfo @m @m'
-- | Lookup table of 'toIndexPair' over every extension index: entry @i'@ is
-- the @(extension, base)@ index pair of @i'@.
baseIndicesPow :: forall m m' . (m `Divides` m') => U.Vector (Int,Int)
{-# INLINABLE baseIndicesPow #-}
baseIndicesPow = baseWrapper @m @m' (\mpps -> toIndexPair (totients mpps))
-- | Lookup table of 'baseIndexDec' over every extension index: the base
-- index and sign flag of each extension index, or 'Nothing' when it has no
-- base counterpart.
baseIndicesDec :: forall m m' . (m `Divides` m')
               => U.Vector (Maybe (Int,Bool))
{-# INLINABLE baseIndicesDec #-}
baseIndicesDec = baseWrapper @m @m' baseIndexDec
-- | Like 'baseIndicesPow', but keeping only the base-ring index (second
-- component) of each pair.
baseIndicesCRT :: forall m m' . (m `Divides` m') => U.Vector Int
baseIndicesCRT = baseWrapper @m @m' baseOnly
  where
    -- the totient list is computed once and shared across all indices
    baseOnly mpps = let tots = totients mpps in snd . toIndexPair tots
-- | @phi(m') `div` phi(m)@ vectors, each of length @phi(m)@, where entry
-- @i0@ of vector @i1@ is @fromIndexPair tots (i1, i0)@.
extIndicesCoeffs :: forall m m' . (m `Divides` m')
                    => V.Vector (U.Vector Int)
extIndicesCoeffs = V.generate (phi' `div` phi) row
  where
    (_, phi, phi', tots) = indexInfo @m @m'
    row i1 = U.generate phi (\i0 -> fromIndexPair tots (i1, i0))
-- | Map an extension index to its base-ring index plus a sign flag, if it has
-- one.
--
-- For each merged triple @(p, e, e')@ the digit @r = i' `mod` phi(p^e')@ is
-- handled as follows:
--
-- * For odd @p@ absent from the base (@e == 0@, @e' > 0@), digit 0 maps to
--   @(0, False)@, digit 1 to @(0, True)@, and anything else has no image.
-- * Otherwise the digit must be below @phi(p^e)@ and maps to itself.
--
-- Digits recombine with radix @phi(p^e)@; the flags combine by XOR (@/=@).
-- Any digit without an image makes the whole result 'Nothing'.
baseIndexDec :: [(Int,Int,Int)] -> Int -> Maybe (Int, Bool)
baseIndexDec [] 0 = Just (0,False)
baseIndexDec ((p,e,e'):rest) i' = do
  (lo, neg)  <- digit
  (hi, neg') <- baseIndexDec rest q
  return (lo + phi*hi, neg /= neg')
  where
    (q, r) = i' `divMod` totientPP (p,e')
    phi = totientPP (p,e)
    digit
      | p > 2 && e == 0 && e' > 0 =
          case r of
            0 -> Just (0,False)
            1 -> Just (0,True)
            _ -> Nothing
      | r < phi   = Just (r,False)
      | otherwise = Nothing
-- | Merge the prime-power factorizations of @m@ (first argument) and @m'@
-- (second argument) into triples @(p, e, e')@.
--
-- Here @e'@ is the exponent of @p@ in @m'@ and @e@ is its exponent in @m@,
-- or 0 when @p@ does not divide @m@.  Both lists are assumed sorted by
-- increasing prime, as the @p > p'@ case implies.  If @m@ does not divide
-- @m'@ (or a list is unsorted), the final equation raises a descriptive
-- error rather than an anonymous pattern-match failure.
mergePPs :: [PP] -> [PP] -> [(Int,Int,Int)]
mergePPs [] pps = LP.map (\(p,e) -> (p,0,e)) pps
mergePPs allpps@((p,e):pps) ((p',e'):pps')
  | p == p' && e <= e' = (p, e, e') : mergePPs pps pps'
  | p > p' = (p', 0, e') : mergePPs allpps pps'
-- Reached when the second list runs out first or both guards above fail
-- (failed guards fall through to the next equation).
mergePPs pps pps' =
  error $ "mergePPs: " ++ show pps ++ " does not divide " ++ show pps'
          ++ " (or a factorization is not sorted by prime)"
-- | Replace each triple @(p, e, e')@ with the totient pair
-- @(totientPP (p, e), totientPP (p, e'))@.
totients :: [(Int, Int, Int)] -> [(Int,Int)]
totients = LP.map totientPair
  where
    totientPair (p, e, e') = (totientPP (p, e), totientPP (p, e'))