module Control.Monad.Unify where
import Data.Data
import Data.Maybe
import Data.Monoid
import Data.Generics (mkT, mkQ, everywhere, everything)
import Control.Applicative
import Control.Monad.State
import Control.Monad.Error
import Data.HashMap.Strict as M
-- | A unification variable, identified by an 'Int'.
--
-- Values are produced by @fresh'@ from the counter stored in 'UnifyState',
-- and the wrapped 'Int' is the key used to look the variable up in a
-- 'Substitution'.
newtype Unknown = Unknown {
  runUnknown :: Int  -- ^ The integer identifying this unknown.
  } deriving (Show, Eq, Ord, Data, Typeable)
-- | Term types that can contain unification variables.
--
-- The 'Data' superclass is what allows the generic traversals in this module
-- ('$?' and 'unknowns') to walk values of type @t@.
class (Typeable t, Data t) => Partial t where
  -- | Build the term that stands for the given unknown.
  unknown :: Unknown -> t
  -- | Return the unknown a term stands for, or 'Nothing' when the term is
  -- not itself an unknown. Presumably the inverse of 'unknown' -- instances
  -- should confirm.
  isUnknown :: t -> Maybe Unknown
-- | Term types that know how to unify two of their values inside 'UnifyT'.
--
-- The functional dependency @t -> m@ means the term type determines the
-- base monad in which its unification runs.
class (Partial t) => Unifiable m t | t -> m where
  -- | Unify two terms. '=:=' calls this to reconcile an unknown's existing
  -- binding with a newly supplied term.
  (=?=) :: t -> t -> UnifyT t m ()
-- | A substitution: a map from the 'Int' identifier of an 'Unknown' (see
-- 'runUnknown') to the term of type @t@ that replaces it.
--
-- This is a @newtype@ because it has exactly one constructor with exactly
-- one field: the wrapper then costs nothing at runtime, and the declaration
-- is consistent with the other single-field wrappers in this module
-- ('Unknown', 'UnifyT'). The constructor and field names are unchanged.
newtype Substitution t = Substitution
  { runSubstitution :: M.HashMap Int t  -- ^ The underlying bindings.
  }
-- | The empty substitution binds nothing. Two substitutions are combined by
-- rewriting each one's bindings with the other and then merging the two
-- rewritten maps with 'M.union'.
instance (Partial t) => Monoid (Substitution t) where
  mempty = Substitution M.empty
  mappend left right = Substitution (M.union leftResolved rightResolved)
    where
      -- Bindings of the left operand, rewritten using the right operand.
      leftResolved = M.map (right $?) (runSubstitution left)
      -- Bindings of the right operand, rewritten using the left operand.
      rightResolved = M.map (left $?) (runSubstitution right)
-- | Apply a substitution to a term.
--
-- Every subterm of type @t@ is visited bottom-up (via 'everywhere'); a
-- subterm that is an unknown with a binding in the substitution is replaced
-- by that binding, and every other subterm is left as it is.
($?) :: (Partial t) => Substitution t -> t -> t
($?) sub = everywhere (mkT replaceBound)
  where
    bindings = runSubstitution sub

    -- Fall back to the original term unless it is an unknown that the
    -- substitution actually binds.
    replaceBound term = fromMaybe term $ do
      Unknown ident <- isUnknown term
      M.lookup ident bindings
-- | The state threaded through a 'UnifyT' computation.
data UnifyState t = UnifyState {
  -- | The number that @fresh'@ will give to the next 'Unknown'; it is
  -- incremented each time one is allocated.
  unifyNextVar :: Int
  -- | The substitution accumulated so far; '=:=' reads it and extends it.
  , unifyCurrentSubstitution :: Substitution t
  }
-- | The initial state: the variable counter starts at zero and the current
-- substitution is empty.
defaultUnifyState :: (Partial t) => UnifyState t
defaultUnifyState = UnifyState
  { unifyNextVar = 0
  , unifyCurrentSubstitution = mempty
  }
-- | The unification monad transformer: a 'StateT' carrying a 'UnifyState'
-- layered over an 'ErrorT' whose errors are 'String's, on top of the base
-- monad @m@.
--
-- Because the state layer is outermost, 'runUnify' yields no final state
-- when an error is thrown.
newtype UnifyT t m a = UnifyT { unUnify :: StateT (UnifyState t) (ErrorT String m) a }
  deriving (Functor, Monad, Applicative, MonadPlus, MonadError String)
-- | Expose the state of the underlying monad, not the internal 'UnifyState'.
-- Both operations are lifted past the 'StateT' layer that holds the
-- unification state, so that state is left untouched.
instance (MonadState s m) => MonadState s (UnifyT t m) where
  get = UnifyT (lift get)
  put outer = UnifyT (lift (put outer))
-- | Collect every 'Unknown' occurring anywhere inside a value, in traversal
-- order. Duplicates are kept.
unknowns :: (Data d) => d -> [Unknown]
unknowns = everything (++) ([] `mkQ` found)
  where
    -- Each node that is an 'Unknown' contributes itself; any other node
    -- contributes the empty list supplied to 'mkQ'.
    found :: Unknown -> [Unknown]
    found u = [u]
-- | Run a unification computation from the given initial state.
--
-- The result, in the base monad, is either the error message ('Left') or
-- the computed value paired with the final state ('Right').
runUnify :: UnifyState t -> UnifyT t m a -> m (Either String (a, UnifyState t))
runUnify initial action = runErrorT (runStateT (unUnify action) initial)
-- | The substitution that binds exactly one unknown to the given term.
substituteOne :: (Partial t) => Unknown -> t -> Substitution t
substituteOne var term = Substitution (M.singleton (runUnknown var) term)
-- | Bind an unknown to a term.
--
-- The steps, in order:
--
-- 1. Apply the current substitution to the term.
--
-- 2. Run the occurs check for the unknown against that term.
--
-- 3. If the unknown already resolves to something other than itself under
--    the current substitution, unify that existing value with the new term
--    using '=?='.
--
-- 4. Prepend the new binding to the substitution held in the state at that
--    point (which step 3 may have changed).
(=:=) :: (Monad m, Unifiable m t) => Unknown -> t -> UnifyT t m ()
(=:=) u t' = do
  sub <- UnifyT (gets unifyCurrentSubstitution)
  let t = sub $? t'
      current = sub $? unknown u
  occursCheck u t
  -- Nothing to reconcile when the unknown still resolves to itself.
  unless (isUnknown current == Just u) (current =?= t)
  UnifyT . modify $ \s ->
    s { unifyCurrentSubstitution = substituteOne u t <> unifyCurrentSubstitution s }
-- | Fail with @"Occurs check fails"@ when the unknown appears inside the
-- term, unless the term is itself an unknown, in which case the check is
-- skipped and always succeeds.
occursCheck :: (Monad m, Partial t) => Unknown -> t -> UnifyT t m ()
occursCheck u t
  | isNothing (isUnknown t) && u `elem` unknowns t =
      UnifyT (lift (throwError "Occurs check fails"))
  | otherwise = return ()
-- | Allocate a new 'Unknown', numbered with the current value of the
-- counter in the state, and advance the counter by one.
fresh' :: (Monad m) => UnifyT t m Unknown
fresh' = do
  next <- UnifyT (gets unifyNextVar)
  UnifyT (modify bump)
  return (Unknown next)
  where
    bump s = s { unifyNextVar = succ (unifyNextVar s) }
-- | Allocate a new unknown and return it as a term of type @t@.
fresh :: (Monad m, Partial t) => UnifyT t m t
fresh = liftM unknown fresh'