{- | Low-level patchscript processing and application.

Patchscripts are applied as a list of @(skip x, write in-place y)@ commands. An
offset-based format is much simpler to use, however. This module processes such
offset patchscripts into a "linear" patchscript, and provides a stream patching
algorithm that can be applied to any forward-seeking byte stream.

Some core types are parameterized over the stream type/patch content. This
enables writing patches in any form (e.g. UTF-8 text), which are then processed
into an applicable patch by transforming edits into a concrete binary
representation (e.g. null-terminated UTF-8 bytestring). See TODO module for
more.
-}

module BytePatch.Linear.Patch
  ( patch
  , MonadFwdByteStream(..)
  , Cfg(..)
  , Error(..)
  ) where

import           BytePatch.Core
import           BytePatch.Linear.Core

import qualified Data.ByteString         as BS
import qualified Data.ByteString.Builder as BB
import           Control.Monad.State
import           Control.Monad.Reader
import           System.IO               ( Handle, SeekMode(..), hSeek )

-- | Strict bytestrings: used both for stream contents and for patch payloads
--   ('Patchscript' 'Bytes') throughout this module.
type Bytes = BS.ByteString

-- TODO also require reporting cursor position (for error reporting)

-- | A byte stream with a cursor that only ever moves forward.
class Monad m => MonadFwdByteStream m where
    -- | Read up to the given number of bytes at the cursor without moving the
    --   cursor. Near the end of the stream, fewer bytes may be returned.
    readahead :: Int -> m Bytes

    -- | Move the cursor forward by the given number of bytes, leaving those
    --   bytes unchanged.
    advance :: Int -> m ()

    -- | Write the given bytes at the cursor, replacing (not inserting before)
    --   the bytes already there, and move the cursor past the written bytes.
    overwrite :: Bytes -> m ()

-- | Pure stream. The state is @(src, out)@:
--
--   * @src@ holds the not-yet-processed source bytes, with the cursor at its
--     head;
--   * @out@ accumulates everything already passed over or written.
--
--   The final patched output is @out <> BB.byteString src@.
instance Monad m => MonadFwdByteStream (StateT (Bytes, BB.Builder) m) where
    -- Peek only. The state is left untouched, so the cursor does not move and
    -- a following 'advance'/'overwrite' still sees these bytes.
    readahead n = gets (BS.take n . fst)

    -- Copy up to @n@ source bytes to the output unchanged.
    advance n = do
        (src, out) <- get
        let (skipped, src') = BS.splitAt n src
        put (src', out <> BB.byteString skipped)

    -- Discard as many source bytes as are written, and emit the replacement.
    overwrite bs = do
        (src, out) <- get
        put (BS.drop (BS.length bs) src, out <> BB.byteString bs)

-- | Stream over a file 'Handle'; the handle's position is the cursor.
--
--   NOTE(review): assumes the handle is seekable and open for both reading and
--   writing (e.g. 'ReadWriteMode') — confirm at call sites.
instance MonadIO m => MonadFwdByteStream (ReaderT Handle m) where
    -- Read, then seek back by the number of bytes actually read. 'BS.hGet'
    -- returns fewer than @n@ bytes at end of file; seeking back by @n@ in that
    -- case would leave the cursor before where it started.
    readahead n = do
        hdl <- ask
        liftIO $ do
            bs <- BS.hGet hdl n
            hSeek hdl RelativeSeek (negate (fromIntegral (BS.length bs)))
            return bs

    -- Skip forward without reading the skipped bytes.
    advance n = do
        hdl <- ask
        liftIO $ hSeek hdl RelativeSeek (fromIntegral n)

    -- Write at the current position; the handle position ends up just past
    -- the written bytes.
    overwrite bs = do
        hdl <- ask
        liftIO $ BS.hPut hdl bs

-- | Options controlling the checks performed while patching.
data Cfg = Cfg
  { -- | When the stream appears to have been patched already, carry on with a
    --   warning rather than failing.
    --
    --   NOTE(review): not read by 'patch' in this module — confirm where it
    --   is consulted.
    cfgWarnIfLikelyReprocessing :: Bool
    -- | Accept the stream when the expected bytes are only a prefix of the
    --   actual bytes. When disabled, the actual bytes must equal the expected
    --   bytes exactly, so a mere prefix match is a failure.
  , cfgAllowPartialExpected :: Bool
  } deriving (Show, Eq)

-- | Reasons patching can stop before every edit has been applied.
data Error
  = ErrorPatchOverlong
    -- ^ Presumably: an edit would extend past the end of the stream.
    --   NOTE(review): confirm against the code that raises it.
  | ErrorPatchUnexpectedNonnull
    -- ^ An edit marked as null-terminated found a byte other than @0x00@ in
    --   the part of the stream that was required to be null padding.
  | ErrorPatchDidNotMatchExpected Bytes Bytes
    -- ^ The bytes found in the stream at the edit site (first field, after
    --   stripping null padding) did not match the edit's expected bytes
    --   (second field).
  deriving (Show, Eq)

-- | Apply a linear patchscript to a forward-seeking byte stream.
--
-- Edits are processed in order. For each entry @(n, 'Overwrite' bs meta)@:
--
--   1. advance the cursor by @n@ bytes;
--   2. inspect the @'BS.length' bs@ bytes at the cursor;
--   3. if they pass every check below, overwrite them with @bs@.
--
-- The checks are:
--
--   * The stream must still hold at least @'BS.length' bs@ bytes at the
--     cursor, else 'ErrorPatchOverlong'.
--   * If @'omNullTerminates' meta@ is @Just i@, every inspected byte from index
--     @i@ onward must be @0x00@, else 'ErrorPatchUnexpectedNonnull'. Only the
--     first @i@ bytes take part in the expected-bytes check.
--   * If @'omExpected' meta@ is @Just e@, the inspected bytes must equal @e@,
--     or merely start with @e@ when 'cfgAllowPartialExpected' is set. Otherwise
--     the result is 'ErrorPatchDidNotMatchExpected' with the actual and
--     expected bytes.
--
-- Processing stops at the first failing edit and returns its error. Edits
-- before it have already been written. Returns 'Nothing' when every edit was
-- applied.
patch :: MonadFwdByteStream m => Cfg -> Patchscript Bytes -> m (Maybe Error)
patch cfg = go
  where
    go [] = return Nothing
    go ((n, Overwrite bs meta) : es) = do
        advance n
        bsStream <- readahead (BS.length bs)
        case validate bs meta bsStream of
          Just err -> return (Just err)
          Nothing  -> overwrite bs >> go es

    -- Check the bytes found in the stream against one edit; 'Nothing' means
    -- the edit may be applied.
    validate bs meta bsStream
      -- 'readahead' may return fewer bytes than requested near end of stream.
      | BS.length bsStream < BS.length bs = Just ErrorPatchOverlong
      | otherwise =
          case tryStripNulls bsStream (omNullTerminates meta) of
            Nothing        -> Just ErrorPatchUnexpectedNonnull
            Just bsStream' ->
              uncurry ErrorPatchDidNotMatchExpected
                <$> checkExpected bsStream' (omExpected meta)

    -- With @Just i@, require every byte from index @i@ on to be null and keep
    -- only the first @i@ bytes; with 'Nothing', keep the bytes as-is.
    tryStripNulls :: Bytes -> Maybe Int -> Maybe Bytes
    tryStripNulls bs = \case
      Nothing        -> Just bs
      Just nullsFrom ->
        let (bs', bsNulls) = BS.splitAt nullsFrom bs
         in if BS.all (== 0x00) bsNulls then Just bs' else Nothing

    -- Returns @Just (actual, expected)@ on mismatch, 'Nothing' when the bytes
    -- are acceptable or no expected bytes were given.
    checkExpected :: Bytes -> Maybe Bytes -> Maybe (Bytes, Bytes)
    checkExpected bs = \case
      Nothing -> Nothing
      Just bsExpected
        | matchesExpected bs bsExpected -> Nothing
        | otherwise                     -> Just (bs, bsExpected)

    -- Partial matching means the expected bytes are a prefix of the actual
    -- bytes (see 'cfgAllowPartialExpected'); otherwise require equality.
    matchesExpected :: Bytes -> Bytes -> Bool
    matchesExpected actual expected
      | cfgAllowPartialExpected cfg = expected `BS.isPrefixOf` actual
      | otherwise                   = actual == expected

{-
finishPurePatch :: (BS.ByteString, BB.Builder) -> BL.ByteString
finishPurePatch (src, out) = BB.toLazyByteString $ out <> BB.byteString src

tmpExPatchscript :: Patchscript
tmpExPatchscript =
  [ (1, "ABC")
  , (2, "DEFG") ]

tmpExInitialState :: (BS.ByteString, BB.Builder)
tmpExInitialState = ("abcdefghijklmnopqrstuvwxyz", mempty)
-}