module Language.PureScript.TypeChecker.Skolems (
newSkolemConstant,
introduceSkolemScope,
newSkolemScope,
skolemize,
skolemizeTypesInValue,
skolemEscapeCheck
) where
import Data.List (nub, (\\))
import Data.Monoid
import Control.Applicative
import Control.Monad.Error
import Control.Monad.Unify
import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Pretty
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.Types
-- |
-- Generate a new skolem constant.
--
-- This is a thin wrapper that delegates to @fresh'@; the resulting 'Int' is the
-- kind of value 'skolemize' and 'skolemizeTypesInValue' accept as their skolem
-- identifier argument. Any uniqueness guarantee is whatever @fresh'@ provides.
--
newSkolemConstant :: UnifyT Type Check Int
newSkolemConstant = fresh'
-- |
-- Introduce a skolem scope at every 'ForAll' that does not have one yet.
--
-- The type is traversed with 'everywhereOnTypesM'. Each 'ForAll' whose scope
-- field is 'Nothing' is rebuilt with a scope obtained from 'newSkolemScope';
-- every other node (including a 'ForAll' that already carries a scope) is
-- returned unchanged.
--
introduceSkolemScope :: Type -> UnifyT Type Check Type
introduceSkolemScope = everywhereOnTypesM assignScope
  where
  assignScope node = case node of
    ForAll name body Nothing -> do
      scope <- newSkolemScope
      return (ForAll name body (Just scope))
    _ -> return node
-- |
-- Generate a new skolem scope.
--
-- Wraps the value produced by @fresh'@ in the 'SkolemScope' constructor.
--
newSkolemScope :: UnifyT Type Check SkolemScope
newSkolemScope = fmap SkolemScope fresh'
-- |
-- Skolemize a type variable.
--
-- Builds the 'Skolem' type for the given variable name, skolem constant and
-- scope, and hands it to 'replaceTypeVars' as the replacement for that
-- variable name in the supplied type.
--
skolemize :: String -> Int -> SkolemScope -> Type -> Type
skolemize ident sko scope ty = replaceTypeVars ident skolem ty
  where
  skolem = Skolem ident sko scope
-- |
-- Skolemize a type variable inside the types embedded in an expression.
--
-- The expression is rewritten with the 'Expr' component of
-- 'everywhereOnValues'. The only nodes changed are 'SuperClassDictionary'
-- values, whose type arguments are each passed through 'skolemize' with the
-- given variable name, skolem constant and scope.
--
skolemizeTypesInValue :: String -> Int -> SkolemScope -> Expr -> Expr
skolemizeTypesInValue ident sko scope = onExpr
  where
  -- Declarations and binders are left alone ('id'); only expressions are rewritten.
  (_, onExpr, _) = everywhereOnValues id rewrite id

  rewrite expr = case expr of
    SuperClassDictionary cls tys -> SuperClassDictionary cls (map (skolemize ident sko scope) tys)
    _ -> expr
-- |
-- Ensure that no skolem variable escapes the scope that binds it.
--
-- The root must be a 'TypedValue':
--
-- * A 'TypedValue' whose flag is 'False' is accepted without inspection.
--
-- * Any other 'TypedValue' is traversed with 'everythingWithContextOnValues'.
--   The context is the list of skolem scopes introduced so far, each paired
--   with the expression that introduced it. A typed value annotated with a
--   'ForAll' carrying a scope adds that scope to the context. Any other typed
--   value has the scopes of the 'Skolem's in its type compared against the
--   context; a scope that is missing from it is reported as an escape. Only
--   the first escape found is thrown as an error.
--
-- * Anything else is rejected with an \"Untyped value\" error.
--
skolemEscapeCheck :: Expr -> Check ()
skolemEscapeCheck (TypedValue False _ _) = return ()
skolemEscapeCheck root@TypedValue{} =
  case findEscapes root of
    [] -> return ()
    ((binder, offender) : _) ->
      throwError $ mkErrorStack (escapeMessage binder) (Just (ExprError offender))
  where
  -- Only the expression traversal does real work; all other node kinds just
  -- thread the context through and contribute no results.
  (_, findEscapes, _, _, _) =
    everythingWithContextOnValues [] [] (++) passThrough visit passThrough passThrough passThrough

  -- Render the error text, mentioning the binding expression when one was found.
  escapeMessage :: Maybe Expr -> String
  escapeMessage binder =
    "Rigid/skolem type variable "
      ++ maybe "" (\b -> "bound by " ++ prettyPrintValue b) binder
      ++ " has escaped."

  -- Kept as a function binding (with arguments) so it stays polymorphic in the
  -- node type it is applied to.
  passThrough ctx _ = (ctx, [])

  visit :: [(SkolemScope, Expr)] -> Expr -> ([(SkolemScope, Expr)], [(Maybe Expr, Expr)])
  visit ctx expr = case expr of
    -- A scoped quantifier brings its skolem scope into the context.
    TypedValue _ _ (ForAll _ _ (Just sco)) -> ((sco, expr) : ctx, [])
    -- Any other typed value must only mention scopes already in the context.
    TypedValue _ _ ty ->
      case skolemScopesOf ty \\ map fst ctx of
        (escaped : _) -> (ctx, [(bindingSiteOf escaped, expr)])
        _ -> (ctx, [])
    _ -> (ctx, [])

  -- The distinct skolem scopes mentioned anywhere inside a type.
  skolemScopesOf :: Type -> [SkolemScope]
  skolemScopesOf ty = nub (everythingOnTypes (++) scopeOf ty)
    where
    scopeOf (Skolem _ _ sc) = [sc]
    scopeOf _ = []

  -- Search the whole root expression for the first typed value whose 'ForAll'
  -- annotation introduces the given scope.
  bindingSiteOf :: SkolemScope -> Maybe Expr
  bindingSiteOf wanted = getFirst (search root)
    where
    (_, search, _, _, _) =
      everythingOnValues mappend (const mempty) match (const mempty) (const mempty) (const mempty)

    match expr = case expr of
      TypedValue _ _ (ForAll _ _ (Just sco)) | wanted == sco -> First (Just expr)
      _ -> mempty
skolemEscapeCheck other =
  throwError (mkErrorStack "Untyped value passed to skolemEscapeCheck" (Just (ExprError other)))