{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE Haskell2010   #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE UnboxedTuples #-}

{-# LANGUAGE Trustworthy   #-}

-- |
-- Copyright: © 2020  Herbert Valerio Riedel
-- SPDX-License-Identifier: GPL-2.0-or-later
--
-- Apply XOR-masks to 'BS.ByteString's and memory regions.
--
module Data.XOR
    ( -- * Apply 32-bit XOR mask
      xor32StrictByteString
    , xor32StrictByteString'
    , xor32LazyByteString
    , xor32ShortByteString
    , xor32CStringLen

      -- * Apply 8-bit XOR mask
    , xor8StrictByteString
    , xor8LazyByteString
    , xor8ShortByteString
    , xor8CStringLen

    ) where

-- base
import           Control.Exception              (assert)
import           Control.Monad                  (void)
import           Control.Monad.ST               (ST, runST)
import           Data.Bits
import           Data.Tuple                     (swap)
import           Endianness                     (ByteOrder (..), Word32, Word8, byteSwap32,
                                                 targetByteOrder)
import           Foreign.C                      (CStringLen)
import           Foreign.ForeignPtr             (withForeignPtr)
import           Foreign.Ptr                    (Ptr, alignPtr, castPtr, minusPtr, plusPtr)
import           Foreign.Storable               (peek, poke)
import           System.IO.Unsafe               (unsafeDupablePerformIO)

import qualified GHC.Exts                       as X
import qualified GHC.ST                         as X
import qualified GHC.Word                       as X

-- bytestring
import qualified Data.ByteString                as BS
import           Data.ByteString.Internal       (mallocByteString, memcpy)
import qualified Data.ByteString.Internal       as BS (ByteString (..))
import qualified Data.ByteString.Lazy.Internal  as BL (ByteString (..))
import qualified Data.ByteString.Short          as SBS
import           Data.ByteString.Short.Internal (ShortByteString (SBS))

----------------------------------------------------------------------------

{- high-level reference impl

-- about 6-7 times slower
xor32StrictByteString'ref :: Word32 -> BS.ByteString -> BS.ByteString
xor32StrictByteString'ref 0    = id
xor32StrictByteString'ref msk0 = snd . BS.mapAccumL go msk0
  where
    go :: Word32 -> Word8 -> (Word32,Word8)
    go msk b = let b'   = fromIntegral msk' `xor` b
                   msk' = rotateL msk 8
               in b' `seq` (msk',b')

-- about 3 times slower
xor8StrictByteString'ref :: Word8 -> BS.ByteString -> BS.ByteString
xor8StrictByteString'ref 0    = id
xor8StrictByteString'ref msk0 = BS.map (xor msk0)

-}

-- | Apply 32-bit XOR mask (considered as four octets in big-endian order) to 'BS.ByteString'.
--
-- >>> xor32StrictByteString 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- "Hello"
--
-- In other words, the 32-bit word @0x37fa213d@ is taken as the infinite series of octets @('cycle' [0x37,0xfa,0x21,0x3d])@ and 'xor'ed with the respective octets from the input 'BS.ByteString'.
--
-- The 'xor' laws give rise to the following laws:
--
-- prop> xor32StrictByteString m (xor32StrictByteString m x) == x
--
-- prop> xor32StrictByteString 0 x == x
--
-- prop> xor32StrictByteString m (xor32StrictByteString n x) == xor32StrictByteString (m `xor` n) x
--
-- This function is semantically equivalent to the (less efficient) implementation shown below
--
-- > xor32StrictByteString'ref :: Word32 -> BS.ByteString -> BS.ByteString
-- > xor32StrictByteString'ref 0    = id
-- > xor32StrictByteString'ref msk0 = snd . BS.mapAccumL go msk0
-- >   where
-- >     go :: Word32 -> Word8 -> (Word32,Word8)
-- >     go msk b = let b'   = fromIntegral (msk' .&. 0xff) `xor` b
-- >                    msk' = rotateL msk 8
-- >                in (msk',b')
--
-- The 'xor32StrictByteString' implementation is about 6-7 times faster than the naive implementation above.
xor32StrictByteString :: Word32 -> BS.ByteString -> BS.ByteString
xor32StrictByteString msk bs
  -- A zero mask or an empty input is the identity: hand back the
  -- argument itself instead of allocating a copy.
  | msk == 0 || BS.null bs = bs
  -- Otherwise mask a fresh copy and drop the rotated mask.
  | otherwise              = fst (xor32StrictByteString'' msk bs)

-- | Convenience version of 'xor32StrictByteString' which also returns the rotated XOR-mask useful for chained masking.
--
-- >>> xor32StrictByteString' 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- (4196482359,"Hello")
--
xor32StrictByteString' :: Word32 -> BS.ByteString -> (Word32,BS.ByteString)
xor32StrictByteString' msk bs
  -- Zero mask: nothing to do, and a zero mask stays zero under rotation.
  | msk == 0   = (0,bs)
  -- Empty input: the mask is not advanced.
  | BS.null bs = (msk,bs)
  -- The worker yields (result, rotated mask); callers expect the mask first.
  | otherwise  = swap (xor32StrictByteString'' msk bs)

-- | Variant of 'xor32StrictByteString' for masking lazy 'BL.ByteString's.
--
-- >>> xor32LazyByteString 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- "Hello"
--
xor32LazyByteString :: Word32 -> BL.ByteString -> BL.ByteString
xor32LazyByteString msk0
  | msk0 == 0 = id
  | otherwise = maskChunks msk0
  where
    -- Mask one chunk at a time, threading the rotated mask returned for
    -- each chunk into the next one so that the mask phase carries across
    -- chunk boundaries.
    maskChunks msk lbs = case lbs of
      BL.Empty        -> BL.Empty
      BL.Chunk c rest ->
        let (c', mskNext) = xor32StrictByteString'' msk c
        in  BL.Chunk c' (maskChunks mskNext rest)

{-# INLINE xor32StrictByteString'' #-}
-- internal
--
-- Worker shared by the strict and lazy 'BS.ByteString' entry points.
--
-- Allocates a fresh buffer of the same length as the input, copies the
-- input octets into it and XOR-masks the copy in place. Returns the new
-- bytestring together with the mask as returned by the final masking
-- step (i.e. rotated according to the input length), so that a caller
-- can continue masking a subsequent fragment.
--
-- No short-cut is taken here for a zero mask or an empty input; the
-- exported wrappers handle those cases themselves where needed.
xor32StrictByteString'' :: Word32 -> BS.ByteString -> (BS.ByteString,Word32)
xor32StrictByteString'' msk0 (BS.PS x s l)
    = unsafeCreate' l $ \p8 ->
        withForeignPtr x $ \f -> do
          -- copy the @l@ payload octets (starting at offset @s@) into the new buffer
          memcpy p8 (f `plusPtr` s) (fromIntegral l)

          case remPtr p8 4 of
            0 -> do
              -- 4-byte aligned destination: mask the largest multiple-of-4
              -- prefix word-wise, then the remaining 0..3 octets byte-wise.
              -- The prefix is a whole number of words, so the mask phase
              -- at the start of the trailer is still @msk0@.
              let trailer = l `rem` 4
                  lbytes = l - trailer
              xor32PtrAligned msk0 (castPtr p8) lbytes
              xor32PtrNonAligned msk0 (p8 `plusPtr` lbytes) trailer
            _ ->
              -- misaligned bytestring...
              --
              -- This should not happen, as newly allocated
              -- bytestrings ought to be word-aligned; but if the
              -- impossible does happen we have a semantically sound
              -- codepath to jump to...
              xor32Ptr msk0 p8 l



-- | Apply 32-bit XOR mask (considered as four octets in big-endian order) to 'SBS.ShortByteString'. See also 'xor32StrictByteString'.
--
-- >>> xor32ShortByteString 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- "Hello"
--
xor32ShortByteString :: Word32 -> SBS.ShortByteString -> SBS.ShortByteString
xor32ShortByteString mask0be sbs
  -- Zero mask or empty input: return the argument unchanged.
  | mask0be == 0 || SBS.null sbs = sbs
  | otherwise = runST $ do
      out <- newSBS total
      xorWords out 0
      xorTail out
      unsafeFreezeSBS out
  where
    total = SBS.length sbs
    -- number of whole 32-bit words, and number of trailing octets (0..3)
    (nWords, nTail) = total `quotRem` 4

    -- The mask is specified as big-endian octets; whole words are read
    -- and written in native byte order, so swap it on little-endian targets.
    maskNative = case targetByteOrder of
      LittleEndian -> byteSwap32 mask0be
      BigEndian    -> mask0be

    -- The octet of the big-endian mask found after shifting right by @sh@ bits.
    maskOctet sh = fromIntegral (mask0be `shiftR` sh) :: Word8

    -- Mask the whole words with indices @i .. nWords-1@.
    xorWords out i
      | i == nWords = return ()
      | otherwise   = do
          writeWord32Array out i (indexWord32Array sbs i `xor` maskNative)
          xorWords out (i + 1)

    -- Mask the single octet at byte offset @ofs@.
    xorOctet out ofs m8 = writeWord8Array out ofs (indexWord8Array sbs ofs `xor` m8)

    -- Mask the 0..3 octets after the last whole word; they line up with
    -- the leading octets of the mask.
    xorTail out = case nTail of
      0 -> return ()
      1 ->
        xorOctet out (total - 1) (maskOctet 24)
      2 -> do
        xorOctet out (total - 2) (maskOctet 24)
        xorOctet out (total - 1) (maskOctet 16)
      3 -> do
        xorOctet out (total - 3) (maskOctet 24)
        xorOctet out (total - 2) (maskOctet 16)
        xorOctet out (total - 1) (maskOctet 8)
      _ -> undefined -- impossible


{-# INLINEABLE xor32CStringLen #-}
-- | Apply 32-bit XOR mask (considered as four octets in big-endian order) to a memory region expressed as base-pointer and size in bytes; the region is modified in place.
--
-- The returned value is the input mask rotated by the remainder of the memory region size with respect to the word-size (useful for chained xor-masking of multiple memory-fragments).
xor32CStringLen :: Word32 -> CStringLen -> IO Word32
xor32CStringLen msk (ptr,len) = xor32Ptr msk (castPtr ptr) len

{-# INLINEABLE xor32Ptr #-}
-- internal
--
-- Apply the 32-bit XOR mask @mask0@ (octets taken in big-endian order)
-- in place to the @n@ octets starting at @p0@, and return the mask
-- rotated to the phase at which a following region would continue.
--
-- * A zero mask or a zero size is a no-op.
-- * A negative size is rejected with 'fail' before any memory is touched.
-- * Fewer than 4 octets are handled byte-wise.
-- * Otherwise the region is split into an unaligned head (@n0@ octets),
--   a 4-byte aligned body (@n1@ octets, masked word-wise) and an
--   unaligned tail (@n2@ octets).
xor32Ptr :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32Ptr 0      !_  !_ = return 0
xor32Ptr !mask0 !_   0 = return mask0
xor32Ptr !mask0 !p0 !n
  -- The sign check must come first: every negative @n@ also satisfies
  -- @n < 4@, and the byte-wise loop only stops on pointer equality with
  -- @p0 + n@, which lies *before* @p0@ for negative @n@.
  | n < 0 = fail "xor32Ptr: negative size argument not supported"
  | n < 4 = xor32PtrNonAligned mask0 p0 n
xor32Ptr !mask0 !p0 !n
  | assert (p0 <= p1 && p1 <= p2 && p2 <= p3 && n0 < 4 && n2 < 4) False = undefined -- assert invariants
  -- no whole aligned word inside the region: do everything byte-wise
  | n1 == 0 = xor32PtrNonAligned mask0 p0 n
  -- region starts aligned: the body keeps the mask phase, only the tail rotates it
  | n0 == 0 = do
      xor32PtrAligned    mask0 p1 n1
      xor32PtrNonAligned mask0 p2 n2
  -- unaligned head: the head rotates the mask; the body (a whole number
  -- of words) and the tail continue from that rotated phase
  | otherwise = do
      mask1 <- xor32PtrNonAligned mask0 p0 n0
      xor32PtrAligned    mask1 p1 n1
      xor32PtrNonAligned mask1 p2 n2
  where
    -- Invariants: p0 <= p1 <= p2 <= p3
    --             0  <= n0  < 4
    --             0  <= n1
    --             0  <= n2  < 4
    --             n  == n0+n1+n2 >= 4
    p1 = castPtr (alignPtr p0 d)   -- first aligned address >= p0
    p2 = alignPtrDown p3 d         -- last aligned address <= p3
    p3 = plusPtr p0 n              -- one past the end of the region
    d  = 4

    n0 = p1 `minusPtr` p0
    n1 = p2 `minusPtr` p1
    n2 = p3 `minusPtr` p2

-- internal
--
-- Byte-wise application of a 32-bit mask (octets taken in big-endian
-- order, most significant octet first) to the @n@ octets starting at the
-- given pointer; no alignment is required.
--
-- Returns the mask rotated left by 8 bits per processed octet, i.e. the
-- mask whose most significant octet is the next one to be applied.
--
-- The sizes 0..3 are unrolled; any other size goes through the generic
-- loop in the last clause.
xor32PtrNonAligned :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32PtrNonAligned mask0 _ 0 = return mask0
-- One octet: after rotating left by 8 the original top octet sits in the
-- low 8 bits, which is what the truncating 'fromIntegral' extracts.
xor32PtrNonAligned mask0 p 1 = do
  let mask1 = rotateL mask0 8
  xor8Ptr1 (fromIntegral mask1) p
  return mask1
-- Two octets: top octet via shift, second octet via the rotated mask.
xor32PtrNonAligned mask0 p 2 = do
  xor8Ptr1 (fromIntegral (mask0 `shiftR` 24)) p
  let mask1 = mask0 `rotateL` 16
  xor8Ptr1 (fromIntegral mask1) (p `plusPtr` 1)
  return mask1
-- Three octets: first two via shifts, third via the rotated mask.
xor32PtrNonAligned mask0 p 3 = do
  xor8Ptr1 (fromIntegral (mask0 `shiftR` 24)) p
  xor8Ptr1 (fromIntegral (mask0 `shiftR` 16)) (p `plusPtr` 1)
  let mask1 = mask0 `rotateL` 24
  xor8Ptr1 (fromIntegral mask1) (p `plusPtr` 2)
  return mask1
-- Generic case: rotate-and-apply one octet at a time.
--
-- NOTE: the loop terminates on pointer equality with @p0 + n@, so callers
-- must pass a non-negative @n@.
xor32PtrNonAligned mask0 p0 n = go mask0 p0
  where
    -- one past the last octet to mask
    p' = p0 `plusPtr` n
    go m p
      | p == p'   = return m
      | otherwise = do
          let m' = rotateL m 8
          xor8Ptr1 (fromIntegral m') p
          go m' (p `plusPtr` 1)

-- internal
--
-- Word-wise application of a 32-bit mask (given in big-endian octet
-- order) to @n@ octets starting at @p0@. The pointer is expected to be
-- 4-byte aligned and @n@ to be a multiple of 4 (both are checked via
-- 'assert'). The mask phase is unaffected, hence no mask is returned.
xor32PtrAligned :: Word32 -> Ptr Word32 -> Int -> IO ()
xor32PtrAligned mask0be p0 n
  | n == 0    = return ()
  | otherwise = assert (p0 `remPtr` 4 == 0 && n `rem` 4 == 0) (loop p0)
  where
    -- one past the last word to mask
    end = p0 `plusPtr` n

    loop cur
      | cur == end = return ()
      | otherwise  = xor32Ptr1 nativeMask cur >> loop (cur `plusPtr` 4)

    -- words are read and written in native byte order, so the big-endian
    -- mask has to be swapped on little-endian targets
    nativeMask = case targetByteOrder of
      LittleEndian -> byteSwap32 mask0be
      BigEndian    -> mask0be

----------------------------------------------------------------------------

-- | Remainder of the address held by a pointer with respect to the given
-- divisor (a thin wrapper around the 'X.remAddr#' primop).
remPtr :: Ptr a -> Int -> Int
remPtr (X.Ptr addr#) (X.I# divisor#) =
  X.I# (X.remAddr# addr# divisor#)

-- | Round a pointer down to the nearest address that is a multiple of
-- the given alignment; an already aligned pointer is returned unchanged.
alignPtrDown :: Ptr a -> Int -> Ptr a
alignPtrDown p i
  | excess == 0 = p
  | otherwise   = p `plusPtr` negate excess
  where
    -- how far the address lies above the previous aligned address
    excess = remPtr p i

-- | XOR the octet stored at the given address with the mask, in place.
xor8Ptr1 :: Word8 -> Ptr Word8 -> IO ()
xor8Ptr1 msk ptr = peek ptr >>= \old -> poke ptr (msk `xor` old)

-- xor16Ptr1 :: Word16 -> Ptr Word16 -> IO ()
-- xor16Ptr1 msk ptr = do { x <- peek ptr; poke ptr (xor msk x) }

-- | XOR the 32-bit word stored at the given address with the mask, in
-- place. The word is read and written via 'peek' and 'poke', i.e. in
-- native byte order.
xor32Ptr1 :: Word32 -> Ptr Word32 -> IO ()
xor32Ptr1 msk ptr = peek ptr >>= \old -> poke ptr (msk `xor` old)

{-# INLINE unsafeCreate' #-}
-- | Allocate a 'BS.ByteString' buffer of the given size, run the filler
-- action on a pointer to it, and return the resulting bytestring paired
-- with the action's result.
--
-- The action runs under 'unsafeDupablePerformIO'; it must therefore only
-- initialise the freshly allocated buffer and have no other observable
-- side effects.
unsafeCreate' :: Int -> (Ptr Word8 -> IO a) -> (BS.ByteString, a)
unsafeCreate' len fill = unsafeDupablePerformIO $ do
    fp  <- mallocByteString len
    res <- withForeignPtr fp fill
    return (BS.PS fp 0 len, res)

----------------------------------------------------------------------------
-- single octet masks -- trivially mapped to 32-bit versions

-- | Replicate an octet into all four octets of a 32-bit word, e.g.
-- @0xab@ becomes @0xabababab@.
expandW8ToW32 :: Word8 -> Word32
expandW8ToW32 x = fromIntegral x * 0x01010101
  -- Multiplying by 0x01010101 places one copy of the octet in each byte
  -- lane; the lanes never carry into each other because x <= 0xff.


-- | Apply 8-bit XOR mask to each octet of a 'BS.ByteString'.
--
-- >>> xor8StrictByteString 0x20 "Hello"
-- "hELLO"
--
-- This function is a faster implementation of the semantically equivalent function shown below:
--
-- > xor8StrictByteString'ref :: Word8 -> BS.ByteString -> BS.ByteString
-- > xor8StrictByteString'ref 0    = id
-- > xor8StrictByteString'ref msk0 = BS.map (xor msk0)
--
xor8StrictByteString :: Word8 -> BS.ByteString -> BS.ByteString
-- Widen the octet to a 32-bit mask of four identical octets and reuse the
-- 32-bit implementation.
xor8StrictByteString = xor32StrictByteString . expandW8ToW32

-- | Apply 8-bit XOR mask to each octet of a lazy 'BL.ByteString'.
--
-- See also 'xor8StrictByteString'
xor8LazyByteString :: Word8 -> BL.ByteString -> BL.ByteString
-- Widen the octet to a 32-bit mask of four identical octets and reuse the
-- 32-bit implementation.
xor8LazyByteString = xor32LazyByteString . expandW8ToW32

-- | Apply 8-bit XOR mask to each octet of a 'SBS.ShortByteString'.
--
-- See also 'xor8StrictByteString'
xor8ShortByteString :: Word8 -> SBS.ShortByteString -> SBS.ShortByteString
-- Widen the octet to a 32-bit mask of four identical octets and reuse the
-- 32-bit implementation.
xor8ShortByteString = xor32ShortByteString . expandW8ToW32

-- | Apply 8-bit XOR mask to each octet of a memory region expressed as start address and length in bytes.
--
-- See also 'xor8StrictByteString'
xor8CStringLen :: Word8 -> CStringLen -> IO ()
xor8CStringLen x (ptr,len) = void $ xor32Ptr mask32 (castPtr ptr) len
  where
    -- With four identical octets the mask is invariant under rotation,
    -- so the rotated mask returned by 'xor32Ptr' carries no information
    -- and is discarded.
    mask32 = expandW8ToW32 x

----------------------------------------------------------------------------
-- The missing mutable ShortByteString abstraction

-- | Mutable counterpart of 'ShortByteString': a thin wrapper around a
-- 'X.MutableByteArray#' tagged with the state thread @s@ it belongs to.
data MShortByteString s = MSBS (X.MutableByteArray# s)

-- | Allocate a new mutable byte array of the given size in bytes. The
-- contents are not initialised by this wrapper.
newSBS :: Int -> ST s (MShortByteString s)
newSBS (X.I# size#) = X.ST $ \st ->
  case X.newByteArray# size# st of
    (# st', arr# #) -> (# st', MSBS arr# #)

-- | Read the octet at the given byte offset. No bounds check is
-- performed by this wrapper.
indexWord8Array :: ShortByteString -> Int -> Word8
indexWord8Array (SBS arr#) (X.I# ofs#) =
  X.W8# (X.indexWord8Array# arr# ofs#)

-- | Store an octet at the given byte offset. No bounds check is
-- performed by this wrapper.
writeWord8Array :: MShortByteString s -> Int -> Word8 -> ST s ()
writeWord8Array (MSBS arr#) (X.I# ofs#) (X.W8# val#) = X.ST $ \st ->
  case X.writeWord8Array# arr# ofs# val# st of
    st' -> (# st', () #)

-- | Read the 32-bit word at the given index, counted in units of 32-bit
-- words (not bytes), in native byte order. No bounds check is performed
-- by this wrapper.
indexWord32Array :: ShortByteString -> Int -> Word32
indexWord32Array (SBS arr#) (X.I# ix#) =
  X.W32# (X.indexWord32Array# arr# ix#)

-- | Store a 32-bit word at the given index, counted in units of 32-bit
-- words (not bytes), in native byte order. No bounds check is performed
-- by this wrapper.
writeWord32Array :: MShortByteString s -> Int -> Word32 -> ST s ()
writeWord32Array (MSBS arr#) (X.I# ix#) (X.W32# val#) = X.ST $ \st ->
  case X.writeWord32Array# arr# ix# val# st of
    st' -> (# st', () #)

-- | Turn the mutable array into an immutable 'ShortByteString' without
-- copying. The mutable array must not be written to afterwards.
unsafeFreezeSBS :: MShortByteString s -> ST s ShortByteString
unsafeFreezeSBS (MSBS arr#) = X.ST $ \st ->
  case X.unsafeFreezeByteArray# arr# st of
    (# st', frozen# #) -> (# st', SBS frozen# #)