{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
-- |
-- Module       : Data.ByteString.Lazy.Base64
-- Copyright    : (c) 2019-2020 Emily Pillmore
-- License      : BSD-style
--
-- Maintainer   : Emily Pillmore <emilypi@cohomolo.gy>
-- Stability    : stable
-- Portability  : non-portable
--
-- This module contains 'Data.ByteString.Lazy.ByteString'-valued combinators for
-- implementing the RFC 4648 specification of the Base64
-- encoding format. This includes lenient decoding variants, as well as
-- internal and external validation for canonicity.
--
module Data.ByteString.Lazy.Base64
( -- * Encoding
  encodeBase64
, encodeBase64'
  -- * Decoding
, decodeBase64
, decodeBase64Lenient
  -- * Validation
, isBase64
, isValidBase64
) where


import Prelude hiding (all, elem)

import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Base64.Internal.Utils (reChunkN)
import Data.ByteString.Lazy (elem, fromChunks, toChunks)
import Data.ByteString.Lazy.Internal (ByteString(..))
import Data.Either (isRight)
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TL

-- | Encode a 'ByteString' value as Base64 'Text' with padding.
--
-- See: <https://tools.ietf.org/html/rfc4648#section-4 RFC-4648 section 4>
--
-- === __Examples__:
--
-- >>> encodeBase64 "Sun"
-- "U3Vu"
--
encodeBase64 :: ByteString -> TL.Text
-- The bytes produced by encodeBase64' are decoded as UTF-8 to obtain 'TL.Text'.
encodeBase64 = TL.decodeUtf8 . encodeBase64'
{-# INLINE encodeBase64 #-}

-- | Encode a 'ByteString' value as a Base64 'ByteString' value with padding.
--
-- See: <https://tools.ietf.org/html/rfc4648#section-4 RFC-4648 section 4>
--
-- === __Examples__:
--
-- >>> encodeBase64' "Sun"
-- "U3Vu"
--
encodeBase64' :: ByteString -> ByteString
encodeBase64' =
  fromChunks
    . fmap B64.encodeBase64'
      -- Re-chunk before encoding each strict chunk on its own.
      -- NOTE(review): presumably 'reChunkN' 3 makes every chunk but the last
      -- a multiple of 3 bytes long, so padding can only appear at the very
      -- end of the output; confirm against 'reChunkN'.
    . reChunkN 3
    . toChunks
{-# INLINE encodeBase64' #-}

-- | Decode a padded Base64-encoded 'ByteString' value.
--
-- All chunks of the input are concatenated into a single strict bytestring
-- before decoding, and the decoded result is returned as a single chunk.
--
-- See: <https://tools.ietf.org/html/rfc4648#section-4 RFC-4648 section 4>
--
-- === __Examples__:
--
-- >>> decodeBase64 "U3Vu"
-- Right "Sun"
--
-- >>> decodeBase64 "U3V"
-- Left "Base64-encoded bytestring requires padding"
--
-- >>> decodeBase64 "U3V="
-- Left "non-canonical encoding detected at offset: 2"
--
decodeBase64 :: ByteString -> Either T.Text ByteString
decodeBase64 =
  fmap (fromChunks . (:[]))
    . B64.decodeBase64
    . BS.concat
    . toChunks
{-# INLINE decodeBase64 #-}

-- | Leniently decode an unpadded Base64-encoded 'ByteString' value. This function
-- will not generate parse errors. If input data contains padding chars,
-- then the input will be parsed up until the first pad character.
--
-- Bytes that are neither part of the Base64 alphabet nor the pad character
-- @=@ are dropped from every chunk before decoding.
--
-- __Note:__ This is not RFC 4648-compliant.
--
-- === __Examples__:
--
-- >>> decodeBase64Lenient "U3Vu"
-- "Sun"
--
-- >>> decodeBase64Lenient "U3V"
-- "Su"
--
-- >>> decodeBase64Lenient "U3V="
-- "Su"
--
decodeBase64Lenient :: ByteString -> ByteString
decodeBase64Lenient =
  fromChunks
    . fmap B64.decodeBase64Lenient
      -- Re-chunk the filtered input before decoding each chunk on its own.
      -- NOTE(review): presumably 'reChunkN' 4 aligns chunk boundaries to
      -- whole 4-byte Base64 groups; confirm against 'reChunkN'.
    . reChunkN 4
      -- Keep only alphabet bytes and the pad character.
    . fmap (BS.filter (flip elem "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="))
    . toChunks
{-# INLINE decodeBase64Lenient #-}

-- | Tell whether a 'ByteString' value is Base64 encoded.
--
-- This function will also detect non-canonical encodings such as @ZE==@, which are
-- externally valid Base64-encoded values, but are internally inconsistent "impossible"
-- values.
--
-- The value must both pass 'isValidBase64' and be decoded successfully by
-- 'decodeBase64'.
--
-- === __Examples__:
--
-- >>> isBase64 "U3Vu"
-- True
--
-- >>> isBase64 "U3V"
-- False
--
-- >>> isBase64 "U3V="
-- False
--
isBase64 :: ByteString -> Bool
isBase64 bs = isValidBase64 bs && isRight (decodeBase64 bs)
{-# INLINE isBase64 #-}

-- | Tell whether a 'ByteString' value is a valid Base64 format.
--
-- This will not tell you whether or not this is a correct Base64 representation,
-- only that it conforms to the correct shape. To check whether it is a true
-- Base64 encoded 'ByteString' value, use 'isBase64'.
--
-- === __Examples__:
--
-- >>> isValidBase64 "U3Vu"
-- True
--
-- >>> isValidBase64 "U3V"
-- True
--
-- >>> isValidBase64 "U3V="
-- True
--
-- >>> isValidBase64 "%"
-- False
--
isValidBase64 :: ByteString -> Bool
isValidBase64 = go . toChunks
  where
    -- Walk the strict chunks: only the final chunk is handed to the strict
    -- validator; every earlier chunk must consist solely of alphabet bytes.
    go :: [BS.ByteString] -> Bool
    go [] = True
    go [c] = B64.isValidBase64 c
    go (c:cs) = -- note the lack of padding char
      BS.all (flip elem "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") c
      && go cs
{-# INLINE isValidBase64 #-}