{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Simplified implementation of the monad-ste package,
-- with a few extras.
module Internal.STE
  ( STE,
    throwSTE,
    runSTE,
    liftST,
    steToST,
    steToIO,
  )
where

import Control.Exception
import Control.Monad.Catch (MonadThrow (..))
import Control.Monad.Primitive (PrimMonad (..))
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import Data.Typeable (Typeable)

-- | Private wrapper that marks exceptions raised by 'throwSTE'.
--
-- The constructor is not exported, and the runners in this module
-- ('steToST', 'steToIO') catch only @InternalErr e@. So an ordinary
-- exception whose type happens to be @e@ is never mistaken for an STE
-- error.
--
-- GHC derives 'Typeable' for every type automatically, so no explicit
-- @deriving Typeable@ clause is needed.
newtype InternalErr e = InternalErr e

-- | Fixed placeholder rendering. The payload type @e@ has no 'Show'
-- constraint here, so its value is never printed.
instance Show (InternalErr e) where
  showsPrec _ _ = showString "(InternalErr _)"

-- | Lets 'InternalErr' be thrown and caught as an 'IO' exception.
-- Every method uses the class default. The @Typeable e@ context is
-- needed because 'Exception' has a 'Typeable' superclass, and
-- @InternalErr e@ is only 'Typeable' when @e@ is.
instance (Typeable e) => Exception (InternalErr e)

-- | An 'ST'-like computation over state thread @s@ that may abort with an
-- error of type @e@ (raised by 'throwSTE').
--
-- It is represented directly as 'IO'. The 'Functor', 'Applicative' and
-- 'Monad' instances are reused from 'IO' unchanged. The constructor is not
-- exported, so values can only be built through this module's API.
newtype STE e s a = STE (IO a)
  deriving newtype Functor
  deriving newtype Applicative
  deriving newtype Monad

-- | Primitive state-token operations run in the 'ST' state thread @s@.
-- They are built with 'ST''s own 'primitive' and then embedded with
-- 'liftST'.
instance PrimMonad (STE e s) where
  type PrimState (STE e s) = s
  primitive op = liftST (primitive op)

-- | Embed an 'ST' computation into 'STE' on the same state thread @s@.
--
-- Internally this goes through 'unsafeSTToIO'. The state-thread parameter
-- @s@ is kept in the result type, and the 'STE' constructor is not
-- exported.
liftST :: ST s a -> STE e s a
liftST = STE . unsafeSTToIO

-- | Abort the current 'STE' computation with the given error.
--
-- The error is wrapped in the private 'InternalErr' type and thrown as an
-- 'IO' exception. 'steToST' and 'steToIO' later unwrap it.
throwSTE :: Exception e => e -> STE e s a
throwSTE = STE . throwIO . InternalErr

-- | Run an 'STE' computation purely.
--
-- Returns @Left err@ if it aborted via 'throwSTE', otherwise @Right@ the
-- result. Other exceptions are not caught. The rank-2 argument works like
-- 'runST' in keeping the state thread private.
runSTE :: Exception e => (forall s. STE e s a) -> Either e a
runSTE action = runST (steToST action)

-- | Convert an 'STE' computation into an 'ST' computation that reports
-- errors as a value.
--
-- Only exceptions raised by 'throwSTE' (wrapped in 'InternalErr') are
-- caught; they become @Left@. Any other exception propagates unchanged.
steToST :: Typeable e => STE e s a -> ST s (Either e a)
steToST (STE action) = unsafeIOToST (either unwrap Right <$> try action)
  where
    -- Strip the private wrapper from a caught STE error.
    unwrap (InternalErr err) = Left err

-- | Run an 'STE' computation in 'IO' on the 'RealWorld' state thread.
--
-- An error raised with 'throwSTE' is unwrapped from 'InternalErr' and
-- rethrown via 'throwIO' as an ordinary exception of type @e@. Other
-- exceptions propagate unchanged.
steToIO :: forall e a. Exception e => STE e RealWorld a -> IO a
steToIO (STE action) = try action >>= either rethrow pure
  where
    -- The signature fixes which wrapped type 'try' catches: exactly the
    -- @e@ from the outer signature (via ScopedTypeVariables).
    rethrow :: InternalErr e -> IO a
    rethrow (InternalErr err) = throwIO err

-- | 'throwM' support when the error type is 'SomeException'.
--
-- Any exception is boxed with 'toException' and raised through
-- 'throwSTE'. It can then be recovered, for example as a @Left@ from
-- 'runSTE'.
instance MonadThrow (STE SomeException s) where
  throwM err = throwSTE (toException err)