-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Crypto.RC4
-- Copyright : (c) Austin Seipp
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- An implementation of RC4 (AKA Rivest Cipher 4 or Alleged RC4/ARC4),
-- using SBV. For information on RC4, see: <http://en.wikipedia.org/wiki/RC4>.
--
-- We make no effort to optimize the code, and instead focus on a clear
-- implementation. In fact, the RC4 algorithm relies on in-place update of
-- its state heavily for efficiency, and is therefore unsuitable for a purely
-- functional implementation.
-----------------------------------------------------------------------------

{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Crypto.RC4 where

import Data.Char  (ord, chr)
import Data.List  (genericIndex)
import Data.Maybe (fromJust)
import Data.SBV

import Data.SBV.Tools.STree

import Numeric (showHex)

-----------------------------------------------------------------------------
-- * Types
-----------------------------------------------------------------------------

-- | The RC4 state contains 256 8-bit values. We use the symbolically accessible
-- full-binary type 'STree' to represent the state, since RC4 needs to access
-- the array via a symbolic index and it is important to minimize access time.
-- Both the index type and the element type are 'Word8'.
type S = STree Word8 Word8

-- | Construct the fully balanced initial tree, where the leaves are simply
-- the literal (i.e., concrete) numbers @0@ through @255@.
initS :: S
initS = mkSTree (map literal [0 .. 255])

-- | The key is a list of symbolic 8-bit words ('SWord8'). Note that
-- 'initRC4' only accepts keys whose length is between 1 and 256.
type Key = [SWord8]

-- | Represents the current state of the RC4 stream: it is the @S@ array
-- along with the @i@ and @j@ index values used by the PRGA, stored in that
-- order.
type RC4 = (S, SWord8, SWord8)

-----------------------------------------------------------------------------
-- * The PRGA
-----------------------------------------------------------------------------

-- | Swaps two elements in the RC4 array.
--
-- Both elements are read from the incoming state before any write is
-- performed; the value read at @j@ is then written at @i@, and the value
-- read at @i@ is written at @j@.
swap :: SWord8 -> SWord8 -> S -> S
swap i j st = writeSTree (writeSTree st i stj) j sti
  where
    sti :: SWord8
    sti = readSTree st i

    stj :: SWord8
    stj = readSTree st j

-- | Implements the PRGA used in RC4. We return the next key value generated,
-- paired with the new state.
--
-- All index arithmetic is done on 'SWord8' values, so both the indices and
-- the sums of state entries stay within the 256-entry state array.
prga :: RC4 -> (SWord8, RC4)
prga (st', i', j') = (readSTree st kInd, (st, i, j))
  where
    -- Advance the first index by one.
    i :: SWord8
    i = i' + 1

    -- Advance the second index by the (pre-swap) state entry at @i@.
    j :: SWord8
    j = j' + readSTree st' i

    -- The new state is the old one with entries @i@ and @j@ exchanged.
    st :: S
    st = swap i j st'

    -- The output byte is looked up at the sum of the two swapped entries.
    kInd :: SWord8
    kInd = readSTree st i + readSTree st j

-----------------------------------------------------------------------------
-- * Key schedule
-----------------------------------------------------------------------------

-- | Constructs the state to be used by the PRGA using the given key.
--
-- The key must contain between 1 and 256 elements; any other length is
-- rejected with a call to 'error'. The key elements themselves may be
-- symbolic: key bytes are selected using a concrete loop counter only.
initRC4 :: Key -> S
initRC4 key
  | keyLength < 1 || keyLength > 256
  = error $ "RC4 requires a key of length between 1 and 256, received: " ++ show keyLength
  | otherwise
  = snd $ foldl mix (0, initS) [0 .. 255]
  where
    keyLength :: Int
    keyLength = length key

    -- One step of key mixing for the concrete counter value @w@. The running
    -- index @j@ is threaded through the fold together with the state.
    mix :: (SWord8, S) -> Word8 -> (SWord8, S)
    mix (j', s) w = (j, swap i j s)
      where
        i :: SWord8
        i = literal w

        -- The key position is reduced modulo the key length at type 'Int'.
        -- Performing the reduction at 'Word8' would truncate a key length
        -- of 256 to 0 and cause a division by zero.
        keyByte :: SWord8
        keyByte = genericIndex key (fromIntegral w `mod` keyLength)

        j :: SWord8
        j = j' + readSTree s i + keyByte

-- | The key-schedule. Note that this function returns an infinite list:
-- the state is initialized from the key with 'initRC4' (with both PRGA
-- indices starting at @0@), and 'prga' is then iterated forever, emitting
-- one key value per step.
keySchedule :: Key -> [SWord8]
keySchedule key = genKeys (initRC4 key, 0, 0)
  where
    genKeys :: RC4 -> [SWord8]
    genKeys st = k : genKeys st'
      where
        (k, st') = prga st

-- | Generate a key-schedule from a given key-string.
--
-- Each character is turned into a literal 8-bit word via 'ord' and
-- 'fromIntegral'; characters whose code point does not fit in 8 bits are
-- therefore truncated to their low 8 bits.
keyScheduleString :: String -> [SWord8]
keyScheduleString = keySchedule . map (literal . fromIntegral . ord)

-----------------------------------------------------------------------------
-- * Encryption and Decryption
-----------------------------------------------------------------------------

-- | RC4 encryption. We generate key-words and xor them with the input. The
-- result has the same length as the plain-text, since the (infinite)
-- key-schedule is zipped against it. The following test-vectors are from
-- Wikipedia <http://en.wikipedia.org/wiki/RC4>:
--
-- >>> concatMap hex2 $ encrypt "Key" "Plaintext"
-- "bbf316e8d940af0ad3"
--
-- >>> concatMap hex2 $ encrypt "Wiki" "pedia"
-- "1021bf0420"
--
-- >>> concatMap hex2 $ encrypt "Secret" "Attack at dawn"
-- "45a01f645fc35b383552544b9bf5"
encrypt :: String -> String -> [SWord8]
encrypt key pt = zipWith xor (keyScheduleString key) (map cvt pt)
  where
    -- Convert a plain-text character to a literal 8-bit word.
    cvt :: Char -> SWord8
    cvt = literal . fromIntegral . ord

-- | RC4 decryption. Essentially the same as encryption: the cipher-text is
-- xor'ed with the key-schedule, and the resulting words are converted back
-- to characters. For the above test vectors we have:
--
-- >>> decrypt "Key" [0xbb, 0xf3, 0x16, 0xe8, 0xd9, 0x40, 0xaf, 0x0a, 0xd3]
-- "Plaintext"
--
-- >>> decrypt "Wiki" [0x10, 0x21, 0xbf, 0x04, 0x20]
-- "pedia"
--
-- >>> decrypt "Secret" [0x45, 0xa0, 0x1f, 0x64, 0x5f, 0xc3, 0x5b, 0x38, 0x35, 0x52, 0x54, 0x4b, 0x9b, 0xf5]
-- "Attack at dawn"
decrypt :: String -> [SWord8] -> String
decrypt key ct = map cvt $ zipWith xor (keyScheduleString key) ct
  where
    -- Extract the concrete value of a word and turn it into a character.
    -- NOTE: 'fromJust' makes this partial; it fails if 'unliteral' returns
    -- 'Nothing', so the cipher-text is expected to consist of literals.
    cvt :: SWord8 -> Char
    cvt = chr . fromIntegral . fromJust . unliteral

-----------------------------------------------------------------------------
-- * Verification
-----------------------------------------------------------------------------

-- | Prove that round-trip encryption/decryption leaves the plain-text unchanged.
-- The proof is performed for a 40-bit key (5 bytes) and a 40-bit plain-text
-- (again 5 bytes), both universally quantified.
--
-- Note that this theorem is trivial to prove, since it is essentially establishing
-- that xor'ing the same value twice leaves a word unchanged (i.e., @x `xor` y `xor` y = x@).
-- However, the proof takes quite a while to complete, as it gives rise to a fairly
-- large symbolic trace.
rc4IsCorrect :: IO ThmResult
rc4IsCorrect = prove $ do
    key <- mkForallVars 5
    pt  <- mkForallVars 5
    let ks :: [SWord8]
        ks = keySchedule key

        -- Encrypt, then decrypt with the very same key-stream.
        ct :: [SWord8]
        ct = zipWith xor ks pt

        pt' :: [SWord8]
        pt' = zipWith xor ks ct
    return $ pt .== pt'

--------------------------------------------------------------------------------------------
-- | For doctest purposes only. Renders the concrete value of a word in
-- hexadecimal, left-padded with @0@ to a width of at least two characters.
--
-- NOTE: 'fromJust' makes this partial; it fails if 'unliteral' returns
-- 'Nothing', so the argument is expected to be a literal.
hex2 :: (SymVal a, Show a, Integral a) => SBV a -> String
hex2 v = replicate (2 - length s) '0' ++ s
  where
    s :: String
    s = flip showHex "" . fromJust . unliteral $ v