{-# LANGUAGE CPP #-}
module Djinn.HCheck (
htCheckEnv,
htCheckType
) where
import Data.List (union)
#if MIN_VERSION_mtl(2,2,0)
import Control.Monad.Except ()
#else
import Control.Monad.Error ()
#endif
import Control.Monad.State
import Data.Graph (SCC (..), stronglyConnComp)
import Data.IntMap (IntMap, empty, insert, (!))
import Djinn.HTypes
type KState = (Int, IntMap (Maybe HKind))
initState :: KState
initState = (0, empty)
type M a = StateT KState (Either String) a
type KEnv = [(HSymbol, HKind)]
newKVar :: M HKind
newKVar = do
(i, m) <- get
put (i+1, insert i Nothing m)
return $ KVar i
getVar :: Int -> M (Maybe HKind)
getVar i = do
(_, m) <- get
case m!i of
Just (KVar i') -> getVar i'
mk -> return mk
addMap :: Int -> HKind -> M ()
addMap i k = do
(n, m) <- get
put (n, insert i (Just k) m)
clearState :: M ()
clearState = put initState
htCheckType :: [(HSymbol, ([HSymbol], HType, HKind))] -> HType -> Either String ()
htCheckType its t = flip evalStateT initState $ do
let vs = getHTVars t
ks <- mapM (const newKVar) vs
let env = zip vs ks ++ [(i, k) | (i, (_, _, k)) <- its ]
iHKindStar env t
htCheckEnv :: [(HSymbol, ([HSymbol], HType, a))] -> Either String [(HSymbol, ([HSymbol], HType, HKind))]
htCheckEnv its =
let graph = [ (n, i, getHTCons t) | n@(i, (_, t, _)) <- its ]
order = stronglyConnComp graph
in case [ c | CyclicSCC c <- order ] of
c : _ -> Left $ "Recursive types are not allowed: " ++ unwords [ i | (i, _) <- c ]
[] -> flip evalStateT initState $ addKinds
where addKinds = do
env <- inferHKinds [] $ map (\ (AcyclicSCC n) -> n) order
let getK i = maybe (error $ "htCheck " ++ i) id $ lookup i env
return [ (i, (vs, t, getK i)) | (i, (vs, t, _)) <- its ]
inferHKinds :: KEnv -> [(HSymbol, ([HSymbol], HType, a))] -> M KEnv
inferHKinds env [] = return env
inferHKinds env ((i, (vs, t, _)) : its) = do
k <- inferHKind env vs t
inferHKinds ((i, k) : env) its
inferHKind :: KEnv -> [HSymbol] -> HType -> M HKind
inferHKind _ _ (HTAbstract _ k) = return k
inferHKind env vs t = do
clearState
ks <- mapM (const newKVar) vs
let env' = zip vs ks ++ env
k <- iHKind env' t
ground $ foldr KArrow k ks
iHKind :: KEnv -> HType -> M HKind
iHKind env (HTApp f a) = do
kf <- iHKind env f
ka <- iHKind env a
r <- newKVar
unifyK (KArrow ka r) kf
return r
iHKind env (HTVar v) = do
getVarHKind env v
iHKind env (HTCon c) = do
getConHKind env c
iHKind env (HTTuple ts) = do
mapM_ (iHKindStar env) ts
return KStar
iHKind env (HTArrow f a) = do
iHKindStar env f
iHKindStar env a
return KStar
iHKind env (HTUnion cs) = do
mapM_ (\ (_, ts) -> mapM_ (iHKindStar env) ts) cs
return KStar
iHKind _ (HTAbstract _ _) = error "iHKind HTAbstract"
iHKindStar :: KEnv -> HType -> M ()
iHKindStar env t = do
k <- iHKind env t
unifyK k KStar
unifyK :: HKind -> HKind -> M ()
unifyK k1 k2 = do
let follow k@(KVar i) = getVar i >>= return . maybe k id
follow k = return k
unify KStar KStar = return ()
unify (KArrow k11 k12) (KArrow k21 k22) = do unifyK k11 k21; unifyK k12 k22
unify (KVar i1) (KVar i2) | i1 == i2 = return ()
unify (KVar i) k = do occurs i k; addMap i k
unify k (KVar i) = do occurs i k; addMap i k
unify _ _ = lift $ Left $ "kind error: " ++ show (k1, k2)
occurs _ KStar = return ()
occurs i (KArrow f a) = do follow f >>= occurs i; follow a >>= occurs i
occurs i (KVar i') = if i == i' then lift $ Left "cyclic kind" else return ()
k1' <- follow k1
k2' <- follow k2
unify k1' k2'
getVarHKind :: KEnv -> HSymbol -> M HKind
getVarHKind env v = case lookup v env of
Just k -> return k
Nothing -> lift $ Left $ "Undefined type variable " ++ v
getConHKind :: KEnv -> HSymbol -> M HKind
getConHKind env v = case lookup v env of
Just k -> return k
Nothing -> lift $ Left $ "Undefined type " ++ v
ground :: HKind -> M HKind
ground KStar = return KStar
ground (KArrow k1 k2) = liftM2 KArrow (ground k1) (ground k2)
ground (KVar i) = do
mk <- getVar i
case mk of
Just k -> return k
Nothing -> return KStar
getHTCons :: HType -> [HSymbol]
getHTCons (HTApp f a) = getHTCons f `union` getHTCons a
getHTCons (HTVar _) = []
getHTCons (HTCon s) = [s]
getHTCons (HTTuple ts) = foldr union [] (map getHTCons ts)
getHTCons (HTArrow f a) = getHTCons f `union` getHTCons a
getHTCons (HTUnion alts) = foldr union [] [ getHTCons t | (_, ts) <- alts, t <- ts ]
getHTCons (HTAbstract _ _) = []