{- | Stream monads.

By stream, I mean a container indexed by the naturals (or integers). Note that
not all parts of streampatch are limited to this, but it's an extremely useful
invariant, without which patch linearization & application get a whole lot more
complex. So for now, streams it is.

These are designed to support easy pure and impure implementations. That's the
reasoning behind forward-only streams, and behind separating overwrites (easy
impure, moderately easy pure) from inserts (hard impure, easy pure).
-}

module StreamPatch.Stream where

import Data.Kind
import Data.ByteString qualified as B
import Data.ByteString.Builder qualified as BB
import Control.Monad.State
import Control.Monad.Reader
import System.IO qualified as IO
import Data.List qualified as List

-- | Streams supporting forward seeking and in-place edits (length never
--   changes).
--
-- The cursor only moves forwards: no method here moves it backwards.
class Monad m => FwdInplaceStream m where
    -- | The container type of the elements read from and written to the
    --   stream (the instances in this module use a strict 'B.ByteString' or
    --   a list).
    type Chunk m :: Type

    -- | The unsigned integral type used for indexing. 'Int' has a sign, but is
    --   often used internally; 'Integer' also gets some use; 'Natural' is the
    --   most mathematically honest. I leave the decision up to the instance to
    --   allow them to be as efficient as possible.
    type Index m :: Type

    -- | Read a number of elements forward without moving the cursor.
    --
    -- Argument must be positive. The instances in this module return fewer
    -- elements than requested when the end of the stream is reached first.
    readahead :: Index m -> m (Chunk m)

    -- | Overlay a chunk onto the stream at the current cursor position,
    --   overwriting existing elements.
    --
    -- Moves the cursor to the right by the length of the chunk.
    overwrite :: Chunk m -> m ()

    -- | Move cursor forwards without reading. Must be positive.
    advance :: Index m -> m ()

    -- | Get the current cursor position.
    --
    -- Intended for error messages.
    getCursor :: m (Index m)

-- | Pure strict-'B.ByteString' stream.
--
-- State is @(src, out, pos)@:
--
--   * @src@ is the input not yet consumed,
--   * @out@ accumulates the edited output,
--   * @pos@ is the cursor: the number of bytes emitted to @out@ so far.
instance Monad m => FwdInplaceStream (StateT (B.ByteString, BB.Builder, Int) m) where
    type Chunk (StateT (B.ByteString, BB.Builder, Int) m) = B.ByteString
    type Index (StateT (B.ByteString, BB.Builder, Int) m) = Int

    -- Peek at up to @n@ bytes of the remaining input; fewer are returned near
    -- the end of the input. State is left untouched.
    readahead n = get >>= \(src, _, _) -> return $ B.take n src

    -- Emit the chunk in place of the same number of input bytes. If the chunk
    -- is longer than the remaining input, all of it is still emitted (so the
    -- output grows); the "length never changes" invariant is not checked here.
    overwrite bs = do
        (src, out, pos) <- get
        let len = B.length bs
        put (B.drop len src, out <> BB.byteString bs, pos + len)

    -- Copy up to @n@ input bytes to the output unchanged. The cursor moves by
    -- the number of bytes actually copied (not by @n@), so it keeps matching
    -- the output length even if @n@ runs past the end of the input.
    advance n = do
        (src, out, pos) <- get
        let (bs, src') = B.splitAt n src
        put (src', out <> BB.byteString bs, pos + B.length bs)

    getCursor = get >>= \(_, _, pos) -> return pos

-- | Impure stream backed by a seekable 'IO.Handle'.
--
-- The handle's own file position is the cursor. Edits are written straight to
-- the handle.
instance MonadIO m => FwdInplaceStream (ReaderT IO.Handle m) where
    type Chunk (ReaderT IO.Handle m) = B.ByteString
    type Index (ReaderT IO.Handle m) = Integer

    readahead n = do
        hdl <- ask
        -- TODO Integer -> Int :( ('fromInteger' wraps if n exceeds maxBound)
        bs <- liftIO $ B.hGet hdl $ fromInteger n
        -- Rewind by the number of bytes actually read, not by n. 'B.hGet'
        -- returns a short read at end of file, and rewinding the full n would
        -- leave the cursor before where it started.
        liftIO $ IO.hSeek hdl IO.RelativeSeek (negate (toInteger (B.length bs)))
        pure bs

    -- Writes the chunk at the current position; the handle's position moves
    -- past the written bytes.
    overwrite bs = do
        hdl <- ask
        liftIO $ B.hPut hdl bs

    advance n = do
        hdl <- ask
        liftIO $ IO.hSeek hdl IO.RelativeSeek n

    getCursor = do
        hdl <- ask
        liftIO $ IO.hTell hdl

-- TODO Defining these instances once for Text, ByteString etc. would need
-- MonoTraversable (bleh). I think Snoyman's advice is to reimplement per type
-- instead. Also bleh.
-- | Pure list stream.
--
-- State is @(src, out, pos)@:
--
--   * @src@ is the input not yet consumed,
--   * @out@ accumulates the edited output,
--   * @pos@ is the cursor: the number of elements emitted to @out@ so far.
--
-- NOTE: each edit appends to @out@ with '(<>)', which is linear in the output
-- built so far. A long run of edits is therefore quadratic overall.
instance Monad m => FwdInplaceStream (StateT ([a], [a], Int) m) where
    type Chunk (StateT ([a], [a], Int) m) = [a]
    type Index (StateT ([a], [a], Int) m) = Int

    -- Peek at up to @n@ elements of the remaining input; state is untouched.
    readahead n = get >>= \(src, _, _) -> return $ List.take n src

    -- Emit the chunk in place of the same number of input elements. If the
    -- chunk is longer than the remaining input, all of it is still emitted.
    overwrite bs = do
        (src, out, pos) <- get
        let len = List.length bs
        put (List.drop len src, out <> bs, pos + len)

    -- Copy up to @n@ input elements to the output unchanged. The cursor moves
    -- by the number actually copied, so it keeps matching the output length
    -- even if @n@ runs past the end of the input.
    advance n = do
        (src, out, pos) <- get
        let (bs, src') = List.splitAt n src
        put (src', out <> bs, pos + List.length bs)

    getCursor = get >>= \(_, _, pos) -> return pos

-- | Streams supporting forward seeking and arbitrary edits.
--
-- Extends 'FwdInplaceStream' with length-changing edits (inserts and deletes).
class FwdInplaceStream m => FwdStream m where
    -- | Write a chunk into the stream at the current cursor position.
    --
    -- Unlike 'overwrite', this inserts: the elements after the cursor are not
    -- consumed.
    --
    -- Moves the cursor to the right by the length of the chunk.
    write :: Chunk m -> m ()

    -- | Delete a sized chunk at the current cursor position.
    --
    -- The deleted elements are never emitted, so the cursor does not move.
    --
    -- Argument must be positive.
    delete :: Index m -> m ()

-- | Length-changing edits for the pure 'B.ByteString' stream.
--
-- The state layout is the same as its 'FwdInplaceStream' instance: unread
-- input, emitted output, cursor.
instance Monad m => FwdStream (StateT (B.ByteString, BB.Builder, Int) m) where
    -- Insert: emit the chunk and advance the cursor past it, leaving the
    -- unread input exactly as it was.
    write chunk = get >>= \(input, emitted, cursor) ->
        put (input, emitted <> BB.byteString chunk, cursor + B.length chunk)

    -- Delete: throw away the next n input bytes without emitting them.
    -- Nothing is written, so the cursor stays put.
    delete n = get >>= \(input, emitted, cursor) ->
        put (B.drop n input, emitted, cursor)