{-# LANGUAGE AllowAmbiguousTypes #-}

module ZkFold.Symbolic.Cardano.UPLC.Inference.Internal where

import           Data.List                         (find)
import           Prelude

import           ZkFold.Symbolic.Cardano.UPLC.Term
import           ZkFold.Symbolic.Cardano.UPLC.Type

-- | Association list pairing UPLC terms with a 'SomeType' annotation.
-- Lookups in this module ('findTermType', 'findLambda') return the first
-- matching entry, so entry order matters when a term occurs more than once.
type TypeList name fun a = [(Term name fun a, SomeType a)]

-- | Merge type information from the second entry into the first.
--
-- When both entries refer to the same term (by 'Eq'), the first term is
-- kept and its type is combined with the second's via '<>' (first type on
-- the left). Otherwise the first entry is returned unchanged. The term in
-- the result is always taken from the FIRST argument.
updateTermType :: (Eq name, Eq fun) =>
    (Term name fun a, SomeType a) -> (Term name fun a, SomeType a) -> (Term name fun a, SomeType a)
updateTermType (term1, t1) (term2, t2)
    | term1 == term2 = (term1, t1 <> t2)
    | otherwise      = (term1, t1)

-- | List every subterm of a term in pre-order, starting with the term itself.
--
-- For 'Apply', the subterms of the function come before those of the
-- argument. Leaves ('Var', 'Constant', 'Builtin', 'Error') contribute only
-- themselves. Duplicates are kept.
makeTermList :: (Eq name, Eq fun) => Term name fun a -> [Term name fun a]
makeTermList term = go term []
    where
        -- Cons onto an accumulator instead of using (++), so a deep
        -- left-nested 'Apply' spine does not re-copy its left sub-list at
        -- every level. The explicit signature keeps type inference simple
        -- when matching the existential 'Constant' constructor.
        go :: Term name fun a -> [Term name fun a] -> [Term name fun a]
        go t acc = t : case t of
            LamAbs _ f -> go f acc
            Apply f x  -> go f (go x acc)
            Force x    -> go x acc
            Delay x    -> go x acc
            Var _      -> acc
            Constant _ -> acc
            Builtin _  -> acc
            Error      -> acc

-- | Build the initial type table for a term. Every subterm produced by
-- 'makeTermList' (same order, duplicates kept) is paired with 'NoType'.
makeTypeList :: (Eq name, Eq fun) => Term name fun a -> TypeList name fun a
makeTypeList = map (\term -> (term, NoType)) . makeTermList

-- | Look up the type recorded for a term. The first entry whose term is
-- equal (by 'Eq') to the given term wins; returns 'Nothing' if there is
-- no such entry.
findTermType :: (Eq name, Eq fun) => Term name fun a -> TypeList name fun a -> Maybe (SomeType a)
findTermType = lookup

-- | Find the first 'LamAbs' in the table whose bound variable equals the
-- given name, and return that lambda term. Returns 'Nothing' if there is
-- none.
findLambda :: (Eq name) => name -> TypeList name fun a -> Maybe (Term name fun a)
findLambda x = fmap fst . find (bindsX . fst)
    where
        -- No local signature on purpose: the outer signature has no explicit
        -- 'forall', so a 'name' written here would be a fresh type variable
        -- that could not be compared with 'x'. Inference ties it to x's type.
        bindsX (LamAbs y _) = x == y
        bindsX _            = False

-- | Fold new type information into an existing table.
--
-- Each entry @(term, t)@ of the second list is processed in order. Every
-- entry of the first list whose term equals @term@ has @t@ merged into its
-- type via 'updateTermType'. All other entries are left unchanged. The
-- result has the same terms, in the same order, as the first list.
updateTypeList :: (Eq name, Eq fun) => TypeList name fun a -> TypeList name fun a -> TypeList name fun a
updateTypeList xs []             = xs
updateTypeList xs ((term, t):ys) = updateTypeList (map mergeInto xs) ys
    where
        -- The existing entry must be the FIRST argument: 'updateTermType'
        -- keeps its first argument's term. Passing (term, t) first would
        -- overwrite every table entry with 'term'.
        mergeInto entry = updateTermType entry (term, t)