{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-} -- for ReifyCrypto
{-# LANGUAGE UndecidableInstances #-} -- for Reifies instances
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- | Finite Field Cryptography (FFC)
-- is a method of implementing discrete logarithm cryptography
-- using finite field mathematics.
module Voting.Protocol.Arith where

import Control.Arrow (first)
import Control.DeepSeq (NFData)
import Control.Monad (Monad(..))
import Data.Aeson (ToJSON(..),FromJSON(..))
import Data.Bits
import Data.Bool
import Data.Eq (Eq(..))
import Data.Foldable (Foldable, foldl')
import Data.Function (($), (.), id)
import Data.Functor ((<$>))
import Data.Int (Int)
import Data.Maybe (Maybe(..), fromJust)
import Data.Ord (Ord(..))
import Data.Proxy (Proxy(..))
import Data.Reflection (Reifies(..))
import Data.Semigroup (Semigroup(..))
import Data.String (IsString(..))
import Data.Text (Text)
import GHC.Generics (Generic)
import GHC.Natural (minusNaturalMaybe)
import Numeric.Natural (Natural)
import Prelude (Integer, Integral(..), fromIntegral, Enum(..))
import Text.Read (readMaybe)
import Text.Show (Show(..))
import qualified Control.Monad.Trans.State.Strict as S
import qualified Crypto.Hash as Crypto
import qualified Data.Aeson as JSON
import qualified Data.Aeson.Types as JSON
import qualified Data.ByteArray as ByteArray
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as BS64
import qualified Data.Char as Char
import qualified Data.List as List
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Builder as TLB
import qualified Data.Text.Lazy.Builder.Int as TLB
import qualified Prelude as Num
import qualified System.Random as Random

-- * Class 'Additive'
-- | Types equipped with an addition '+'
-- and a distinguished element 'zero' from which 'sum' starts.
class Additive a where
        zero :: a
        (+) :: a -> a -> a; infixl 6 +
        -- | Adds up the elements of a 'Foldable' with a strict left fold,
        -- starting from 'zero'.
        sum :: Foldable f => f a -> a
        sum xs = foldl' (+) zero xs
instance Additive Natural where
        zero = 0
        x + y = x Num.+ y
instance Additive Integer where
        zero = 0
        x + y = x Num.+ y
instance Additive Int where
        zero = 0
        x + y = x Num.+ y

-- ** Class 'Negable'
-- | 'Additive' types whose elements have an opposite, given by 'neg';
-- subtraction defaults to adding the opposite.
class Additive a => Negable a where
        neg :: a -> a
        (-) :: a -> a -> a; infixl 6 -
        (-) x y = x + neg y
instance Negable Integer where
        neg x = Num.negate x
instance Negable Int where
        neg x = Num.negate x

-- * Class 'Multiplicative'
-- | Types equipped with a multiplication '*'
-- and a distinguished element 'one'.
class Multiplicative a where
        one :: a
        (*) :: a -> a -> a; infixl 7 *
instance Multiplicative Natural where
        one = 1
        x * y = x Num.* y
instance Multiplicative Integer where
        one = 1
        x * y = x Num.* y
instance Multiplicative Int where
        one = 1
        x * y = x Num.* y

-- ** Class 'Invertible'
-- | 'Multiplicative' types whose elements have an inverse, given by 'inv';
-- division defaults to multiplying by the inverse.
class Multiplicative a => Invertible a where
        inv :: a -> a
        (/) :: a -> a -> a; infixl 7 /
        (/) x y = x * inv y

-- | @(b '^' e)@ raises the group element @b@ to the exponent @e@
-- by square-and-multiply: the lowest bit of @e@ selects
-- either @b@ (bit set) or 'one' (bit clear) as a factor,
-- and the remaining bits are handled recursively on @b*b@.
(^) ::
 Reifies c crypto =>
 Multiplicative (FieldElement crypto c) =>
 G crypto c -> E crypto c -> G crypto c
(^) base (E ex) =
        if ex == 0
        then one
        else factor * (base * base) ^ E (ex `shiftR` 1)
        where
        factor = if testBit ex 0 then base else one
infixr 8 ^

-- | 'groupGenInverses' returns the infinite list
-- of 'inv'erse powers of 'groupGen':
-- @['groupGen' '^' 'neg' i | i <- [0..]]@,
-- but obtained by multiplying each value by the inverse of 'groupGen'
-- to get the next one, starting from 'one'.
--
-- Used by 'intervalDisjunctions'.
groupGenInverses ::
 forall crypto c.
 Reifies c crypto =>
 Group crypto =>
 Multiplicative (FieldElement crypto c) =>
 [G crypto c]
groupGenInverses = List.iterate (* invGen) one
        where
        -- The inverse of the generator, multiplied in at each step.
        invGen = inv (groupGen @crypto @c)

-- | 'groupGenPowers' returns the infinite list of powers of 'groupGen',
-- @['one', 'groupGen', 'groupGen' '*' 'groupGen', …]@,
-- each value being the previous one multiplied by 'groupGen'.
groupGenPowers ::
 forall crypto c.
 Reifies c crypto =>
 Group crypto =>
 Multiplicative (FieldElement crypto c) =>
 [G crypto c]
groupGenPowers = List.iterate (* gen) one
        where
        gen = groupGen @crypto @c

-- | @('randomR' i)@ returns a random integer in @[0..i-1]@,
-- drawn with 'Random.randomR' from the generator threaded in the state.
randomR ::
 Monad m =>
 Random.RandomGen r =>
 Random.Random i =>
 Negable i =>
 Multiplicative i =>
 i -> S.StateT r m i
randomR i = S.state (Random.randomR (zero, i - one))

-- | @('random')@ returns a random integer
-- in the range determined by its type,
-- drawn with 'Random.random' from the generator threaded in the state.
random ::
 Monad m =>
 Random.RandomGen r =>
 Random.Random i =>
 Negable i =>
 Multiplicative i =>
 S.StateT r m i
random = S.state Random.random

-- | Orphan 'Random.Random' instance for 'Natural', going through 'Integer'.
instance Random.Random Natural where
        randomR (mini,maxi) =
                first (fromIntegral::Integer -> Natural) .
                Random.randomR (fromIntegral mini, fromIntegral maxi)
        -- NOTE: 'Random.random' at 'Integer' draws from the range of 'Int',
        -- negative values included, and converting a negative 'Integer'
        -- to 'Natural' throws an arithmetic underflow;
        -- hence draw from @[0..maxBound::Int]@ instead.
        random =
                first (fromIntegral::Integer -> Natural) .
                Random.randomR (0, toInteger (Num.maxBound::Int))

-- * Type family 'FieldElement'
-- | @('FieldElement' crypto)@ is the type constructor of the field elements
-- associated with the parameters @crypto@; it is applied to the
-- reflected parameter type @c@, as in @'G' crypto c@.
type family FieldElement crypto :: (* -> *)

-- * Class 'Group'
-- | Parameters @crypto@ providing, for the parameters reflected by @c@,
-- the order 'groupOrder' and the generator 'groupGen'
-- of the subgroup in which the 'G' elements live.
class Group crypto where
        groupOrder :: Reifies c crypto => Proxy c -> Natural
        groupGen   :: Reifies c crypto => G crypto c

-- ** Type 'G'
-- | An element of a subgroup of a field,
-- wrapping the underlying @'FieldElement' crypto c@.
newtype G crypto c = G { unG :: FieldElement crypto c }
deriving newtype instance NFData (FieldElement crypto c) => NFData (G crypto c)
deriving newtype instance Show   (FieldElement crypto c) => Show   (G crypto c)
deriving newtype instance Ord    (FieldElement crypto c) => Ord    (G crypto c)
deriving newtype instance Eq     (FieldElement crypto c) => Eq     (G crypto c)
-- | Encodes a 'G' exactly like its underlying field element.
instance ToJSON (FieldElement crypto c) => ToJSON (G crypto c) where
        toJSON (G x) = JSON.toJSON x
instance FromNatural (FieldElement crypto c) => FromNatural (G crypto c) where
        fromNatural n = G (fromNatural n)
instance ToNatural (FieldElement crypto c) => ToNatural (G crypto c) where
        nat (G x) = nat x
-- | Multiplication and 'one' are those of the underlying field elements.
instance Multiplicative (FieldElement crypto c) => Multiplicative (G crypto c) where
        one = G one
        (*) (G x) (G y) = G (x * y)
instance
 ( Reifies c crypto
 , Group crypto
 , Multiplicative (FieldElement crypto c)
 ) => Invertible (G crypto c) where
        -- NOTE: inverts by exponentiating to @'groupOrder' - 1@:
        -- assuming @x@ belongs to the subgroup of order @q = 'groupOrder'@,
        -- @x^q = 'one'@, hence @x^(q-1)@ is the inverse of @x@.
        -- 'fromJust' only fails on a degenerate 'groupOrder' of @0@.
        inv = (^ E (fromJust $ groupOrder @crypto (Proxy @c)`minusNaturalMaybe`1))

-- ** Type 'E'
-- | An exponent of a (cyclic) subgroup of a field.
-- The value is always in @[0..'groupOrder'-1]@.
newtype E crypto c = E { unE :: Natural }
 deriving stock (Eq,Ord,Show)
 deriving newtype NFData
-- | Encodes the exponent as a JSON string of its decimal digits.
instance ToJSON (E crypto c) where
        toJSON (E x) = JSON.toJSON (show x)
-- | Decodes an exponent from a JSON string of decimal digits,
-- as produced by its 'ToJSON' instance.
-- The value must be strictly below 'groupOrder',
-- and leading zeros are rejected to keep the encoding canonical.
instance (Reifies c crypto, Group crypto) => FromJSON (E crypto c) where
        parseJSON (JSON.String s)
         | Just (c0,_) <- Text.uncons s
         -- NOTE: "0" itself must be accepted: it is what 'toJSON' emits for @E 0@.
         , c0 /= '0' || s == "0"
         , Text.all Char.isDigit s
         , Just x <- readMaybe (Text.unpack s)
         , x < groupOrder @crypto (Proxy @c)
         = return (E x)
        parseJSON json = JSON.typeMismatch "Exponent" json
instance (Reifies c crypto, Group crypto) => FromNatural (E crypto c) where
        -- Reduce modulo 'groupOrder' to stay within @[0..'groupOrder'-1]@.
        -- A 'Natural' remainder is never negative, so no sign fix-up is needed.
        fromNatural i = E $ i `mod` groupOrder @crypto (Proxy @c)
instance ToNatural (E crypto c) where
        nat = unE
instance (Reifies c crypto, Group crypto) => Additive (E crypto c) where
        zero = E zero
        -- Addition modulo 'groupOrder'.
        (+) (E a) (E b) =
                let q = groupOrder @crypto (Proxy @c)
                in E (mod (a + b) q)
instance (Reifies c crypto, Group crypto) => Negable (E crypto c) where
        -- The opposite of a non-zero @x@ is @'groupOrder' - x@; @0@ is its own opposite.
        neg (E x) =
                if x == 0
                then zero
                else E (fromJust (minusNaturalMaybe (groupOrder @crypto (Proxy @c)) x))
instance (Reifies c crypto, Group crypto) => Multiplicative (E crypto c) where
        one = E one
        -- Multiplication modulo 'groupOrder'.
        (*) (E a) (E b) =
                let q = groupOrder @crypto (Proxy @c)
                in E (mod (a * b) q)
instance (Reifies c crypto, Group crypto) => Random.Random (E crypto c) where
        -- Draws an 'Integer' within the requested bounds,
        -- clamped to @[0..'groupOrder'-1]@, then wraps it back into 'E'.
        randomR (E lo, E hi) gen =
                let maxE = toInteger (groupOrder @crypto (Proxy @c)) - 1 in
                let bounds = (max 0 (toInteger lo), min (toInteger hi) maxE) in
                first (E . fromIntegral) (Random.randomR bounds gen)
        -- Draws over the whole range @[0..'groupOrder'-1]@.
        random gen =
                let maxE = toInteger (groupOrder @crypto (Proxy @c)) - 1 in
                first (E . fromIntegral) (Random.randomR (0, maxE) gen)
instance (Reifies c crypto, Group crypto) => Enum (E crypto c) where
        -- NOTE(review): a negative 'Int' makes the conversion to 'Natural' fail.
        toEnum = fromNatural . fromIntegral
        fromEnum = fromIntegral . nat
        -- NOTE: enumerate the underlying 'Natural's instead of stepping with
        -- the modular ('+'): the latter wraps @'groupOrder'-1@ back to @0@,
        -- which is again @<= hi@, so @[lo .. E ('groupOrder'-1)]@ never ended.
        -- Yields the empty list when @lo > hi@.
        enumFromTo (E lo) (E hi) = E <$> [lo .. hi]

-- * Class 'FromNatural'
-- | Types whose values can be built from a 'Natural'.
class FromNatural a where
        fromNatural :: Natural -> a

-- * Class 'ToNatural'
-- | Types whose values can be converted to a 'Natural'.
class ToNatural a where
        nat :: a -> Natural
instance ToNatural Natural where
        nat n = n

-- | @('bytesNat' x)@ returns the decimal representation of @'nat' x@
-- as a 'BS.ByteString', used as the serialization of @x@.
bytesNat :: ToNatural n => n -> BS.ByteString
bytesNat x = fromString (show (nat x))

-- * Type 'Hash'
-- | A hash value, represented as an exponent 'E'
-- and sharing all its instances.
newtype Hash crypto c = Hash (E crypto c)
 deriving newtype (Eq,Ord)
 deriving newtype (Show,NFData)

-- | @('hash' bs gs)@ returns, as an exponent 'E' (reduced modulo 'groupOrder'),
-- the 'Crypto.SHA256' hash of the given 'BS.ByteString' 'bs'
-- directly followed by the decimal representations of the given subgroup elements 'gs',
-- with a comma (",") intercalated between those representations.
--
-- NOTE: to avoid any collision when the 'hash' function is used in different contexts,
-- a message 'gs' is actually prefixed by a 'bs' indicating the context.
--
-- Used by 'proveEncryption' and 'verifyEncryption',
-- where the 'bs' usually contains the 'statement' to be proven,
-- and the 'gs' contains the 'commitments'.
hash ::
 Reifies c crypto =>
 Group crypto =>
 ToNatural (FieldElement crypto c) =>
 BS.ByteString ->
 [G crypto c] ->
 E crypto c
hash bs gs =
        let s = bs <> BS.intercalate (fromString ",") (bytesNat <$> gs) in
        let h = Crypto.hashWith Crypto.SHA256 s in
        -- Interpret the digest as a big-endian number,
        -- then let 'fromNatural' reduce it into the exponent range.
        fromNatural $ decodeBigEndian $ ByteArray.convert h

-- | @('decodeBigEndian' bs)@ interprets the bytes of @bs@
-- as an unsigned big-endian number (most significant byte first);
-- the empty 'BS.ByteString' decodes to @0@.
decodeBigEndian :: BS.ByteString -> Natural
decodeBigEndian = BS.foldl' step (0::Natural)
        where
        -- Shift the accumulator left by one byte and put the next byte
        -- in the freed low byte (OR-ing into zero bits equals adding).
        step acc byte = acc `shiftL` 8 .|. fromIntegral byte

-- ** Type 'Base64SHA256'
-- | Holds the 'Text' produced by 'base64SHA256':
-- a 'Crypto.SHA256' digest in base64 without padding.
newtype Base64SHA256 = Base64SHA256 Text
 deriving stock (Eq,Ord,Show,Generic)
 deriving anyclass (ToJSON,FromJSON)
 deriving newtype NFData

-- | @('base64SHA256' bs)@ returns the 'Crypto.SHA256' hash
-- of the given 'BS.ByteString' 'bs',
-- as a 'Text' escaped in @base64@ encoding
-- (<https://tools.ietf.org/html/rfc4648 RFC 4648>),
-- with the trailing @=@ padding removed.
base64SHA256 :: BS.ByteString -> Base64SHA256
base64SHA256 bs = Base64SHA256 (withoutPadding encoded)
        where
        digest = Crypto.hashWith Crypto.SHA256 bs
        encoded = Text.decodeUtf8 (BS64.encode (ByteArray.convert digest))
        -- NOTE: no padding.
        withoutPadding = Text.takeWhile (/= '=')

-- ** Type 'HexSHA256'
-- | Wraps a hexadecimal 'Crypto.SHA256' digest as 'Text'
-- (see 'hexSHA256', which produces such text).
newtype HexSHA256 = HexSHA256 Text
 deriving stock (Eq,Ord,Show,Generic)
 deriving anyclass (ToJSON,FromJSON)
 deriving newtype NFData
-- | @('hexSHA256' bs)@ returns the 'Crypto.SHA256' hash
-- of the given 'BS.ByteString' 'bs', escaped in hexadecimal
-- into a 'Text' of 64 lowercase characters
-- (two per byte of the 32-byte digest, leading zeros included).
--
-- Used (in retro-dependencies of this library) to hash
-- the 'PublicKey' of a voter or a trustee.
hexSHA256 :: BS.ByteString -> Text
hexSHA256 bs =
        let h = Crypto.hashWith Crypto.SHA256 bs in
        let n = decodeBigEndian $ ByteArray.convert h in
        -- NOTE: always set the 256 bit then remove it
        -- to always have leading zeros,
        -- and thus always 64 characters wide hashes.
        TL.toStrict $
        TL.tail $ TLB.toLazyText $ TLB.hexadecimal $
        setBit n 256