{-# LANGUAGE TemplateHaskell, DefaultSignatures, FlexibleInstances #-}
module AST.Class.Infer
( InferOf
, Infer(..)
, InferChild(..), _InferChild
, InferredChild(..), inType, inRep
) where
import AST
import AST.Class.Unify
import AST.Recurse
import Control.Lens (makeLenses, makePrisms)
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Data.Constraint (Dict(..), withDict)
import Data.Functor.Sum.PolyKinds (Sum(..))
import Data.Kind (Type)
import Data.Proxy (Proxy(..))
import Prelude.Compat
-- | The type-level shape of the inference result for the syntax @t@.
--
-- 'inferBody' returns a @'Tree' ('InferOf' t) ('UVarOf' m)@ alongside the
-- re-assembled term, so this family names what inferring a @t@ produces.
type family InferOf (t :: Knot -> Type) :: Knot -> Type

-- A 'Sum' of two syntaxes takes its inference result from the left summand.
-- The 'Infer' instance for 'Sum' requires @InferOf left ~ InferOf right@,
-- so choosing the left side loses nothing.
type instance InferOf (Sum left right) = InferOf left
-- | The outcome of inferring one child node.
--
-- Lenses 'inRep' and 'inType' are generated for the two fields.
data InferredChild var k t = InferredChild
    { _inRep :: !(k t)
      -- ^ The child node, in the representation @k@.
    , _inType :: !(Tree (InferOf (GetKnot t)) var)
      -- ^ The child's 'InferOf' result, expressed over the variable type
      -- @var@ (for 'InferChild' this is @'UVarOf' m@).
    }

makeLenses ''InferredChild
-- | An action in @m@ that infers a single child node.
--
-- Running it yields the child's 'InferredChild', whose inference result
-- uses @'UVarOf' m@ as its variable type. 'inferBody' receives its term's
-- children wrapped in this type. The prism '_InferChild' is generated below.
newtype InferChild m k t = InferChild
    { inferChild :: m (InferredChild (UVarOf m) k t)
    }

makePrisms ''InferChild
-- | Type inference for the syntax @t@, carried out in the monad @m@.
class (Monad m, KFunctor t) => Infer m t where
    -- | Infer a term whose direct children are supplied as 'InferChild'
    -- actions.
    --
    -- The result pairs the term, with its children in the representation
    -- @h@, with the term's own 'InferOf' result over @'UVarOf' m@.
    inferBody ::
        Tree t (InferChild m h) ->
        m (Tree t h, Tree (InferOf t) (UVarOf m))

    -- | Evidence that @'Infer' m@ holds for the nodes of @t@ and that
    -- @'Unify' m@ holds for the nodes of @'InferOf' t@ (both expressed via
    -- 'KNodesConstraint').
    inferContext ::
        Proxy m ->
        Proxy t ->
        Dict (KNodesConstraint t (Infer m), KNodesConstraint (InferOf t) (Unify m))
    {-# INLINE inferContext #-}
    -- Default: usable whenever both constraints are already solvable at the
    -- instance, in which case the evidence is just 'Dict'.
    default inferContext ::
        (KNodesConstraint t (Infer m), KNodesConstraint (InferOf t) (Unify m)) =>
        Proxy m ->
        Proxy t ->
        Dict (KNodesConstraint t (Infer m), KNodesConstraint (InferOf t) (Unify m))
    inferContext _monadProxy _syntaxProxy = Dict
-- | 'Infer' is recursive: the evidence for a node's children comes from that
-- node's own 'inferContext'.
instance Recursive (Infer m) where
    {-# INLINE recurse #-}
    recurse constraintProxy =
        withDict context Dict
        where
            context = inferContext (monadOf constraintProxy) (syntaxOf constraintProxy)
            -- Split the proxy of @Infer m t@ into separate monad and syntax
            -- proxies, as 'inferContext' expects.
            monadOf :: Proxy (Infer m t) -> Proxy m
            monadOf _ = Proxy
            syntaxOf :: Proxy (Infer m t) -> Proxy t
            syntaxOf _ = Proxy
-- | Inference on a 'Sum' of two syntaxes delegates to whichever summand is
-- present and re-wraps the result in the same constructor. Both summands
-- must share one 'InferOf', so the inference result type does not depend on
-- which side was taken.
instance (InferOf a ~ InferOf b, Infer m a, Infer m b) => Infer m (Sum a b) where
    {-# INLINE inferBody #-}
    inferBody (InL leftTerm) = Lens.over Lens._1 InL <$> inferBody leftTerm
    inferBody (InR rightTerm) = Lens.over Lens._1 InR <$> inferBody rightTerm

    -- The evidence for the sum is assembled from both summands' evidence.
    {-# INLINE inferContext #-}
    inferContext monadProxy _ =
        case inferContext monadProxy (Proxy @a) of
            Dict ->
                case inferContext monadProxy (Proxy @b) of
                    Dict -> Dict