{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module       : Data.ByteString.Base32.Internal
-- Copyright 	: (c) 2020 Emily Pillmore
-- License	: BSD-style
--
-- Maintainer	: Emily Pillmore <emilypi@cohomolo.gy>
-- Stability	: Experimental
-- Portability	: portable
--
-- Internal module defining the encoding and decoding
-- processes and tables.
--
module Data.ByteString.Base32.Internal
( validateBase32
, validateLastNPads
) where


import qualified Data.ByteString as BS
import Data.ByteString.Internal
import Data.Text (Text)

import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable

import GHC.Word

import System.IO.Unsafe

-- -------------------------------------------------------------------------- --
-- Validating Base32

-- | Check that a bytestring has the /shape/ of Base32 text over the given
-- alphabet. This does not decode the input; it only inspects its bytes.
--
-- The first argument is the set of admissible symbol bytes, the second is
-- the input, which may be padded or unpadded. The input is accepted when:
--
-- * it is empty, or
--
-- * its length is @0@, @2@, @4@, @5@ or @7@ modulo @8@ (for a non-zero
--   remainder the missing @6@, @4@, @3@ or @1@ trailing @=@ bytes are treated
--   as implied),
--
-- * every byte before the first @=@ (@0x3d@) is a member of the alphabet,
--
-- * every byte from the first @=@ onwards is itself @=@, and
--
-- * the total number of pad bytes (present plus implied) is @0@, @1@, @3@,
--   @4@ or @6@.
--
-- Inputs whose length is @1@, @3@ or @6@ modulo @8@ are always rejected.
--
validateBase32 :: ByteString -> ByteString -> Bool
validateBase32 !alphabet bs
    | l == 0 = True
    | otherwise = case l `rem` 8 of
        0 -> wellFormed 0
        2 -> wellFormed 6
        4 -> wellFormed 4
        5 -> wellFormed 3
        7 -> wellFormed 1
        _ -> False
  where
    l :: Int
    l = BS.length bs

    -- Everything before the first '=' must be alphabet symbols; everything
    -- from the first '=' onwards must be padding.
    (symbols, pads) = BS.break (== 0x3d) bs

    -- The only pad counts a Base32 final quantum can carry.
    padCounts :: [Int]
    padCounts = [0, 1, 3, 4, 6]

    -- 'implied' is the number of pad bytes the input is missing to reach a
    -- multiple of 8. Counting them arithmetically avoids allocating a padded
    -- copy of the input just to inspect it.
    wellFormed :: Int -> Bool
    wellFormed implied =
      BS.all (`BS.elem` alphabet) symbols
        -- A symbol after a pad byte can never be valid, even if it is close
        -- to the end of the input.
        && BS.all (== 0x3d) pads
        && (BS.length pads + implied) `elem` padCounts
{-# INLINE validateBase32 #-}

-- | Check that none of the last @n@ bytes of a bytestring is a pad byte
-- (@=@, @0x3d@). If one is found, fail with an error message; otherwise run
-- the supplied IO action and return its result.
--
-- @n@ is clamped to the range @[0, length]@, so the scan never leaves the
-- payload of the bytestring: a non-positive @n@ inspects nothing and an @n@
-- larger than the input inspects the whole input.
--
-- This is necessary to check when decoding permissively (i.e. filling in padding chars).
-- Consider the following 8 cases of a string of length l:
--
-- l = 0 mod 8: No pad chars are added, since the input is assumed to be good.
-- l = 1 mod 8: Never an admissible length in base32
-- l = 2 mod 8: 6 padding chars are added. If padding chars are present in the string, they will fail to decode as final quanta
-- l = 3 mod 8: Never an admissible length in base32
-- l = 4 mod 8: 4 padding chars are added. If 2 padding chars are present in the string this can be "completed" in the sense that
--              it now acts like a string `l == 2 mod 8` with 6 padding chars, and could potentially form corrupted data.
-- l = 5 mod 8: 3 padding chars are added. If 3 padding chars are present in the string, this could form corrupted data like in the
--              previous case.
-- l = 6 mod 8: Never an admissible length in base32
-- l = 7 mod 8: 1 padding char is added. If 5 padding chars are present in the string, this could form corrupted data like the
--              previous cases.
--
-- Hence, permissive decodes should only fill in padding chars when it makes sense to add them. That is,
-- if an input is degenerate, it should never succeed when we add padding chars. We need the following invariant to hold:
--
-- @
--   B32.decodeUnpadded <|> B32.decodePadded ~ B32.decode
-- @
--
validateLastNPads
    :: Int
    -> ByteString
    -> IO (Either Text ByteString)
    -> Either Text ByteString
validateLastNPads !n (PS !fp !o !l) io
    | not valid = Left "Base32-encoded bytestring has invalid padding"
    | otherwise = unsafeDupablePerformIO io
  where
    -- Offset of the first byte past the payload, relative to the base pointer.
    !lo = l + o

    -- Number of trailing bytes actually inspected. Without the clamp,
    -- @n > l@ would read memory in front of the payload and @n < 0@ would
    -- start past 'end' and never reach it.
    !k = max 0 (min n l)

    valid :: Bool
    valid = accursedUnutterablePerformIO $ withForeignPtr fp $ \p -> do
      let end = plusPtr p lo :: Ptr Word8

      let go :: Ptr Word8 -> IO Bool
          go !q
            | q == end = return True
            | otherwise = do
              a <- peek q
              -- Any '=' within the inspected suffix makes the input invalid.
              if a == 0x3d then return False else go (plusPtr q 1)

      go (plusPtr p (lo - k))
{-# INLINE validateLastNPads #-}