{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      :   Grisette.IR.SymPrim.Data.Prim.PartialEval.Unfold
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.IR.SymPrim.Data.Prim.PartialEval.Unfold
  ( unaryUnfoldOnce,
    binaryUnfoldOnce,
  )
where

import Control.Monad.Except
import Data.Typeable
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool
import Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval

-- | Extend a partial unary rule so that it is also tried on the branches of
-- an if-then-else term.
--
-- For an argument of the form @ITETerm _ cond vt vf@ the rule is applied to
-- @vt@ and @vf@ separately. If it matches at least one of them, the result is
-- @pevalITETerm cond@ applied to the two per-branch results; a branch the rule
-- did not match is retried through 'oneLevel' (so nested if-then-else terms
-- are inspected as well) and finally handed to the fallback. If the rule
-- matches neither branch, or the argument is not an if-then-else term, the
-- result is simply @partial x@.
unaryPartialUnfoldOnce ::
  forall a b.
  (Typeable a, SupportedPrim b) =>
  PartialRuleUnary a b ->
  TotalRuleUnary a b ->
  PartialRuleUnary a b
unaryPartialUnfoldOnce partial fallback = ret
  where
    -- Try the rule on the term itself and, for if-then-else terms, on both
    -- branches. @fallback'@ is the total rule used for branches that the
    -- partial rule cannot handle.
    oneLevel :: TotalRuleUnary a b -> PartialRuleUnary a b
    oneLevel fallback' x = case (x, partial x) of
      (ITETerm _ cond vt vf, pr) ->
        let pt = partial vt
            pf = partial vf
         in case (pt, pf) of
              -- Neither branch matched: unfolding gains nothing, keep the
              -- result of the rule on the whole term.
              (Nothing, Nothing) -> pr
              -- At least one branch matched: rebuild the if-then-else from
              -- the per-branch results. The handlers only run for a branch
              -- whose direct application returned 'Nothing'.
              (mt, mf) ->
                pevalITETerm cond
                  <$> catchError
                    mt
                    (\_ -> Just $ totalize (oneLevel fallback') fallback' vt)
                  -- NOTE(review): this branch passes @fallback@ while the
                  -- true branch passes @fallback'@. The handler only runs
                  -- when @partial vf@ is 'Nothing', so the two presumably
                  -- coincide -- confirm against 'totalize' before unifying.
                  <*> catchError
                    mf
                    (\_ -> Just $ totalize (oneLevel fallback') fallback vf)
      (_, pr) -> pr
    -- Entry point: the fallback handed to 'oneLevel' first retries the
    -- partial rule before resorting to the caller's fallback.
    ret :: PartialRuleUnary a b
    ret = oneLevel (totalize @(Term a) @(Term b) partial fallback)

-- | Total variant of the unary unfolding rule.
--
-- Combines 'unaryPartialUnfoldOnce' with the supplied fallback through
-- 'totalize', so that a result is always produced (the fallback is expected
-- to be used whenever the unfolded partial rule yields no result).
unaryUnfoldOnce ::
  forall a b.
  (Typeable a, SupportedPrim b) =>
  PartialRuleUnary a b ->
  TotalRuleUnary a b ->
  TotalRuleUnary a b
unaryUnfoldOnce partial fallback =
  totalize (unaryPartialUnfoldOnce partial fallback) fallback

-- | Extend a partial binary rule so that it is also tried on the branches of
-- an if-then-else term appearing as either argument.
--
-- The rule is first applied to the arguments as given. If that fails and the
-- first argument is @ITETerm _ cond vt vf@, the rule is applied to @vt@ and
-- @vf@ (each paired with the second argument). If that also fails, the same
-- is attempted for an if-then-else term in the second argument, using the
-- flipped rule and flipped fallback. When the rule matches at least one
-- branch, the result is @pevalITETerm cond@ applied to the per-branch
-- results; otherwise the result is 'Nothing'.
binaryPartialUnfoldOnce ::
  forall a b c.
  (Typeable a, Typeable b, SupportedPrim c) =>
  PartialRuleBinary a b c ->
  TotalRuleBinary a b c ->
  PartialRuleBinary a b c
binaryPartialUnfoldOnce partial fallback = ret
  where
    -- Apply the rule directly, then try unfolding the first argument, then
    -- the second. Each later attempt only runs if the previous one returned
    -- 'Nothing'.
    oneLevel ::
      (Typeable x, Typeable y) =>
      PartialRuleBinary x y c ->
      TotalRuleBinary x y c ->
      PartialRuleBinary x y c
    oneLevel partial' fallback' x y =
      catchError
        (partial' x y)
        ( \_ ->
            catchError
              ( case x of
                  ITETerm _ cond vt vf -> left cond vt vf y partial' fallback'
                  _ -> Nothing
              )
              ( \_ -> case y of
                  -- Reuse 'left' for the second argument by swapping the
                  -- argument order of both the rule and the fallback.
                  ITETerm _ cond vt vf ->
                    left cond vt vf x (flip partial') (flip fallback')
                  _ -> Nothing
              )
        )
    -- Unfold an if-then-else term (given as its condition and two branches)
    -- occurring in the first argument position; @y@ is the other argument.
    left ::
      (Typeable x, Typeable y) =>
      Term Bool ->
      Term x ->
      Term x ->
      Term y ->
      PartialRuleBinary x y c ->
      TotalRuleBinary x y c ->
      Maybe (Term c)
    left cond vt vf y partial' fallback' =
      let pt = partial' vt y
          pf = partial' vf y
       in case (pt, pf) of
            -- The rule matched neither branch: report failure so the caller
            -- can try something else.
            (Nothing, Nothing) -> Nothing
            -- At least one branch matched: rebuild the if-then-else. A branch
            -- the rule did not match is retried through 'oneLevel' and then
            -- handed to the fallback.
            (mt, mf) ->
              pevalITETerm cond
                <$> catchError
                  mt
                  ( \_ ->
                      Just $
                        totalize2 (oneLevel partial' fallback') fallback' vt y
                  )
                <*> catchError
                  mf
                  ( \_ ->
                      Just $
                        totalize2 (oneLevel partial' fallback') fallback' vf y
                  )
    -- Entry point: the fallback handed to 'oneLevel' first retries the
    -- partial rule before resorting to the caller's fallback.
    ret :: PartialRuleBinary a b c
    ret =
      oneLevel
        partial
        (totalize2 @(Term a) @(Term b) @(Term c) partial fallback)

-- | Total variant of the binary unfolding rule.
--
-- Combines 'binaryPartialUnfoldOnce' with the supplied fallback through
-- 'totalize2', so that a result is always produced (the fallback is expected
-- to be used whenever the unfolded partial rule yields no result).
binaryUnfoldOnce ::
  forall a b c.
  (Typeable a, Typeable b, SupportedPrim c) =>
  PartialRuleBinary a b c ->
  TotalRuleBinary a b c ->
  TotalRuleBinary a b c
binaryUnfoldOnce partial fallback =
  totalize2 (binaryPartialUnfoldOnce partial fallback) fallback