{-# LANGUAGE CPP, LambdaCase, ViewPatterns #-}
{-# OPTIONS -Wno-name-shadowing #-}
module TypeLevel.Rewrite.Internal.TypeRule where

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Plugins (getOccString)
import GHC.Core.Predicate (mkPrimEqPred)
import GHC.Plugins (TyVar, Type, mkTyVarTy)
#else
import Name (getOccString)
import Predicate (mkPrimEqPred)
import Type (TyVar, Type, mkTyVarTy)
#endif

-- term-rewriting API
import Data.Rewriting.Rule (Rule(..))
import Data.Rewriting.Term (Term(..))

import TypeLevel.Rewrite.Internal.TypeNode
import TypeLevel.Rewrite.Internal.TypeTemplate


-- | A rewrite 'Rule' whose function symbols are 'TypeNode's and whose
-- variables are GHC 'TyVar's.
type TypeRule = Rule TypeNode TyVar

-- | Try to interpret a 'Type' as a rewrite rule.
--
-- The type is first converted with 'toTypeTemplate_maybe'. The conversion
-- must yield an application of a type constructor whose occurrence name is
-- @"~"@ to exactly three arguments. The first argument is ignored
-- (presumably the kind at which the equality holds -- TODO confirm); the
-- second and third become the left- and right-hand sides of the rule.
--
-- Returns 'Nothing' for every other shape, including types that
-- 'toTypeTemplate_maybe' itself rejects.
toTypeRule_maybe
  :: Type
  -> Maybe TypeRule
toTypeRule_maybe ty
  = case toTypeTemplate_maybe ty of
      Just (Fun (TyCon tyCon) [_ignored, lhs_, rhs_])
        -- Only an application headed by the "~" constructor is a rule.
        | getOccString tyCon == "~"
        -> Just (Rule lhs_ rhs_)
      _
        -> Nothing

-- | Turn a type variable back into a 'Type'.
--
-- This is just 'mkTyVarTy' under a name that parallels the other @from*@
-- conversions, so it can be passed to 'fromTerm' as the variable converter.
fromTyVar
  :: TyVar
  -> Type
fromTyVar
  = mkTyVarTy

-- | Fold a 'Term' into a 'Type'.
--
-- * The first argument rebuilds a 'Type' from a function symbol and the
--   already-converted types of its arguments.
-- * The second argument converts a variable into a 'Type'.
--
-- Arguments of a 'Fun' node are converted recursively, in order, before
-- being handed to the function-symbol converter.
fromTerm
  :: (f -> [Type] -> Type)
  -> (v -> Type)
  -> Term f v
  -> Type
fromTerm fromF fromV
  = go
  where
    -- The two converters never change during the traversal, so the
    -- recursion closes over them instead of re-passing them.
    go (Var v)
      = fromV v
    go (Fun f args)
      = fromF f (map go args)

-- | Turn a 'TypeRule' back into a 'Type'.
--
-- Both sides of the rule are converted with 'fromTerm' (using
-- 'fromTypeNode' for function symbols and 'fromTyVar' for variables) and
-- the two resulting types are combined with 'mkPrimEqPred'.
fromTypeRule
  :: TypeRule
  -> Type
fromTypeRule (Rule lhs rhs)
  = mkPrimEqPred (templateToType lhs)
                 (templateToType rhs)
  where
    -- Shared conversion for both sides of the rule.
    templateToType
      = fromTerm fromTypeNode fromTyVar