{-# LANGUAGE CPP #-}
{-# LANGUAGE InstanceSigs #-}
-- |
-- Module: Language.KURE.Combinators.Transform
-- Copyright: (c) 2012--2021 The University of Kansas
-- License: BSD3
--
-- Maintainer: Neil Sculthorpe <neil.sculthorpe@ntu.ac.uk>
-- Stability: beta
-- Portability: ghc
--
-- This module provides a variety of combinators over 'Transform' and 'Rewrite'.

module Language.KURE.Combinators.Transform
        ( -- * Transformation Combinators
          idR
        , successT
        , contextT
        , exposeT
        , liftContext
        , readerT
        , resultT
        , catchesT
        , mapT
        , joinT
        , guardT
          -- * Rewrite Combinators
        , tryR
        , andR
        , orR
        , (>+>)
        , repeatR
        , acceptR
        , acceptWithFailMsgR
        , accepterR
        , changedR
        , changedByR
        , sideEffectR
          -- * Monad Transformers
          -- ** anyR Support
          -- $AnyR_doc
        , AnyR
        , wrapAnyR
        , unwrapAnyR
          -- ** oneR Support
          -- $OneR_doc
        , OneR
        , wrapOneR
        , unwrapOneR
) where

import Prelude hiding (id, map, foldr, mapM)

import Control.Category ((>>>),id)
import Control.Monad (liftM,ap)

#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
import qualified Control.Monad.Fail
#endif

import Data.Foldable ()
import Data.Traversable

import Language.KURE.Combinators.Arrow
import Language.KURE.Combinators.Monad
import Language.KURE.MonadCatch
import Language.KURE.Transform

------------------------------------------------------------------------------------------

-- | The identity rewrite: returns its argument unchanged.
idR :: Monad m => Rewrite c m a
idR = id
{-# INLINE idR #-}

-- | A transformation that always succeeds, producing @()@ and ignoring
--   both the context and the argument.
successT :: Monad m => Transform c m a ()
successT = return ()
{-# INLINE successT #-}

-- | Return the current context, ignoring the value being transformed.
contextT :: Monad m => Transform c m a c
contextT = transform (\ ctx _ -> return ctx)
{-# INLINE contextT #-}

-- | Return the current context paired with the current value.
exposeT :: Monad m => Transform c m a (c,a)
exposeT = transform (\ ctx val -> return (ctx, val))
{-# INLINE exposeT #-}

-- | Run a transformation under a context derived from the current one:
--   the supplied function converts the outer context into the one the
--   inner transformation expects.
liftContext :: (c -> c') -> Transform c' m a b -> Transform c m a b
liftContext f t = transform (\ ctx -> applyT t (f ctx))
{-# INLINE liftContext #-}

-- | Apply a transformation to every element of a 'Traversable' container
--   (a list, for example), giving each element the same context.
--   The element results are collected with 'mapM', so effects are sequenced
--   in the container's traversal order.
mapT :: (Traversable t, Monad m) => Transform c m a b -> Transform c m (t a) (t b)
mapT t = transform (\ ctx -> mapM (applyT t ctx))
{-# INLINE mapT #-}

-- | Perform an effectful action on the context and the argument, discard its
--   @()@ result, and then behave as the identity rewrite.
sideEffectR :: Monad m => (c -> a -> m ()) -> Rewrite c m a
sideEffectR effect = transform effect >> id
{-# INLINE sideEffectR #-}

-- | Inspect the argument before deciding which 'Transform' to run.
--   The chosen transformation is then applied to the same context and argument.
readerT :: (a -> Transform c m a b) -> Transform c m a b
readerT choose = transform $ \ ctx arg -> applyT (choose arg) ctx arg
{-# INLINE readerT #-}

-- | Post-process the monadic result of a transformation with the given
--   function, which may also move the result into a different monad.
resultT :: (m b -> n d) -> Transform c m a b -> Transform c n a d
resultT f t = transform (\ ctx arg -> f (applyT t ctx arg))
{-# INLINE resultT #-}

-- | Compose a collection of rewrites in sequence (via 'serialise');
--   every rewrite in the collection must succeed.
andR :: (Foldable f, Monad m) => f (Rewrite c m a) -> Rewrite c m a
andR = serialise
{-# INLINE andR #-}

-- | Sequence two rewrites, succeeding if at least one of them succeeds.
--   A rewrite that fails is treated as the identity. If both fail, the
--   combination fails with the message produced by 'unwrapAnyR'.
(>+>) :: MonadCatch m => Rewrite c m a -> Rewrite c m a -> Rewrite c m a
(>+>) r1 r2 = unwrapAnyR (wrapAnyR r1 >>> wrapAnyR r2)
{-# INLINE (>+>) #-}

-- | Sequence a collection of rewrites, succeeding if at least one succeeds.
--   Each member is wrapped with 'wrapAnyR', so failing members act as
--   identities. The sequence is run with 'andR' and unwrapped with 'unwrapAnyR'.
orR :: (Functor f, Foldable f, MonadCatch m) => f (Rewrite c m a) -> Rewrite c m a
orR rs = unwrapAnyR (andR (fmap wrapAnyR rs))
{-# INLINE orR #-}

-- | As 'acceptR', but takes a custom failure message.
--   Behaves as the identity when the predicate holds for the argument, and
--   fails with the given message otherwise.
acceptWithFailMsgR :: MonadFail m => (a -> Bool) -> String -> Rewrite c m a
acceptWithFailMsgR p msg = readerT check
  where
    check arg
      | p arg     = id
      | otherwise = fail msg
{-# INLINE acceptWithFailMsgR #-}

-- | Behave as 'idR' when the predicate holds for the argument; otherwise
--   fail with @\"acceptR: predicate failed\"@.
acceptR :: MonadFail m => (a -> Bool) -> Rewrite c m a
acceptR p = acceptWithFailMsgR p "acceptR: predicate failed"
{-# INLINE acceptR #-}

-- | A generalisation of 'acceptR' where the predicate is itself a 'Transform'.
--   The predicate runs on the current context and argument. A 'True' result
--   gives 'idR'; a 'False' result fails with
--   @\"accepterR: predicate failed\"@.
accepterR :: MonadFail m => Transform c m a Bool -> Rewrite c m a
accepterR t = ifM t idR rejected
  where
    rejected = fail "accepterR: predicate failed"
{-# INLINE accepterR #-}

-- | Run a rewrite, falling back to the identity if it fails.
tryR :: MonadCatch m => Rewrite c m a -> Rewrite c m a
tryR r = r <+ id
{-# INLINE tryR #-}

-- | Run a rewrite, then fail with @\"changedByR: value is unchanged\"@ if the
--   predicate holds between the original argument (first) and the result (second).
--   This generalises 'changedR': @changedR = changedByR ('==')@
changedByR :: MonadCatch m => (a -> a -> Bool) -> Rewrite c m a -> Rewrite c m a
changedByR p r = readerT $ \ before ->
    r >>> acceptWithFailMsgR (\ after -> not (p before after)) "changedByR: value is unchanged"
{-# INLINE changedByR #-}

-- | Make a rewrite fail if its result is equal (by '==') to its argument.
changedR :: (MonadCatch m, Eq a) => Rewrite c m a -> Rewrite c m a
changedR = changedByR (==)
{-# INLINE changedR #-}

-- | Apply a rewrite repeatedly, feeding each result into the next attempt,
--   until an attempt fails. The value from the last successful attempt is returned.
--   The first attempt is not protected by 'tryR', so it must succeed.
repeatR :: MonadCatch m => Rewrite c m a -> Rewrite c m a
repeatR r = loop
  where
    loop = r >>> tryR loop
{-# INLINE repeatR #-}

-- | Try each transformation in turn and return the result of the first one
--   that succeeds; the remaining transformations are discarded.
--   This is a synonym for 'catchesM'.
catchesT :: MonadCatch m => [Transform c m a b] -> Transform c m a b
catchesT = catchesM
{-# INLINE catchesT #-}
{-# DEPRECATED catchesT "Please use 'catchesM' instead." #-}


-- | Run the monadic computation supplied as the argument and return its
--   result, ignoring the context. Resembles a monadic 'Control.Monad.join'.
joinT :: Transform c m (m a) a
joinT = contextfreeT (\ action -> action)
{-# INLINE joinT #-}

-- | Succeed with @()@ when the argument is 'True' and fail when it is 'False'
--   (delegating to 'guardM'); the context is ignored.
guardT :: MonadFail m => Transform c m Bool ()
guardT = contextfreeT guardM
{-# INLINE guardT #-}

-------------------------------------------------------------------------------

-- | A value paired with a strict Boolean flag. In this module the flag
--   records whether a wrapped rewrite has succeeded.
data PBool a = PBool !Bool a

-- | Maps over the payload and carries the flag through unchanged.
instance Functor PBool where
  fmap :: (a -> b) -> PBool a -> PBool b
  fmap g (PBool flag x) = PBool flag (g x)

-- | Run a flagged computation and return its payload if the flag is 'True';
--   otherwise fail with the given message.
checkSuccessPBool :: MonadFail m => String -> m (PBool a) -> m a
checkSuccessPBool msg m = m >>= \ (PBool succeeded a) ->
    if succeeded then return a else fail msg
{-# INLINE checkSuccessPBool #-}

-------------------------------------------------------------------------------

-- $AnyR_doc
-- These are useful when defining congruence combinators that succeed if /any/ child rewrite succeeds.
-- See the \"Expr\" example, or the HERMIT package.

-- | The 'AnyR' transformer, used with 'wrapAnyR' and 'unwrapAnyR', makes a
--   sequence of rewrites succeed if at least one of them succeeds; failures
--   are converted to identity rewrites.
--
--   Internally it is a computation in @m@ whose result carries a 'PBool' flag
--   recording whether any wrapped rewrite has succeeded so far.
newtype AnyR m a = AnyR { unAnyR :: m (PBool a) }

-- Functor and Applicative are derived from the Monad instance via 'liftM' and 'ap'.
instance Monad m => Functor (AnyR m) where
   fmap :: (a -> b) -> AnyR m a -> AnyR m b
   fmap = liftM
   {-# INLINE fmap #-}

instance Monad m => Applicative (AnyR m) where
   -- Lifting a pure value does not count as a successful rewrite: flag starts 'False'.
   pure :: a -> AnyR m a
   pure = AnyR . return . PBool False
   {-# INLINE pure #-}

   (<*>) :: AnyR m (a -> b) -> AnyR m a -> AnyR m b
   (<*>) = ap
   {-# INLINE (<*>) #-}

-- The definition lives in 'pure' and 'return' is just 'pure'. This is the canonical
-- form required by -Wnoncanonical-monad-instances and by the planned removal of
-- 'return' from 'Monad'. Defining 'pure = return' with the body in 'return' is the
-- non-canonical direction.
instance Monad m => Monad (AnyR m) where
   return :: a -> AnyR m a
   return = pure
   {-# INLINE return #-}

   -- Run both steps in sequence; the result is flagged as a success if either step succeeded.
   (>>=) :: AnyR m a -> (a -> AnyR m d) -> AnyR m d
   ma >>= f = AnyR $ do PBool b1 a <- unAnyR ma
                        PBool b2 d <- unAnyR (f a)
                        return (PBool (b1 || b2) d)
   {-# INLINE (>>=) #-}

#if !MIN_VERSION_base(4,13,0)
   fail :: String -> AnyR m a
   fail = AnyR . fail
   {-# INLINE fail #-}
#endif

-- Failure is delegated to the underlying monad.
instance MonadFail m => MonadFail (AnyR m) where
   fail :: String -> AnyR m a
   fail msg = AnyR (fail msg)
   {-# INLINE fail #-}

-- Catching is delegated to the underlying monad's 'catchM'; on failure the
-- handler's flagged computation replaces the failed one.
instance MonadCatch m => MonadCatch (AnyR m) where
   catchM :: AnyR m a -> (String -> AnyR m a) -> AnyR m a
   catchM (AnyR inner) handler = AnyR (catchM inner (\ msg -> unAnyR (handler msg)))
   {-# INLINE catchM #-}

-- | Wrap a 'Rewrite' using the 'AnyR' monad transformer.
--   If the rewrite succeeds, its result is flagged 'True'. If it fails, the
--   original argument is returned flagged 'False', so the rewrite acts as the identity.
wrapAnyR :: MonadCatch m => Rewrite c m a -> Rewrite c (AnyR m) a
wrapAnyR r = rewrite $ \ ctx arg -> AnyR (attempt ctx arg)
  where
    attempt ctx arg = (PBool True `liftM` applyR r ctx arg) <+ return (PBool False arg)
{-# INLINE wrapAnyR #-}

-- | Unwrap a 'Rewrite' from the 'AnyR' monad transformer.
--   Fails with @\"anyR failed\"@ unless the resulting flag is 'True',
--   i.e. unless at least one wrapped rewrite succeeded.
unwrapAnyR :: MonadFail m => Rewrite c (AnyR m) a -> Rewrite c m a
unwrapAnyR = resultT (\ flagged -> checkSuccessPBool "anyR failed" (unAnyR flagged))
{-# INLINE unwrapAnyR #-}

-------------------------------------------------------------------------------

-- $OneR_doc
-- These are useful when defining congruence combinators that succeed if one child rewrite succeeds
-- (and the remainder are then discarded).
-- See the \"Expr\" example, or the HERMIT package.

-- | The 'OneR' transformer, used with 'wrapOneR' and 'unwrapOneR', makes a
--   sequence of rewrites apply only the first one that succeeds. The remainder
--   (and any failures) are converted to identity rewrites.
--
--   Internally it is a function from an incoming flag (has a rewrite already
--   succeeded?) to a computation in @m@ whose result carries the outgoing flag.
newtype OneR m a = OneR { unOneR :: Bool -> m (PBool a) }

-- Functor and Applicative are derived from the Monad instance via 'liftM' and 'ap'.
instance Monad m => Functor (OneR m) where
   fmap :: (a -> b) -> OneR m a -> OneR m b
   fmap = liftM
   {-# INLINE fmap #-}

instance Monad m => Applicative (OneR m) where
   -- Lifting a pure value passes the incoming success flag through unchanged.
   pure :: a -> OneR m a
   pure a = OneR (\ b -> return (PBool b a))
   {-# INLINE pure #-}

   (<*>) :: OneR m (a -> b) -> OneR m a -> OneR m b
   (<*>) = ap
   {-# INLINE (<*>) #-}

-- The definition lives in 'pure' and 'return' is just 'pure'. This is the canonical
-- form required by -Wnoncanonical-monad-instances and by the planned removal of
-- 'return' from 'Monad'. Defining 'pure = return' with the body in 'return' is the
-- non-canonical direction.
instance Monad m => Monad (OneR m) where
   return :: a -> OneR m a
   return = pure
   {-# INLINE return #-}

   -- Thread the success flag: the flag produced by the first step is the flag
   -- seen by the second step.
   (>>=) :: OneR m a -> (a -> OneR m d) -> OneR m d
   ma >>= f = OneR $ \ b1 -> do PBool b2 a <- unOneR ma b1
                                unOneR (f a) b2
   {-# INLINE (>>=) #-}

#if !MIN_VERSION_base(4,13,0)
   fail :: String -> OneR m a
   fail msg = OneR (\ _ -> fail msg)
   {-# INLINE fail #-}
#endif

-- Failure ignores the incoming flag and is delegated to the underlying monad.
instance MonadFail m => MonadFail (OneR m) where
   fail :: String -> OneR m a
   fail msg = OneR (const (fail msg))
   {-# INLINE fail #-}

-- Catching is delegated to the underlying monad's 'catchM'. The handler's
-- computation receives the same incoming flag as the computation that failed.
instance MonadCatch m => MonadCatch (OneR m) where
   catchM :: OneR m a -> (String -> OneR m a) -> OneR m a
   catchM (OneR body) handler =
      OneR (\ flag -> body flag `catchM` (\ msg -> unOneR (handler msg) flag))
   {-# INLINE catchM #-}

-- | Wrap a 'Rewrite' using the 'OneR' monad transformer.
--   If the incoming flag is 'True' (an earlier rewrite already succeeded), the
--   argument is returned unchanged without running the rewrite. Otherwise the
--   rewrite is attempted: success is flagged 'True', and failure falls back to
--   the unchanged argument flagged 'False'.
wrapOneR :: MonadCatch m => Rewrite c m g -> Rewrite c (OneR m) g
wrapOneR r = rewrite $ \ ctx arg -> OneR (step ctx arg)
  where
    step ctx arg alreadyDone
      | alreadyDone = return (PBool True arg)
      | otherwise   = (PBool True `liftM` applyR r ctx arg) <+ return (PBool False arg)
{-# INLINE wrapOneR #-}

-- | Unwrap a 'Rewrite' from the 'OneR' monad transformer.
--   The computation starts with the flag 'False' (nothing has succeeded yet)
--   and fails with @\"oneR failed\"@ if no wrapped rewrite succeeded.
unwrapOneR :: MonadFail m => Rewrite c (OneR m) a -> Rewrite c m a
unwrapOneR = resultT (\ flagged -> checkSuccessPBool "oneR failed" (unOneR flagged False))
{-# INLINE unwrapOneR #-}

-------------------------------------------------------------------------------