{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.Prim.Internal.Unfold
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.Prim.Internal.Unfold
  ( unaryUnfoldOnce,
    binaryUnfoldOnce,
  )
where

import Control.Monad.Except (MonadError (catchError))
import Data.Typeable (Typeable)
import Grisette.Internal.SymPrim.Prim.Internal.PartialEval
  ( PartialRuleBinary,
    PartialRuleUnary,
    TotalRuleBinary,
    TotalRuleUnary,
    totalize,
    totalize2,
  )
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( SupportedPrim (pevalITETerm),
    Term (ITETerm),
  )

-- | Extend a partial unary rule so that it can see through an 'ITETerm'
-- argument.
--
-- For an argument of the form @ITETerm _ cond vt vf@, the partial rule is
-- applied to each branch. If it succeeds on at least one branch, the result
-- is @pevalITETerm cond t f@. A branch on which the rule failed is rewritten
-- by recursively attempting the same unfolding on that branch and, if that
-- also fails, by the fallback rule. Note that in this case the branch-wise
-- result is preferred over applying the partial rule to the whole 'ITETerm'.
--
-- If the rule fails on both branches, or the argument is not an 'ITETerm',
-- the result is whatever the partial rule returns on the argument itself.
unaryPartialUnfoldOnce ::
  forall a b.
  (SupportedPrim b) =>
  -- | Partial rule to lift.
  PartialRuleUnary a b ->
  -- | Total rule used where the partial rule does not apply.
  TotalRuleUnary a b ->
  PartialRuleUnary a b
unaryPartialUnfoldOnce partial fallback = oneLevel total
  where
    -- Total version of the rule: @partial@ where it applies, @fallback@
    -- otherwise.
    total :: TotalRuleUnary a b
    total = totalize @(Term a) @(Term b) partial fallback

    oneLevel :: TotalRuleUnary a b -> PartialRuleUnary a b
    oneLevel fallback' x = case (x, partial x) of
      (ITETerm _ cond vt vf, pr) ->
        let -- Used for a branch the partial rule failed on: retry the
            -- unfolding on that branch, then fall back. Both branches use
            -- the same fallback (the original mixed `fallback` and
            -- `fallback'`, which coincide here because `partial` already
            -- returned Nothing for the branch).
            recurse = totalize (oneLevel fallback') fallback'
         in case (partial vt, partial vf) of
              (Nothing, Nothing) -> pr
              (mt, mf) ->
                -- `catchError` on Maybe (error type ()) replaces a Nothing
                -- branch with the recursively unfolded / fallback result.
                pevalITETerm cond
                  <$> catchError mt (\_ -> Just $ recurse vt)
                  <*> catchError mf (\_ -> Just $ recurse vf)
      (_, pr) -> pr

-- | Turn a partial unary rule into a total one that also sees through
-- 'ITETerm' arguments (via @unaryPartialUnfoldOnce@).
--
-- When the unfolded partial rule produces no result, @fallback@ is applied
-- to the original argument.
unaryUnfoldOnce ::
  forall a b.
  (SupportedPrim b) =>
  -- | Partial rule to lift.
  PartialRuleUnary a b ->
  -- | Total rule used where the partial rule does not apply.
  TotalRuleUnary a b ->
  TotalRuleUnary a b
unaryUnfoldOnce partial fallback =
  totalize (unaryPartialUnfoldOnce partial fallback) fallback

-- | Extend a partial binary rule so that it can see through 'ITETerm'
-- arguments.
--
-- The rule is first applied to the arguments directly. If that fails and the
-- first argument is @ITETerm _ cond vt vf@, the rule is applied to @vt@ and
-- @vf@ (each paired with the second argument). If it succeeds on at least
-- one of them, the result is @pevalITETerm cond t f@. A side on which the
-- rule failed is rewritten by recursively attempting the same unfolding and,
-- if that also fails, by the fallback rule. If the first argument is not an
-- 'ITETerm', or neither of its branches succeeded, the same is tried on the
-- second argument's 'ITETerm' with the rules' arguments flipped.
binaryPartialUnfoldOnce ::
  forall a b c.
  (SupportedPrim c) =>
  -- | Partial rule to lift.
  PartialRuleBinary a b c ->
  -- | Total rule used where the partial rule does not apply.
  TotalRuleBinary a b c ->
  PartialRuleBinary a b c
binaryPartialUnfoldOnce partial fallback =
  oneLevel partial (totalize2 @(Term a) @(Term b) @(Term c) partial fallback)
  where
    -- Polymorphic in the two argument types so that it can be re-entered
    -- with flipped rules when unfolding the second argument.
    oneLevel ::
      PartialRuleBinary x y c ->
      TotalRuleBinary x y c ->
      PartialRuleBinary x y c
    oneLevel partial' fallback' x y =
      catchError
        (partial' x y)
        ( \_ ->
            catchError
              ( case x of
                  ITETerm _ cond vt vf ->
                    unfoldFirst cond vt vf y partial' fallback'
                  _ -> Nothing
              )
              ( \_ -> case y of
                  ITETerm _ cond vt vf ->
                    unfoldFirst cond vt vf x (flip partial') (flip fallback')
                  _ -> Nothing
              )
        )

    -- Unfold an 'ITETerm' (condition and branches given) sitting in the
    -- first argument position. Returns Nothing when the rule fails on both
    -- branches.
    unfoldFirst ::
      Term Bool ->
      Term x ->
      Term x ->
      Term y ->
      PartialRuleBinary x y c ->
      TotalRuleBinary x y c ->
      Maybe (Term c)
    unfoldFirst cond vt vf y partial' fallback' =
      case (partial' vt y, partial' vf y) of
        (Nothing, Nothing) -> Nothing
        (mt, mf) ->
          -- `catchError` on Maybe (error type ()) replaces a Nothing side
          -- with the recursively unfolded / fallback result.
          pevalITETerm cond
            <$> catchError mt (\_ -> Just $ recurse vt)
            <*> catchError mf (\_ -> Just $ recurse vf)
      where
        recurse v = totalize2 (oneLevel partial' fallback') fallback' v y

-- | Turn a partial binary rule into a total one that also sees through
-- 'ITETerm' arguments (via @binaryPartialUnfoldOnce@).
--
-- When the unfolded partial rule produces no result, @fallback@ is applied
-- to the original arguments.
--
-- NOTE(review): the 'Typeable' constraints are not used by this body; they
-- are kept so that the exported signature stays unchanged.
binaryUnfoldOnce ::
  forall a b c.
  (Typeable a, Typeable b, SupportedPrim c) =>
  -- | Partial rule to lift.
  PartialRuleBinary a b c ->
  -- | Total rule used where the partial rule does not apply.
  TotalRuleBinary a b c ->
  TotalRuleBinary a b c
binaryUnfoldOnce partial fallback =
  totalize2 (binaryPartialUnfoldOnce partial fallback) fallback