{-# LANGUAGE CPP
           , ScopedTypeVariables
           , GADTs
           , DataKinds
           , KindSignatures
           , GeneralizedNewtypeDeriving
           , TypeOperators
           , FlexibleContexts
           , FlexibleInstances
           , OverloadedStrings
           , PatternGuards
           , Rank2Types
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.05.28
-- |
-- Module      :  Language.Hakaru.Syntax.TypeCheck
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Bidirectional type checking for our AST.
----------------------------------------------------------------
module Language.Hakaru.Syntax.TypeCheck
    (
    -- * The type checking monad
      TypeCheckError
    , TypeCheckMonad(), runTCM, unTCM
    , TypeCheckMode(..)
    -- * Type checking itself
    , inferable
    , mustCheck
    , TypedAST(..)
    , onTypedAST, onTypedASTM, elimTypedAST
    , inferType
    , checkType
    ) where

import           Prelude hiding (id, (.))
import           Control.Category
import           Data.Proxy            (KProxy(..))
import           Data.Text             (pack, Text())
import           Data.Either           (partitionEithers)
import qualified Data.IntMap           as IM
import qualified Data.Traversable      as T
import qualified Data.List.NonEmpty    as L
import qualified Data.Foldable         as F
import qualified Data.Sequence         as S
import qualified Data.Vector           as V
#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative   (Applicative(..), (<$>))
import           Data.Monoid           (Monoid(..))
#endif

import qualified Language.Hakaru.Parser.AST as U
import Data.Number.Nat (fromNat)

import Language.Hakaru.Syntax.TypeCheck.TypeCheckMonad
import Language.Hakaru.Syntax.TypeCheck.Unification

import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind (Hakaru(..), HData', HBool)
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.HClasses
    ( HEq, hEq_Sing, HOrd, hOrd_Sing, HSemiring, hSemiring_Sing
    , hRing_Sing, sing_HRing, hFractional_Sing, sing_HFractional
    , sing_NonNegative, hDiscrete_Sing
    , HIntegrable(..)
    , HRadical(..), HContinuous(..))
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.Reducer
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Sing (sing_Literal, sing_MeasureOp)
import Language.Hakaru.Pretty.Concrete (prettyType, prettyTypeT)
import Language.Hakaru.Syntax.TypeOf (typeOf)
import Language.Hakaru.Syntax.Prelude (triv)

----------------------------------------------------------------
----------------------------------------------------------------

-- | Those terms from which we can synthesize a unique type. We are
-- also allowed to check them, via the change-of-direction rule.
inferable :: U.AST -> Bool
inferable = not . mustCheck


-- | Those terms whose types must be checked analytically. We cannot
-- synthesize (unambiguous) types for these terms.
--
-- N.B., this function assumes we're in 'StrictMode'. If we're
-- actually in 'LaxMode' then a handful of AST nodes behave
-- differently: in particular, 'U.NaryOp_', 'U.Superpose_', and
-- 'U.Case_'. In strict mode those cases can just infer one of their
-- arguments and then check the rest against the inferred type.
-- (For case-expressions, we must also check the scrutinee since
-- its type cannot be unambiguously inferred from the patterns.)
-- Whereas in lax mode we must infer all arguments and then take
-- the lub of their types in order to know which coercions to
-- introduce.
--
-- For nodes whose inference rule in 'inferType' infers /several/
-- subterms (pairs, chains, MCMC), the node must be checked as soon
-- as /any/ of those inferred subterms must be checked; hence those
-- cases combine their subterms with '||' rather than '&&'.
mustCheck :: U.AST -> Bool
mustCheck e = caseVarSyn e (const False) go
  where
    go :: U.MetaTerm -> Bool
    go (U.Lam_ _ e2) = mustCheck' e2

    -- In general, applications don't require checking; we infer
    -- the first applicand to get the type of the second and of the
    -- result, then we check the second and return the result type.
    -- Thus, applications will only yield \"must check\" errors if
    -- the function does; but that's the responsibility of the
    -- function term, not of the application term it's embedded
    -- within.
    --
    -- However, do note that the above only applies to lambda-defined
    -- functions, not to all \"function-like\" things. In particular,
    -- data constructors require checking (see the note below).
    go (U.App_ _ _) = False

    -- We follow Dunfield & Pientka and \Pi\Sigma in inferring or
    -- checking depending on what the body requires. This is as
    -- opposed to the TLDI'05 paper, which always infers @e2@ but
    -- will check or infer the @e1@ depending on whether it has a
    -- type annotation or not.
    go (U.Let_ _ e2) = mustCheck' e2

    go (U.Ann_      _ _) = False
    go (U.CoerceTo_ _ _) = False
    go (U.UnsafeTo_ _ _) = False

    -- In general (according to Dunfield & Pientka), we should be
    -- able to infer the result of a fully saturated primop by
    -- looking up its type and then checking all the arguments.
    go (U.PrimOp_ _ _) = False

    -- Every array operation handled by 'inferArrayOp' ('U.Index_',
    -- 'U.Size', 'U.Reduce') infers its first argument and checks
    -- the remaining ones, so only the first argument matters. An
    -- empty argument list cannot be inferred at all.
    go (U.ArrayOp_ _ es) =
        case es of
          e1 : _ -> mustCheck e1
          []     -> True

    -- In strict mode: if we can infer any of the arguments, then
    -- we can check all the rest at the same type.
    --
    -- BUG: in lax mode we must be able to infer all of them;
    -- otherwise we may not be able to take the lub of the types
    go (U.NaryOp_ _ es)   = F.all mustCheck es
    go (U.Superpose_ pes) = F.all (mustCheck . snd) pes

    -- Our numeric literals aren't polymorphic, so we can infer
    -- them just fine. Or rather, according to our AST they aren't;
    -- in truth, they are in the surface language. Which is part
    -- of the reason for needing 'LaxMode'
    --
    -- TODO: correctly capture our surface-language semantics by
    -- always treating literals as if we're in 'LaxMode'.
    go (U.Literal_ _) = False

    -- I return true because most folks (neelk, Pfenning, Dunfield
    -- & Pientka) say all data constructors mustCheck. The main
    -- issue here is dealing with (polymorphic) sum types and phantom
    -- types, since these mean the term doesn't contain enough
    -- information for all the type indices. Even for record types,
    -- there's the additional issue of the term (perhaps) not giving
    -- enough information about the nominal type even if it does
    -- give enough info for the structural type.
    --
    -- Still, given those limitations, we should be able to infer
    -- a subset of data constructors which happen to avoid the
    -- problem areas. In particular, given that our surface syntax
    -- doesn't use the sum-of-products representation, we should
    -- be able to rely on symbol resolution to avoid the nominal
    -- typing issue. Thus, for non-empty arrays and non-phantom
    -- record types, we should be able to infer the whole type
    -- provided we can infer the various subterms.
    --
    -- N.B., 'inferType' infers /both/ components of a pair, so a
    -- pair must be checked if /either/ component must be.
    go (U.Pair_ e1 e2)      = mustCheck e1 || mustCheck e2
    go (U.Array_ _ e1)      = mustCheck' e1
    go (U.ArrayLiteral_ es) = F.all mustCheck es
    go (U.Datum_ _)         = True

    -- TODO: everyone says this, but it seems to me that if we can
    -- infer any of the branches (and check the rest to agree) then
    -- we should be able to infer the whole thing... Or maybe the
    -- problem is that the change-of-direction rule might send us
    -- down the wrong path?
    go (U.Case_ _ _) = True

    go (U.Dirac_ e1)   = mustCheck e1
    go (U.MBind_ _ e2) = mustCheck' e2
    go (U.Plate_ _ e2) = mustCheck' e2
    -- 'inferType' infers both the initial state @e2@ and the body
    -- @e3@ of a chain, so either one needing checking is enough.
    go (U.Chain_ _ e2 e3)   = mustCheck e2 || mustCheck' e3
    go (U.MeasureOp_ _ _)   = False
    go (U.Integrate_ _ _ _) = False
    go (U.Summate_ _ _ _)   = False
    go (U.Product_ _ _ _)   = False
    go (U.Bucket_ _ _ _)    = False
    go U.Reject_            = True
    go (U.Transform_ tr es) =
        case (tr, es) of
          (Expect   , (Nil2, e1) U.:* _ U.:* U.End) -> mustCheck e1
          (Observe  , (Nil2, e1) U.:* _ U.:* U.End) -> mustCheck e1
          -- 'inferTransform' infers both arguments of MCMC.
          (MCMC     , (Nil2, e1) U.:* (Nil2, e2) U.:* U.End) ->
              mustCheck e1 || mustCheck e2
          (Disint _ , (Nil2, e1) U.:* U.End) -> mustCheck e1
          (Simplify , (Nil2, e1) U.:* U.End) -> mustCheck e1
          (Summarize, (Nil2, e1) U.:* U.End) -> mustCheck e1
          (Reparam  , (Nil2, e1) U.:* U.End) -> mustCheck e1
    go U.InjTyped{} = False


-- | 'mustCheck' for the body of a single-variable binder.
mustCheck'
    :: MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> Bool
mustCheck' e = caseBind e $ \_ e' -> mustCheck e'

----------------------------------------------------------------
----------------------------------------------------------------

-- | Re-type the bound variable of @e@ at @typ@, infer the body in
-- the extended context, and pass the body's type and the re-bound
-- term to the continuation.
inferBinder
    :: (ABT Term abt)
    => Sing a
    -> MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> (forall b. Sing b -> abt '[ a ] b -> TypeCheckMonad r)
    -> TypeCheckMonad r
inferBinder typ e k =
    caseBind e $ \x e1 -> do
        let x' = x {varType = typ}
        TypedAST typ1 e1' <- pushCtx x' (inferType e1)
        k typ1 (bind x' e1')


-- | Infer the body @e@ with all the variables @xs@ in scope, then
-- bind them all and pass the result to the continuation.
inferBinders
    :: (ABT Term abt)
    => List1 Variable xs
    -> U.AST
    -> (forall a. Sing a -> abt xs a -> TypeCheckMonad r)
    -> TypeCheckMonad r
inferBinders = \xs e k -> do
    TypedAST typ e' <- pushesCtx xs (inferType e)
    k typ (binds_ xs e')
  where
    -- TODO: make sure the 'TCM'\/'unTCM' stuff doesn't do stupid asymptotic things
    pushesCtx
        :: List1 Variable (xs :: [Hakaru])
        -> TypeCheckMonad b
        -> TypeCheckMonad b
    pushesCtx Nil1         m = m
    pushesCtx (Cons1 x xs) m = pushesCtx xs (TCM (unTCM m . insertVarSet x))


-- | Re-type the bound variable of @e@ at @typ@ and check the body
-- against @eTyp@ in the extended context.
checkBinder
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> TypeCheckMonad (abt '[ a ] b)
checkBinder typ eTyp e =
    caseBind e $ \x e1 -> do
        let x' = x {varType = typ}
        pushCtx x' (bind x' <$> checkType eTyp e1)


-- | Check @e@ against @eTyp@ with all of @xs@ in scope, binding each
-- of them in turn (outermost first).
checkBinders
    :: (ABT Term abt)
    => List1 Variable xs
    -> Sing a
    -> U.AST
    -> TypeCheckMonad (abt xs a)
checkBinders xs eTyp e =
    case xs of
      Nil1        -> checkType eTyp e
      Cons1 x xs' -> pushCtx x (bind x <$> checkBinders xs' eTyp e)


----------------------------------------------------------------
-- | Given a typing environment and a term, synthesize the term's
-- type (and produce an elaborated term):
--
-- > Γ ⊢ e ⇒ e' ∈ τ
inferType
    :: forall abt
    .  (ABT Term abt)
    => U.AST
    -> TypeCheckMonad (TypedAST abt)
inferType = inferType_
  where
    -- HACK: we need to give these local definitions to avoid
    -- \"ambiguity\" in the choice of ABT instance...
    checkType_ :: forall b. Sing b -> U.AST -> TypeCheckMonad (abt '[] b)
    checkType_ = checkType

    inferOneCheckOthers_ :: [U.AST] -> TypeCheckMonad (TypedASTs abt)
    inferOneCheckOthers_ = inferOneCheckOthers

    -- Look the variable up (by its numeric ID) in the typing context.
    inferVariable
        :: Maybe U.SourceSpan
        -> Variable 'U.U
        -> TypeCheckMonad (TypedAST abt)
    inferVariable sourceSpan (Variable hintID nameID _) = do
        ctx <- getCtx
        case IM.lookup (fromNat nameID) (unVarSet ctx) of
          Just (SomeVariable x') ->
              return $ TypedAST (varType x') (var x')
          Nothing -> ambiguousFreeVariable hintID sourceSpan

    -- HACK: We need this monomorphic binding so that GHC doesn't get
    -- confused about which @(ABT AST abt)@ instance to use in recursive
    -- calls.
    inferType_ :: U.AST -> TypeCheckMonad (TypedAST abt)
    inferType_ e0 =
        let s = getMetadata e0
        in  caseVarSyn e0 (inferVariable s) (go s)
      where
        go :: Maybe U.SourceSpan -> U.MetaTerm -> TypeCheckMonad (TypedAST abt)
        go sourceSpan t =
            case t of
            U.Lam_ (U.SSing typ) e -> do
                inferBinder typ e $ \typ2 e2 ->
                    return . TypedAST (SFun typ typ2) $
                        syn (Lam_ :$ e2 :* End)

            U.App_ e1 e2 -> do
                TypedAST typ1 e1' <- inferType_ e1
                unifyFun typ1 sourceSpan $ \typ2 typ3 -> do
                    e2' <- checkType_ typ2 e2
                    return . TypedAST typ3 $
                        syn (App_ :$ e1' :* e2' :* End)
                -- case typ1 of
                --     SFun typ2 typ3 -> do
                --         e2' <- checkType_ typ2 e2
                --         return . TypedAST typ3 $ syn (App_ :$ e1' :* e2' :* End)
                --     _ -> typeMismatch sourceSpan (Left "function type") (Right typ1)
                -- The above is the standard rule that everyone uses.
                -- However, if the @e1@ is a lambda (rather than a primop
                -- or a variable), then it will require a type annotation.
                -- Couldn't we just as well add an additional rule that
                -- says to infer @e2@ and then infer @e1@ under the assumption
                -- that the variable has the same type as the argument? (or
                -- generalize that idea to keep track of a bunch of arguments
                -- being passed in; sort of like a dual to our typing
                -- environments?) Is this at all related to what Dunfield
                -- & Neelk are doing in their ICFP'13 paper with that
                -- \"=>=>\" judgment? (prolly not, but...)

            U.Let_ e1 e2 -> do
                TypedAST typ1 e1' <- inferType_ e1
                inferBinder typ1 e2 $ \typ2 e2' ->
                    return . TypedAST typ2 $
                        syn (Let_ :$ e1' :* e2' :* End)

            U.Ann_ (U.SSing typ1) e1 -> do
                -- N.B., this requires that @typ1@ is a 'Sing' not a 'Proxy',
                -- since we can't generate a 'Sing' from a 'Proxy'.
                TypedAST typ1 <$> checkType_ typ1 e1

            U.PrimOp_  op es -> inferPrimOp  op es
            U.ArrayOp_ op es -> inferArrayOp op es

            U.NaryOp_ op es -> do
                mode <- getMode
                TypedASTs typ es' <-
                    case mode of
                      StrictMode -> inferOneCheckOthers_ es
                      LaxMode    -> inferLubType sourceSpan es
                      UnsafeMode -> inferLubType sourceSpan es
                op' <- make_NaryOp typ op
                return . TypedAST typ $ syn (NaryOp_ op' $ S.fromList es')

            U.Literal_ (Some1 v) ->
                -- TODO: in truth, we can infer this to be any supertype
                -- (adjusting the concrete @v@ as necessary). That is, the
                -- surface language treats numeric literals as polymorphic,
                -- so we should capture that somehow--- even if we're not
                -- in 'LaxMode'. We'll prolly need to handle this
                -- subtype-polymorphism the same way as we do for
                -- everything when in 'UnsafeMode'.
                return . TypedAST (sing_Literal v) $ syn (Literal_ v)

            -- N.B., 'U.Case_' is handled below via 'inferCaseStrict' /
            -- 'inferCaseLax', even though 'mustCheck' still
            -- conservatively reports case-expressions as must-check.

            U.CoerceTo_ (Some2 c) e1 ->
                case singCoerceDomCod c of
                  Nothing
                      | inferable e1 -> inferType_ e1
                      | otherwise    -> ambiguousNullCoercion sourceSpan
                  Just (dom,cod) -> do
                      e1' <- checkType_ dom e1
                      return . TypedAST cod $ syn (CoerceTo_ c :$ e1' :* End)

            U.UnsafeTo_ (Some2 c) e1 ->
                case singCoerceDomCod c of
                  Nothing
                      | inferable e1 -> inferType_ e1
                      | otherwise    -> ambiguousNullCoercion sourceSpan
                  Just (dom,cod) -> do
                      e1' <- checkType_ cod e1
                      return . TypedAST dom $ syn (UnsafeFrom_ c :$ e1' :* End)

            U.MeasureOp_ (U.SomeOp op) es -> do
                let (typs, typ1) = sing_MeasureOp op
                es' <- checkSArgs typs es
                return . TypedAST (SMeasure typ1) $ syn (MeasureOp_ op :$ es')

            U.Pair_ e1 e2 -> do
                -- Both components are inferred; see the 'U.Pair_'
                -- case of 'mustCheck'.
                TypedAST typ1 e1' <- inferType_ e1
                TypedAST typ2 e2' <- inferType_ e2
                return . TypedAST (sPair typ1 typ2) $
                    syn (Datum_ $ dPair_ typ1 typ2 e1' e2')

            U.Array_ e1 e2 -> do
                e1' <- checkType_ SNat e1
                inferBinder SNat e2 $ \typ2 e2' ->
                    return . TypedAST (SArray typ2) $ syn (Array_ e1' e2')

            U.ArrayLiteral_ es -> do
                mode <- getMode
                TypedASTs typ es' <-
                    case mode of
                      StrictMode -> inferOneCheckOthers_ es
                      LaxMode    -> inferLubType sourceSpan es
                      UnsafeMode -> inferLubType sourceSpan es
                return . TypedAST (SArray typ) $ syn (ArrayLiteral_ es')

            U.Case_ e1 branches -> do
                TypedAST typ1 e1' <- inferType_ e1
                mode <- getMode
                case mode of
                  StrictMode -> inferCaseStrict typ1 e1' branches
                  LaxMode    -> inferCaseLax sourceSpan typ1 e1' branches
                  UnsafeMode -> inferCaseLax sourceSpan typ1 e1' branches

            U.Dirac_ e1 -> do
                TypedAST typ1 e1' <- inferType_ e1
                return . TypedAST (SMeasure typ1) $ syn (Dirac :$ e1' :* End)

            U.MBind_ e1 e2 ->
                caseBind e2 $ \x e2' -> do
                    TypedAST typ1 e1' <- inferType_ e1
                    unifyMeasure typ1 sourceSpan $ \typ2 ->
                        let x' = makeVar x typ2 in
                        pushCtx x' $ do
                            TypedAST typ3 e3' <- inferType_ e2'
                            unifyMeasure typ3 sourceSpan $ \_ ->
                                return . TypedAST typ3 $
                                    syn (MBind :$ e1' :* bind x' e3' :* End)

            U.Plate_ e1 e2 ->
                caseBind e2 $ \x e2' -> do
                    e1' <- checkType_ SNat e1
                    let x' = makeVar x SNat
                    pushCtx x' $ do
                        TypedAST typ2 e3' <- inferType_ e2'
                        unifyMeasure typ2 sourceSpan $ \typ3 ->
                            return . TypedAST (SMeasure . SArray $ typ3) $
                                syn (Plate :$ e1' :* bind x' e3' :* End)

            U.Chain_ e1 e2 e3 ->
                -- Both the initial state @e2@ and the body @e3@ are
                -- inferred; see the 'U.Chain_' case of 'mustCheck'.
                caseBind e3 $ \x e3' -> do
                    e1' <- checkType_ SNat e1
                    TypedAST typ2 e2' <- inferType_ e2
                    let x' = makeVar x typ2
                    pushCtx x' $ do
                        TypedAST typ3 e4' <- inferType_ e3'
                        unifyMeasure typ3 sourceSpan $ \typ4 ->
                            unifyPair typ4 sourceSpan $ \a b ->
                                matchTypes typ2 b sourceSpan () () $
                                    return . TypedAST (SMeasure $ sPair (SArray a) typ2) $
                                        syn (Chain :$ e1' :* e2' :* bind x' e4' :* End)

            U.Integrate_ e1 e2 e3 -> do
                e1' <- checkType_ SReal e1
                e2' <- checkType_ SReal e2
                e3' <- checkBinder SReal SProb e3
                return . TypedAST SProb $
                    syn (Integrate :$ e1' :* e2' :* e3' :* End)

            U.Summate_ e1 e2 e3 -> do
                TypedAST typ1 e1' <- inferType_ e1
                e2' <- checkType_ typ1 e2
                case hDiscrete_Sing typ1 of
                  Nothing -> failwith_ "Summate given bounds which are not discrete"
                  Just h1 ->
                    inferBinder typ1 e3 $ \typ2 e3' ->
                    case hSemiring_Sing typ2 of
                      Nothing -> failwith_ "Summate given summands which are not in a semiring"
                      Just h2 ->
                        return . TypedAST typ2 $
                            syn (Summate h1 h2 :$ e1' :* e2' :* e3' :* End)

            U.Product_ e1 e2 e3 -> do
                TypedAST typ1 e1' <- inferType_ e1
                e2' <- checkType_ typ1 e2
                case hDiscrete_Sing typ1 of
                  Nothing -> failwith_ "Product given bounds which are not discrete"
                  Just h1 ->
                    inferBinder typ1 e3 $ \typ2 e3' ->
                    case hSemiring_Sing typ2 of
                      Nothing -> failwith_ "Product given factors which are not in a semiring"
                      Just h2 ->
                        return . TypedAST typ2 $
                            syn (Product h1 h2 :$ e1' :* e2' :* e3' :* End)

            U.Bucket_ e1 e2 r1 -> do
                e1' <- checkType_ SNat e1
                e2' <- checkType_ SNat e2
                TypedReducer typ1 Nil1 r1' <- inferReducer r1 Nil1
                return . TypedAST typ1 $ syn (Bucket e1' e2' r1')

            U.Transform_ tr es -> inferTransform sourceSpan tr es

            U.Superpose_ pes -> do
                -- TODO: clean up all this @map fst@, @map snd@, @zip@ stuff
                mode <- getMode
                TypedASTs typ es' <-
                    case mode of
                      StrictMode -> inferOneCheckOthers_ (L.toList $ fmap snd pes)
                      LaxMode    -> inferLubType sourceSpan (L.toList $ fmap snd pes)
                      UnsafeMode -> inferLubType sourceSpan (L.toList $ fmap snd pes)
                unifyMeasure typ sourceSpan $ \_ -> do
                    ps' <- T.traverse (checkType_ SProb) (fmap fst pes)
                    return $ TypedAST typ (syn (Superpose_ (L.zip ps' (L.fromList es'))))

            U.InjTyped t ->
                -- NOTE(review): the local rebinding looks deliberate
                -- (it appears to pin @t'@ to a single ABT instance for
                -- both uses below); do not inline it without checking
                -- that the result still typechecks.
                let t' = t in return $ TypedAST (typeOf t') t'

            _   | mustCheck e0 -> ambiguousMustCheck sourceSpan
                | otherwise    -> error "inferType: missing an inferable branch!"

    -- Infer each program transformation from its argument(s). None
    -- of the rules use the transform's own source span; they report
    -- errors at the span of the relevant argument instead.
    inferTransform
        :: Maybe U.SourceSpan
        -> Transform as x
        -> U.SArgs U.U_ABT as
        -> TypeCheckMonad (TypedAST abt)
    inferTransform _ Expect ((Nil2, e1) U.:* (Cons2 U.ToU Nil2, e2) U.:* U.End) = do
        let e1src = getMetadata e1
        TypedAST typ1 e1' <- inferType_ e1
        unifyMeasure typ1 e1src $ \typ2 -> do
            e2' <- checkBinder typ2 SProb e2
            return . TypedAST SProb $
                syn (Transform_ Expect :$ e1' :* e2' :* End)
    inferTransform _ Observe ((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) = do
        let e1src = getMetadata e1
        TypedAST typ1 e1' <- inferType_ e1
        unifyMeasure typ1 e1src $ \typ2 -> do
            e2' <- checkType_ typ2 e2
            return . TypedAST typ1 $
                syn (Transform_ Observe :$ e1' :* e2' :* End)
    inferTransform _ MCMC ((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) = do
        -- Both arguments are inferred; see the MCMC case of 'mustCheck'.
        let e1src = getMetadata e1
            e2src = getMetadata e2
        TypedAST typ1 e1' <- inferType_ e1
        TypedAST typ2 e2' <- inferType_ e2
        unifyFun typ1 e1src $ \typa typmb ->
            unifyMeasure typmb e1src $ \typb ->
            unifyMeasure typ2 e2src $ \typc ->
            matchTypes typa typb e1src (SFun typa (SMeasure typa)) typ1 $
            matchTypes typb typc e2src typmb typ2 $
            return $ TypedAST (SFun typa (SMeasure typa)) $
                syn $ Transform_ MCMC :$ e1' :* e2' :* End
    inferTransform _ (Disint k) ((Nil2, e1) U.:* U.End) = do
        let e1src = getMetadata e1
        TypedAST typ1 e1' <- inferType_ e1
        unifyMeasure typ1 e1src $ \typ2 ->
            unifyPair typ2 e1src $ \typa typb ->
            return $ TypedAST (SFun typa (SMeasure typb)) $
                syn $ Transform_ (Disint k) :$ e1' :* End
    inferTransform _ Simplify ((Nil2, e1) U.:* U.End) = do
        TypedAST typ1 e1' <- inferType_ e1
        return $ TypedAST typ1 $ syn (Transform_ Simplify :$ e1' :* End)
    inferTransform _ Reparam ((Nil2, e1) U.:* U.End) = do
        TypedAST typ1 e1' <- inferType_ e1
        return $ TypedAST typ1 $ syn (Transform_ Reparam :$ e1' :* End)
    inferTransform _ Summarize ((Nil2, e1) U.:* U.End) = do
        TypedAST typ1 e1' <- inferType_ e1
        return $ TypedAST typ1 $ syn (Transform_ Summarize :$ e1' :* End)
    inferTransform _ tr _ = error $ "inferTransform{" ++ show tr ++ "}: TODO"

    -- Infer a saturated primitive operation: look up its type, check
    -- the arguments, and report 'argumentNumberError' on arity errors.
    inferPrimOp
        :: U.PrimOp
        -> [U.AST]
        -> TypeCheckMonad (TypedAST abt)
    inferPrimOp U.Not es =
        case es of
          [e] -> do
              e' <- checkType_ sBool e
              return . TypedAST sBool $ syn (PrimOp_ Not :$ e' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Pi es =
        case es of
          [] -> return . TypedAST SProb $ syn (PrimOp_ Pi :$ End)
          _  -> argumentNumberError

    inferPrimOp U.Cos es =
        case es of
          [e] -> do
              e' <- checkType_ SReal e
              return . TypedAST SReal $ syn (PrimOp_ Cos :$ e' :* End)
          _ -> argumentNumberError

    inferPrimOp U.RealPow es =
        case es of
          [e1, e2] -> do
              e1' <- checkType_ SProb e1
              e2' <- checkType_ SReal e2
              return . TypedAST SProb $
                  syn (PrimOp_ RealPow :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Choose es =
        case es of
          [e1, e2] -> do
              e1' <- checkType_ SNat e1
              e2' <- checkType_ SNat e2
              return . TypedAST SNat $
                  syn (PrimOp_ Choose :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Exp es =
        case es of
          [e] -> do
              e' <- checkType_ SReal e
              return . TypedAST SProb $ syn (PrimOp_ Exp :$ e' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Log es =
        case es of
          [e] -> do
              e' <- checkType_ SProb e
              return . TypedAST SReal $ syn (PrimOp_ Log :$ e' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Infinity es =
        case es of
          [] -> return . TypedAST SProb $
                    syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)
          _  -> argumentNumberError

    inferPrimOp U.GammaFunc es =
        case es of
          [e] -> do
              e' <- checkType_ SReal e
              return . TypedAST SProb $ syn (PrimOp_ GammaFunc :$ e' :* End)
          _ -> argumentNumberError

    inferPrimOp U.BetaFunc es =
        case es of
          [e1, e2] -> do
              e1' <- checkType_ SProb e1
              e2' <- checkType_ SProb e2
              return . TypedAST SProb $
                  syn (PrimOp_ BetaFunc :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Equal es =
        case es of
          [_, _] -> do
              mode <- getMode
              TypedASTs typ [e1', e2'] <-
                  case mode of
                    StrictMode -> inferOneCheckOthers_ es
                    _          -> inferLubType Nothing es
              primop <- Equal <$> getHEq typ
              return . TypedAST sBool $
                  syn (PrimOp_ primop :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Less es =
        case es of
          [_, _] -> do
              mode <- getMode
              TypedASTs typ [e1', e2'] <-
                  case mode of
                    StrictMode -> inferOneCheckOthers_ es
                    _          -> inferLubType Nothing es
              primop <- Less <$> getHOrd typ
              return . TypedAST sBool $
                  syn (PrimOp_ primop :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferPrimOp U.NatPow es =
        case es of
          [e1, e2] -> do
              TypedAST typ e1' <- inferType_ e1
              e2' <- checkType_ SNat e2
              primop <- NatPow <$> getHSemiring typ
              return . TypedAST typ $
                  syn (PrimOp_ primop :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Negate es =
        case es of
          [e] -> do
              TypedAST typ e' <- inferType_ e
              mode <- getMode
              SomeRing ring c <- getHRing typ mode
              primop <- Negate <$> return ring
              -- Insert the coercion into the ring only when one is needed.
              let e'' = case c of
                          CNil -> e'
                          c'   -> unLC_ . coerceTo c' $ LC_ e'
              return . TypedAST (sing_HRing ring) $
                  syn (PrimOp_ primop :$ e'' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Abs es =
        case es of
          [e] -> do
              TypedAST typ e' <- inferType_ e
              mode <- getMode
              SomeRing ring c <- getHRing typ mode
              primop <- Abs <$> return ring
              let e'' = case c of
                          CNil -> e'
                          c'   -> unLC_ . coerceTo c' $ LC_ e'
              return . TypedAST (sing_NonNegative ring) $
                  syn (PrimOp_ primop :$ e'' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Signum es =
        case es of
          [e] -> do
              TypedAST typ e' <- inferType_ e
              mode <- getMode
              SomeRing ring c <- getHRing typ mode
              primop <- Signum <$> return ring
              let e'' = case c of
                          CNil -> e'
                          c'   -> unLC_ . coerceTo c' $ LC_ e'
              return . TypedAST (sing_HRing ring) $
                  syn (PrimOp_ primop :$ e'' :* End)
          _ -> argumentNumberError

    inferPrimOp U.Recip es =
        case es of
          [e] -> do
              TypedAST typ e' <- inferType_ e
              mode <- getMode
              SomeFractional frac c <- getHFractional typ mode
              primop <- Recip <$> return frac
              let e'' = case c of
                          CNil -> e'
                          c'   -> unLC_ . coerceTo c' $ LC_ e'
              return . TypedAST (sing_HFractional frac) $
                  syn (PrimOp_ primop :$ e'' :* End)
          _ -> argumentNumberError

    -- BUG: Only defined for HRadical_Prob
    inferPrimOp U.NatRoot es =
        case es of
          [e1, e2] -> do
              e1' <- checkType_ SProb e1
              e2' <- checkType_ SNat  e2
              return . TypedAST SProb $
                  syn (PrimOp_ (NatRoot HRadical_Prob) :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    -- BUG: Only defined for HContinuous_Real
    inferPrimOp U.Erf es =
        case es of
          [e] -> do
              e' <- checkType_ SReal e
              return . TypedAST SReal $
                  syn (PrimOp_ (Erf HContinuous_Real) :$ e' :* End)
          _ -> argumentNumberError

    -- Real-to-real trigonometric and hyperbolic functions. N.B., the
    -- 'U.Cos' entry is shadowed by the explicit 'U.Cos' equation above.
    inferPrimOp x es
        | Just y <- lookup x
            [ (U.Sin  , Sin  ), (U.Cos  , Cos  ), (U.Tan  , Tan  )
            , (U.Asin , Asin ), (U.Acos , Acos ), (U.Atan , Atan )
            , (U.Sinh , Sinh ), (U.Cosh , Cosh ), (U.Tanh , Tanh )
            , (U.Asinh, Asinh), (U.Acosh, Acosh), (U.Atanh, Atanh)
            ]
        = case es of
            [e] -> do
                e' <- checkType_ SReal e
                return . TypedAST SReal $ syn (PrimOp_ y :$ e' :* End)
            _ -> argumentNumberError

    inferPrimOp U.Floor es =
        case es of
          [e] -> do
              e' <- checkType_ SProb e
              return . TypedAST SNat $ syn (PrimOp_ Floor :$ e' :* End)
          _ -> argumentNumberError

    inferPrimOp x _ = error ("TODO: inferPrimOp: " ++ show x)

    -- Each array operation infers its first argument (which must be an
    -- array, or for 'U.Reduce' a function) and checks the rest.
    inferArrayOp :: U.ArrayOp -> [U.AST] -> TypeCheckMonad (TypedAST abt)
    inferArrayOp U.Index_ es =
        case es of
          [e1, e2] -> do
              TypedAST typ1 e1' <- inferType_ e1
              unifyArray typ1 Nothing $ \typ2 -> do
                  e2' <- checkType_ SNat e2
                  return . TypedAST typ2 $
                      syn (ArrayOp_ (Index typ2) :$ e1' :* e2' :* End)
          _ -> argumentNumberError

    inferArrayOp U.Size es =
        case es of
          [e] -> do
              TypedAST typ e' <- inferType_ e
              unifyArray typ Nothing $ \typ1 ->
                  return . TypedAST SNat $
                      syn (ArrayOp_ (Size typ1) :$ e' :* End)
          _ -> argumentNumberError

    inferArrayOp U.Reduce es =
        case es of
          [e1, e2, e3] -> do
              TypedAST typ e1' <- inferType_ e1
              unifyFun typ Nothing $ \typ1 typ2 -> do
                  Refl <- jmEq1_ typ2 (SFun typ1 typ1)
                  e2' <- checkType_ typ1 e2
                  e3' <- checkType_ (SArray typ1) e3
                  return . TypedAST typ1 $
                      syn (ArrayOp_ (Reduce typ1) :$ e1' :* e2' :* e3' :* End)
          _ -> argumentNumberError

    -- Infer a reducer, given the variables @xs@ bound by the enclosing
    -- 'Bucket' / index reducers.
    inferReducer
        :: U.Reducer xs U.U_ABT 'U.U
        -> List1 Variable xs1
        -> TypeCheckMonad (TypedReducer abt xs1)
    inferReducer (U.R_Fanout_ r1 r2) xs = do
        TypedReducer t1 _ r1' <- inferReducer r1 xs
        TypedReducer t2 _ r2' <- inferReducer r2 xs
        return (TypedReducer (sPair t1 t2) xs (Red_Fanout r1' r2'))

    inferReducer (U.R_Index_ x n ix r1) xs = do
        let (_, n') = caseBinds n
        let b       = makeVar x SNat
        TypedReducer t1 _ r1' <- inferReducer r1 (Cons1 b xs)
        n'' <- checkBinders xs SNat n'
        caseBind ix $ \i ix1 ->
            let i'       = makeVar i SNat
                (_, ix2) = caseBinds ix1
            in do
                ix3 <- pushCtx i' (checkBinders xs SNat ix2)
                return . TypedReducer (SArray t1) xs $
                    Red_Index n'' (bind i' ix3) r1'

    inferReducer (U.R_Split_ b r1 r2) xs = do
        TypedReducer t1 _ r1' <- inferReducer r1 xs
        TypedReducer t2 _ r2' <- inferReducer r2 xs
        caseBind b $ \x b1 ->
            let (_, b2) = caseBinds b1
                x'      = makeVar x SNat
            in do
                b3 <- pushCtx x' (checkBinders xs sBool b2)
                return . TypedReducer (sPair t1 t2) xs $
                    (Red_Split (bind x' b3) r1' r2')

    inferReducer U.R_Nop_ xs = return (TypedReducer sUnit xs Red_Nop)

    inferReducer (U.R_Add_ e) xs =
        caseBind e $ \x e1 ->
            let (_, e2) = caseBinds e1
                x'      = makeVar x SNat
            in pushCtx x' $
                inferBinders xs e2 $ \typ e3 -> do
                    h <- getHSemiring typ
                    return $ TypedReducer typ xs (Red_Add h (bind x' e3))

-- TODO: can we make this lazier in the second component of 'TypedASTs'
-- so that we can perform case analysis on the type component before
-- actually evaluating 'checkOthers'? Problem is, even though we
-- have the type to return we don't know whether the whole thing
-- will succeed or not until after calling 'checkOthers'... We could
-- handle this by changing the return type to @TypeCheckMonad (exists
-- b. (Sing b, TypeCheckMonad [abt '[] b]))@ thereby making the
-- staging explicit.
-- -- | Given a list of terms which must all have the same type, try -- inferring each term in order until one of them succeeds and then -- check all the others against that type. This is appropriate for -- 'StrictMode' where we won't need to insert coercions; for -- 'LaxMode', see 'inferLubType' instead. inferOneCheckOthers :: forall abt . (ABT Term abt) => [U.AST] -> TypeCheckMonad (TypedASTs abt) inferOneCheckOthers = inferOne [] where inferOne :: [U.AST] -> [U.AST] -> TypeCheckMonad (TypedASTs abt) inferOne ls [] | null ls = ambiguousEmptyNary Nothing | otherwise = ambiguousMustCheckNary Nothing inferOne ls (e:rs) = do m <- try $ inferType e case m of Nothing -> inferOne (e:ls) rs Just (TypedAST typ e') -> do ls' <- checkOthers typ ls rs' <- checkOthers typ rs return (TypedASTs typ (reverse ls' ++ e' : rs')) checkOthers :: forall a. Sing a -> [U.AST] -> TypeCheckMonad [abt '[] a] checkOthers typ = T.traverse (checkType typ) -- | Given a list of terms which must all have the same type, infer -- all the terms in order and coerce them to the lub of all their -- types. This is appropriate for 'LaxMode' where we need to insert -- coercions; for 'StrictMode', see 'inferOneCheckOthers' instead. inferLubType :: forall abt . (ABT Term abt) => Maybe U.SourceSpan -> [U.AST] -> TypeCheckMonad (TypedASTs abt) inferLubType s = start where start :: [U.AST] -> TypeCheckMonad (TypedASTs abt) start [] = ambiguousEmptyNary Nothing start (u:us) = do TypedAST typ1 e1 <- inferType u TypedASTs typ2 es <- F.foldlM step (TypedASTs typ1 [e1]) us return (TypedASTs typ2 (reverse es)) -- TODO: inline 'F.foldlM' and then inline this, to unpack the first argument. step :: TypedASTs abt -> U.AST -> TypeCheckMonad (TypedASTs abt) step (TypedASTs typ1 es) u = do TypedAST typ2 e2 <- inferType u case findLub typ1 typ2 of Nothing -> missingLub typ1 typ2 s Just (Lub typ c1 c2) -> let es' = map (unLC_ . coerceTo c1 . LC_) es e2' = unLC_ . 
coerceTo c2 $ LC_ e2 in return (TypedASTs typ (e2' : es')) inferCaseStrict :: forall abt a . (ABT Term abt) => Sing a -> abt '[] a -> [U.Branch] -> TypeCheckMonad (TypedAST abt) inferCaseStrict typA e1 = inferOne [] where inferOne :: [U.Branch] -> [U.Branch] -> TypeCheckMonad (TypedAST abt) inferOne ls [] | null ls = ambiguousEmptyNary Nothing | otherwise = ambiguousMustCheckNary Nothing inferOne ls (b@(U.Branch_ pat e):rs) = do SP pat' vars <- checkPattern typA pat m <- try $ inferBinders vars e $ \typ e' -> do ls' <- checkOthers typ ls rs' <- checkOthers typ rs return (TypedAST typ $ syn (Case_ e1 (reverse ls' ++ (Branch pat' e') : rs'))) case m of Nothing -> inferOne (b:ls) rs Just m' -> return m' checkOthers :: forall b. Sing b -> [U.Branch] -> TypeCheckMonad [Branch a abt b] checkOthers typ = T.traverse (checkBranch typA typ) inferCaseLax :: forall abt a . (ABT Term abt) => Maybe U.SourceSpan -> Sing a -> abt '[] a -> [U.Branch] -> TypeCheckMonad (TypedAST abt) inferCaseLax s typA e1 = start where start :: [U.Branch] -> TypeCheckMonad (TypedAST abt) start [] = ambiguousEmptyNary Nothing start ((U.Branch_ pat e):us) = do SP pat' vars <- checkPattern typA pat inferBinders vars e $ \typ1 e' -> do SomeBranch typ2 bs <- F.foldlM step (SomeBranch typ1 [Branch pat' e']) us return . TypedAST typ2 . syn . Case_ e1 $ reverse bs -- TODO: inline 'F.foldlM' and then inline this, to unpack the first argument. 
step :: SomeBranch a abt -> U.Branch -> TypeCheckMonad (SomeBranch a abt) step (SomeBranch typB bs) (U.Branch_ pat e) = do SP pat' vars <- checkPattern typA pat inferBinders vars e $ \typE e' -> case findLub typB typE of Nothing -> missingLub typB typE s Just (Lub typLub coeB coeE) -> return $ SomeBranch typLub ( Branch pat' (coerceTo_nonLC coeE e') : map (coerceTo coeB) bs ) ---------------------------------------------------------------- ---------------------------------------------------------------- -- HACK: we must add the constraints that 'LCs' and 'UnLCs' are inverses. -- TODO: how can we do that in general rather than needing to repeat -- it here and in the various constructors of 'SCon'? checkSArgs :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs) => List1 Sing typs -> [U.AST] -> TypeCheckMonad (SArgs abt args) checkSArgs Nil1 [] = return End checkSArgs (Cons1 typ typs) (e:es) = (:*) <$> checkType typ e <*> checkSArgs typs es checkSArgs _ _ = error "checkSArgs: the number of types and terms doesn't match up" -- | Given a typing environment, a type, and a term, verify that -- the term satisfies the type (and produce an elaborated term): -- -- > Γ ⊢ τ ∋ e ⇒ e' checkType :: forall abt a . (ABT Term abt) => Sing a -> U.AST -> TypeCheckMonad (abt '[] a) checkType = checkType_ where -- HACK: to convince GHC to stop being stupid about resolving -- the \"choice\" of @abt'@. I'm not sure why we don't need to -- use this same hack when 'inferType' calls 'checkType', but whatevs. inferType_ :: U.AST -> TypeCheckMonad (TypedAST abt) inferType_ = inferType checkVariable :: forall b . 
Sing b -> Maybe U.SourceSpan -> Variable 'U.U -> TypeCheckMonad (abt '[] b)
    -- Check a variable occurrence: infer its type from the typing
    -- context, then reconcile that with the expected type @typ0@.
    checkVariable typ0 sourceSpan x = do
        TypedAST typ' e0' <- inferType_ (var x)
        reconcileInferred sourceSpan typ0 typ' e0'

    -- Reconcile a term whose type @typ'@ was inferred with the
    -- expected type @typ0@, according to the current 'TypeCheckMode':
    --   * StrictMode: the two types must be identical;
    --   * LaxMode:    a safe coercion may be inserted ('checkOrCoerce');
    --   * UnsafeMode: safe or unsafe coercions may be inserted
    --                 ('checkOrUnsafeCoerce').
    -- Shared by 'checkVariable' and the inferable fallback of 'checkType_'.
    reconcileInferred
        :: forall x y
        .  Maybe U.SourceSpan
        -> Sing y
        -> Sing x
        -> abt '[] x
        -> TypeCheckMonad (abt '[] y)
    reconcileInferred sourceSpan typ0 typ' e0' = do
        mode <- getMode
        case mode of
            StrictMode ->
                case jmEq1 typ0 typ' of
                    Just Refl -> return e0'
                    Nothing   -> typeMismatch sourceSpan (Right typ0) (Right typ')
            LaxMode    -> checkOrCoerce sourceSpan e0' typ' typ0
            UnsafeMode -> checkOrUnsafeCoerce sourceSpan e0' typ' typ0

    -- Check the term @e0@ against the expected type @typ0@, producing
    -- the elaborated term. Variables are handled by 'checkVariable';
    -- forms with dedicated checking rules are handled by @go@; any
    -- other inferable form is inferred and then reconciled with @typ0@.
    checkType_
        :: forall b. Sing b -> U.AST -> TypeCheckMonad (abt '[] b)
    checkType_ typ0 e0 =
        let s = getMetadata e0 in
        caseVarSyn e0 (checkVariable typ0 s) (go s)
      where
        go sourceSpan t = case t of
            -- The change-of-direction rule suggests this doesn't need to
            -- be here. We keep it in case we later use a U.Lam which
            -- doesn't carry the type of its variable.
            U.Lam_ (U.SSing typ) e1 ->
                unifyFun typ0 sourceSpan $ \typ1 typ2 ->
                matchTypes typ1 typ sourceSpan () () $ do
                    e1' <- checkBinder typ1 typ2 e1
                    return $ syn (Lam_ :$ e1' :* End)

            U.Let_ e1 e2 -> do
                TypedAST typ1 e1' <- inferType_ e1
                e2' <- checkBinder typ1 typ0 e2
                return $ syn (Let_ :$ e1' :* e2' :* End)

            U.CoerceTo_ (Some2 c) e1 ->
                case singCoerceDomCod c of
                    Nothing -> do
                        e1' <- checkType_ typ0 e1
                        return $ syn (CoerceTo_ CNil :$ e1' :* End)
                    Just (dom, cod) ->
                        matchTypes typ0 cod sourceSpan () () $ do
                            e1' <- checkType_ dom e1
                            return $ syn (CoerceTo_ c :$ e1' :* End)

            -- An unsafe coercion @c@ maps its codomain back to its
            -- domain, so the result must have type @dom@ and the
            -- argument is checked against @cod@.
            U.UnsafeTo_ (Some2 c) e1 ->
                case singCoerceDomCod c of
                    Nothing -> do
                        e1' <- checkType_ typ0 e1
                        return $ syn (UnsafeFrom_ CNil :$ e1' :* End)
                    Just (dom, cod) ->
                        matchTypes typ0 dom sourceSpan () () $ do
                            e1' <- checkType_ cod e1
                            return $ syn (UnsafeFrom_ c :$ e1' :* End)

            -- TODO: Find better place to put this logic
            -- @infinity@ exists natively at nat and prob; at int and
            -- real it is built at nat/prob and then safely coerced.
            U.PrimOp_ U.Infinity [] ->
                case typ0 of
                    SNat  -> return $ syn (PrimOp_ (Infinity HIntegrable_Nat) :$ End)
                    SInt  -> checkOrCoerce sourceSpan
                                (syn (PrimOp_ (Infinity HIntegrable_Nat) :$ End))
                                SNat SInt
                    SProb -> return $ syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)
                    SReal -> checkOrCoerce sourceSpan
                                (syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End))
                                SProb SReal
                    _     -> failwith =<< makeErrMsg
                                "Type Mismatch:"
                                sourceSpan
                                "infinity can only be checked against nat, int, prob, or real"

            U.Product_ e1 e2 e3 ->
                case hSemiring_Sing typ0 of
                    Nothing -> failwith_ "Product given factors which are not in a semiring"
                    Just h2 -> do
                        TypedAST typ1 e1' <- inferType e1
                        e2' <- checkType_ typ1 e2
                        case hDiscrete_Sing typ1 of
                            Nothing -> failwith_ "Product given bounds which are not discrete"
                            Just h1 -> do
                                e3' <- checkBinder typ1 typ0 e3
                                return $ syn (Product h1 h2 :$ e1' :* e2' :* e3' :* End)

            U.NaryOp_ op es -> do
                mode <- getMode
                case mode of
                    StrictMode -> safeNaryOp typ0
                    LaxMode    -> safeNaryOp typ0
                    UnsafeMode ->
                        case op of
                            U.Prod -> do
                                op' <- make_NaryOp typ0 op
                                -- Each factor is first checked in LaxMode,
                                -- then inferred and safely coerced; factors
                                -- failing both are collected as 'Left'.
                                (bads, goods) <- fmap partitionEithers . T.forM es $ \e -> do
                                    r <- tryWith LaxMode (checkType_ typ0 e)
                                    case r of
                                        Just er -> return (Right er)
                                        Nothing -> do
                                            r' <- try $ do
                                                TypedAST typE eInf <- inferType e
                                                checkOrCoerce sourceSpan eInf typE typ0
                                            return (maybe (Left e) Right r')
                                if null bads
                                    then return $ syn (NaryOp_ op' (S.fromList goods))
                                    else do
                                        -- Infer the remaining factors (as a single
                                        -- product when there are several) and force
                                        -- the result to typ0 with a possibly-unsafe
                                        -- coercion.
                                        TypedAST typ bad <- inferType
                                            (case bads of { [b] -> b; _ -> syn $ U.NaryOp_ op bads })
                                        bad' <- checkOrUnsafeCoerce sourceSpan bad typ typ0
                                        return $ case bad' : goods of
                                            [e] -> e
                                            es' -> syn $ NaryOp_ op' (S.fromList es')
                            _ -> do
                                mbTerm <- tryWith LaxMode (safeNaryOp typ0)
                                case mbTerm of
                                    Just term -> return term
                                    Nothing -> do
                                        TypedAST typ e0' <- inferType (syn $ U.NaryOp_ op es)
                                        checkOrUnsafeCoerce sourceSpan e0' typ typ0
              where
                -- Check every argument against one common type @typ@.
                safeNaryOp :: forall c. Sing c -> TypeCheckMonad (abt '[] c)
                safeNaryOp typ = do
                    op' <- make_NaryOp typ op
                    es' <- T.forM es $ checkType_ typ
                    return $ syn (NaryOp_ op' (S.fromList es'))

            U.Pair_ e1 e2 ->
                unifyPair typ0 sourceSpan $ \a b -> do
                    e1' <- checkType_ a e1
                    e2' <- checkType_ b e2
                    return $ syn (Datum_ $ dPair_ a b e1' e2')

            U.Array_ e1 e2 ->
                unifyArray typ0 sourceSpan $ \typ1 -> do
                    e1' <- checkType_ SNat e1
                    e2' <- checkBinder SNat typ1 e2
                    return $ syn (Array_ e1' e2')

            U.ArrayLiteral_ es ->
                unifyArray typ0 sourceSpan $ \typ1 ->
                    if null es
                    then return $ syn (Empty_ typ0)
                    else do
                        es' <- T.forM es $ checkType_ typ1
                        return $ syn (ArrayLiteral_ es')

            U.Datum_ (U.Datum hint d) ->
                case typ0 of
                    SData _ typ2 ->
                        (syn . Datum_ . Datum hint typ0)
                            <$> checkDatumCode typ0 typ2 d
                    _ -> typeMismatch sourceSpan (Right typ0) (Left "HData")

            U.Case_ e1 branches -> do
                TypedAST typ1 e1' <- inferType_ e1
                branches' <- T.forM branches $ checkBranch typ1 typ0
                return $ syn (Case_ e1' branches')

            U.Dirac_ e1 ->
                unifyMeasure typ0 sourceSpan $ \typ1 -> do
                    e1' <- checkType_ typ1 e1
                    return $ syn (Dirac :$ e1' :* End)

            U.MBind_ e1 e2 ->
                unifyMeasure typ0 sourceSpan $ \_ -> do
                    TypedAST typ1 e1' <- inferType_ e1
                    unifyMeasure typ1 (getMetadata e1) $ \typ2 -> do
                        e2' <- checkBinder typ2 typ0 e2
                        return $ syn (MBind :$ e1' :* e2' :* End)

            U.Plate_ e1 e2 ->
                unifyMeasure typ0 sourceSpan $ \typ1 -> do
                    e1' <- checkType_ SNat e1
                    unifyArray typ1 sourceSpan $ \typ2 -> do
                        e2' <- checkBinder SNat (SMeasure typ2) e2
                        return $ syn (Plate :$ e1' :* e2' :* End)

            U.Chain_ e1 e2 e3 ->
                unifyMeasure typ0 sourceSpan $ \typ1 ->
                unifyPair typ1 sourceSpan $ \aa s ->
                unifyArray aa sourceSpan $ \a -> do
                    e1' <- checkType_ SNat e1
                    e2' <- checkType_ s e2
                    e3' <- checkBinder s (SMeasure $ sPair a s) e3
                    return $ syn (Chain :$ e1' :* e2' :* e3' :* End)

            U.Transform_ tr es -> checkTransform sourceSpan typ0 tr es

            U.Superpose_ pes ->
                unifyMeasure typ0 sourceSpan $ \_ ->
                    fmap (syn . Superpose_) . T.forM pes $ \(p,e) ->
                        (,) <$> checkType_ SProb p <*> checkType_ typ0 e

            U.Reject_ ->
                unifyMeasure typ0 sourceSpan $ \_ ->
                    return $ syn (Reject_ typ0)

            U.InjTyped tm ->
                let typ1 = typeOf $ triv tm
                in case jmEq1 typ0 typ1 of
                    Just Refl -> return tm
                    Nothing   -> typeMismatch sourceSpan (Right typ0) (Right typ1)

            _ | inferable e0 -> do
                    TypedAST typ' e0' <- inferType_ e0
                    reconcileInferred sourceSpan typ0 typ' e0'
              | otherwise -> error "checkType: missing an mustCheck branch!"

    -- Check the arguments of a 'Transform' application against the
    -- expected result type @typ0@. Transforms without a rule here
    -- hit the final catch-all and raise an 'error'.
    checkTransform
        :: Maybe U.SourceSpan
        -> Sing x'
        -> Transform as x
        -> U.SArgs U.U_ABT as
        -> TypeCheckMonad (abt '[] x')
    checkTransform sourceSpan typ0 Expect
        ((Nil2, e1) U.:* (Cons2 U.ToU Nil2, e2) U.:* U.End) =
        case typ0 of
            SProb -> do
                TypedAST typ1 e1' <- inferType_ e1
                unifyMeasure typ1 sourceSpan $ \typ2 -> do
                    e2' <- checkBinder typ2 typ0 e2
                    return $ syn (Transform_ Expect :$ e1' :* e2' :* End)
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HProb")
    checkTransform sourceSpan typ0 Observe
        ((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) =
        unifyMeasure typ0 sourceSpan $ \typ2 -> do
            e1' <- checkType_ typ0 e1
            e2' <- checkType_ typ2 e2
            return $ syn (Transform_ Observe :$ e1' :* e2' :* End)
    checkTransform sourceSpan typ0 MCMC
        ((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) =
        unifyFun typ0 sourceSpan $ \typa typmb ->
        unifyMeasure typmb sourceSpan $ \typb ->
        matchTypes typa typb sourceSpan (SFun typa (SMeasure typa)) typ0 $ do
            e1' <- checkType (SFun typa (SMeasure typa)) e1
            e2' <- checkType (SMeasure typa) e2
            return $ syn $ Transform_ MCMC :$ e1' :* e2' :* End
    checkTransform sourceSpan typ0 (Disint k)
        ((Nil2, e1) U.:* U.End) =
        unifyFun typ0 sourceSpan $ \typa typmb ->
        unifyMeasure typmb sourceSpan $ \typb -> do
            e1' <- checkType (SMeasure (sPair typa typb)) e1
            return $ syn $ Transform_ (Disint k) :$ e1' :* End
    -- Simplify, Reparam and Summarize preserve the type of their argument.
    checkTransform _ typ0 Simplify ((Nil2, e1) U.:* U.End) = do
        e1' <- checkType_ typ0 e1
        return $ syn (Transform_ Simplify :$ e1' :* End)
    checkTransform _ typ0 Reparam ((Nil2, e1) U.:* U.End) = do
        e1' <- checkType_ typ0 e1
        return $ syn (Transform_ Reparam :$ e1' :* End)
    checkTransform _ typ0 Summarize ((Nil2, e1) U.:* U.End) = do
        e1' <- checkType_ typ0 e1
        return $ syn (Transform_ Summarize :$ e1' :* End)
    checkTransform _ _ tr _ =
        error $ "checkTransform{" ++ show tr ++ "}: TODO"

    --------------------------------------------------------
    -- We make these local to 'checkType' for the same reason we have 'checkType_'
    -- TODO: can we combine these in with the 'checkBranch' functions somehow?

    -- Check the sum-of-products code of a datum against the code
    -- @typ@ of the data type @typA@, choosing the summand by Inl/Inr.
    checkDatumCode
        :: forall xss t
        .  Sing (HData' t)
        -> Sing xss
        -> U.DCode_
        -> TypeCheckMonad (DatumCode xss (abt '[]) (HData' t))
    checkDatumCode typA typ d =
        case d of
            U.Inr d2 ->
                case typ of
                    SPlus _ typ2 -> Inr <$> checkDatumCode typA typ2 d2
                    _            -> failwith_ "expected datum of `inr' type"
            U.Inl d1 ->
                case typ of
                    SPlus typ1 _ -> Inl <$> checkDatumStruct typA typ1 d1
                    _            -> failwith_ "expected datum of `inl' type"

    -- Check one constructor's fields (a product) field by field.
    checkDatumStruct
        :: forall xs t
        .  Sing (HData' t)
        -> Sing xs
        -> U.DStruct_
        -> TypeCheckMonad (DatumStruct xs (abt '[]) (HData' t))
    checkDatumStruct typA typ d =
        case d of
            U.Et d1 d2 ->
                case typ of
                    SEt typ1 typ2 -> Et
                        <$> checkDatumFun    typA typ1 d1
                        <*> checkDatumStruct typA typ2 d2
                    _ -> failwith_ "expected datum of `et' type"
            U.Done ->
                case typ of
                    SDone -> return Done
                    _     -> failwith_ "expected datum of `done' type"

    -- Check a single field: an 'Ident' field is a recursive occurrence
    -- of the data type @typA@ itself; a 'Konst' field has its own type.
    checkDatumFun
        :: forall x t
        .  Sing (HData' t)
        -> Sing x
        -> U.DFun_
        -> TypeCheckMonad (DatumFun x (abt '[]) (HData' t))
    checkDatumFun typA typ d =
        case d of
            U.Ident e1 ->
                case typ of
                    SIdent -> Ident <$> checkType_ typA e1
                    _      -> failwith_ "expected datum of `I' type"
            U.Konst e1 ->
                case typ of
                    SKonst typ1 -> Konst <$> checkType_ typ1 e1
                    _           -> failwith_ "expected datum of `K' type"

----------------------------------------------------------------
-- | Check a case branch: its pattern against the scrutinee type
-- @patTyp@, and its body (under the variables bound by the pattern)
-- against the result type @bodyTyp@.
checkBranch
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> U.Branch
    -> TypeCheckMonad (Branch a abt b)
checkBranch patTyp bodyTyp (U.Branch_ pat body) = do
    SP pat' vars <- checkPattern patTyp pat
    Branch pat' <$> checkBinders vars bodyTyp body

-- | Check a pattern against the type @typA@, returning the typed
-- pattern together with the list of variables it binds.
checkPattern
    :: Sing a
    -> U.Pattern
    -> TypeCheckMonad (SomePattern a)
checkPattern = \typA pat ->
    case pat of
        U.PVar x -> return $ SP PVar (Cons1 (makeVar (U.nameToVar x) typA) Nil1)
        U.PWild  -> return $ SP PWild Nil1
        U.PDatum hint pat1 ->
            case typA of
                SData _ typ1 -> do
                    SPC pat1' xs <- checkPatternCode typA typ1 pat1
                    return $ SP (PDatum hint pat1') xs
                _ -> typeMismatch Nothing (Right typA) (Left "HData")
  where
    -- Walk the sum (Inl/Inr) structure of a datum pattern.
    checkPatternCode
        :: Sing (HData' t)
        -> Sing xss
        -> U.PCode
        -> TypeCheckMonad (SomePatternCode xss t)
    checkPatternCode typA typ pat =
        case pat of
            U.PInr pat2 ->
                case typ of
                    SPlus _ typ2 -> do
                        SPC pat2' xs <- checkPatternCode typA typ2 pat2
                        return $ SPC (PInr pat2') xs
                    _ -> failwith_ "expected pattern of `sum' type"
            U.PInl pat1 ->
                case typ of
                    SPlus typ1 _ -> do
                        SPS pat1' xs <- checkPatternStruct typA typ1 pat1
                        return $ SPC (PInl pat1') xs
                    _ -> failwith_ "expected pattern of `zero' type"

    -- Walk the product (Et/Done) structure, concatenating bound variables.
    checkPatternStruct
        :: Sing (HData' t)
        -> Sing xs
        -> U.PStruct
        -> TypeCheckMonad (SomePatternStruct xs t)
    checkPatternStruct typA typ pat =
        case pat of
            U.PEt pat1 pat2 ->
                case typ of
                    SEt typ1 typ2 -> do
                        SPF pat1' xs <- checkPatternFun typA typ1 pat1
                        SPS pat2' ys <- checkPatternStruct typA typ2 pat2
                        return $ SPS (PEt pat1' pat2') (append1 xs ys)
                    _ -> failwith_ "expected pattern of `et' type"
            U.PDone ->
                case typ of
                    SDone -> return $ SPS PDone Nil1
                    _     -> failwith_ "expected pattern of `done' type"

    -- Check a single field pattern: 'PIdent' recurses at the data type
    -- @typA@ itself, 'PKonst' at the field's own type.
    checkPatternFun
        :: Sing (HData' t)
        -> Sing x
        -> U.PFun
        -> TypeCheckMonad (SomePatternFun x t)
    checkPatternFun typA typ pat =
        case pat of
            U.PIdent pat1 ->
                case typ of
                    SIdent -> do
                        SP pat1' xs <- checkPattern typA pat1
                        return $ SPF (PIdent pat1') xs
                    _ -> failwith_ "expected pattern of `I' type"
            U.PKonst pat1 ->
                case typ of
                    SKonst typ1 -> do
                        SP pat1' xs <- checkPattern typ1 pat1
                        return $ SPF (PKonst pat1') xs
                    _ -> failwith_ "expected pattern of `K' type"

-- | Turn a term of type @typA@ into one of type @typB@ by inserting
-- the safe coercion found by 'findCoercion'; if there is none, report
-- a type mismatch at the given source span.
checkOrCoerce
    :: (ABT Term abt)
    => Maybe (U.SourceSpan)
    -> abt '[] a
    -> Sing a
    -> Sing b
    -> TypeCheckMonad (abt '[] b)
checkOrCoerce s e typA typB =
    case findCoercion typA typB of
        Just c  -> return . unLC_ . coerceTo c $ LC_ e
        Nothing -> typeMismatch s (Right typB) (Right typA)

-- | Like 'checkOrCoerce', but also accepts unsafe coercions, and
-- "mixed" ones (an unsafe coercion @c1@ followed by a safe one @c2@),
-- as found by 'findEitherCoercion'.
--
-- If no coercion exists but both types are measures, the term is
-- rewritten as @e >>= \\x -> dirac x@ and @dirac x@ is checked
-- against @typB@ with @x@ at the source element type, so the element
-- may itself be coerced. Otherwise a type mismatch is reported.
checkOrUnsafeCoerce
    :: (ABT Term abt)
    => Maybe (U.SourceSpan)
    -> abt '[] a
    -> Sing a
    -> Sing b
    -> TypeCheckMonad (abt '[] b)
checkOrUnsafeCoerce s e typA typB =
    case findEitherCoercion typA typB of
        Just (Unsafe c) -> return . unLC_ . coerceFrom c $ LC_ e
        Just (Safe   c) -> return . unLC_ . coerceTo   c $ LC_ e
        Just (Mixed (_, c1, c2)) ->
            return . unLC_ . coerceTo c2 . coerceFrom c1 $ LC_ e
        Nothing ->
            case (typA, typB) of
                -- mighty, mighty hack!
                (SMeasure typ1, SMeasure _) -> do
                    let x = Variable (pack "") 0 U.SU
                    e2' <- checkBinder typ1 typB (bind x $ syn $ U.Dirac_ (var x))
                    return $ syn (MBind :$ e :* e2' :* End)
                (_ , _) -> typeMismatch s (Right typB) (Right typA)

----------------------------------------------------------------
----------------------------------------------------------- fin.