-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Crypto.Prince
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Implementation of Prince encryption and decryption, following the spec
-- <https://eprint.iacr.org/2012/529.pdf>
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE ParallelListComp #-}

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

module Documentation.SBV.Examples.Crypto.Prince where

import Prelude hiding(round)
import Numeric

import Data.SBV
import Data.SBV.Tools.CodeGen

-- $setup
-- >>> -- For doctest purposes only:
-- >>> import Data.SBV

-- * Types
-- | Section 2: Prince is essentially a 64-bit block cipher with a 128-bit key,
-- which is supplied as two 64-bit halves (see 'encrypt' and 'decrypt').
type Block = SWord 64

-- | Plaintext is simply a block.
type PT = Block

-- | Each of the two key halves is a 64-bit block.
type Key = Block

-- | Ciphertext is another 64-bit block.
type CT = Block

-- | A nibble is 4 bits. Ideally, we would represent a nibble by @SWord 4@; and indeed SBV can do that for
-- verification purposes just fine. Unfortunately, SBV's C code generator doesn't support 4-bit bit-vectors, as
-- there's nothing meaningful in C that we can map them to. Thus, we represent a nibble with 8 bits, and the
-- top 4 bits are always kept at 0.
type Nibble = SWord 8

-- * Key expansion

-- | Expanding a key, from Section 3.4 of the spec. The derived key is
-- @(k `rotateR` 1) `xor` (k `shiftR` 63)@: a one-bit right rotation, xor'ed with
-- the original most significant bit moved down to the least significant position.
expandKey :: Key -> Key
expandKey k = rotated `xor` topBit
  where rotated = k `rotateR` 1
        topBit  = k `shiftR` 63

-- | Checks that the equation @expandKey x == x@ has exactly one solution.
-- Models are enumerated with z3, capped at 10. Any count other than one is
-- reported, along with the solutions that were found. We have:
--
-- >>> prop_ExpandKey
-- Q.E.D.
prop_ExpandKey :: IO ()
prop_ExpandKey = do
    models <- extractModels <$> allSatWith z3{allSatMaxModelCount = Just cap} isFixedPoint
    report models
  where -- Upper bound on the number of models requested from the solver.
        cap :: Int
        cap = 10

        isFixedPoint :: SWord 64 -> SBool
        isFixedPoint x = x .== expandKey x

        report :: [WordN 64] -> IO ()
        report []  = putStrLn "No solutions to equation `x == expandKey x`!"
        report [_] = putStrLn "Q.E.D."
        report ms  = do
          let n    = length ms
              -- Reaching the cap means there may be even more solutions.
              qual = if n == cap then "at least " else ""
          putStrLn $ "Failed. There are " ++ qual ++ show n ++ " solutions to `x == expandKey x`!"
          mapM_ (putStrLn . ("    " ++) . show) ms


-- | Section 2: Encryption. The key is given as two halves @k0@ and @k1@. The
-- second whitening key is derived from @k0@ with 'expandKey'.
encrypt :: PT -> Key -> Key -> CT
encrypt pt k0 k1 = prince k0 (expandKey k0) k1 pt

-- | Decryption. This runs the same 'prince' core as 'encrypt', with two changes:
-- the whitening keys are swapped (the expanded key is applied first), and @k1@
-- is xor'ed with the constant alpha.
--
-- Alpha is taken from @'rConstants' 11@, the same value 'prop_RoundKeys' checks
-- against, so it cannot drift from the round-constant table.
decrypt :: CT -> Key -> Key -> PT
decrypt ct k0 k1 = prince k0' k0 (k1 `xor` alpha) ct
  where k0'   = expandKey k0
        alpha = rConstants 11   -- 0xc0ac29b7c97c50dd

-- * Main algorithm

-- | Basic Prince algorithm. The input is whitened with @k0@, run through the
-- 12-round 'princeCore' keyed by @k1@, and the result is whitened with @k0'@.
prince :: Block -> Key -> Key -> Key -> Block
prince k0 k0' k1 inp = princeCore k1 (inp `xor` k0) `xor` k0'

-- | Core Prince: 12 rounds stitched together, applied in this order:
--
-- * xor in @k1@ and round constant 0;
-- * apply forward 'round's 1..5;
-- * apply the middle layer @S^-1 . M' . S@;
-- * apply 'invRound's 6..10;
-- * xor in round constant 11 and @k1@ again.
princeCore :: Key -> Block -> Block
princeCore k1 = finish . backward . middle . forward . begin
  where begin    b = b `xor` k1 `xor` rConstants 0
        forward  b = foldl (round k1) b [1 .. 5]
        middle     = sBoxInv . m' . sBox
        backward b = foldl (invRound k1) b [6 .. 10]
        finish   b = b `xor` rConstants 11 `xor` k1

-- | Forward round @i@. Apply the S-box layer, then the linear layer 'm', and
-- finally add the round key @k1@ together with round constant @i@.
round :: Key -> Block -> Int -> Block
round k1 b i = roundKey `xor` mixed
  where roundKey = k1 `xor` rConstants i
        mixed    = m (sBox b)

-- | Inverse (backward) round @i@. Add the key @k1@ and round constant @i@, then
-- undo the linear layer with 'mInv' and the S-box layer with 'sBoxInv'.
invRound :: Key -> Block -> Int -> Block
invRound k1 b i = sBoxInv . mInv $ rConstants i `xor` keyed
  where keyed = b `xor` k1

-- | The M linear layer: the @m'@ matrix multiplication, followed by the 'sr'
-- nibble shuffle.
m :: Block -> Block
m b = sr (m' b)

-- | Inverse of 'm'. Undo the shuffle with 'srInv', then apply @m'@ again. Reusing
-- @m'@ here relies on @m'@ being its own inverse.
mInv :: Block -> Block
mInv b = m' (srInv b)

-- | SR: the nibble shuffle of the linear layer. Output nibble @j@ is input
-- nibble @perm !! j@. Nibbles are numbered in the order produced by 'toNibbles'.
sr :: Block -> Block
sr b = fromNibbles (map (nibbles !!) perm)
  where nibbles = toNibbles b   -- always 16 entries for a 64-bit block
        perm    = [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]

-- | Inverse of 'sr'. Output nibble @j@ is input nibble @perm !! j@, where @perm@
-- is the inverse of the permutation used by 'sr'.
srInv :: Block -> Block
srInv b = fromNibbles (map (nibbles !!) perm)
  where nibbles = toNibbles b   -- always 16 entries for a 64-bit block
        perm    = [0, 13, 10, 7, 4, 1, 14, 11, 8, 5, 2, 15, 12, 9, 6, 3]

-- | Prove that 'sr' and 'srInv' are inverses, composing them in both orders.
-- We have:
--
-- >>> prove prop_sr
-- Q.E.D.
prop_sr :: Predicate
prop_sr = roundTrips <$> free "block"
  where roundTrips :: Block -> SBool
        roundTrips b = b .== sr (srInv b) .&& b .== srInv (sr b)

-- | The M' transformation: multiply the block, viewed as a bit vector over
-- GF(2), by the matrix 'mat' (see 'mMult').
m' :: Block -> Block
m' b = mMult b

-- | The matrix of M', as described in Section 3.3. It is a 64x64 0/1 matrix,
-- block-diagonal with the 16x16 blocks @Mhat(0), Mhat(1), Mhat(1), Mhat(0)@.
--
-- * In @Mhat(s)@, the 4x4 block at position @(i, j)@ is @M_((i + j + s) mod 4)@.
-- * @M_k@ is the 4x4 identity matrix with its @k@-th diagonal entry cleared.
mat :: [[Int]]
mat = blockDiag [mHat 0, mHat 1, mHat 1, mHat 0]
  where -- M_k: the 4x4 identity with its k-th diagonal entry cleared.
        mBlock :: Int -> [[Int]]
        mBlock k = [[if r == c && r /= k then 1 else 0 | c <- [0 .. 3]] | r <- [0 .. 3]]

        -- Glue equally tall matrices side by side.
        hcat :: [[[Int]]] -> [[Int]]
        hcat = foldr1 (zipWith (++))

        -- Mhat(s): block (i, j) is M_((i + j + s) mod 4).
        mHat :: Int -> [[Int]]
        mHat s = concat [hcat [mBlock ((i + j + s) `mod` 4) | j <- [0 .. 3]] | i <- [0 .. 3]]

        -- Put the given 16x16 blocks on the diagonal, with zero blocks elsewhere.
        blockDiag :: [[[Int]]] -> [[Int]]
        blockDiag ds = concat [ hcat [if i == j then d else zeros | j <- idxs]
                              | (i, d) <- zip idxs ds ]
          where idxs  = [0 .. length ds - 1]
                zeros = replicate 16 (replicate 16 0)

-- | Multiply a block by 'mat' over GF(2). The block is treated as a 64-element
-- bit vector in big-endian order (via 'blastBE' / 'fromBitsBE'). Output bit @r@
-- is the xor of the input bits selected by the 1 entries of row @r@.
--
-- Calls 'error' if 'mat' is not 64x64, or if an entry other than 0 or 1 is
-- encountered.
mMult :: Block -> Block
mMult b
  | rowCount /= 64             = error $ "mMult: Expected 64 rows, got       : " ++ show rowCount
  | any ((/= 64) . length) mat = error $ "mMult: Expected 64 on each row, got: " ++ show badRows
  | otherwise                  = fromBitsBE [dot row | row <- mat]
  where rowCount = length mat

        -- (1-based row number, row length) for every malformed row.
        badRows :: [(Int, Int)]
        badRows = filter ((/= 64) . snd) (zip [1 ..] (map length mat))

        inputBits :: [SBool]
        inputBits = blastBE b

        -- Inner product over GF(2): xor together the selected input bits.
        dot :: [Int] -> SBool
        dot row = foldr (.<+>) sFalse [coeff c v | (c, v) <- zip row inputBits]

        coeff :: Int -> SBool -> SBool
        coeff c v = case c of
                      0 -> sFalse
                      1 -> v
                      _ -> error $ "mMult: Unexpected constant: " ++ show c

-- | Non-linear transformation of a block. Each nibble @n@ is replaced by the
-- entry of @box@ at index @n@, chosen with 'select'. @def@ is passed to 'select'
-- as its default. Since 'toNibbles' only yields values below 16, indices stay
-- within a 16-entry box.
nonLinear :: [Nibble] -> Nibble -> Block -> Block
nonLinear box def b = fromNibbles [select box def n | n <- toNibbles b]

-- | SBox transformation: apply the Prince S-box to every nibble of the block.
sBox :: Block -> Block
sBox = nonLinear table 0x0
  where table :: [Nibble]
        table = [ 0xB, 0xF, 0x3, 0x2, 0xA, 0xC, 0x9, 0x1
                , 0x6, 0x7, 0x8, 0x0, 0xE, 0x5, 0xD, 0x4
                ]

-- | Inverse SBox transformation: apply the inverse Prince S-box to every nibble
-- of the block.
sBoxInv :: Block -> Block
sBoxInv = nonLinear table 0x0
  where table :: [Nibble]
        table = [ 0xB, 0x7, 0x3, 0x2, 0xF, 0xD, 0x8, 0x9
                , 0xA, 0x6, 0x4, 0x0, 0x5, 0xE, 0xC, 0x1
                ]

-- | Prove that 'sBox' and 'sBoxInv' are inverses, composing them in both
-- orders. We have:
--
-- >>> prove prop_SBox
-- Q.E.D.
prop_SBox :: Predicate
prop_SBox = roundTrips <$> free "block"
  where roundTrips :: Block -> SBool
        roundTrips b = b .== sBoxInv (sBox b) .&& b .== sBox (sBoxInv b)

-- * Round constants

-- | Round constants @RC_0@ .. @RC_11@, indexed by round number. Any other index
-- is an 'error'.
rConstants :: Int -> SWord 64
rConstants n = case lookup n (zip [0 ..] table) of
                 Just rc -> rc
                 Nothing -> error $ "rConstants called with invalid round number: " ++ show n
  where table :: [SWord 64]
        table = [ 0x0000000000000000, 0x13198a2e03707344, 0xa4093822299f31d0, 0x082efa98ec4e6c89
                , 0x452821e638d01377, 0xbe5466cf34e90c6c, 0x7ef84f78fd955cb1, 0x85840851f1ac43aa
                , 0xc882d32f25323c54, 0x64a51195e0e3610d, 0xd3b5a399ca0c2399, 0xc0ac29b7c97c50dd
                ]

-- | Round-constants property: for every @i@ in 0..11, @rc_i `xor` rc_{11-i}@ equals
-- the same constant, namely @rc_11@. We have:
--
-- >>> prop_RoundKeys
-- True
prop_RoundKeys :: SBool
prop_RoundKeys = sAnd (zipWith pairsToAlpha [0 .. 11] [11, 10 .. 0])
  where alpha :: SWord 64
        alpha = rConstants 11

        pairsToAlpha :: Int -> Int -> SBool
        pairsToAlpha i j = alpha .== rConstants i `xor` rConstants j

-- | Convert a 64-bit word to its 16 nibbles. Each byte from 'toBytes' yields its
-- high nibble followed by its low nibble, each held in the low 4 bits of a 'Nibble'.
toNibbles :: SWord 64 -> [Nibble]
toNibbles w = [half | byte <- toBytes w, half <- [byte `shiftR` 4, byte .&. 0xF]]

-- | Convert 16 nibbles back to a 64-bit word, reversing 'toNibbles'. Consecutive
-- (high, low) nibble pairs are packed into bytes, which are reassembled with
-- 'fromBytes'. Calls 'error' unless exactly 16 nibbles are given.
fromNibbles :: [Nibble] -> SWord 64
fromNibbles xs
  | count /= 16 = error $ "fromNibbles: Incorrect number of nibbles, expected 16, got: " ++ show count
  | otherwise   = fromBytes (pairUp xs)
  where count = length xs

        -- Combine consecutive (high, low) nibbles into a byte.
        pairUp :: [Nibble] -> [SWord 8]
        pairUp (hi : lo : rest) = (hi `shiftL` 4 .|. lo) : pairUp rest
        pairUp _                = []

-- * Test vectors

-- | Test vectors from Appendix A of the spec, each checked in both directions:
-- encrypting the plaintext must give the ciphertext, and decrypting the
-- ciphertext must give back the plaintext. We have:
--
-- >>> testVectors
-- True
testVectors :: SBool
testVectors = sAnd (map encOK tvs ++ map decOK tvs)
  where encOK (pt, k0, k1, ct) = encrypt pt k0 k1 .== ct
        decOK (pt, k0, k1, ct) = decrypt ct k0 k1 .== pt

        -- (plaintext, k0, k1, ciphertext)
        tvs :: [(PT, Key, Key, CT)]
        tvs = [ (0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x818665aa0d02dfda)
              , (0xffffffffffffffff, 0x0000000000000000, 0x0000000000000000, 0x604ae6ca03c20ada)
              , (0x0000000000000000, 0xffffffffffffffff, 0x0000000000000000, 0x9fb51935fc3df524)
              , (0x0000000000000000, 0x0000000000000000, 0xffffffffffffffff, 0x78a54cbe737bb7ef)
              , (0x0123456789abcdef, 0x0000000000000000, 0xfedcba9876543210, 0xae25ad3ca8fa9ccf)
              ]

-- | Nicely show a concrete block as @0x@ followed by 16 hex digits, zero-padded
-- on the left. Calls 'error' on a symbolic block.
showBlock :: Block -> String
showBlock b = maybe (error "showBlock: Symbolic input!") render (unliteral b)
  where render v = "0x" ++ pad16 (showHex v "")
        -- Keep the last 16 characters of the zero-prefixed string.
        pad16 s = let padded = replicate 16 '0' ++ s
                  in  drop (length padded - 16) padded

-- * Code generation

-- | Generate C code for the encryption block. The function is named @enc@, takes
-- inputs @inp@, @k0@ and @k1@, and produces output @ct@. Existing files are
-- overwritten. The 'Nothing' passed to 'compileToC' selects SBV's default
-- output destination (NOTE(review): presumably stdout — confirm).
codeGen :: IO ()
codeGen = compileToC Nothing "enc" encBody
  where encBody :: SBVCodeGen ()
        encBody = do pt <- cgInput "inp"
                     k0 <- cgInput "k0"
                     k1 <- cgInput "k1"
                     cgOverwriteFiles True
                     cgOutput "ct" (encrypt pt k0 k1)