{-# LANGUAGE AllowAmbiguousTypes #-}

module StreamPatch.Apply where

import GHC.Generics ( Generic )

import StreamPatch.Patch
import StreamPatch.Stream
import StreamPatch.HFunctorList
import StreamPatch.Patch.Binary qualified as Bin
import StreamPatch.Patch.Compare qualified as Compare
import StreamPatch.Patch.Compare ( Compare(..), compareTo )
import StreamPatch.Patch.Linearize.InPlace ( HasLength, getLength )

import Data.Vinyl
import Data.ByteString qualified as BS
import Data.ByteString.Builder qualified as BB
import Data.ByteString.Lazy qualified as BL
import Control.Monad.State
import StreamPatch.Util ( traverseM_ )

import Control.Monad.Except

-- | Reasons a patch application can fail.
data Error
  = ErrorCompare String
  -- ^ The stream contents did not pass the patch's comparison check; the
  --   'String' is the message returned by 'compareTo'.
  | ErrorBinUnexpectedNonNull BS.ByteString
  -- ^ Bytes that were required to be null (@0x00@) were not; carries the
  --   stream data that was being checked.
    deriving (Eq, Show, Generic)

-- | Apply binary patches in place, in order, checking each one before writing.
--
-- For every patch:
--
--   1. advance the stream to the patch's position;
--   2. read ahead as many bytes as the patch data is long;
--   3. if the 'Bin.Meta' requests null termination at index @n@, require
--      every byte from @n@ onwards in the read data to be @0x00@ and strip
--      them ('ErrorBinUnexpectedNonNull' otherwise);
--   4. if the 'Compare.Meta' carries an expected value, compare it against
--      the (stripped) stream data ('ErrorCompare' on mismatch);
--   5. only if all checks passed, overwrite the stream with the patch data.
--
-- Each patch's checks run in 'ExceptT', so a failing check skips the write
-- for that patch; how failures propagate across the list is decided by
-- 'traverseM_'.
applyBinCompareFwd
    :: forall v m
    .  ( FwdInplaceStream m, Chunk m ~ BS.ByteString
       , Compare v BS.ByteString, Num (Index m) )
    => [Patch (Index m) '[Compare.Meta v, Bin.Meta] BS.ByteString]
    -> m (Either Error ())
applyBinCompareFwd =
    traverseM_ $ \(Patch bs s (HFunctorList (Flap cm :& Flap bm :& RNil))) -> runExceptT $ do
        -- advance to patch location
        lift $ advance s

        -- read same number of bytes as patch data
        bsStream  <- lift $ readahead $ fromIntegral $ getLength bs

        -- check for & strip expected terminating nulls
        bsStream' <- doNullTermCheck bsStream (Bin.mNullTerminates bm)

        -- compare with expected data
        doCompare bsStream' $ Compare.mCompare cm

        -- if that was all successful, write patch in-place
        lift $ overwrite bs
  where
    err :: Error -> ExceptT Error m a
    err = throwError

    -- Fail with 'ErrorCompare' if an expected value is given and the
    -- stream data does not match it.
    doCompare :: BS.ByteString -> Maybe (CompareRep v BS.ByteString) -> ExceptT Error m ()
    doCompare bs' = \case
      Nothing  -> return ()
      Just cmp -> maybe (return ()) (err . ErrorCompare) (compareTo @v cmp bs')

    -- Split the stream data at the null-termination index; everything after
    -- it must be 0x00. Returns the data before the index.
    doNullTermCheck
        :: Integral n
        => BS.ByteString -> Maybe n -> ExceptT Error m BS.ByteString
    doNullTermCheck bs' = \case
      Nothing -> return bs'
      Just nt ->
        let (bs'', bsNulls) = BS.splitAt (fromIntegral nt) bs'
         -- 'BS.all' tests the padding directly instead of building an
         -- all-zero buffer of the same length to compare against.
         in if   BS.all (== 0x00) bsNulls
            then return bs''
            else err $ ErrorBinUnexpectedNonNull bs'

-- | Run 'applyBinCompareFwd' purely against an in-memory 'BS.ByteString'.
--
-- The patches run in 'State' over a triple of (remaining input, patched
-- output built so far, an 'Int' — presumably the cursor; it is not inspected
-- here), starting from @(bs, mempty, 0)@. On success the patched output is
-- followed by whatever input remained and rendered as a lazy 'BL.ByteString';
-- on failure the 'Error' is returned and the partial output is dropped.
runPureBinCompareFwd
    :: (Compare v BS.ByteString)
    => [Patch Int '[Compare.Meta v, Bin.Meta] BS.ByteString]
    -> BS.ByteString
    -> Either Error BL.ByteString
runPureBinCompareFwd ps bs = either Left (\() -> Right rendered) outcome
  where
    start :: (BS.ByteString, BB.Builder, Int)
    start = (bs, mempty, 0)

    (outcome, (leftover, built, _)) = runState (applyBinCompareFwd ps) start

    -- untouched tail of the input goes after the patched prefix
    rendered = BB.toLazyByteString (built <> BB.byteString leftover)

-- | Apply metadata-free patches in order: for each one, advance the stream
-- to its position and overwrite with its data. No checks are performed.
applyFwd
    :: (FwdInplaceStream m, Chunk m ~ a)
    => [Patch (Index m) '[] a]
    -> m ()
applyFwd = mapM_ $ \(Patch payload pos (HFunctorList RNil)) -> do
    advance pos
    overwrite payload

-- | Run 'applyFwd' purely over a list.
--
-- The state starts as @(start, [], 0)@. The result is the patched output
-- followed by whatever input remained after the last patch.
runPureFwdList
    :: [Patch Int '[] [a]]
    -> [a]
    -> [a]
runPureFwdList ps start = patched ++ remaining
  where
    ((), (remaining, patched, _)) = runState (applyFwd ps) (start, [], 0 :: Int)

-- | Apply patches in order, checking each against the stream before writing.
--
-- For every patch: advance to its position, read ahead as many elements as
-- the patch data is long, and — if the 'Compare.Meta' carries an expected
-- value — compare it with what was read. A mismatch yields
-- 'Left' ('ErrorCompare' msg) without writing. Otherwise (no expected value,
-- or the comparison passed) the patch data is written in place.
--
-- Previously a patch whose comparison *passed* was never written; only
-- patches without a comparison were applied. This now matches
-- 'applyBinCompareFwd', which writes after a successful check.
applyFwdCompare
    :: forall a v m
    .  ( FwdInplaceStream m, Chunk m ~ a
       , Compare v a, HasLength a, Num (Index m) )
    => [Patch (Index m) '[Compare.Meta v] a]
    -> m (Either Error ())
applyFwdCompare = traverseM_ $ \(Patch a s (HFunctorList (Flap cm :& RNil))) -> do
    advance s
    aStream <- readahead $ fromIntegral $ getLength a
    -- Nothing: no check requested, or the check passed.
    let mismatch = Compare.mCompare cm >>= \aCmp -> compareTo @v aCmp aStream
    case mismatch of
      Nothing -> Right <$> overwrite a
      Just e  -> pure $ Left $ ErrorCompare e

-- | Run 'applyFwdCompare' purely over a 'String'.
--
-- The state starts as @(start, "", 0)@. On success the result is the patched
-- output followed by the remaining input; on failure the 'Error' is returned
-- and the partial output is dropped.
runPureFwdCompareString
    :: Compare v String
    => [Patch Int '[Compare.Meta v] String]
    -> String
    -> Either Error String
runPureFwdCompareString ps start =
    either Left (\() -> Right (patched ++ remaining)) result
  where
    (result, (remaining, patched, _)) =
        runState (applyFwdCompare ps) (start, "", 0 :: Int)