{- | Stream monads.

By stream, I mean a container indexed by the naturals (or integers). Note that
not all parts of streampatch are limited to this, but it's an extremely useful
invariant, without which patch linearization & application get a whole lot more
complex. So for now, streams it is.

These classes are designed to support easy pure and impure implementations.
That's the reasoning behind forward-only streams, and behind separating
overwrites (easy impure, middling pure) from inserts (hard impure, easy pure).
-}

module StreamPatch.Stream where

import Data.Kind
import Data.ByteString qualified as B
import Data.ByteString.Builder qualified as BB
import Control.Monad.State
import Control.Monad.Reader
import System.IO qualified as IO
import Data.List qualified as List

-- | Forward-only streams that can be edited in place: existing elements may be
--   overwritten, but nothing is inserted or removed, so the stream's length
--   stays fixed.
class Monad m => FwdInplaceStream m where
    -- | Integral type used for positions and element counts. Conceptually a
    --   natural number, but each instance picks whichever concrete type suits
    --   it best ('Int', 'Integer', 'Natural', ...) so it can stay efficient.
    type Index m :: Type

    -- | A run of consecutive stream elements, as read or written in one go.
    type Chunk m :: Type

    -- | Report the cursor's current position. Mainly intended for error
    --   messages.
    getCursor :: m (Index m)

    -- | Skip forward over the given number of elements without reading them.
    --   The count must be positive.
    advance :: Index m -> m ()

    -- | Peek at the next @n@ elements; the cursor does not move.
    --   @n@ must be positive.
    readahead :: Index m -> m (Chunk m)

    -- | Replace the elements under the cursor with the given chunk, then move
    --   the cursor forward by the chunk's length.
    overwrite :: Chunk m -> m ()

-- | Pure in-place stream over a strict 'B.ByteString'.
--
-- The state is @(unconsumed input, output built so far, cursor)@. Edits move
-- bytes from the front of the input into the output 'BB.Builder', and the
-- cursor counts how far the output has advanced.
instance Monad m => FwdInplaceStream (StateT (B.ByteString, BB.Builder, Int) m) where
    type Chunk (StateT (B.ByteString, BB.Builder, Int) m) = B.ByteString
    type Index (StateT (B.ByteString, BB.Builder, Int) m) = Int

    readahead n = gets $ \(rest, _, _) -> B.take n rest

    -- Emit the new bytes and discard the same number of input bytes they
    -- replace.
    overwrite chunk = do
        (rest, built, cur) <- get
        let len = B.length chunk
        put (B.drop len rest, built <> BB.byteString chunk, cur + len)

    -- Copy the next @n@ input bytes through unchanged. If fewer than @n@
    -- bytes remain, only those are copied, but the cursor still moves by @n@.
    advance n = do
        (rest, built, cur) <- get
        let (skipped, remaining) = B.splitAt n rest
        put (remaining, built <> BB.byteString skipped, cur + n)

    getCursor = gets $ \(_, _, cur) -> cur

-- | In-place stream backed by an 'IO.Handle' held in the reader environment.
--
-- Edits go straight to the handle, and the cursor is the handle's own
-- position as reported by 'IO.hTell'.
--
-- NOTE(review): assumes the handle is seekable and opened for binary
-- read/write — confirm at the call sites.
instance MonadIO m => FwdInplaceStream (ReaderT IO.Handle m) where
    type Chunk (ReaderT IO.Handle m) = B.ByteString
    type Index (ReaderT IO.Handle m) = Integer

    -- Read up to @n@ bytes, then seek back so the cursor ends where it began.
    --
    -- Near end of file 'B.hGet' returns fewer than @n@ bytes, so rewind by
    -- the number of bytes actually read. Rewinding by @n@ would leave the
    -- cursor before its starting position.
    readahead n = do
        hdl <- ask
        liftIO $ do
            bs <- B.hGet hdl (clampToInt n)
            IO.hSeek hdl IO.RelativeSeek (negate (toInteger (B.length bs)))
            pure bs
      where
        -- 'B.hGet' takes an 'Int'. Clamp rather than letting 'fromInteger'
        -- silently wrap an oversized count into an unrelated (possibly
        -- negative) size. TODO: consider making 'Index' an 'Int' here.
        clampToInt :: Integer -> Int
        clampToInt = fromInteger . min (toInteger (maxBound :: Int))

    -- Writing through the handle advances its position by the chunk length.
    overwrite bs = do
        hdl <- ask
        liftIO $ B.hPut hdl bs

    advance n = do
        hdl <- ask
        liftIO $ IO.hSeek hdl IO.RelativeSeek n

    -- 'IO.hTell' already yields an 'Integer', which is this instance's 'Index'.
    getCursor = do
        hdl <- ask
        liftIO $ IO.hTell hdl

-- TODO: Defining these instances easily for Text, ByteString etc. needs
-- MonoTraversable. Bleh. I think Snoyman's advice is to reimplement. Also bleh.
-- | Pure in-place stream over an ordinary list.
--
-- The state is @(unconsumed input, output so far, cursor)@. Edits move
-- elements from the front of the input onto the end of the output list.
-- Appending with '++' costs time linear in the output produced so far.
instance Monad m => FwdInplaceStream (StateT ([a], [a], Int) m) where
    type Chunk (StateT ([a], [a], Int) m) = [a]
    type Index (StateT ([a], [a], Int) m) = Int

    readahead n = gets $ \(rest, _, _) -> List.take n rest

    -- Emit the new elements and discard the same number of input elements
    -- they replace.
    overwrite xs = do
        (rest, done, cur) <- get
        let len = List.length xs
        put (List.drop len rest, done ++ xs, cur + len)

    -- Copy the next @n@ input elements through unchanged. If fewer than @n@
    -- elements remain, only those are copied, but the cursor still moves by @n@.
    advance n = do
        (rest, done, cur) <- get
        let (skipped, remaining) = List.splitAt n rest
        put (remaining, done ++ skipped, cur + n)

    getCursor = gets $ \(_, _, cur) -> cur

-- | Forward-only streams that, on top of in-place edits, also allow inserting
--   and removing elements, so the stream's length may change.
class FwdInplaceStream m => FwdStream m where
    -- | Remove the given number of elements starting at the cursor.
    --   The count must be positive.
    delete :: Index m -> m ()

    -- | Insert a chunk at the cursor, then move the cursor forward by the
    --   chunk's length.
    write :: Chunk m -> m ()

-- | Insert/delete support for the pure 'B.ByteString' state stream.
--
-- 'write' emits the chunk without consuming any input. 'delete' consumes
-- input without emitting it.
instance Monad m => FwdStream (StateT (B.ByteString, BB.Builder, Int) m) where
    write chunk = modify $ \(rest, built, cur) ->
        (rest, built <> BB.byteString chunk, cur + B.length chunk)

    -- The cursor tracks output position, and deleted input never reaches the
    -- output, so the cursor is left unchanged.
    delete n = modify $ \(rest, built, cur) -> (B.drop n rest, built, cur)