module Control.Monad.Unify where
import Data.Maybe
import Data.Monoid
import Control.Applicative
import Control.Monad.State
import Control.Monad.Error.Class
import Data.HashMap.Strict as M
-- | Identifier of a unification variable ("unknown"); a plain 'Int'.
type Unknown = Int
-- | Terms that may contain unification variables ('Unknown's).
class Partial t where
  -- | Build a term from a single unknown.
  unknown :: Unknown -> t
  -- | Presumably @Just u@ when the term is the bare unknown @u@; code in
  -- this module treats 'Nothing' as "not a variable".
  isUnknown :: t -> Maybe Unknown
  -- | The unknowns occurring in a term (consulted by the occurs check).
  unknowns :: t -> [Unknown]
  -- | Apply a substitution to a term.
  ($?) :: Substitution t -> t -> t
-- | Term types that know how to unify two of their values inside 'UnifyT'.
-- The functional dependency @t -> m@ fixes the underlying monad for a
-- given term type.
class (Partial t) => Unifiable m t | t -> m where
  -- | Unify two terms within the unification monad.
  (=?=) :: t -> t -> UnifyT t m ()
-- | A finite map from unknowns to the terms they are bound to.
--
-- Declared as a @newtype@ (rather than a one-field @data@) so the wrapper
-- has no runtime box; construction and 'runSubstitution' are unchanged.
newtype Substitution t = Substitution { runSubstitution :: M.HashMap Unknown t }
-- | Composition of substitutions.
--
-- 'mappend' rewrites the right-hand sides of each operand with the other
-- operand, then takes the union. 'M.union' is left-biased, so for an
-- unknown bound by both operands the binding from the first one wins.
--
-- NOTE(review): GHC >= 8.4 (base >= 4.11) additionally requires a
-- 'Semigroup' instance here; confirm which compiler this targets.
instance (Partial t) => Monoid (Substitution t) where
  mempty = Substitution M.empty
  mappend s1 s2 = Substitution (leftRewritten `M.union` rightRewritten)
    where
      leftRewritten = M.map (s2 $?) (runSubstitution s1)
      rightRewritten = M.map (s1 $?) (runSubstitution s2)
-- | State threaded through a unification computation.
data UnifyState t = UnifyState
  { unifyNextVar :: !Int
    -- ^ The next unknown to hand out; @fresh'@ reads and increments it.
    -- Strict so that a forced state holds an evaluated counter rather than
    -- a growing chain of 'succ' thunks.
  , unifyCurrentSubstitution :: Substitution t
    -- ^ The substitution accumulated so far.
  }
-- | Starting state: unknowns are numbered from 0 and nothing is bound.
defaultUnifyState :: (Partial t) => UnifyState t
defaultUnifyState = UnifyState
  { unifyNextVar = 0
  , unifyCurrentSubstitution = mempty
  }
-- | The unification monad transformer: a 'StateT' carrying a 'UnifyState'
-- over an underlying monad @m@.
--
-- 'Alternative' is derived alongside 'MonadPlus' because it has been a
-- superclass of 'MonadPlus' since GHC 7.10 (the AMP). Deriving 'MonadPlus'
-- alone no longer compiles there. On older compilers the extra instance
-- is harmless.
newtype UnifyT t m a = UnifyT { unUnify :: StateT (UnifyState t) m a }
  deriving (Functor, Monad, Applicative, Alternative, MonadPlus)
-- | 'get'/'put' on 'UnifyT' reach the state of the underlying monad @m@,
-- not the internal 'UnifyState' (which is only reachable by unwrapping).
instance (MonadState s m) => MonadState s (UnifyT t m) where
  get = UnifyT (lift get)
  put newState = UnifyT (lift (put newState))
-- | Errors are thrown and caught through the wrapped 'StateT'.
instance (MonadError e m) => MonadError e (UnifyT t m) where
  throwError err = UnifyT (throwError err)
  catchError action handler =
    UnifyT (unUnify action `catchError` (unUnify . handler))
-- | Run a unification computation from the given initial state. Returns
-- the result paired with the final state.
runUnify :: UnifyState t -> UnifyT t m a -> m (a, UnifyState t)
runUnify initial action = runStateT (unUnify action) initial
-- | The substitution binding exactly one unknown to one term.
substituteOne :: (Partial t) => Unknown -> t -> Substitution t
substituteOne u = Substitution . M.singleton u
-- | Bind the unknown @u@ to the term @t'@.
--
-- The term is first rewritten under the current substitution and passed to
-- 'occursCheck'. If @u@ currently resolves to something other than itself,
-- that value is unified with the rewritten term using '=?='. Finally the
-- binding @u := t@ is composed in front of the stored substitution.
(=:=) :: (Error e, Monad m, MonadError e m, Unifiable m t) => Unknown -> t -> UnifyT t m ()
u =:= t' = do
  sub <- UnifyT (gets unifyCurrentSubstitution)
  let t = sub $? t'
      current = sub $? unknown u
  occursCheck u t
  -- Only unify against the existing value when u is not (still) just itself.
  unless (isUnknown current == Just u) (current =?= t)
  UnifyT . modify $ \s ->
    s { unifyCurrentSubstitution = substituteOne u t <> unifyCurrentSubstitution s }
-- | Fail with @"Occurs check fails"@ (via 'strMsg') when @u@ appears among
-- the 'unknowns' of @t@. A term that is itself a bare unknown (per
-- 'isUnknown') is never rejected here.
occursCheck :: (Error e, Monad m, MonadError e m, Partial t) => Unknown -> t -> UnifyT t m ()
occursCheck u t
  | isJust (isUnknown t) = return ()
  | u `elem` unknowns t = UnifyT (lift (throwError (strMsg "Occurs check fails")))
  | otherwise = return ()
-- | Allocate a new unknown. Returns the current counter and advances it.
fresh' :: (Monad m) => UnifyT t m Unknown
fresh' = UnifyT $ do
  next <- gets unifyNextVar
  modify (\s -> s { unifyNextVar = succ (unifyNextVar s) })
  return next
-- | Allocate a new unknown and return it as a term via 'unknown'.
fresh :: (Monad m, Partial t) => UnifyT t m t
fresh = liftM unknown fresh'