{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
module Data.ByteString.Base64.Internal
( validateBase64
, validateBase64Url
, validateLastPad
) where
import qualified Data.ByteString as BS
import Data.ByteString.Internal
import Data.Text (Text)
import Data.Word (Word8)
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe
-- | Check that every byte of a padded Base64 bytestring is either a member of
-- @alphabet@ or part of a trailing run of at most two @=@ (0x3d) pad bytes.
--
-- Only the character set and the placement of padding are checked; the total
-- length is not required to be a multiple of four.
--
-- Once a @=@ is seen, every remaining byte must also be @=@. Inputs such as
-- @"Zm9vYg=x"@ (a pad followed by data) are therefore rejected.
validateBase64 :: ByteString -> ByteString -> Bool
validateBase64 !alphabet (PS !fp !off !l) =
    accursedUnutterablePerformIO $ withForeignPtr fp $ \p ->
      go (plusPtr p off) (plusPtr p (l + off))
  where
    -- Scan data bytes until the end of input or the first '='.
    go :: Ptr Word8 -> Ptr Word8 -> IO Bool
    go !p !end
      | p == end = return True
      | otherwise = do
          w <- peek p
          if w == 0x3d
            then padding (plusPtr p 1) end
            else if BS.elem w alphabet
              then go (plusPtr p 1) end
              else return False

    -- Entered just past the first '='. At most one more byte may follow,
    -- and if present it must be a second '='.
    padding :: Ptr Word8 -> Ptr Word8 -> IO Bool
    padding !q !end = case minusPtr end q of
      0 -> return True
      1 -> (== 0x3d) <$> peek q
      _ -> return False
{-# INLINE validateBase64 #-}
-- | Check a Base64url bytestring that may be padded or unpadded.
--
-- * Empty input is accepted.
-- * A length of 1 (mod 4) is rejected outright.
-- * A length of 2 or 3 (mod 4) is completed with @"=="@ or @"="@
--   respectively before checking.
--
-- The (possibly completed) input is then accepted only if every byte is a
-- member of @alphabet@ or part of a trailing run of at most two @=@ (0x3d)
-- pad bytes. Once a @=@ is seen, every remaining byte must also be @=@, so
-- @"Zm9vYg=x"@ is rejected.
validateBase64Url :: ByteString -> ByteString -> Bool
validateBase64Url !alphabet bs@(PS _ _ l)
    | l == 0 = True
    | r == 0 = f bs
    | r == 2 = f (BS.append bs "==")
    | r == 3 = f (BS.append bs "=")
    | otherwise = False
  where
    r = l `rem` 4

    f (PS fp o n) = accursedUnutterablePerformIO $
      withForeignPtr fp $ \p -> go (plusPtr p o) (plusPtr p (n + o))

    -- Scan data bytes until the end of input or the first '='.
    go :: Ptr Word8 -> Ptr Word8 -> IO Bool
    go !p !end
      | p == end = return True
      | otherwise = do
          w <- peek p
          if w == 0x3d
            then padding (plusPtr p 1) end
            else if BS.elem w alphabet
              then go (plusPtr p 1) end
              else return False

    -- Entered just past the first '='. At most one more byte may follow,
    -- and if present it must be a second '='.
    padding :: Ptr Word8 -> Ptr Word8 -> IO Bool
    padding !q !end = case minusPtr end q of
      0 -> return True
      1 -> (== 0x3d) <$> peek q
      _ -> return False
{-# INLINE validateBase64Url #-}
-- | Reject input whose final byte is a @=@ (0x3d) pad; otherwise run the
-- supplied decoding action and return its result.
--
-- The action is run with 'unsafeDupablePerformIO'. It must therefore be safe
-- to run outside of the 'IO' sequencing of the caller, and possibly more than
-- once.
--
-- Empty input has no trailing pad, so the action is run. The previous
-- implementation called the partial 'BS.last' and crashed on empty input.
validateLastPad
    :: ByteString
    -> IO (Either Text ByteString)
    -> Either Text ByteString
validateLastPad !bs io = case BS.unsnoc bs of
    Just (_, 0x3d) -> Left "Base64-encoded bytestring has invalid padding"
    _ -> unsafeDupablePerformIO io
{-# INLINE validateLastPad #-}