{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Nix.Type.Infer (
Constraint(..),
TypeError(..),
InferError(..),
Subst(..),
inferTop
) where
import Control.Applicative
import Control.Arrow
import Control.Monad.Catch
import Control.Monad.Except
import Control.Monad.Logic
import Control.Monad.Reader
import Control.Monad.ST
import Control.Monad.State
import Data.Fix
import Data.Foldable
import qualified Data.HashMap.Lazy as M
import Data.List (delete, find, nub, intersect, (\\))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromJust)
import Data.STRef
import Data.Semigroup
import qualified Data.Set as Set
import Data.Text (Text)
import Nix.Atoms
import Nix.Convert
import Nix.Eval (MonadEval(..))
import qualified Nix.Eval as Eval
import Nix.Expr.Types
import Nix.Expr.Types.Annotated
import Nix.Scope
import Nix.Thunk
import qualified Nix.Type.Assumption as As
import Nix.Type.Env
import qualified Nix.Type.Env as Env
import Nix.Type.Type
import Nix.Utils
newtype Infer s a = Infer
{ getInfer ::
ReaderT (Set.Set TVar, Scopes (Infer s) (JThunk s))
(StateT InferState (ExceptT InferError (ST s))) a
}
deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadFix,
MonadReader (Set.Set TVar, Scopes (Infer s) (JThunk s)),
MonadState InferState, MonadError InferError)
newtype InferState = InferState { count :: Int }
initInfer :: InferState
initInfer = InferState { count = 0 }
data Constraint
= EqConst Type Type
| ExpInstConst Type Scheme
| ImpInstConst Type (Set.Set TVar) Type
deriving (Show, Eq, Ord)
newtype Subst = Subst (Map TVar Type)
deriving (Eq, Ord, Show, Semigroup, Monoid)
class Substitutable a where
apply :: Subst -> a -> a
instance Substitutable TVar where
apply (Subst s) a = tv
where t = TVar a
(TVar tv) = Map.findWithDefault t a s
instance Substitutable Type where
apply _ (TCon a) = TCon a
apply s (TSet b a) = TSet b (M.map (apply s) a)
apply s (TList a) = TList (map (apply s) a)
apply (Subst s) t@(TVar a) = Map.findWithDefault t a s
apply s (t1 `TArr` t2) = apply s t1 `TArr` apply s t2
apply s (TMany ts) = TMany (map (apply s) ts)
instance Substitutable Scheme where
apply (Subst s) (Forall as t) = Forall as $ apply s' t
where s' = Subst $ foldr Map.delete s as
instance Substitutable Constraint where
apply s (EqConst t1 t2) = EqConst (apply s t1) (apply s t2)
apply s (ExpInstConst t sc) = ExpInstConst (apply s t) (apply s sc)
apply s (ImpInstConst t1 ms t2) = ImpInstConst (apply s t1) (apply s ms) (apply s t2)
instance Substitutable a => Substitutable [a] where
apply = map . apply
instance (Ord a, Substitutable a) => Substitutable (Set.Set a) where
apply = Set.map . apply
class FreeTypeVars a where
ftv :: a -> Set.Set TVar
instance FreeTypeVars Type where
ftv TCon{} = Set.empty
ftv (TVar a) = Set.singleton a
ftv (TSet _ a) = Set.unions (map ftv (M.elems a))
ftv (TList a) = Set.unions (map ftv a)
ftv (t1 `TArr` t2) = ftv t1 `Set.union` ftv t2
ftv (TMany ts) = Set.unions (map ftv ts)
instance FreeTypeVars TVar where
ftv = Set.singleton
instance FreeTypeVars Scheme where
ftv (Forall as t) = ftv t `Set.difference` Set.fromList as
instance FreeTypeVars a => FreeTypeVars [a] where
ftv = foldr (Set.union . ftv) Set.empty
instance (Ord a, FreeTypeVars a) => FreeTypeVars (Set.Set a) where
ftv = foldr (Set.union . ftv) Set.empty
class ActiveTypeVars a where
atv :: a -> Set.Set TVar
instance ActiveTypeVars Constraint where
atv (EqConst t1 t2) = ftv t1 `Set.union` ftv t2
atv (ImpInstConst t1 ms t2) = ftv t1 `Set.union` (ftv ms `Set.intersection` ftv t2)
atv (ExpInstConst t s) = ftv t `Set.union` ftv s
instance ActiveTypeVars a => ActiveTypeVars [a] where
atv = foldr (Set.union . atv) Set.empty
data TypeError
= UnificationFail Type Type
| InfiniteType TVar Type
| UnboundVariables [Text]
| Ambigious [Constraint]
| UnificationMismatch [Type] [Type]
deriving (Eq, Show)
data InferError
= TypeInferenceErrors [TypeError]
| TypeInferenceAborted
| forall s. Exception s => EvaluationError s
typeError :: MonadError InferError m => TypeError -> m ()
typeError err = throwError $ TypeInferenceErrors [err]
deriving instance Show InferError
instance Exception InferError
instance Semigroup InferError where
x <> _ = x
instance Monoid InferError where
mempty = TypeInferenceAborted
mappend = (<>)
runInfer' :: Infer s a -> ST s (Either InferError a)
runInfer' = runExceptT
. (`evalStateT` initInfer)
. (`runReaderT` (Set.empty, emptyScopes))
. getInfer
runInfer :: (forall s. Infer s a) -> Either InferError a
runInfer m = runST (runInfer' m)
inferType :: Env -> NExpr -> Infer s [(Subst, Type)]
inferType env ex = do
Judgment as cs t <- infer ex
let unbounds = Set.fromList (As.keys as) `Set.difference`
Set.fromList (Env.keys env)
unless (Set.null unbounds) $
typeError $ UnboundVariables (nub (Set.toList unbounds))
let cs' = [ ExpInstConst t s
| (x, ss) <- Env.toList env
, s <- ss
, t <- As.lookup x as]
inferState <- get
let eres = (`evalState` inferState) $ runSolver $ do
subst <- solve (cs ++ cs')
return (subst, subst `apply` t)
case eres of
Left errs -> throwError $ TypeInferenceErrors errs
Right xs -> pure xs
inferExpr :: Env -> NExpr -> Either InferError [Scheme]
inferExpr env ex = case runInfer (inferType env ex) of
Left err -> Left err
Right xs -> Right $ map (\(subst, ty) -> closeOver (subst `apply` ty)) xs
closeOver :: Type -> Scheme
closeOver = normalize . generalize Set.empty
extendMSet :: TVar -> Infer s a -> Infer s a
extendMSet x = Infer . local (first (Set.insert x)) . getInfer
letters :: [String]
letters = [1..] >>= flip replicateM ['a'..'z']
fresh :: MonadState InferState m => m Type
fresh = do
s <- get
put s{count = count s + 1}
return $ TVar $ TV (letters !! count s)
instantiate :: MonadState InferState m => Scheme -> m Type
instantiate (Forall as t) = do
as' <- mapM (const fresh) as
let s = Subst $ Map.fromList $ zip as as'
return $ apply s t
generalize :: Set.Set TVar -> Type -> Scheme
generalize free t = Forall as t
where as = Set.toList $ ftv t `Set.difference` free
unops :: Type -> NUnaryOp -> [Constraint]
unops u1 = \case
NNot -> [ EqConst u1 (typeFun [typeBool, typeBool]) ]
NNeg -> [ EqConst u1 (TMany [ typeFun [typeInt, typeInt]
, typeFun [typeFloat, typeFloat] ]) ]
binops :: Type -> NBinaryOp -> [Constraint]
binops u1 = \case
NApp -> []
NEq -> []
NNEq -> []
NGt -> inequality
NGte -> inequality
NLt -> inequality
NLte -> inequality
NAnd -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
NOr -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
NImpl -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
NConcat -> [ EqConst u1 (TMany [ typeFun [typeList, typeList, typeList]
, typeFun [typeList, typeNull, typeList]
, typeFun [typeNull, typeList, typeList]
]) ]
NUpdate -> [ EqConst u1 (TMany [ typeFun [typeSet, typeSet, typeSet]
, typeFun [typeSet, typeNull, typeSet]
, typeFun [typeNull, typeSet, typeSet]
]) ]
NPlus -> [ EqConst u1 (TMany [ typeFun [typeInt, typeInt, typeInt]
, typeFun [typeFloat, typeFloat, typeFloat]
, typeFun [typeInt, typeFloat, typeFloat]
, typeFun [typeFloat, typeInt, typeFloat]
, typeFun [typeString, typeString, typeString]
, typeFun [typePath, typePath, typePath]
, typeFun [typeString, typeString, typePath]
]) ]
NMinus -> arithmetic
NMult -> arithmetic
NDiv -> arithmetic
where
inequality =
[ EqConst u1 (TMany [ typeFun [typeInt, typeInt, typeBool]
, typeFun [typeFloat, typeFloat, typeBool]
, typeFun [typeInt, typeFloat, typeBool]
, typeFun [typeFloat, typeInt, typeBool]
]) ]
arithmetic =
[ EqConst u1 (TMany [ typeFun [typeInt, typeInt, typeInt]
, typeFun [typeFloat, typeFloat, typeFloat]
, typeFun [typeInt, typeFloat, typeFloat]
, typeFun [typeFloat, typeInt, typeFloat]
]) ]
instance MonadVar (Infer s) where
type Var (Infer s) = STRef s
newVar x = Infer $ lift $ lift $ lift $ newSTRef x
readVar x = Infer $ lift $ lift $ lift $ readSTRef x
writeVar x y = Infer $ lift $ lift $ lift $ writeSTRef x y
atomicModifyVar x f = Infer $ lift $ lift $ lift $ do
res <- snd . f <$> readSTRef x
_ <- modifySTRef x (fst . f)
return res
newtype JThunk s = JThunk (Thunk (Infer s) (Judgment s))
instance MonadThrow (Infer s) where
throwM = throwError . EvaluationError
instance MonadCatch (Infer s) where
catch m h = catchError m $ \case
EvaluationError e ->
maybe (error $ "Exception was not an exception: " ++ show e) h
(fromException (toException e))
err -> error $ "Unexpected error: " ++ show err
instance MonadThunk (Judgment s) (JThunk s) (Infer s) where
thunk = fmap JThunk . buildThunk
force (JThunk t) f = catch (forceThunk t f) $ \(_ :: ThunkLoop) ->
f =<< Judgment As.empty [] <$> fresh
value = JThunk . valueRef
instance MonadEval (Judgment s) (Infer s) where
freeVariable var = do
tv <- fresh
return $ Judgment (As.singleton var tv) [] tv
attrMissing _ _ = Judgment As.empty [] <$> fresh
evaledSym _ = pure
evalCurPos =
return $ Judgment As.empty [] $ TSet False $ M.fromList
[ ("file", typePath)
, ("line", typeInt)
, ("col", typeInt) ]
evalConstant c = return $ Judgment As.empty [] (go c)
where
go = \case
NInt _ -> typeInt
NFloat _ -> typeFloat
NBool _ -> typeBool
NNull -> typeNull
evalString = const $ return $ Judgment As.empty [] typeString
evalLiteralPath = const $ return $ Judgment As.empty [] typePath
evalEnvPath = const $ return $ Judgment As.empty [] typePath
evalUnary op (Judgment as1 cs1 t1) = do
tv <- fresh
return $ Judgment as1 (cs1 ++ unops (t1 `TArr` tv) op) tv
evalBinary op (Judgment as1 cs1 t1) e2 = do
Judgment as2 cs2 t2 <- e2
tv <- fresh
return $ Judgment
(as1 `As.merge` as2)
(cs1 ++ cs2 ++ binops (t1 `TArr` (t2 `TArr` tv)) op)
tv
evalWith = Eval.evalWithAttrSet
evalIf (Judgment as1 cs1 t1) t f = do
Judgment as2 cs2 t2 <- t
Judgment as3 cs3 t3 <- f
return $ Judgment
(as1 `As.merge` as2 `As.merge` as3)
(cs1 ++ cs2 ++ cs3 ++ [EqConst t1 typeBool, EqConst t2 t3])
t2
evalAssert (Judgment as1 cs1 t1) body = do
Judgment as2 cs2 t2 <- body
return $ Judgment
(as1 `As.merge` as2)
(cs1 ++ cs2 ++ [EqConst t1 typeBool])
t2
evalApp (Judgment as1 cs1 t1) e2 = do
Judgment as2 cs2 t2 <- e2
tv <- fresh
return $ Judgment
(as1 `As.merge` as2)
(cs1 ++ cs2 ++ [EqConst t1 (t2 `TArr` tv)])
tv
evalAbs (Param x) k = do
tv@(TVar a) <- fresh
((), Judgment as cs t) <-
extendMSet a (k (pure (Judgment (As.singleton x tv) [] tv))
(\_ b -> ((),) <$> b))
return $ Judgment
(as `As.remove` x)
(cs ++ [EqConst t' tv | t' <- As.lookup x as])
(tv `TArr` t)
evalAbs (ParamSet ps variadic _mname) k = do
js <- fmap concat $ forM ps $ \(name, _) -> do
tv <- fresh
pure [(name, tv)]
let (env, tys) = (\f -> foldl' f (As.empty, M.empty) js)
$ \(as1, t1) (k, t) ->
(as1 `As.merge` As.singleton k t, M.insert k t t1)
arg = pure $ Judgment env [] (TSet True tys)
call = k arg $ \args b -> (args,) <$> b
names = map fst js
(args, Judgment as cs t) <-
foldr (\(_, TVar a) -> extendMSet a) call js
ty <- TSet variadic <$> traverse (inferredType <$>) args
return $ Judgment
(foldl' As.remove as names)
(cs ++ [ EqConst t' (tys M.! x)
| x <- names
, t' <- As.lookup x as])
(ty `TArr` t)
evalError = throwError . EvaluationError
data Judgment s = Judgment
{ assumptions :: As.Assumption
, typeConstraints :: [Constraint]
, inferredType :: Type
}
deriving Show
instance FromValue (Text, DList Text) (Infer s) (Judgment s) where
fromValueMay _ = return Nothing
fromValue _ = error "Unused"
instance FromValue (AttrSet (JThunk s), AttrSet SourcePos) (Infer s) (Judgment s) where
fromValueMay (Judgment _ _ (TSet _ xs)) = do
let sing _ = Judgment As.empty []
pure $ Just (M.mapWithKey (\k v -> value (sing k v)) xs, M.empty)
fromValueMay _ = pure Nothing
fromValue = fromValueMay >=> \case
Just v -> pure v
Nothing -> pure (M.empty, M.empty)
instance ToValue (AttrSet (JThunk s), AttrSet SourcePos) (Infer s) (Judgment s) where
toValue (xs, _) = Judgment
<$> foldrM go As.empty xs
<*> (concat <$> traverse (`force` (pure . typeConstraints)) xs)
<*> (TSet True <$> traverse (`force` (pure . inferredType)) xs)
where
go x rest = force x $ \x' -> pure $ As.merge (assumptions x') rest
instance ToValue [JThunk s] (Infer s) (Judgment s) where
toValue xs = Judgment
<$> foldrM go As.empty xs
<*> (concat <$> traverse (`force` (pure . typeConstraints)) xs)
<*> (TList <$> traverse (`force` (pure . inferredType)) xs)
where
go x rest = force x $ \x' -> pure $ As.merge (assumptions x') rest
instance ToValue Bool (Infer s) (Judgment s) where
toValue _ = pure $ Judgment As.empty [] typeBool
infer :: NExpr -> Infer s (Judgment s)
infer = cata Eval.eval
inferTop :: Env -> [(Text, NExpr)] -> Either InferError Env
inferTop env [] = Right env
inferTop env ((name, ex):xs) = case inferExpr env ex of
Left err -> Left err
Right ty -> inferTop (extend env (name, ty)) xs
normalize :: Scheme -> Scheme
normalize (Forall _ body) = Forall (map snd ord) (normtype body)
where
ord = zip (nub $ fv body) (map TV letters)
fv (TVar a) = [a]
fv (TArr a b) = fv a ++ fv b
fv (TCon _) = []
fv (TSet _ a) = concatMap fv (M.elems a)
fv (TList a) = concatMap fv a
fv (TMany ts) = concatMap fv ts
normtype (TArr a b) = TArr (normtype a) (normtype b)
normtype (TCon a) = TCon a
normtype (TSet b a) = TSet b (M.map normtype a)
normtype (TList a) = TList (map normtype a)
normtype (TMany ts) = TMany (map normtype ts)
normtype (TVar a) =
case Prelude.lookup a ord of
Just x -> TVar x
Nothing -> error "type variable not in signature"
newtype Solver m a = Solver (LogicT (StateT [TypeError] m) a)
deriving (Functor, Applicative, Alternative, Monad, MonadPlus,
MonadLogic, MonadState [TypeError])
instance MonadTrans Solver where
lift = Solver . lift . lift
instance Monad m => MonadError TypeError (Solver m) where
throwError err = Solver $ lift (modify (err:)) >> mzero
catchError _ _ = error "This is never used"
runSolver :: Monad m => Solver m a -> m (Either [TypeError] [a])
runSolver (Solver s) = do
res <- runStateT (observeAllT s) []
pure $ case res of
(x:xs, _) -> Right (x:xs)
(_, es) -> Left (nub es)
emptySubst :: Subst
emptySubst = mempty
compose :: Subst -> Subst -> Subst
Subst s1 `compose` Subst s2 =
Subst $ Map.map (apply (Subst s1)) s2 `Map.union` s1
unifyMany :: Monad m => [Type] -> [Type] -> Solver m Subst
unifyMany [] [] = return emptySubst
unifyMany (t1 : ts1) (t2 : ts2) =
do su1 <- unifies t1 t2
su2 <- unifyMany (apply su1 ts1) (apply su1 ts2)
return (su2 `compose` su1)
unifyMany t1 t2 = throwError $ UnificationMismatch t1 t2
allSameType :: [Type] -> Bool
allSameType [] = True
allSameType [_] = True
allSameType (x:y:ys) = x == y && allSameType (y:ys)
unifies :: Monad m => Type -> Type -> Solver m Subst
unifies t1 t2 | t1 == t2 = return emptySubst
unifies (TVar v) t = v `bind` t
unifies t (TVar v) = v `bind` t
unifies (TList xs) (TList ys)
| allSameType xs && allSameType ys = case (xs, ys) of
(x:_, y:_) -> unifies x y
_ -> return emptySubst
| length xs == length ys = unifyMany xs ys
unifies t1@(TList _) t2@(TList _) = throwError $ UnificationFail t1 t2
unifies (TSet True _) (TSet True _) = return emptySubst
unifies (TSet False b) (TSet True s)
| M.keys b `intersect` M.keys s == M.keys s = return emptySubst
unifies (TSet True s) (TSet False b)
| M.keys b `intersect` M.keys s == M.keys b = return emptySubst
unifies (TSet False s) (TSet False b)
| null (M.keys b \\ M.keys s) = return emptySubst
unifies (TArr t1 t2) (TArr t3 t4) = unifyMany [t1, t2] [t3, t4]
unifies (TMany t1s) t2 = considering t1s >>- unifies ?? t2
unifies t1 (TMany t2s) = considering t2s >>- unifies t1
unifies t1 t2 = throwError $ UnificationFail t1 t2
bind :: Monad m => TVar -> Type -> Solver m Subst
bind a t | t == TVar a = return emptySubst
| occursCheck a t = throwError $ InfiniteType a t
| otherwise = return (Subst $ Map.singleton a t)
occursCheck :: FreeTypeVars a => TVar -> a -> Bool
occursCheck a t = a `Set.member` ftv t
nextSolvable :: [Constraint] -> (Constraint, [Constraint])
nextSolvable xs = fromJust (find solvable (chooseOne xs))
where
chooseOne xs = [(x, ys) | x <- xs, let ys = delete x xs]
solvable (EqConst{}, _) = True
solvable (ExpInstConst{}, _) = True
solvable (ImpInstConst _t1 ms t2, cs) =
Set.null ((ftv t2 `Set.difference` ms) `Set.intersection` atv cs)
considering :: [a] -> Solver m a
considering xs = Solver $ LogicT $ \c n -> foldr c n xs
solve :: MonadState InferState m => [Constraint] -> Solver m Subst
solve [] = return emptySubst
solve cs = solve' (nextSolvable cs)
where
solve' (EqConst t1 t2, cs) =
unifies t1 t2 >>- \su1 ->
solve (apply su1 cs) >>- \su2 ->
return (su2 `compose` su1)
solve' (ImpInstConst t1 ms t2, cs) =
solve (ExpInstConst t1 (generalize ms t2) : cs)
solve' (ExpInstConst t s, cs) = do
s' <- lift $ instantiate s
solve (EqConst t s' : cs)