-- | Type schemes: types with explicitly quantified variables, together with
-- their inference ('inferType' and the 'Infer' instance of 'Scheme') and their
-- conversion to and from unification terms ('loadScheme', 'saveScheme').

-- NOTE(review): the definitions in this module also use @\\case@, visible type
-- application, tuple sections, @deriving stock@ and scoped type variables;
-- presumably these extensions are enabled project-wide - confirm in the
-- package's build configuration.

{-# LANGUAGE TemplateHaskell, FlexibleContexts, DefaultSignatures #-}
{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}

module AST.Term.Scheme
    ( Scheme(..), sForAlls, sTyp, KWitness(..)
    , QVars(..), _QVars
    , HasScheme(..), loadScheme, saveScheme
    , MonadInstantiate(..), inferType

    , QVarInstances(..), _QVarInstances
    , makeQVarInstances
    ) where

import AST
import AST.Class.Has (HasChild(..))
import AST.Class.Recursive
import AST.Combinator.ANode (ANode)
import AST.Combinator.Flip (Flip(..))
import AST.Infer
import AST.Recurse
import AST.TH.Internal.Instances (makeCommonInstances)
import AST.Unify
import AST.Unify.Lookup (semiPruneLookup)
import AST.Unify.New (newTerm)
import AST.Unify.Generalize
import AST.Unify.QuantifiedVar (HasQuantifiedVar(..), MonadQuantify(..), OrdQVar)
import AST.Unify.Term (UTerm(..), uBody)
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State (StateT(..))
import Data.Constraint
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Proxy (Proxy(..))
import GHC.Generics (Generic)
import Text.PrettyPrint ((<+>))
import qualified Text.PrettyPrint as Pretty
import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens)

import Prelude.Compat

-- | A type scheme representing a polymorphic type.
data Scheme varTypes typ k = Scheme
    { _sForAlls :: Tree varTypes QVars
      -- ^ The quantified variables of the scheme: a 'QVars' map for each
      -- child of @varTypes@.
    , _sTyp :: k # typ
      -- ^ The body of the scheme.
    } deriving Generic

-- | The quantified variables of one type, each mapped to its constraints.
newtype QVars typ = QVars
    (Map (QVar (GetKnot typ)) (TypeConstraintsOf (GetKnot typ)))
    deriving stock Generic

-- | A value of @k typ@ for each quantified variable of one type.
-- 'makeQVarInstances' builds these with fresh unification variables.
newtype QVarInstances k typ = QVarInstances (Map (QVar (GetKnot typ)) (k typ))
    deriving stock Generic

-- Template Haskell splices. Their order is kept as in the original, as later
-- declarations use the names they generate ('sForAlls', 'sTyp', '_QVars',
-- '_QVarInstances').
Lens.makeLenses ''Scheme
Lens.makePrisms ''QVars
Lens.makePrisms ''QVarInstances
makeCommonInstances [''Scheme, ''QVars, ''QVarInstances]
makeKTraversableApplyAndBases ''Scheme

-- The following instances have no explicit method definitions and rely on the
-- defaults of their classes.
instance RNodes t => RNodes (Scheme v t)
instance (c (Scheme v t), Recursively c t) => Recursively c (Scheme v t)
instance (KTraversable (Scheme v t), RTraversable t) => RTraversable (Scheme v t)
instance (RTraversable t, RTraversableInferOf t) => RTraversableInferOf (Scheme v t)

instance
    (RNodes t, c t, Recursive c, ITermVarsConstraint c t) =>
    ITermVarsConstraint c (Scheme v t)

-- | Union of the variable maps. The constraints of a variable that is present
-- on both sides are combined with '<>'.
instance
    ( Ord (QVar (GetKnot typ))
    , Semigroup (TypeConstraintsOf (GetKnot typ))
    ) =>
    Semigroup (QVars typ) where
    QVars m <> QVars n = QVars (Map.unionWith (<>) m n)

-- | The empty set of quantified variables.
instance
    ( Ord (QVar (GetKnot typ))
    , Semigroup (TypeConstraintsOf (GetKnot typ))
    ) =>
    Monoid (QVars typ) where
    mempty = QVars Map.empty

-- | Prints the quantified variables followed by the body. The result is
-- parenthesised when the surrounding precedence is above zero.
instance
    (Pretty (Tree varTypes QVars), Pretty (k # typ)) =>
    Pretty (Scheme varTypes typ k) where

    pPrintPrec lvl p (Scheme forAlls typ) =
        pPrintPrec lvl 0 forAlls <+> pPrintPrec lvl 0 typ
        & maybeParens (p > 0)

-- | Prints each variable as @∀v.@, or as @∀v(constraints).@ when its
-- constraints do not pretty-print to the empty document.
instance
    (Pretty (TypeConstraintsOf typ), Pretty (QVar typ)) =>
    Pretty (Tree QVars typ) where

    pPrint (QVars qvars) =
        Map.toList qvars
        <&> printVar
        <&> (Pretty.text "∀" <>)
        <&> (<> Pretty.text ".")
        & Pretty.hsep
        where
            printVar (q, c)
                | cP == mempty = pPrint q
                | otherwise = pPrint q <> Pretty.text "(" <> cP <> Pretty.text ")"
                where
                    cP = pPrint c

type instance Lens.Index (QVars typ) = QVar (GetKnot typ)
type instance Lens.IxValue (QVars typ) = TypeConstraintsOf (GetKnot typ)

instance Ord (QVar (GetKnot typ)) => Lens.Ixed (QVars typ)

-- | Access to the constraints of a single variable, via the underlying 'Map'.
instance Ord (QVar (GetKnot typ)) => Lens.At (QVars typ) where
    at k = _QVars . Lens.at k

-- | The inference result of a 'Scheme' is a 'GTerm' of its type.
type instance InferOf (Scheme v t) = Flip GTerm t

-- | Monads that keep track of the instantiations of quantified variables.
class Unify m t => MonadInstantiate m t where
    -- | Run an action with the given instantiations.
    -- NOTE(review): presumably they are in scope only for that action, as the
    -- name suggests - confirm against the instances.
    localInstantiations ::
        Tree (QVarInstances (UVarOf m)) t ->
        m a ->
        m a
    -- | The unification variable that a quantified variable is instantiated to.
    lookupQVar :: QVar t -> m (Tree (UVarOf m) t)

-- | Inference of a scheme:
--
-- 1. make a skolem unification variable for each quantified variable,
-- 2. infer the body with these instantiations applied via
--    'localInstantiations',
-- 3. 'generalize' the inferred type of the body.
instance
    ( Monad m
    , HasInferredValue typ
    , Unify m typ
    , KTraversable varTypes
    , KNodesConstraint varTypes (MonadInstantiate m)
    , RTraversable typ
    , Infer m typ
    ) =>
    Infer m (Scheme varTypes typ) where

    {-# INLINE inferBody #-}
    inferBody (Scheme vars typ) =
        do
            foralls <- traverseK (Proxy @(MonadInstantiate m) #> makeQVarInstances) vars
            let withForalls =
                    -- One 'localInstantiations' wrapper per child of
                    -- @varTypes@, composed into a single wrapper. Function
                    -- composition is associative with identity 'id', so this
                    -- fold gives the same function as a left fold would.
                    foldMapK
                    (Proxy @(MonadInstantiate m) #> (:[]) . localInstantiations)
                    foralls
                    & foldr (.) id
            InferredChild typI typR <- inferChild typ & withForalls
            generalize (typR ^. inferredValue)
                <&> (Scheme vars typI, ) . MkFlip

-- | Infer a type node whose inference result is a single node ('ANode').
--
-- A quantified variable is resolved with 'lookupQVar'. For any other node the
-- children are inferred and a new unification term is created from their
-- inferred values.
inferType ::
    ( InferOf t ~ ANode t
    , KTraversable t
    , KNodesConstraint t HasInferredValue
    , Unify m t
    , MonadInstantiate m t
    ) =>
    Tree t (InferChild m k) ->
    m (Tree t k, Tree (InferOf t) (UVarOf m))
inferType x =
    case x ^? quantifiedVar of
        Just q -> lookupQVar q <&> (quantifiedVar # q, ) . MkANode
        Nothing ->
            do
                xI <- traverseK (const inferChild) x
                mapK (Proxy @HasInferredValue #> (^. inType . inferredValue)) xI
                    & newTerm
                    <&> (mapK (const (^. inRep)) xI, ) . MkANode

-- | Create a new skolem unification variable, carrying the variable's
-- constraints, for each quantified variable.
{-# INLINE makeQVarInstances #-}
makeQVarInstances ::
    Unify m typ =>
    Tree QVars typ -> m (Tree (QVarInstances (UVarOf m)) typ)
makeQVarInstances (QVars foralls) =
    traverse (newVar binding . USkolem) foralls <&> QVarInstances

-- Load a single node whose children were already loaded (not exported):
--
-- * a quantified variable that has an instance in @foralls@ becomes 'GPoly',
-- * a node whose children are all 'GMono' becomes a new unification term
--   wrapped in 'GMono',
-- * any other node becomes 'GBody'.
{-# INLINE loadBody #-}
loadBody ::
    ( Unify m typ
    , HasChild varTypes typ
    , Ord (QVar typ)
    ) =>
    Tree varTypes (QVarInstances (UVarOf m)) ->
    Tree typ (GTerm (UVarOf m)) ->
    m (Tree (GTerm (UVarOf m)) typ)
loadBody foralls x =
    case x ^? quantifiedVar >>= getForAll of
        Just r -> GPoly r & pure
        Nothing ->
            case traverseK (const (^? _GMono)) x of
                Just xm -> newTerm xm <&> GMono
                Nothing -> GBody x & pure
    where
        getForAll v = foralls ^? getChild . _QVarInstances . Lens.ix v

-- | Types that can be loaded from and saved to schemes with the quantified
-- variables @varTypes@ in the monad @m@.
class
    (Unify m t, HasChild varTypes t, Ord (QVar t)) =>
    HasScheme varTypes m t where

    -- | Evidence that the child nodes of @t@ are also instances of
    -- 'HasScheme'.
    hasSchemeRecursive ::
        Proxy varTypes -> Proxy m -> Proxy t ->
        Dict (KNodesConstraint t (HasScheme varTypes m))
    {-# INLINE hasSchemeRecursive #-}
    default hasSchemeRecursive ::
        KNodesConstraint t (HasScheme varTypes m) =>
        Proxy varTypes -> Proxy m -> Proxy t ->
        Dict (KNodesConstraint t (HasScheme varTypes m))
    hasSchemeRecursive _ _ _ = Dict

instance Recursive (HasScheme varTypes m) where
    recurse =
        hasSchemeRecursive (Proxy @varTypes) (Proxy @m) . p
        where
            -- Only converts between the proxy types.
            p :: Proxy (HasScheme varTypes m t) -> Proxy t
            p _ = Proxy

{-# INLINE loadScheme #-}
-- | Load scheme into unification monad so that different instantiations share
-- the scheme's monomorphic parts -
-- their unification is O(1) as it is the same shared unification term.
loadScheme ::
    forall m varTypes typ.
    ( Monad m
    , KTraversable varTypes
    , KNodesConstraint varTypes (Unify m)
    , HasScheme varTypes m typ
    ) =>
    Tree Pure (Scheme varTypes typ) ->
    m (Tree (GTerm (UVarOf m)) typ)
loadScheme (Pure (Scheme vars typ)) =
    do
        foralls <- traverseK (Proxy @(Unify m) #> makeQVarInstances) vars
        wrapM (Proxy @(HasScheme varTypes m) #>> loadBody foralls) typ

-- Worker of 'saveScheme' (not exported). The state holds the quantified
-- variables collected so far and a list of actions, which 'saveScheme' runs
-- once the traversal is done.
saveH ::
    forall typ varTypes m.
    (Monad m, HasScheme varTypes m typ) =>
    Tree (GTerm (UVarOf m)) typ ->
    StateT (Tree varTypes QVars, [m ()]) m (Tree Pure typ)
saveH (GBody x) =
    -- Save the children recursively.
    withDict (hasSchemeRecursive (Proxy @varTypes) (Proxy @m) (Proxy @typ)) $
    traverseK (Proxy @(HasScheme varTypes m) #> saveH) x <&> (_Pure #)
saveH (GMono x) =
    -- Read the term back from the unification variables.
    unwrapM (Proxy @(HasScheme varTypes m) #>> f) x & lift
    where
        f v =
            semiPruneLookup v
            <&>
            \case
                (_, UTerm t) -> t ^. uBody
                (_, UUnbound{}) -> error "saveScheme of non-toplevel scheme!"
                _ -> error "unexpected state at saveScheme of monomorphic part"
saveH (GPoly x) =
    lookupVar binding x & lift
    >>=
    \case
        USkolem l ->
            do
                -- The new variable is created with the scope's constraints
                -- combined with the skolem's own constraints.
                r <-
                    scopeConstraints <&> (<> l)
                    >>= newQuantifiedVariable & lift
                -- Record the new variable, with the skolem's constraints.
                Lens._1 . getChild %=
                    (\v -> v & _QVars . Lens.at r ?~ l :: Tree QVars typ)
                -- Queue an action that binds the variable back to its skolem
                -- state.
                Lens._2 %= (bindVar binding x (USkolem l) :)
                let result = _Pure . quantifiedVar # r
                -- Bind the variable to the result so that its other
                -- occurrences take the 'UResolved' branch below.
                UResolved result & bindVar binding x & lift
                pure result
        UResolved v -> pure v
        _ -> error "unexpected state at saveScheme's forall"

-- | Convert a generalized term to a 'Scheme'.
--
-- Runs 'saveH' starting from no quantified variables and no queued actions,
-- then executes the queued actions and returns the collected variables
-- together with the saved body.
saveScheme ::
    ( KNodesConstraint varTypes OrdQVar
    , KPointed varTypes
    , HasScheme varTypes m typ
    ) =>
    Tree (GTerm (UVarOf m)) typ ->
    m (Tree Pure (Scheme varTypes typ))
saveScheme x =
    do
        (t, (v, recover)) <-
            runStateT (saveH x)
            ( pureK (Proxy @OrdQVar #> QVars mempty)
            , []
            )
        _Pure # Scheme v t <$ sequence_ recover