{-# LANGUAGE Safe #-}
module Crypto.NewHope.CPA_PKE ( keypair
, encrypt
, decrypt
, publicKeyBytes
, secretKeyBytes
, cipherTextBytes
, PublicKey
, makePublicKey
, makePublicKeyFromBytes
, getPKData
, getPKPolyData
, getPKPoly
, getPKSeed
, SecretKey
, makeSecretKey
, makeSecretKeyFromBytes
, getSKPoly
, getSKPolyData
, CipherText
, makeCipherTextFromBytes
, getCTData
, getCTb
, getCTbData
, getCTv
, makePlainText
, getPTData
) where
import Control.Applicative
import qualified Data.ByteString as BS
import Crypto.NewHope.FIPS202
import Crypto.NewHope.Internals (N (N1024, N512), symBytes)
import qualified Crypto.NewHope.Internals as Internals
import Crypto.NewHope.Poly (Poly)
import qualified Crypto.NewHope.Poly as Poly
import Crypto.NewHope.RNG
import StringUtils
-- | Size in bytes of a serialised 'PublicKey' for the given parameter set:
-- one encoded polynomial plus a 'symBytes'-long public seed.
publicKeyBytes :: Internals.N -> Int
publicKeyBytes = (+ symBytes) . Poly.polyBytes
-- | Size in bytes of a serialised 'SecretKey': exactly one encoded polynomial.
secretKeyBytes :: Internals.N -> Int
secretKeyBytes n = Poly.polyBytes n
-- | Size in bytes of a serialised 'CipherText': the full encoding of one
-- polynomial plus the compressed encoding of another.
cipherTextBytes :: Internals.N -> Int
cipherTextBytes n = Poly.polyBytes n + Poly.polyCompressedBytes n
-- | A message for the CPA-PKE scheme, held as raw bytes.
newtype PlainText = PlainText BS.ByteString
  deriving (Eq)

-- | Wrap a raw message as a 'PlainText'.
--
-- Calls 'error' if the input is not exactly 'Poly.polyMsgBytes' bytes long.
makePlainText :: BS.ByteString -> PlainText
makePlainText bs
  | BS.length bs /= Poly.polyMsgBytes = error "Invalid length for PlainText"
  | otherwise = PlainText bs
-- | The raw message bytes wrapped by a 'PlainText'.
getPTData :: PlainText -> BS.ByteString
getPTData pt = case pt of
  PlainText bytes -> bytes
-- | Encode a 'PlainText' as a polynomial for parameter set @n@, using
-- 'Poly.fromMsg'.
getPTv :: PlainText -> Internals.N -> Poly
getPTv pt n = Poly.fromMsg n (getPTData pt)
-- | A serialised public key: an encoded polynomial followed by a
-- 'symBytes'-long public seed.
newtype PublicKey = PublicKey BS.ByteString
  deriving (Eq)

-- | Recover the parameter set from the byte length of a public key.
--
-- Calls 'error' if the length matches neither 'N512' nor 'N1024'.
getPKn :: PublicKey -> Internals.N
getPKn (PublicKey pkData)
  | hasSizeFor N512 = N512
  | hasSizeFor N1024 = N1024
  | otherwise = error "Invalid N for PublicKey"
  where
    hasSizeFor n = BS.length pkData == publicKeyBytes n
-- | Build a 'PublicKey' by appending the bytes of @seed@ to the byte encoding
-- of @poly@.
--
-- Calls 'error' if the result is not 'publicKeyBytes' long for the
-- polynomial's parameter set.
makePublicKey :: Poly -> Internals.Seed -> PublicKey
makePublicKey poly seed =
  if BS.length encoded == publicKeyBytes (Poly.getN poly)
    then PublicKey encoded
    else error "Invalid imputed length for PublicKey"
  where
    encoded = Poly.toByteString poly `BS.append` Internals.getSeedData seed
-- | Wrap raw bytes as a 'PublicKey'.
--
-- Calls 'error' unless the length is the public-key size for 'N512' or for
-- 'N1024'.
makePublicKeyFromBytes :: BS.ByteString -> PublicKey
makePublicKeyFromBytes pkData
  | BS.length pkData `elem` validLengths = PublicKey pkData
  | otherwise = error "Invalid length for PublicKey"
  where
    validLengths = map publicKeyBytes [N512, N1024]
-- | The raw serialised bytes of a 'PublicKey'.
getPKData :: PublicKey -> BS.ByteString
getPKData pk = case pk of
  PublicKey bytes -> bytes
-- | Length in bytes of the polynomial part of a public key, derived from the
-- key's parameter set.
getPKPolyBytes :: PublicKey -> Int
getPKPolyBytes = Poly.polyBytes . getPKn
-- | The encoded polynomial part of a public key: its first
-- 'getPKPolyBytes' bytes.
getPKPolyData :: PublicKey -> BS.ByteString
getPKPolyData pk@(PublicKey pkData) = BS.take (getPKPolyBytes pk) pkData

-- | The polynomial part of a public key, decoded with 'Poly.fromByteString'.
--
-- Defined in terms of 'getPKPolyData' (as 'getCTb' is for ciphertexts) so the
-- slicing of the key bytes lives in exactly one place.
getPKPoly :: PublicKey -> Poly
getPKPoly = Poly.fromByteString . getPKPolyData
-- | The public seed stored in the last 'symBytes' bytes of a public key.
getPKSeed :: PublicKey -> Internals.Seed
getPKSeed (PublicKey pkData) =
  Internals.makeSeed $ BS.drop (BS.length pkData - symBytes) pkData
-- | A serialised secret key: the byte encoding of a single polynomial.
newtype SecretKey = SecretKey BS.ByteString
  deriving (Eq)

-- | Build a 'SecretKey' from the byte encoding of @poly@.
--
-- Calls 'error' if the encoding is not 'secretKeyBytes' long for the
-- polynomial's parameter set.
makeSecretKey :: Poly -> SecretKey
makeSecretKey poly
  | BS.length encoded /= secretKeyBytes (Poly.getN poly) =
      error "Invalid imputed N for SecretKey"
  | otherwise = SecretKey encoded
  where
    encoded = Poly.toByteString poly
-- | Wrap raw bytes as a 'SecretKey'.
--
-- Calls 'error' unless the length is the secret-key size for 'N512' or for
-- 'N1024'.
makeSecretKeyFromBytes :: BS.ByteString -> SecretKey
makeSecretKeyFromBytes bs =
  if any ((== BS.length bs) . secretKeyBytes) [N512, N1024]
    then SecretKey bs
    else error "Invalid length for SecretKey"
-- | Recover the parameter set from the byte length of a secret key.
-- The leading underscore suppresses GHC's unused-binding warning.
--
-- Calls 'error' if the length matches neither 'N512' nor 'N1024'.
_getSKn :: SecretKey -> Internals.N
_getSKn (SecretKey skData) =
  case filter ((== BS.length skData) . secretKeyBytes) [N512, N1024] of
    (n : _) -> n
    [] -> error "Invalid N for SecretKey"
-- | The raw serialised bytes of a 'SecretKey' (the encoded polynomial).
getSKPolyData :: SecretKey -> BS.ByteString
getSKPolyData sk = case sk of
  SecretKey bytes -> bytes
-- | The secret polynomial, decoded from the key bytes with
-- 'Poly.fromByteString'.
getSKPoly :: SecretKey -> Poly
getSKPoly = Poly.fromByteString . getSKPolyData
-- | A serialised ciphertext: one fully encoded polynomial followed by one
-- compressed polynomial.
newtype CipherText = CipherText BS.ByteString
  deriving (Eq)

-- | Renders as @CipherText: @ followed by the hex encoding of the bytes.
instance Show CipherText where
  show (CipherText bytes) = concat ["CipherText: ", byteStringToHexString bytes]
-- | Build a 'CipherText' from @b@ (encoded in full) followed by @v@
-- (compressed).
--
-- Calls 'error' if the two polynomials use different parameter sets or the
-- result is not 'cipherTextBytes' long.
makeCipherText :: Poly -> Poly -> CipherText
makeCipherText b v
  | sameN && BS.length bytes == cipherTextBytes bN = CipherText bytes
  | otherwise = error "Invalid imputed length for CipherText"
  where
    bN = Poly.getN b
    sameN = bN == Poly.getN v
    bytes = Poly.toByteString b `BS.append` Poly.compress v
-- | Wrap raw bytes as a 'CipherText'.
--
-- Calls 'error' unless the length is the ciphertext size for 'N512' or for
-- 'N1024'.
makeCipherTextFromBytes :: BS.ByteString -> CipherText
makeCipherTextFromBytes bs
  | BS.length bs `notElem` map cipherTextBytes [N512, N1024] =
      error "Invalid length for CipherText"
  | otherwise = CipherText bs
-- | Recover the parameter set from the byte length of a ciphertext.
--
-- Calls 'error' if the length matches neither 'N512' nor 'N1024'.
getCTn :: CipherText -> Internals.N
getCTn (CipherText ctData) =
  case BS.length ctData of
    len
      | len == cipherTextBytes N512 -> N512
      | len == cipherTextBytes N1024 -> N1024
      | otherwise -> error "Invalid N for CipherText"
-- | The fully encoded polynomial part of a ciphertext: its first
-- 'Poly.polyBytes' bytes for the ciphertext's parameter set.
getCTbData :: CipherText -> BS.ByteString
getCTbData ct = BS.take (Poly.polyBytes (getCTn ct)) (getCTData ct)
-- | The fully encoded polynomial of a ciphertext, decoded with
-- 'Poly.fromByteString'.
getCTb :: CipherText -> Poly
getCTb ct = Poly.fromByteString (getCTbData ct)
-- | The raw serialised bytes of a 'CipherText'.
getCTData :: CipherText -> BS.ByteString
getCTData ct = case ct of
  CipherText bytes -> bytes
-- | The compressed polynomial of a ciphertext: every byte after the first
-- 'Poly.polyBytes', decoded with 'Poly.decompress'.
getCTv :: CipherText -> Poly
getCTv ct@(CipherText ctData) = Poly.decompress compressedV
  where
    (_, compressedV) = BS.splitAt (Poly.polyBytes (getCTn ct)) ctData
-- | Expand a public seed into the public polynomial for parameter set @n@,
-- using 'Poly.uniform'.
genA :: Internals.N -> Internals.Seed -> Poly
genA n seed = Poly.uniform n seed
-- | Generate a CPA-PKE key pair for parameter set @n@.
--
-- Draws 'symBytes' random bytes from the RNG context and expands them with
-- SHAKE-256 to @2 * symBytes@ bytes. The first half is the public seed and
-- the second half is the noise seed. In the NTT domain it then computes
-- @bhat = ehat + shat * ahat@ (pointwise), where @ahat@ is expanded from the
-- public seed.
--
-- Returns the public key (@bhat@ plus the public seed), the secret key
-- (@shat@), and the advanced RNG context.
keypair :: Context -> Internals.N -> (PublicKey, SecretKey, Context)
keypair ctx n = (pk, sk, ctx')
  where
    (z, ctx') = randomBytes ctx symBytes
    (publicSeedBytes, noiseSeedBytes) =
      BS.splitAt symBytes $ shake256 z (2 * symBytes)
    -- Convert each half to a Seed once and share it between its uses.
    publicSeed = Internals.makeSeed publicSeedBytes
    noiseSeed = Internals.makeSeed noiseSeedBytes
    ahat = genA n publicSeed
    -- Distinct nonces (0 and 1) derive the secret and error polynomials
    -- from the same noise seed.
    shat = Poly.ntt $ Poly.sample n noiseSeed 0
    ehat = Poly.ntt $ Poly.sample n noiseSeed 1
    bhat = Poly.add ehat (Poly.mulPointwise shat ahat)
    sk = makeSecretKey shat
    pk = makePublicKey bhat publicSeed
-- | Encrypt a 'PlainText' under a 'PublicKey'. @coin@ seeds all sampled
-- noise.
--
-- Computes @uhat = e' + s' * ahat@ (pointwise, NTT domain) and
-- @v' = v + (e'' + invntt (bhat * s'))@, where @v@ is the message
-- polynomial. Both results are serialised with 'makeCipherText'.
encrypt :: PlainText -> PublicKey -> Internals.Seed -> CipherText
encrypt pt pk coin = makeCipherText uhat vprime
  where
    n = getPKn pk
    noise nonce = Poly.sample n coin nonce
    ahat = genA n (getPKSeed pk)
    bhat = getPKPoly pk
    sprime = Poly.ntt (noise 0)
    eprime = Poly.ntt (noise 1)
    eprimeprime = noise 2
    uhat = Poly.add eprime (Poly.mulPointwise sprime ahat)
    vprime =
      Poly.add (getPTv pt n) $
        Poly.add eprimeprime (Poly.invntt (Poly.mulPointwise bhat sprime))
-- | Decrypt a 'CipherText' with a 'SecretKey'.
--
-- Computes @invntt (shat * uhat) - v'@ (pointwise product) and decodes the
-- result with 'Poly.toMsg'.
decrypt :: CipherText -> SecretKey -> PlainText
decrypt c sk = PlainText (Poly.toMsg noisyMsg)
  where
    masked = Poly.invntt (Poly.mulPointwise (getSKPoly sk) (getCTb c))
    noisyMsg = Poly.sub masked (getCTv c)