module Crypto.Lol.Cyclotomic.Tensor.RepaTensor
( RT ) where
import Crypto.Lol.Cyclotomic.Tensor as T
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.CRT
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.Extension
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.Gauss
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.GL
import Crypto.Lol.Cyclotomic.Tensor.RepaTensor.RTCommon as RT
import Crypto.Lol.LatticePrelude as LP hiding
((!!))
import Crypto.Lol.Types.IZipVector
import Algebra.Additive as Additive (C)
import Algebra.Ring as Ring (C)
import Algebra.ZeroTestable as ZeroTestable (C)
import Control.Applicative
import Control.DeepSeq (NFData (rnf))
import Control.Monad (liftM)
import Control.Monad.Random
import Data.Coerce
import Data.Constraint
import Data.Foldable as F
import Data.Maybe
import Data.Traversable as T
import Data.Typeable
import Data.Vector.Unboxed as U hiding (force)
import Test.QuickCheck
-- | Tensor type with two internal representations:
--
--   * 'RT' stores the entries in a strict 'Arr' (repa-backed) and requires
--     'Unbox' elements.
--
--   * 'ZV' stores the entries in an 'IZipVector' and carries no element
--     constraint. The 'Applicative' / 'Traversable' instances in this module
--     work on this representation ('pure' builds a 'ZV').
--
-- 'toRT' and 'toZV' convert between the two representations.
data RT (m :: Factored) r where
  RT :: Unbox r => !(Arr m r) -> RT m r
  ZV :: IZipVector m r -> RT m r
  deriving (Typeable)
-- | Standalone-derived 'Show' for both constructors.
-- NOTE(review): relies on 'Show' instances for 'Arr' and 'IZipVector'
-- defined outside this module.
deriving instance Show r => Show (RT m r)
-- | Equality across representations.
--
-- Two values in the same representation are compared directly. When exactly
-- one side is an 'RT', the other side is converted with 'toRT' (the 'Unbox'
-- evidence comes from matching the 'RT' side) and the comparison is retried.
instance Eq r => Eq (RT m r) where
  x == y = case (x, y) of
    (ZV u, ZV w) -> u == w
    (RT u, RT w) -> u == w
    (RT _, _)    -> x == toRT y
    (_, RT _)    -> toRT x == y
-- | Converts an 'IZipVector' into the array form stored by 'RT'.
--
-- The underlying vector is converted to an unboxed vector with 'convert' and
-- wrapped as a one-dimensional repa array whose extent is its length.
zvToArr :: Unbox r => IZipVector m r -> Arr m r
zvToArr zv = Arr (fromUnboxed (Z :. U.length flat) flat)
  where
    flat = convert (unIZipVector zv)
-- | Forces the 'RT' (array) representation.
--
-- An 'RT' value is returned unchanged; a 'ZV' value is converted with
-- 'zvToArr'.
toRT :: Unbox r => RT m r -> RT m r
toRT t = case t of
  ZV zv -> RT (zvToArr zv)
  RT _  -> t
-- | Forces the 'ZV' (zip-vector) representation.
--
-- A 'ZV' value is returned unchanged. For an 'RT' value, the unboxed data is
-- converted and passed to 'iZipVector'. A 'Nothing' from 'iZipVector' is
-- reported as an internal error, raised lazily when the payload is demanded.
toZV :: Fact m => RT m r -> RT m r
toZV t = case t of
  ZV _ -> t
  RT (Arr arr) ->
    ZV . fromMaybe (error "toZV: internal error") . iZipVector . convert $
      toUnboxed arr
-- | Lifts a pure array transformation to 'RT'.
--
-- A 'ZV' input is first converted with 'zvToArr'. The result is always in
-- the 'RT' representation.
wrap :: Unbox r => (Arr l r -> Arr m r) -> RT l r -> RT m r
wrap f t = case t of
  RT arr -> RT (f arr)
  ZV zv  -> RT (f (zvToArr zv))
-- | Monadic counterpart of 'wrap': lifts an array transformation whose result
-- is produced in some monad @mon@.
--
-- A 'ZV' input is first converted with 'zvToArr'. The result array is
-- re-wrapped with 'RT' via 'liftM'.
wrapM :: (Unbox r, Monad mon) => (Arr l r -> mon (Arr m r))
      -> RT l r -> mon (RT m r)
wrapM f t = case t of
  RT arr -> liftM RT (f arr)
  ZV zv  -> liftM RT (f (zvToArr zv))
-- | 'Tensor' instance for 'RT'.
--
-- Almost every method delegates to an array-level implementation imported
-- from the RepaTensor submodules. The result is lifted to 'RT' with 'wrap'
-- (pure @Arr -> Arr@ functions), 'wrapM' (monadic @Arr -> mon Arr@
-- functions), or by applying the 'RT' constructor directly. 'ZV' inputs are
-- converted to arrays by those helpers (or by 'toRT') before the array code
-- runs.
instance Tensor RT where
  -- Constraints every element type must satisfy for tensor operations on RT.
  type TElt RT r = (IntegralDomain r, ZeroTestable r,
                    Eq r, Random r, NFData r,
                    Unbox r, Elt r)

  -- Entailment evidence built from the instances in scope.
  -- NOTE(review): the exact entailed constraints are fixed by the class
  -- declaration, which is not visible here.
  entailIndexT = tag $ Sub Dict
  entailFullT = tag $ Sub Dict

  scalarPow = RT . scalarPow'

  l = wrap fL
  lInv = wrap fLInv

  mulGPow = wrap fGPow
  mulGDec = wrap fGDec

  -- fGInvPow / fGInvDec return their result in a monad (presumably to signal
  -- failure -- confirm), hence the monadic lifting 'wrapM'.
  divGPow = wrapM fGInvPow
  divGDec = wrapM fGInvDec

  -- Five CRT-related functions combined applicatively, so they all share the
  -- outer context. NOTE(review): that context is presumably Maybe-like
  -- (absent when no CRT exists) -- confirm against the class definition.
  crtFuncs = (,,,,) <$>
    (liftM (RT .) scalarCRT') <*>
    (wrap <$> mulGCRT') <*>
    (wrap <$> divGCRT') <*>
    (wrap <$> fCRT) <*>
    (wrap <$> fCRTInv)

  -- Runs tGaussianDec' in the random monad and wraps the resulting array.
  tGaussianDec :: forall v rnd m q .
    (Fact m, OrdFloat q, Random q, TElt RT q,
     ToRational v, MonadRandom rnd) => v -> rnd (RT m q)
  tGaussianDec = liftM RT . tGaussianDec'

  twacePowDec = wrap twacePowDec'
  embedPow = wrap embedPow'
  embedDec = wrap embedDec'

  -- Pair of CRT extension functions, each lifted with 'wrap' inside the
  -- shared outer context.
  crtExtFuncs = (,) <$> (liftM wrap twaceCRT')
                    <*> (liftM wrap embedCRT')

  -- 'wrapM' at the result "monad" of coeffs' (NOTE(review): presumably a
  -- list of arrays -- confirm).
  coeffs = wrapM coeffs'

  -- Maps the RT constructor over the inner container of each basis set.
  powBasisPow = (RT <$>) <$> powBasisPow'
  crtSetDec = (RT <$>) <$> crtSetDec'

  -- Entrywise map on the array; the result is forced. 'ZV' inputs are
  -- converted with 'toRT' first.
  fmapT f (RT v) = RT $ (coerce $ force . RT.map f) v
  fmapT f v@(ZV _) = fmapT f $ toRT v

  -- Monadic entrywise map over the underlying unboxed vector. The result
  -- array is rebuilt with the same extent. 'ZV' inputs are converted first.
  fmapTM f (RT (Arr arr)) = liftM (RT . Arr . fromUnboxed (extent arr)) $
    U.mapM f $ toUnboxed arr
  fmapTM f v@(ZV _) = fmapTM f $ toRT v
-- | 'Functor' defined through the 'Applicative' instance: @fmap f@ is
-- @(pure f <*>)@, so the result is in the 'ZV' representation.
instance Fact m => Functor (RT m) where
  fmap f = (pure f <*>)
-- | 'Applicative' over the 'ZV' representation.
--
-- 'pure' always builds a 'ZV'. An 'RT' argument on the right of '<*>' is
-- converted with 'toZV' before zipping.
--
-- An 'RT' on the left would have to hold functions, which requires an
-- @Unbox (a -> b)@ instance. No such instance is expected to exist, but GHC
-- cannot prove that. The case is therefore reported with an explicit
-- internal error instead of an anonymous non-exhaustive-pattern failure.
instance Fact m => Applicative (RT m) where
  pure = ZV . pure
  (ZV f) <*> (ZV a) = ZV (f <*> a)
  f@(ZV _) <*> v@(RT _) = f <*> toZV v
  (RT _) <*> _ = error "RT.<*>: an RT cannot hold functions (internal error)"
-- | 'Foldable' obtained from the 'Traversable' instance via
-- 'foldMapDefault'.
instance Fact m => Foldable (RT m) where
  foldMap f = foldMapDefault f
-- | 'Traversable' over the 'ZV' representation.
--
-- A 'ZV' is traversed through its 'IZipVector'. An 'RT' is first converted
-- with 'toZV'. The result is always a 'ZV'.
instance Fact m => Traversable (RT m) where
  traverse f t = case t of
    ZV zv -> ZV <$> T.traverse f zv
    RT _  -> T.traverse f (toZV t)
-- | Entrywise additive structure.
--
-- Arithmetic is done on the 'RT' (array) representation. A 'ZV' operand is
-- converted with 'toRT' first, and results are forced arrays. 'zero' is
-- the replicated element 'zero'.
instance (Fact m, Additive r, Unbox r, Elt r) => Additive.C (RT m r) where
  zero = RT $ repl zero

  negate t = case t of
    RT x -> RT $ (coerce $ force . RT.map negate) x
    ZV _ -> negate (toRT t)

  x + y = case (x, y) of
    (RT u, RT w) -> RT $ coerce (\p -> force . RT.zipWith (+) p) u w
    _            -> toRT x + toRT y
-- | Entrywise ring structure.
--
-- Products are computed on the array representation; 'ZV' operands are
-- converted with 'toRT' first. 'fromInteger' replicates the converted
-- scalar into every entry.
instance (Fact m, Ring r, Unbox r, Elt r) => Ring.C (RT m r) where
  fromInteger n = RT (repl (fromInteger n))

  x * y = case (x, y) of
    (RT u, RT w) -> RT $ coerce (\p -> force . RT.zipWith (*) p) u w
    _            -> toRT x * toRT y
-- | A tensor is zero iff every entry is zero.
--
-- The array case tests each entry of the underlying unboxed vector with
-- 'U.all'. This stops at the first nonzero entry and never reads a seed
-- element, so it is also total on an empty array (vacuously zero).
-- A 'ZV' value defers to the 'IZipVector' instance.
instance (Fact m, ZeroTestable r, Unbox r, Elt r) => ZeroTestable.C (RT m r) where
  isZero (RT (Arr a)) = U.all isZero (toUnboxed a)
  isZero (ZV v) = isZero v
-- | Uniformly random tensors, produced by sampling an 'Arr' and wrapping it
-- in 'RT'. Ranged sampling is not supported.
instance (Unbox r, Random (Arr m r)) => Random (RT m r) where
  random = runRand $ do
    arr <- liftRand random
    return (RT arr)
  randomR = error "randomR nonsensical for RT"
-- | QuickCheck generator: an arbitrary 'Arr' wrapped in 'RT'.
instance (Unbox r, Arbitrary (Arr m r)) => Arbitrary (RT m r) where
  arbitrary = fmap RT arbitrary
-- | Full evaluation delegates to whichever representation is present.
instance NFData r => NFData (RT m r) where
  rnf t = case t of
    RT arr -> rnf arr
    ZV zv  -> rnf zv