{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications    #-}

module ZkFold.Symbolic.Cardano.UPLC.Inference where

import           Data.Maybe                                      (fromMaybe, maybeToList)
import           Data.Typeable                                   (Proxy (..))
import           Prelude

import           ZkFold.Symbolic.Cardano.UPLC.Builtins
import           ZkFold.Symbolic.Cardano.UPLC.Inference.Internal
import           ZkFold.Symbolic.Cardano.UPLC.Term
import           ZkFold.Symbolic.Cardano.UPLC.Type

-- TODO: Variable names must be unique for this to work!
-- TODO: Properly infer polymorphic type instantiations

-- | One local inference step for a single @(term, type)@ entry.
--
-- From a term and its currently known type, derive typing facts about related
-- terms and merge them into @types@ with 'updateTypeList'. Entries that teach
-- nothing (the final catch-all, e.g. a 'LamAbs' whose known type is not a
-- function type) return @types@ unchanged, without calling 'updateTypeList'.
inferType :: forall name fun a . (Eq name, Eq fun, PlutusBuiltinFunction a fun) =>
    (Term name fun a, SomeType a) -> TypeList name fun a -> TypeList name fun a
inferType (term, ty) types = maybe types (updateTypeList types) (factsFor term ty)
    where
        -- Facts implied by one entry; 'Nothing' means "leave the list alone".
        factsFor :: Term name fun a -> SomeType a -> Maybe (TypeList name fun a)
        factsFor (Var x) varTy =
            -- The lambda binding @x@ (if 'findLambda' finds one) gets
            -- @SomeFunction varTy NoType@.
            let withArg lam = (lam, SomeFunction varTy NoType)
            in Just (maybeToList (withArg <$> findLambda x types))
        factsFor (LamAbs x body) (SomeFunction argTy resTy) =
            Just [(Var x, argTy), (body, resTy)]
        factsFor (Apply f x) resTy =
            let argTy = fromMaybe NoType (findTermType x types)
                -- The application's own type is the result half of f's known
                -- function type, or NoType if f has no known function type.
                appTy = case findTermType f types of
                    Just (SomeFunction _ r) -> r
                    _                       -> NoType
            in Just [(f, SomeFunction argTy resTy), (Apply f x, appTy)]
        factsFor (Force inner) t = Just [(inner, t)]
        factsFor (Delay inner) t = Just [(inner, t)]
        factsFor (Constant (c :: c)) _ =
            Just [(Constant c, SomeSym (SomeData (Proxy :: Proxy c)))]
        factsFor (Builtin b) _ = Just [(Builtin b, builtinFunctionType b)]
        factsFor Error _ = Just [(Error, AnyType)]
        factsFor _ _ = Nothing

-- | Infer the type of a whole term.
--
-- Builds the initial list with 'makeTypeList', then repeatedly sweeps
-- 'inferType' over every entry until the list stops changing (a fixed point).
-- Returns the first entry of the final list; presumably 'makeTypeList' puts the
-- root term first — NOTE(review): confirm against 'makeTypeList'.
inferTypes :: forall name fun a . (Eq name, Eq fun, PlutusBuiltinFunction a fun) =>
    Term name fun a -> (Term name fun a, SomeType a)
inferTypes term =
    case go (makeTypeList term) of
        (root : _) -> root
        []         -> error "inferTypes: makeTypeList produced an empty type list"
    where
        -- One sweep: apply 'inferType' to every entry, threading the list through.
        step :: TypeList name fun a -> TypeList name fun a
        step types = foldr inferType types types

        -- Iterate 'step' to a fixed point. The original version returned after a
        -- single sweep, which made the equality test dead code.
        -- NOTE(review): termination relies on 'updateTypeList' eventually
        -- stabilising (e.g. being monotone); confirm.
        go :: TypeList name fun a -> TypeList name fun a
        go types =
            let types' = step types
            in if types' == types then types else go types'

-- | Whether an inferred type is fully concrete.
--
-- To obtain an arithmetizable term, all types must be concrete. Symbolic
-- leaves ('SomeSym') are accepted. A function type ('SomeFunction') is
-- accepted only when both of its component types are. Every other type is
-- rejected.
--
-- @name@ and @fun@ do not occur in the argument type, so callers must supply
-- them with type applications.
inferSuccess :: forall name fun a . (Eq name, Eq fun) => SomeType a -> Bool
inferSuccess = concrete
    where
        concrete :: SomeType a -> Bool
        concrete ty = case ty of
            SomeSym _          -> True
            SomeFunction t1 t2 -> concrete t1 && concrete t2
            _                  -> False