{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE Safe #-}
module Crypto.NewHope.Internal.CCA_KEM where
import Control.DeepSeq
import qualified Data.ByteString as BS
import StringUtils
import qualified Crypto.NewHope.CPA_PKE as CPA_PKE
import Crypto.NewHope.FIPS202
import Crypto.NewHope.Internals (N (N1024, N512), symBytes)
import qualified Crypto.NewHope.Internals as Internals
import Crypto.NewHope.Poly (Poly)
import qualified Crypto.NewHope.Poly as Poly
import Crypto.NewHope.RNG
import Crypto.NewHope.Verify
-- | Size in bytes of an encoded CCA-KEM public key for parameter set @n@.
-- The KEM public key is the CPA-PKE public key unchanged, so this delegates
-- to 'CPA_PKE.publicKeyBytes'.
publicKeyBytes :: Internals.N -> Int
publicKeyBytes n = CPA_PKE.publicKeyBytes n
-- | Size in bytes of a CCA-KEM secret key for parameter set @n@.
--
-- The secret key is laid out as: the CPA-PKE secret key, the CPA-PKE public
-- key, then two 'symBytes'-long fields (a public-key hash and a random
-- value, as written by 'keypair').
secretKeyBytes :: Internals.N -> Int
secretKeyBytes n = cpaSecret + cpaPublic + 2 * symBytes
  where
    cpaSecret = CPA_PKE.secretKeyBytes n
    cpaPublic = CPA_PKE.publicKeyBytes n
-- | Size in bytes of a CCA-KEM ciphertext for parameter set @n@: the CPA-PKE
-- ciphertext followed by one extra 'symBytes'-long field.
cipherTextBytes :: Internals.N -> Int
cipherTextBytes n = symBytes + CPA_PKE.cipherTextBytes n
-- | A shared secret produced by 'encrypt' / 'decrypt'.
-- 'makeSharedSecret' is the length-checking constructor.
newtype SharedSecret = SharedSecret BS.ByteString
  deriving (Eq)

-- | Fully evaluates the wrapped bytes.
instance NFData SharedSecret where
  rnf (SharedSecret ssBytes) = rnf ssBytes
-- | Wrap raw bytes as a 'SharedSecret'.
--
-- Calls 'error' with @"Invalid length for SharedSecret"@ unless the input is
-- exactly 'Internals.sharedSecretBytes' long.
makeSharedSecret :: BS.ByteString -> SharedSecret
makeSharedSecret bs
  | BS.length bs == Internals.sharedSecretBytes = SharedSecret bs
  | otherwise = error "Invalid length for SharedSecret"
-- | The raw bytes of a 'SharedSecret'.
getSSData :: SharedSecret -> BS.ByteString
getSSData ss = case ss of
  SharedSecret bytes -> bytes
-- | An encoded CCA-KEM public key.
-- 'makePublicKey' is the length-checking constructor.
newtype PublicKey = PublicKey BS.ByteString
  deriving (Eq)

-- | Fully evaluates the wrapped bytes. Forcing them is what makes
-- 'force' / 'deepseq' actually evaluate a key, including a pending length
-- check when the key came from 'makePublicKey'.
instance NFData PublicKey where
  rnf (PublicKey pkBytes) = rnf pkBytes
-- | Wrap raw bytes as a 'PublicKey'.
--
-- Only the public-key sizes of the 'N512' and 'N1024' parameter sets are
-- accepted; any other length is reported with 'error'.
makePublicKey :: BS.ByteString -> PublicKey
makePublicKey bs
  | BS.length bs `elem` validLengths = PublicKey bs
  | otherwise = error "Invalid length for PublicKey"
  where
    validLengths = map publicKeyBytes [N512, N1024]
-- | The raw encoded bytes of a 'PublicKey'.
getPKData :: PublicKey -> BS.ByteString
getPKData pk = case pk of
  PublicKey bytes -> bytes
-- | Recover the parameter set of a 'PublicKey' from its encoded length.
-- Calls 'error' if the length matches neither 'N512' nor 'N1024'.
getPKn :: PublicKey -> Internals.N
getPKn (PublicKey pkData) =
  case lookup (BS.length pkData) bySize of
    Just param -> param
    Nothing -> error "Invalid N for PublicKey"
  where
    bySize = [(publicKeyBytes param, param) | param <- [N512, N1024]]
-- | Decode the polynomial stored at the front of a 'PublicKey', i.e. its
-- first @Poly.polyBytes n@ bytes for the key's parameter set @n@.
getPKPoly :: PublicKey -> Poly
getPKPoly pk@(PublicKey pkData) =
  let polyLen = Poly.polyBytes (getPKn pk)
  in Poly.fromByteString (BS.take polyLen pkData)
-- | The still-encoded polynomial part of a 'PublicKey': its first
-- @Poly.polyBytes n@ bytes.
getPKPolyData :: PublicKey -> BS.ByteString
getPKPolyData pk@(PublicKey pkData) =
  BS.take (Poly.polyBytes $ getPKn pk) pkData
-- | The bytes of a 'PublicKey' that follow the encoded polynomial
-- (everything from offset @Poly.polyBytes n@ to the end).
getPKSeedData :: PublicKey -> BS.ByteString
getPKSeedData pk@(PublicKey pkData) =
  BS.drop (Poly.polyBytes $ getPKn pk) pkData
-- | An encoded CCA-KEM secret key.
-- 'makeSecretKey' is the length-checking constructor.
newtype SecretKey = SecretKey BS.ByteString
  deriving (Eq)

-- | Fully evaluates the wrapped bytes. Forcing them is what makes
-- 'force' / 'deepseq' actually evaluate a key, including a pending length
-- check when the key came from 'makeSecretKey'.
instance NFData SecretKey where
  rnf (SecretKey skBytes) = rnf skBytes
-- | Wrap raw bytes as a 'SecretKey'.
--
-- Only the secret-key sizes of the 'N512' and 'N1024' parameter sets are
-- accepted; any other length is reported with 'error'.
makeSecretKey :: BS.ByteString -> SecretKey
makeSecretKey bs
  | BS.length bs `elem` validLengths = SecretKey bs
  | otherwise = error "Invalid length for SecretKey"
  where
    validLengths = map secretKeyBytes [N512, N1024]
-- | Recover the parameter set of a 'SecretKey' from its encoded length.
-- Calls 'error' if the length matches neither 'N512' nor 'N1024'.
getSKn :: SecretKey -> Internals.N
getSKn (SecretKey skData) =
  case lookup (BS.length skData) bySize of
    Just param -> param
    Nothing -> error "Invalid N for SecretKey"
  where
    bySize = [(secretKeyBytes param, param) | param <- [N512, N1024]]
-- | The raw encoded bytes of a 'SecretKey'.
getSKData :: SecretKey -> BS.ByteString
getSKData sk = case sk of
  SecretKey bytes -> bytes
-- | The embedded CPA-PKE secret key: the first @CPA_PKE.secretKeyBytes n@
-- bytes of the KEM secret key.
getSkPkeSecretKeyData :: SecretKey -> BS.ByteString
getSkPkeSecretKeyData sk@(SecretKey skData) =
  BS.take (CPA_PKE.secretKeyBytes $ getSKn sk) skData
-- | The embedded CPA-PKE secret key, decoded into a 'CPA_PKE.SecretKey'.
getSkSecretKey :: SecretKey -> CPA_PKE.SecretKey
getSkSecretKey = CPA_PKE.makeSecretKeyFromBytes . getSkPkeSecretKeyData
-- | The copy of the CPA-PKE public key stored inside the KEM secret key.
-- It occupies @CPA_PKE.publicKeyBytes n@ bytes starting at offset
-- @CPA_PKE.secretKeyBytes n@.
getSkPkePublicKeyData :: SecretKey -> BS.ByteString
getSkPkePublicKeyData sk@(SecretKey skData) = bsRange skData start size
  where
    param = getSKn sk
    start = CPA_PKE.secretKeyBytes param
    size = CPA_PKE.publicKeyBytes param
-- | The CPA-PKE public key stored inside the KEM secret key, decoded into a
-- 'CPA_PKE.PublicKey'. 'decrypt' uses it for re-encryption.
getSkPkePublicKey :: SecretKey -> CPA_PKE.PublicKey
getSkPkePublicKey sk = CPA_PKE.makePublicKeyFromBytes (getSkPkePublicKeyData sk)
-- | The 'symBytes'-long field that follows the stored CPA-PKE public key.
-- 'keypair' fills it with @shake256 pkData symBytes@, a hash of the public
-- key.
getSkPkHash :: SecretKey -> BS.ByteString
getSkPkHash sk@(SecretKey skData) =
  let param = getSKn sk
      hashOffset = CPA_PKE.secretKeyBytes param + CPA_PKE.publicKeyBytes param
  in bsRange skData hashOffset symBytes
-- | The last 'symBytes'-long field of the secret key (@z@), which 'keypair'
-- fills with random bytes. 'decrypt' passes it to 'constantTimeChoose' as
-- the alternative to the first derived coin.
getSkZ :: SecretKey -> BS.ByteString
getSkZ sk@(SecretKey skData) =
  let param = getSKn sk
      zOffset = CPA_PKE.secretKeyBytes param + CPA_PKE.publicKeyBytes param + symBytes
  in bsRange skData zOffset symBytes
-- | An encoded CCA-KEM ciphertext.
-- 'makeCipherText' is the length-checking constructor.
newtype CipherText = CipherText BS.ByteString
  deriving (Eq)

-- | Fully evaluates the wrapped bytes. Forcing them is what makes
-- 'force' / 'deepseq' actually evaluate a ciphertext, including a pending
-- length check when it came from 'makeCipherText'.
instance NFData CipherText where
  rnf (CipherText ctBytes) = rnf ctBytes
-- | Wrap raw bytes as a 'CipherText'.
--
-- Only the ciphertext sizes of the 'N512' and 'N1024' parameter sets are
-- accepted; any other length is reported with 'error'.
makeCipherText :: BS.ByteString -> CipherText
makeCipherText bs
  | BS.length bs `elem` validLengths = CipherText bs
  | otherwise = error "Invalid length for CipherText"
  where
    validLengths = map cipherTextBytes [N512, N1024]
-- | Recover the parameter set of a 'CipherText' from its encoded length.
-- Calls 'error' if the length matches neither 'N512' nor 'N1024'.
getCTn :: CipherText -> Internals.N
getCTn (CipherText ctData) =
  case lookup (BS.length ctData) bySize of
    Just param -> param
    Nothing -> error "Invalid N for CipherText"
  where
    bySize = [(cipherTextBytes param, param) | param <- [N512, N1024]]
-- | The raw encoded bytes of a 'CipherText'.
getCTData :: CipherText -> BS.ByteString
getCTData ct = case ct of
  CipherText bytes -> bytes
-- | The CPA-PKE part of a KEM ciphertext: its first
-- @CPA_PKE.cipherTextBytes n@ bytes.
getCtCTData :: CipherText -> BS.ByteString
getCtCTData ct@(CipherText ctData) =
  BS.take (CPA_PKE.cipherTextBytes $ getCTn ct) ctData
-- | The CPA-PKE part of a KEM ciphertext, decoded into a
-- 'CPA_PKE.CipherText'.
getCtCT :: CipherText -> CPA_PKE.CipherText
getCtCT = CPA_PKE.makeCipherTextFromBytes . getCtCTData
-- | Generate a CCA-KEM key pair for parameter set @n@.
--
-- Returns the public key, the secret key and the advanced RNG context. The
-- secret key is the concatenation of
--
--   1. the CPA-PKE secret key data,
--   2. the CPA-PKE public key (used for re-encryption in 'decrypt'),
--   3. @shake256 pkData symBytes@, a hash of the public key, and
--   4. 'symBytes' fresh random bytes (@z@).
keypair :: Context -> Internals.N -> (PublicKey, SecretKey, Context)
keypair ctx n = (makePublicKey pkData, makeSecretKey skData, ctx1)
  where
    -- The RNG context is threaded: CPA key generation first, then z.
    (cpaPkePk, cpaPkeSk, ctx0) = CPA_PKE.keypair ctx n
    (z, ctx1) = randomBytes ctx0 symBytes
    pkData = CPA_PKE.getPKData cpaPkePk
    -- BS.concat sizes and fills the result once, instead of allocating a new
    -- buffer for every BS.append in a fold.
    skData = BS.concat
      [ CPA_PKE.getSKPolyData cpaPkeSk
      , pkData
      , shake256 pkData symBytes
      , z
      ]
-- | Encapsulate: derive a fresh shared secret for public key @pk@.
--
-- Returns the ciphertext, the shared secret and the advanced RNG context.
--
--   * 'symBytes' random bytes are hashed with 'shake256' to give @msg@;
--   * @msg@ followed by a hash of the public key is expanded with 'shake256'
--     into three 'symBytes'-long parts: @k@, @coins@ and @d@;
--   * @msg@ is CPA-encrypted under @pk@ with @coins@ as the seed, and @d@ is
--     appended to form the ciphertext;
--   * the shared secret is 'shake256' of @k@ followed by a hash of the
--     ciphertext.
encrypt :: Context -> PublicKey -> (CipherText, SharedSecret, Context)
encrypt ctx pk = (makeCipherText ctData, makeSharedSecret ss, ctx')
  where
    pkData = getPKData pk
    (rnd, ctx') = randomBytes ctx symBytes
    msg = shake256 rnd symBytes
    pkHash = shake256 pkData symBytes
    expanded = shake256 (BS.append msg pkHash) (3 * symBytes)
    (k, coinsAndD) = BS.splitAt symBytes expanded
    (coins, d) = BS.splitAt symBytes coinsAndD
    cpaCt = CPA_PKE.encrypt (CPA_PKE.makePlainText msg)
                            (CPA_PKE.makePublicKeyFromBytes pkData)
                            (Internals.makeSeed coins)
    ctData = BS.append (CPA_PKE.getCTData cpaCt) d
    ctHash = shake256 (BS.take (cipherTextBytes (getPKn pk)) ctData) symBytes
    ss = shake256 (BS.append k ctHash) symBytes
-- | Decapsulate ciphertext @ct@ with secret key @sk@.
--
-- Returns @(success, ss)@. @success@ is the result of 'verify' on the input
-- ciphertext and a re-encryption of the recovered message (presumably True
-- when they match; confirm against "Crypto.NewHope.Verify"). The shared
-- secret is 'shake256' of a 'symBytes'-long coin chosen by
-- 'constantTimeChoose' between the derived key part and 'getSkZ', followed by
-- a hash of the ciphertext.
--
-- NOTE(review): nothing here checks that @ct@ and @sk@ use the same
-- parameter set.
decrypt :: CipherText -> SecretKey -> (Bool, SharedSecret)
decrypt ct sk = (success, ss)
  where
    ctData = getCTData ct

    -- Recover the message and append the public-key hash stored in sk.
    recovered = CPA_PKE.getPTData $ CPA_PKE.decrypt (getCtCT ct) (getSkSecretKey sk)
    buf = BS.append recovered (getSkPkHash sk)
    expanded = shake256 buf (3 * symBytes)

    -- Re-encrypt with the derived coins and compare with the input ciphertext.
    coins = bsRange expanded symBytes symBytes
    d = bsRange expanded (2 * symBytes) symBytes
    reEncrypted = CPA_PKE.encrypt (CPA_PKE.makePlainText (BS.take symBytes buf))
                                  (getSkPkePublicKey sk)
                                  (Internals.makeSeed coins)
    ctCmp = BS.append (CPA_PKE.getCTData reEncrypted) d
    success = verify ctData ctCmp

    -- Put a hash of the ciphertext in place of the coins, choose the first
    -- coin via constantTimeChoose, and hash the first two coins.
    withCtHash = bsReplace expanded symBytes (shake256 ctData symBytes)
    chosen = constantTimeChoose success (BS.take symBytes withCtHash) (getSkZ sk)
    keyInput = BS.take (2 * symBytes) (bsReplace withCtHash 0 chosen)
    !ss = makeSharedSecret (shake256 keyInput symBytes)