{-|
Module      : Crypto.Lol.RLWE.Discrete
Description : Functions and types for working with discretized ring-LWE samples.
Copyright   : (c) Eric Crockett, 2011-2017
                  Chris Peikert, 2011-2018
License     : GPL-3
Maintainer  : ecrockett0@gmail.com
Stability   : experimental
Portability : POSIX

Functions and types for working with discretized ring-LWE samples.
-}

{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RebindableSyntax      #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}

module Crypto.Lol.RLWE.Discrete where

import Crypto.Lol
import Crypto.Lol.RLWE.Continuous as C (errorBound, tailGaussian)

import Control.Monad.Random

-- | A discrete ring-LWE sample \( (a,b) \in R_q \times R_q \): a pair whose
-- two components both have the ring type @cm zq@.
type Sample cm zq = (cm zq, cm zq)

-- | Constraints shared by the discrete RLWE functions in this module:
-- @cm zq@ must be a 'Cyclotomic' 'Ring', and values of the lifted type
-- @cm (LiftOf zq)@ must be reducible into @cm zq@.
type RLWECtx cm zq =
  ( Cyclotomic (cm zq)
  , Ring (cm zq)
  , Reduce (cm (LiftOf zq)) (cm zq)
  )

-- | Draw one discrete RLWE sample \( (a, a \cdot s + e) \) for the secret @s@.
--
-- @a@ is drawn with 'getRandom'. The error @e@ is drawn over the lifted
-- ring @cm (LiftOf zq)@ with 'roundedGaussian' at the given scaled
-- variance, then 'reduce'd into @cm zq@ before being added.
sample :: forall rnd v cm zq .
  (RLWECtx cm zq, Random (cm zq), RoundedGaussianCyc (cm (LiftOf zq)),
   MonadRandom rnd, ToRational v)
  => v -> cm zq -> rnd (Sample cm zq)
{-# INLINABLE sample #-}
sample svar s = do
  a <- getRandom
  noise <- roundedGaussian svar
  return (a, a * secretCRT + reduce (noise :: cm (LiftOf zq)))
  where
    -- the secret after 'adviseCRT', computed once per call and used in
    -- the product with @a@
    secretCRT = adviseCRT s

-- | The error term \( b - a \cdot s \) of an RLWE sample \( (a,b) \),
-- computed against the purported secret @s@ and lifted with 'liftDec'.
--
-- 'adviseCRT' is applied to the secret once per partial application
-- @errorTerm s@. The result is shared by every sample that partially
-- applied function is then given.
errorTerm :: (RLWECtx cm zq, LiftCyc (cm zq))
          => cm zq -> Sample cm zq -> LiftOf (cm zq)
{-# INLINABLE errorTerm #-}
errorTerm s = residual
  where
    secretCRT = adviseCRT s
    residual (a, b) = liftDec (b - a * secretCRT)

-- | The 'gSqNorm' of the error term of an RLWE sample, measured against
-- the purported secret @s@. This is 'gSqNorm' applied to the result of
-- 'errorTerm'.
errorGSqNorm :: (RLWECtx cm zq, GSqNormCyc cm (LiftOf zq),
                 LiftCyc (cm zq), LiftOf (cm zq) ~ cm (LiftOf zq))
             => cm zq -> Sample cm zq -> LiftOf zq
{-# INLINABLE errorGSqNorm #-}
errorGSqNorm s =
  -- bind the partial application once so the secret it prepares is
  -- shared across all samples passed to the returned function
  let termOf = errorTerm s
  in \smp -> gSqNorm (termOf smp)

-- | A bound such that the 'gSqNorm' of a discretized error term
-- generated by 'roundedGaussian' with scaled variance \(v\)
-- (over the \(m\)th cyclotomic field) is less than the
-- bound except with probability approximately \(\epsilon\).
--
-- The result is \( \lceil (\sqrt{b} + \sqrt{f})^2 \rceil \), computed in
-- expanded form as \( f + b + 2\sqrt{b}\sqrt{f} \). The terms are:
--
-- * \(b\), the continuous bound @C.errorBound \@m v eps@;
-- * \( f = 2^k \cdot n \cdot c \), where \(n\) is 'totientFact' of \(m\),
--   \(k\) is the number of odd primes dividing \(m\), and \(c\) is
--   @C.tailGaussian \@m eps@.
--
-- (\(f\) presumably bounds the contribution of the rounding step; see the
-- underlying analysis to confirm.)
errorBound :: forall m v . (Fact m, RealRing v, Transcendental v)
              => v              -- ^ the scaled variance
              -> v              -- ^ \(\epsilon\)
              -> Int64
errorBound = bound
  where
    -- quantities that depend only on the index m
    n = fromIntegral (totientFact @m)
    oddPrimeCount = length [p | (p, _) <- ppsFact @m, p /= 2]

    bound :: v -> v -> Int64
    bound v eps = ceiling (fsq + bsq + 2 * sqrt bsq * sqrt fsq)
      where
        bsq = C.errorBound @m v eps -- continuous bound
        csq = C.tailGaussian @m eps
        fsq = (2 ^ oddPrimeCount) * n * csq