-----------------------------------------------------------------------------
-- |
-- Module      :  Codec.Encryption.Padding
-- Copyright   :  (c) Dominic Steinitz 2003
-- License     :  BSD-style (see the file ReadMe.tex)
--
-- Stability   :  experimental
-- Portability :  portable
--
-- Padding algorithms for use with block ciphers.
--
-- This module currently supports:
--
-- * PKCS5 padding and unpadding.
--
-- * Null padding and unpadding.
--
-----------------------------------------------------------------------------

module Codec.Encryption.Padding (
   -- * Function types
   pkcs5, unPkcs5,
   padNulls, unPadNulls
   ) where

import Data.Word
import Data.Bits
import Data.List
import Codec.Utils

-- | Pad a list of octets as described in PKCS5 and pack it into blocks.
--
-- When the last block of plaintext is shorter than the block size then it
-- must be padded. PKCS5 specifies that the padding octets should each
-- contain the number of octets which must be stripped off. So, for example,
-- with a block size of 8, \"0a0b0c\" will be padded with \"05\" resulting in
-- \"0a0b0c0505050505\". If the final block is a full block of 8 octets
-- then a whole block of \"0808080808080808\" is appended.
--
-- The block size is not passed explicitly: 'pad' derives it from the bit
-- width of the result type @a@.

pkcs5 :: (Integral a, Bits a) => [Octet] -> [a]
pkcs5 s = pad padding s
   where
      -- n padding octets, each holding the count n itself.
      padding :: Int -> [Octet]
      padding n = replicate n (fromIntegral n)

-- | Pad a list of octets with nulls and pack it into blocks.
--
-- When the last block of plaintext is shorter than the block size then it
-- must be padded. Nulls padding specifies that the padding octets should each
-- contain a null. So, for example,
-- with a block size of 8, \"0a0b0c\" will be padded to
-- \"0a0b0c0000000000\". If the final block is a full block of 8 octets
-- then a whole block of \"0000000000000000\" is appended.
-- NB this is only suitable for data which does not contain nulls,
-- for example, ASCII.

padNulls :: (Integral a, Bits a) => [Octet] -> [a]
padNulls s = pad padding s
   where
      -- n padding octets, all zero.
      padding :: Int -> [Octet]
      padding n = replicate n 0

-- Internal padding variant (not exported): the padding consists of
-- (n - 1) octets of 0xff followed by a single octet holding the count n.
-- As with the exported padders, the block size comes from the result type
-- via 'pad'.
testPad :: (Integral a, Bits a) => [Octet] -> [a]
testPad s = pad padding s
   where
      padding :: Int -> [Octet]
      padding n = replicate (n - 1) 0xff ++ [fromIntegral n]

-- Generic padding driver shared by the exported padders.
--
-- The first argument maps a count n to the n padding octets to append;
-- the second is the plaintext. The octets are cut into chunks of
-- octetSize, the final chunk is padded, and every chunk is converted to
-- a block with 'fromOctets' (base 256).
--
-- Behaviour at the edges, as implemented below:
--
--  * a short final chunk is topped up with @p (octetSize - length)@;
--
--  * if the input is an exact multiple of octetSize, a whole extra block
--    @p octetSize@ is appended;
--
--  * an empty input yields no blocks at all (and hence no padding).
pad :: (Integral b, Bits b) => (Int -> [Octet]) -> [Octet] -> [b]
pad p s = blocks
   where
      -- Deliberately left without a signature so that it shares the
      -- result type b of the enclosing function.
      blocks = map (fromOctets 256) (unfoldr chunk (concat (unfoldr padChunk s)))

      -- Number of octets per block, taken from the bit width of the block
      -- type. blocks and octetSize refer to each other; this only
      -- terminates because bitSize is expected not to inspect its
      -- argument, so head blocks is never actually forced.
      octetSize :: Int
      octetSize = bitSize (head blocks) `div` 8

      -- Peel one chunk off the plaintext, padding the last one. When the
      -- remaining input is exactly one full chunk, the emitted piece is
      -- two blocks long (data plus a full block of padding); chunk below
      -- re-splits it.
      padChunk :: [Octet] -> Maybe ([Octet], [Octet])
      padChunk x
         | null t               = Nothing
         | length t < octetSize = Just (t ++ p (octetSize - length t), [])
         | null d               = Just (t ++ p octetSize, [])
         | otherwise            = Just (t, d)
         where
            (t, d) = splitAt octetSize x

      -- Cut the padded octet stream into pieces of exactly octetSize.
      chunk :: [Octet] -> Maybe ([Octet], [Octet])
      chunk [] = Nothing
      chunk x  = Just (splitAt octetSize x)

-- | Take a list of blocks padded using the method described in PKCS5
-- (see <http://www.rsasecurity.com/rsalabs/pkcs/pkcs-5>)
-- and return the list of unpadded octets. NB this function does not
-- currently check that the padded block is correctly formed and should
-- only be used for blocks that have been padded correctly.
--
-- Only the final block is stripped: its last octet is read as the number
-- of padding octets, and that many octets are removed from its end.

unPkcs5 :: (Bits a, Integral a) => [a] -> [Octet]
unPkcs5 s = unPad strip s
   where
      -- octetSize is the number of octets in one block; x is the final
      -- block as octets. The last octet of x is the padding length.
      strip :: Int -> [Octet] -> [Octet]
      strip octetSize x = take (octetSize - fromIntegral (last x)) x

-- | Take a list of blocks padded with nulls
-- and return the list of unpadded octets. NB if the blocks contain
-- a null then the result is unpredictable.
--
-- Only the final block is stripped: it is cut at its first zero octet,
-- so a genuine null inside the final block truncates the data there.

unPadNulls :: (Bits a, Integral a) => [a] -> [Octet]
unPadNulls s = unPad strip s
   where
      -- The block size is not needed; keep octets up to the first null.
      strip :: Int -> [Octet] -> [Octet]
      strip _ x = takeWhile (/= 0) x

-- Generic unpadding driver shared by the exported unpadders.
--
-- Each block is converted to octets with 'i2osp', using a length equal to
-- the block's bit width divided by 8. Every block except the last is
-- passed through unchanged; the octets of the last block are handed to
-- the stripping function p together with that octet count, and whatever
-- p returns is emitted in their place. An empty list of blocks yields an
-- empty list of octets.
unPad :: (Integral a, Bits a) => (Int -> [Octet] -> [Octet]) -> [a] -> [Octet]
unPad p s = concat (unfoldr g s)
   where
      g :: (Integral a, Bits a) => [a] -> Maybe ([Octet], [a])
      g [] = Nothing
      g (u:d)
         | null d    = Just (p octetSize v, [])   -- final block: strip padding
         | otherwise = Just (v, d)
         where
            -- Octets per block, from the bit width of the block type.
            octetSize = bitSize u `div` 8
            -- The block rendered as octetSize octets.
            v = i2osp octetSize u