-- | Nominal (named) types declaration, instantiation, construction, and access.

{-# LANGUAGE FlexibleInstances, GeneralizedNewtypeDeriving, UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts, TemplateHaskell, EmptyCase #-}

-- NOTE(review): this module also relies on extensions that are not listed
-- above (e.g. LambdaCase, TypeApplications, TupleSections, DerivingStrategies,
-- TypeFamilies, GADTs); presumably they are enabled project-wide - confirm.

module AST.Term.Nominal
    ( NominalDecl(..), nParams, nScheme, KWitness(..)
    , NominalInst(..), nId, nArgs
    , ToNom(..), tnId, tnVal
    , FromNom(..), _FromNom

    , HasNominalInst(..)
    , NomVarTypes
    , MonadNominals(..)
    , LoadedNominalDecl, loadNominalDecl
    ) where

import AST
import AST.Class.Has (HasChild(..))
import AST.Class.Traversable (ContainedK(..))
import AST.Class.ZipMatch (ZipMatch(..))
import AST.Combinator.Flip (_Flip)
import AST.Infer
import AST.Recurse
import AST.Term.FuncType (FuncType(..))
import AST.Term.Map (TermMap(..), _TermMap)
import AST.Term.Scheme
import AST.TH.Internal.Instances (makeCommonInstances)
import AST.Unify
import AST.Unify.Generalize (GTerm(..), _GMono, instantiateWith, instantiateForAll)
import AST.Unify.New (newTerm)
import AST.Unify.QuantifiedVar (HasQuantifiedVar(..), OrdQVar)
import AST.Unify.Term (UTerm(..))
import Control.Applicative (Alternative(..))
import Control.DeepSeq (NFData)
import Control.Lens (Prism', makeLenses, makePrisms)
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad.Trans.Writer (execWriterT)
import Data.Binary (Binary)
import Data.Foldable (traverse_)
import Data.Kind (Type)
import Data.Proxy (Proxy(..))
import qualified Data.Map as Map
import Generics.Constraints (Constraints)
import GHC.Generics (Generic)
import Text.PrettyPrint ((<+>))
import qualified Text.PrettyPrint as Pretty
import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens)

import Prelude.Compat

-- | For a type AST @t@, the structure that holds its collections of
-- quantified variables. It is the shape used for a nominal's parameters
-- ('nParams') and for the arguments of an instantiation ('nArgs').
type family NomVarTypes (t :: Knot -> Type) :: Knot -> Type

-- | A declaration of a nominal type.
data NominalDecl typ k = NominalDecl
    { _nParams :: Tree (NomVarTypes typ) QVars
      -- ^ The type parameters of the nominal.
    , _nScheme :: Scheme (NomVarTypes typ) typ k
      -- ^ The scheme of the wrapped type.
    } deriving Generic

-- | An instantiation of a nominal type.
data NominalInst nomId varTypes k = NominalInst
    { _nId :: nomId
      -- ^ Identifier of the nominal being instantiated.
    , _nArgs :: Tree varTypes (QVarInstances (GetKnot k))
      -- ^ The arguments given for the nominal's parameters.
    } deriving Generic

-- | Nominal data constructor.
--
-- Wrap content with a data constructor
-- (analogous to a data constructor of a Haskell `newtype`).
--
-- Introduces the nominal's foralled type variables into the value's scope.
data ToNom nomId term k = ToNom
    { _tnId :: nomId
    , _tnVal :: k # term
    } deriving Generic

-- | Access the data in a nominally typed value.
--
-- Analogous to a getter of a Haskell `newtype`.
newtype FromNom nomId (term :: Knot -> Type) (k :: Knot) = FromNom nomId
    deriving newtype (Eq, Ord, Binary, NFData)
    deriving stock (Show, Generic)

-- | A nominal declaration loaded into scope in an inference monad.
data LoadedNominalDecl typ v = LoadedNominalDecl
    { _lnParams :: Tree (NomVarTypes typ) (QVarInstances (GetKnot v))
      -- ^ Loaded counterparts of the declaration's parameters.
    , _lnForalls :: Tree (NomVarTypes typ) (QVarInstances (GetKnot v))
      -- ^ Loaded counterparts of the scheme's foralled variables.
    , _lnType :: Tree (GTerm (GetKnot v)) typ
      -- ^ The loaded body of the scheme.
    } deriving Generic

makeLenses ''NominalDecl
makeLenses ''NominalInst
makeLenses ''ToNom
makePrisms ''FromNom
makeCommonInstances [''NominalDecl, ''NominalInst, ''ToNom, ''LoadedNominalDecl]
makeKTraversableAndBases ''NominalDecl
makeKTraversableApplyAndBases ''ToNom
makeKTraversableApplyAndBases ''FromNom

-- | The child nodes of a 'NominalInst' are those of its @varTypes@ structure;
-- witnesses and constraints are delegated to it.
instance KNodes v => KNodes (NominalInst n v) where
    type KNodesConstraint (NominalInst n v) c = KNodesConstraint v c
    data KWitness (NominalInst n v) c = E_NominalInst_k (KWitness v c)
    {-# INLINE kLiftConstraint #-}
    kLiftConstraint (E_NominalInst_k w) = kLiftConstraint w

-- | Maps over every value stored in the arguments' 'QVarInstances'.
instance KFunctor v => KFunctor (NominalInst n v) where
    {-# INLINE mapK #-}
    mapK f =
        nArgs %~
        mapK (\w -> _QVarInstances . Lens.mapped %~ f (E_NominalInst_k w))

-- | Folds over every value stored in the arguments' 'QVarInstances'.
instance KFoldable v => KFoldable (NominalInst n v) where
    {-# INLINE foldMapK #-}
    foldMapK f =
        foldMapK (\w -> foldMap (f (E_NominalInst_k w)) . (^. _QVarInstances))
        . (^. nArgs)

-- | Sequences the contained actions of the arguments, keeping the nominal id.
instance KTraversable v => KTraversable (NominalInst n v) where
    {-# INLINE sequenceK #-}
    sequenceK (NominalInst n v) =
        traverseK (const (_QVarInstances (traverse runContainedK))) v
        <&> NominalInst n

-- | Two instantiations match only when their nominal ids are equal and
-- their argument structures (and the variable maps inside them) match.
instance
    ( Eq nomId
    , ZipMatch varTypes
    , KTraversable varTypes
    , KNodesConstraint varTypes ZipMatch
    , KNodesConstraint varTypes OrdQVar
    ) =>
    ZipMatch (NominalInst nomId varTypes) where

    {-# INLINE zipMatch #-}
    zipMatch (NominalInst xId x) (NominalInst yId y)
        | xId /= yId = Nothing
        | otherwise =
            zipMatch x y
            >>= traverseK
                ( Proxy @ZipMatch #*# Proxy @OrdQVar #>
                    \(Pair (QVarInstances c0) (QVarInstances c1)) ->
                    -- Reuse the 'ZipMatch' instance of 'TermMap' for the
                    -- underlying maps, then unwrap the result again.
                    zipMatch (TermMap c0) (TermMap c1)
                    <&> (^. _TermMap)
                    <&> QVarInstances
                )
            <&> NominalInst xId

-- | Prints as @nomId# value@. The value is printed at precedence 11 and the
-- whole expression is parenthesised when the context precedence exceeds 10.
instance Constraints (ToNom nomId term k) Pretty => Pretty (ToNom nomId term k) where
    pPrintPrec lvl p (ToNom nomId term) =
        (pPrint nomId <> Pretty.text "#") <+> pPrintPrec lvl 11 term
        & maybeParens (p > 10)

-- | Constraint synonym (as a class with a single catch-all instance) so it
-- can be partially applied in 'KNodesConstraint'.
class    (Pretty (QVar k), Pretty (outer # k)) => PrettyConstraints outer k
instance (Pretty (QVar k), Pretty (outer # k)) => PrettyConstraints outer k

-- | Prints the nominal id followed by its arguments as
-- @[var: value, ...]@; nothing is appended when there are no arguments.
instance
    ( Pretty nomId
    , KApply varTypes, KFoldable varTypes
    , KNodesConstraint varTypes (PrettyConstraints k)
    ) =>
    Pretty (NominalInst nomId varTypes k) where

    pPrint (NominalInst n vars) =
        pPrint n <>
        joinArgs
        (foldMapK (Proxy @(PrettyConstraints k) #> mkArgs) vars)
        where
            joinArgs [] = mempty
            joinArgs xs =
                Pretty.text "[" <>
                Pretty.sep (Pretty.punctuate (Pretty.text ",") xs)
                <> Pretty.text "]"
            mkArgs (QVarInstances m) =
                Map.toList m <&>
                \(k, v) ->
                (pPrint k <> Pretty.text ":") <+> pPrint v

-- | The nodes of a 'LoadedNominalDecl' are the (recursive) nodes of the
-- body type plus the nodes of the variable collections.
instance (RNodes t, KNodes (NomVarTypes t)) => KNodes (LoadedNominalDecl t) where
    type KNodesConstraint (LoadedNominalDecl t) c =
        ( KNodesConstraint (NomVarTypes t) c
        , c t
        , Recursive c
        )
    data KWitness (LoadedNominalDecl t) n where
        E_LoadedNominalDecl_Body :: KRecWitness t n -> KWitness (LoadedNominalDecl t) n
        E_LoadedNominalDecl_NomVarTypes :: KWitness (NomVarTypes t) n -> KWitness (LoadedNominalDecl t) n
    {-# INLINE kLiftConstraint #-}
    kLiftConstraint (E_LoadedNominalDecl_Body w) = kLiftConstraint (E_Flip_GTerm w)
    kLiftConstraint (E_LoadedNominalDecl_NomVarTypes w) = kLiftConstraint w

-- | Maps over both variable collections and over the body (via '_Flip').
instance
    (Recursively KFunctor typ, KFunctor (NomVarTypes typ)) =>
    KFunctor (LoadedNominalDecl typ) where
    {-# INLINE mapK #-}
    mapK f (LoadedNominalDecl mp mf t) =
        LoadedNominalDecl (onMap mp) (onMap mf)
        (t & Lens.from _Flip %~ mapK (\(E_Flip_GTerm w) -> f (E_LoadedNominalDecl_Body w)))
        where
            onMap = mapK (\w -> _QVarInstances . Lens.mapped %~ f (E_LoadedNominalDecl_NomVarTypes w))

-- | Folds over the parameters, then the foralls, then the body.
instance
    (Recursively KFoldable typ, KFoldable (NomVarTypes typ)) =>
    KFoldable (LoadedNominalDecl typ) where
    {-# INLINE foldMapK #-}
    foldMapK f (LoadedNominalDecl mp mf t) =
        onMap mp <> onMap mf <>
        foldMapK (\(E_Flip_GTerm w) -> f (E_LoadedNominalDecl_Body w)) (_Flip # t)
        where
            onMap = foldMapK (\w -> foldMap (f (E_LoadedNominalDecl_NomVarTypes w)) . (^. _QVarInstances))

-- | Sequences the parameters, then the foralls, then the body.
instance
    (RTraversable typ, KTraversable (NomVarTypes typ)) =>
    KTraversable (LoadedNominalDecl typ) where
    {-# INLINE sequenceK #-}
    sequenceK (LoadedNominalDecl p f t) =
        LoadedNominalDecl
        <$> onMap p
        <*> onMap f
        <*> Lens.from _Flip sequenceK t
        where
            onMap = traverseK (const ((_QVarInstances . traverse) runContainedK))

{-# INLINE loadBody #-}
-- Load a single layer of a nominal's type body (used with 'wrapM' by
-- 'loadNominalDecl'):
--
-- * A quantified variable found in the parameters or, failing that, in the
--   foralls becomes 'GPoly' of the variable stored for it.
-- * Otherwise, if every child is a 'GMono', a new term is created from the
--   unwrapped children and wrapped in 'GMono'.
-- * Otherwise the layer is kept as 'GBody'.
loadBody ::
    ( Unify m typ
    , HasChild varTypes typ
    , Ord (QVar typ)
    ) =>
    Tree varTypes (QVarInstances (UVarOf m)) ->
    Tree varTypes (QVarInstances (UVarOf m)) ->
    Tree typ (GTerm (UVarOf m)) ->
    m (Tree (GTerm (UVarOf m)) typ)
loadBody params foralls x =
    case x ^? quantifiedVar >>= get of
        Just r -> GPoly r & pure
        Nothing ->
            case traverseK (const (^? _GMono)) x of
                Just xm -> newTerm xm <&> GMono
                Nothing -> GBody x & pure
    where
        -- Parameters take priority over foralls ('<|>' is left-biased).
        get v =
            params ^? getChild . _QVarInstances . Lens.ix v <|>
            foralls ^? getChild . _QVarInstances . Lens.ix v

{-# INLINE loadNominalDecl #-}
-- | Load a pure 'NominalDecl' into the inference monad.
--
-- Runs 'makeQVarInstances' on the declaration's parameters and on the
-- scheme's foralls, then rebuilds the scheme's type bottom-up with those
-- instances, producing a 'LoadedNominalDecl'.
loadNominalDecl ::
    forall m typ.
    ( Monad m
    , KTraversable (NomVarTypes typ)
    , KNodesConstraint (NomVarTypes typ) (Unify m)
    , HasScheme (NomVarTypes typ) m typ
    ) =>
    Tree Pure (NominalDecl typ) ->
    m (Tree (LoadedNominalDecl typ) (UVarOf m))
loadNominalDecl (Pure (NominalDecl params (Scheme foralls typ))) =
    do
        paramsL <- traverseK (Proxy @(Unify m) #> makeQVarInstances) params
        forallsL <- traverseK (Proxy @(Unify m) #> makeQVarInstances) foralls
        wrapM (Proxy @(HasScheme (NomVarTypes typ) m) #>> loadBody paramsL forallsL) typ
            <&> LoadedNominalDecl paramsL forallsL

-- | Monads in which loaded nominal declarations can be looked up by id.
class MonadNominals nomId typ m where
    getNominalDecl :: nomId -> m (Tree (LoadedNominalDecl typ) (UVarOf m))

-- | Type ASTs that can embed an instantiation of a nominal type.
class HasNominalInst nomId typ where
    nominalInst :: Prism' (Tree typ k) (Tree (NominalInst nomId (NomVarTypes typ)) k)

{-# INLINE lookupParams #-}
-- Resolve each parameter variable to the variable it was instantiated with.
--
-- For every variable, its binding is looked up:
--
-- * 'UInstantiated' yields the variable it points to.
-- * 'USkolem' yields a new unbound variable whose constraints are the
--   current scope constraints combined with the skolem's.
-- * Any other state calls 'error'.
lookupParams ::
    forall m varTypes.
    ( Applicative m
    , KTraversable varTypes
    , KNodesConstraint varTypes (Unify m)
    ) =>
    Tree varTypes (QVarInstances (UVarOf m)) ->
    m (Tree varTypes (QVarInstances (UVarOf m)))
lookupParams =
    traverseK (Proxy @(Unify m) #> (_QVarInstances . traverse) lookupParam)
    where
        lookupParam v =
            lookupVar binding v
            >>=
            \case
            UInstantiated r -> pure r
            USkolem l ->
                -- This is a phantom-type, wasn't instantiated by `instantiate`.
                scopeConstraints <&> (<> l) >>= newVar binding . UUnbound
            _ -> error "unexpected state at nominal's parameter"

type instance InferOf (ToNom n e) = NominalInst n (NomVarTypes (TypeOf e))

-- | Infer a nominal construction: the wrapped value's type is unified with
-- the nominal's instantiated body, and the result is the corresponding
-- 'NominalInst'.
instance
    ( MonadScopeLevel m
    , MonadNominals nomId (TypeOf expr) m
    , KTraversable (NomVarTypes (TypeOf expr))
    , KNodesConstraint (NomVarTypes (TypeOf expr)) (Unify m)
    , Unify m (TypeOf expr)
    , HasInferredType expr
    , Infer m expr
    ) =>
    Infer m (ToNom nomId expr) where

    {-# INLINE inferBody #-}
    inferBody (ToNom nomId val) =
        do
            -- The value is inferred and the declaration instantiated
            -- together under 'localLevel'.
            (InferredChild valI valR, typ, paramsT) <-
                do
                    v <- inferChild val
                    LoadedNominalDecl params foralls gen <- getNominalDecl nomId
                    -- Instantiate the foralls as skolems, collecting the
                    -- actions written while doing so; they are run below,
                    -- after the body has been instantiated.
                    recover <-
                        traverseK_
                        ( Proxy @(Unify m) #>
                            traverse_ (instantiateForAll USkolem) . (^. _QVarInstances)
                        ) foralls
                        & execWriterT
                    (typ, paramsT) <- instantiateWith (lookupParams params) UUnbound gen
                    (v, typ, paramsT) <$ sequence_ recover
                & localLevel
            (ToNom nomId valI, NominalInst nomId paramsT)
                <$ unify typ (valR ^# inferredType (Proxy @expr))

type instance InferOf (FromNom n e) = FuncType (TypeOf e)

-- | Infer a nominal accessor: its type is a function from the instantiated
-- nominal type to the nominal's instantiated body.
instance
    ( Infer m expr
    , HasNominalInst nomId (TypeOf expr)
    , MonadNominals nomId (TypeOf expr) m
    , KTraversable (NomVarTypes (TypeOf expr))
    , KNodesConstraint (NomVarTypes (TypeOf expr)) (Unify m)
    , Unify m (TypeOf expr)
    ) =>
    Infer m (FromNom nomId expr) where

    {-# INLINE inferBody #-}
    inferBody (FromNom nomId) =
        do
            LoadedNominalDecl params _ gen <- getNominalDecl nomId
            (typ, paramsT) <- instantiateWith (lookupParams params) UUnbound gen
            nominalInst # NominalInst nomId paramsT & newTerm
                <&> (`FuncType` typ)
                <&> (FromNom nomId, )