module Control.Monad.Unify where
import Data.Data
import Data.Maybe
import Data.Monoid
import Data.Generics (mkT, mkQ, everywhere, everything)
import Control.Applicative
import Control.Monad.State
import Control.Monad.Error
-- | An opaque identifier for a unification variable.
--
-- The wrapped 'Int' is the value that 'fresh'' takes from the
-- 'unifyNextVar' counter when the unknown is created.
newtype Unknown = Unknown
  { runUnknown :: Int -- ^ The raw numeric identifier.
  } deriving (Show, Eq, Ord, Data, Typeable)
-- | An 'Unknown' tagged with a phantom type @t@ naming the kind of term it
-- stands for. The tag lives only at the type level; at runtime this is
-- just the wrapped 'Unknown'.
newtype TypedUnknown t = TypedUnknown
  { runTypedUnknown :: Unknown -- ^ The untyped identifier.
  } deriving (Show, Eq, Ord, Data, Typeable)
-- | Term types @t@ that support unification inside @'UnifyT' m@.
--
-- The functional dependency @t -> m@ makes the term type determine the
-- base monad its unifier runs in.
class (Typeable t, Data t, Show t) => Unifiable m t | t -> m where
  -- | Embed a unification variable as a term.
  unknown :: TypedUnknown t -> t

  -- | Return the variable if the term is a bare unification variable,
  -- 'Nothing' otherwise.
  isUnknown :: t -> Maybe (TypedUnknown t)

  -- | Unify two terms in 'UnifyT'. Instances presumably report a mismatch
  -- through the 'String' error layer; confirm against concrete instances.
  (?=) :: t -> t -> UnifyT m ()
-- | A substitution: a rewrite that can be applied to a value of any 'Data'
-- type (in this module, replacing unknowns by the terms bound to them).
--
-- The field is rank-2 polymorphic, so a single substitution can be run
-- over any 'Data' structure.
newtype Substitution = Substitution
  { runSubstitution :: forall d. (Data d) => d -> d
  }
-- | Substitutions compose like functions: @s1 \`mappend\` s2@ runs @s2@
-- first and then @s1@. The identity substitution changes nothing.
instance Monoid Substitution where
  mempty = Substitution id
  mappend (Substitution outer) (Substitution inner) =
    Substitution (\x -> outer (inner x))
-- | State threaded through a unification run.
data UnifyState = UnifyState
  { unifyNextVar :: Int
    -- ^ Identifier that the next call to 'fresh'' will hand out.
  , unifyCurrentSubstitution :: Substitution
    -- ^ Every binding recorded so far, composed into one substitution.
  }
-- | Starting state: unknowns are numbered from zero and the substitution
-- is the identity ('mempty').
defaultUnifyState :: UnifyState
defaultUnifyState = UnifyState
  { unifyNextVar = 0
  , unifyCurrentSubstitution = mempty
  }
-- | The unification monad transformer: a 'UnifyState' state layer over a
-- 'String' error layer, on top of an arbitrary base monad @m@.
newtype UnifyT m a = UnifyT
  { unUnify :: StateT UnifyState (ErrorT String m) a
  } deriving (Functor, Monad, Applicative, MonadPlus, MonadError String)
-- | Give access to the /base/ monad's state from inside 'UnifyT'.
--
-- 'get' and 'put' are lifted past the internal 'StateT' layer, so they act
-- on @m@'s state, not on the internal 'UnifyState'.
instance (MonadState s m) => MonadState s (UnifyT m) where
  get = UnifyT (lift get)
  put newState = UnifyT (lift (put newState))
-- | Collect every 'Unknown' occurring anywhere inside a value.
-- Repeated occurrences are all kept; nothing is deduplicated.
unknowns :: (Data d) => d -> [Unknown]
unknowns = everything (++) ([] `mkQ` asList)
  where
    -- Only 'Unknown' nodes contribute; every other node yields @[]@.
    asList :: Unknown -> [Unknown]
    asList u = [u]
-- | Run a unification computation from the given starting state.
--
-- Yields @Left msg@ if the computation threw an error, otherwise
-- @Right (result, finalState)@.
runUnify :: UnifyState -> UnifyT m a -> m (Either String (a, UnifyState))
runUnify s action = runErrorT (runStateT (unUnify action) s)
-- | The single-binding substitution @u := t@.
--
-- Every subterm of type @t@ that 'isUnknown' identifies as @u@ is replaced
-- by @t@. All other values are left as they are.
substituteOne :: (Unifiable m t) => TypedUnknown t -> t -> Substitution
substituteOne u t = Substitution (everywhere (mkT replaceIfU))
  where
    replaceIfU term
      | isUnknown term == Just u = t
      | otherwise = term
-- | Bind the unknown @u@ to the term @t'@.
--
-- Steps:
--
-- 1. Normalise @t'@ under the current substitution.
-- 2. Run 'occursCheck' on the normalised term.
-- 3. If the current substitution already maps @u@ to something other than
--    @u@ itself, unify that existing image with the new term using '?='.
-- 4. Compose the binding @u := t@ onto the front of the substitution, so it
--    is applied after every earlier binding.
replace :: (Monad m, Unifiable m t) => TypedUnknown t -> t -> UnifyT m ()
replace u t' = do
  sub <- UnifyT (gets unifyCurrentSubstitution)
  let normalised = runSubstitution sub t'
      existing = runSubstitution sub (unknown u)
  occursCheck u normalised
  -- When u still maps to itself it is unbound, so there is nothing to unify.
  unless (isUnknown existing == Just u) (existing ?= normalised)
  UnifyT . modify $ \st ->
    st { unifyCurrentSubstitution =
           substituteOne u normalised <> unifyCurrentSubstitution st }
-- | Throw @"Occurs check fails"@ if the unknown occurs inside @t@.
--
-- If @t@ is itself a bare unknown (per 'isUnknown'), no check is done.
-- Only the underlying 'Unknown' is compared, so the type tag @s@ may differ
-- from @t@.
occursCheck :: (Monad m, Unifiable m t) => TypedUnknown s -> t -> UnifyT m ()
occursCheck (TypedUnknown u) t
  | isJust (isUnknown t) = return ()
  | u `elem` unknowns t = UnifyT (lift (throwError "Occurs check fails"))
  | otherwise = return ()
-- | Allocate an 'Unknown' carrying the current 'unifyNextVar' value, then
-- advance that counter with 'succ'.
fresh' :: (Monad m) => UnifyT m Unknown
fresh' = UnifyT $ do
  next <- gets unifyNextVar
  modify (\st -> st { unifyNextVar = succ (unifyNextVar st) })
  return (Unknown next)
-- | Create a fresh unification variable and embed it as a term of type @t@.
fresh :: (Monad m, Unifiable m t) => UnifyT m t
fresh = liftM (unknown . TypedUnknown) fresh'