module Language.Haskell.TH.Context
( InstMap
, ContextM
, DecStatus(Declared, Undeclared, instanceDec)
, reifyInstancesWithContext
, tellInstance
, tellUndeclared
, noInstance
) where
import Control.Lens (view)
import Control.Monad (filterM)
import Control.Monad.Reader (ReaderT)
import Control.Monad.State (execStateT)
import Control.Monad.States (MonadStates, getPoly, modifyPoly)
import Control.Monad.Writer (MonadWriter, tell, WriterT)
import Data.Generics (everywhere, mkT)
import Data.List (intercalate)
import Data.Logic.ATP.TH (expandBindings )
import Data.Logic.ATP.Unif (Unify(unify'), unify)
import Data.Map as Map (elems, insert, lookup, Map)
import Data.Maybe (mapMaybe)
import Debug.Trace (trace)
import Language.Haskell.TH
import Language.Haskell.TH.Desugar as DS (DsMonad)
import Language.Haskell.TH.PprLib (cat, ptext)
import Language.Haskell.TH.Syntax hiding (lift)
import Language.Haskell.TH.Expand (ExpandMap, expandType, E, unE)
import Language.Haskell.TH.Instances ()
-- | Cache of instance lookups, keyed by the class predicate after
-- 'expandType'.  Each entry lists the instances known for that predicate,
-- tagged 'Declared' (found by 'reifyInstancesWithContext') or 'Undeclared'
-- (recorded by 'tellInstance').  An empty list is stored as a placeholder
-- while a lookup for that predicate is in progress, and remains if no
-- instance was accepted.
type InstMap = Map (E Pred) [DecStatus InstanceDec]
-- | The capabilities this module needs from its monad: 'DsMonad' (which
-- provides the 'Quasi' operations such as 'qReifyInstances'), plus three
-- pieces of state: the instance cache ('InstMap'), the 'ExpandMap' state
-- (presumably used by 'expandType' -- confirm), and a 'String' used as the
-- indentation prefix of debug traces.
class ( DsMonad m
      , MonadStates InstMap m
      , MonadStates ExpandMap m
      , MonadStates String m
      ) => ContextM m

-- | 'ContextM' lifts through a 'ReaderT' layer.
instance ContextM m => ContextM (ReaderT r m)

-- | 'ContextM' lifts through a 'WriterT' layer (with a 'Monoid' output, as
-- the 'WriterT' instances require).
instance (ContextM m, Monoid w) => ContextM (WriterT w m)
-- | Provenance of an instance stored in an 'InstMap'.
data DecStatus a
  = -- | Returned by 'qReifyInstances' and accepted by the context check in
    -- 'reifyInstancesWithContext'.
    Declared {instanceDec :: a}
  | -- | Recorded with 'tellInstance'; these are the entries that
    -- 'tellUndeclared' emits.
    Undeclared {instanceDec :: a}
  deriving (Show)
-- | Pretty-prints the wrapped value inside @Declared (...)@ or
-- @Undeclared (...)@.
instance Ppr a => Ppr (DecStatus a) where
  ppr status = cat [ptext opener, ppr (instanceDec status), ptext ")"]
    where
      opener = case status of
        Declared _ -> "Declared ("
        Undeclared _ -> "Undeclared ("
-- | Like 'qReifyInstances', but keeps only the instances whose context can
-- be satisfied (as decided by 'testInstance').  The answer is cached in the
-- 'InstMap' state, keyed by the 'expandType'd predicate
-- @className typeParameters@.
--
-- If the predicate is already cached, the cached entries are returned
-- unchanged.  This includes any 'Undeclared' instances recorded with
-- 'tellInstance'.
--
-- Before the instance contexts are examined, an empty entry is stored for
-- the predicate.  A lookup that (directly or indirectly) depends on itself
-- therefore sees no instances instead of recursing forever.
--
-- The 'String' state is the indentation prefix of the @DEBUG@ trace.  It is
-- extended on entry and shortened by exactly the same amount on exit.
reifyInstancesWithContext :: forall m. ContextM m => Name -> [Type] -> m [InstanceDec]
reifyInstancesWithContext className typeParameters = do
  p <- expandType $ foldInstance className typeParameters
  mp <- getPoly :: m InstMap
  case Map.lookup p mp of
    Just x -> return $ map instanceDec x
    Nothing -> do
      -- Indent any nested trace output by one step.
      modifyPoly (indentStep ++)
      pre <- getPoly :: m String
      -- Placeholder entry: breaks cycles through instance contexts.
      modifyPoly (Map.insert p [] :: InstMap -> InstMap)
      insts <- qReifyInstances className typeParameters
      r <- filterM (testInstance className typeParameters) insts
#ifdef DEBUG
      trace (intercalate ("\n" ++ pre ++ " ")
               ((pre ++ "reifyInstancesWithContext " ++ pprint1 (foldInstance className typeParameters) ++ " -> [") :
                map (pprint1 . instanceHead) r) ++
             "]") (return ())
#endif
      modifyPoly (Map.insert p (map Declared r))
      -- Remove exactly the prefix added above.  Dropping more than was
      -- added loses the indentation of enclosing lookups and can eat into
      -- a prefix set by the caller.
      modifyPoly (drop (length indentStep) :: String -> String)
      return r
  where
    -- One level of trace indentation.  It is both prepended and dropped,
    -- so entry and exit stay balanced.
    indentStep :: String
    indentStep = "  "
#ifdef DEBUG
    -- Head of an instance declaration, for the trace.  The InstanceD
    -- layout depends on the template-haskell version.
    instanceHead :: InstanceDec -> Type
#if MIN_VERSION_template_haskell(2,11,0)
    instanceHead (InstanceD _ _ typ _) = typ
#else
    instanceHead (InstanceD _ typ _) = typ
#endif
    instanceHead dec = error $ "reifyInstancesWithContext - not an instance: " ++ show dec
#endif
-- | Decide whether an instance returned by 'qReifyInstances' really applies
-- to @className typeParameters@.
--
-- The instance head is turned into one equality predicate
-- @parameter ~ instanceArg@ per class argument.  These are appended to the
-- instance context, expanded with 'expandType', and checked by
-- 'testContext'.
--
-- Calls 'error' in three cases:
--
-- * the instance head is not a type constructor applied to arguments;
-- * the head has a different number of arguments than @typeParameters@;
-- * the declaration is not an 'InstanceD'.
testInstance :: ContextM m => Name -> [Type] -> InstanceDec -> m Bool
testInstance className typeParameters
#if MIN_VERSION_template_haskell(2,11,0)
    -- The first field is the overlap annotation.  Reified instances that
    -- carry an OVERLAPPING, OVERLAPPABLE, OVERLAPS or INCOHERENT pragma have
    -- a Just value here.  Matching only Nothing sent them to the
    -- "not an instance" error below, so accept any annotation.
    (InstanceD _overlap instanceContext instanceType _)
#else
    (InstanceD instanceContext instanceType _)
#endif
  = mapM expandType (instancePredicates ++ instanceContext) >>= testContext . map (view unE)
  where
    -- Equalities tying each requested type parameter to the corresponding
    -- argument of the instance head.
    instancePredicates :: [Pred]
    instancePredicates =
      maybe (error $ "Invalid instance type: " ++ show instanceType)
            instanceEqualities
            (unfoldInstance instanceType)
    instanceEqualities (_, instanceArgs)
      | length instanceArgs /= length typeParameters =
          error $ "type class arity error:" ++
                  "\n class name = " ++ show className ++
                  "\n type parameters = " ++ show typeParameters ++
                  "\n instance args = " ++ show instanceArgs
    instanceEqualities (_, instanceArgs) =
      map (\(a, b) -> AppT (AppT EqualityT a) b) (zip typeParameters instanceArgs)
testInstance _ _ x = error $ "qReifyInstances returned something that doesn't appear to be an instance declaration: " ++ show x
-- | Check that a list of predicates can hold together.
--
-- The predicates are first unified (starting from an empty binding map).
-- The resulting bindings are substituted throughout with 'expandBindings'.
-- Each resulting predicate is then checked with 'consistent'.  The result is
-- 'True' only when every predicate is consistent.  Every predicate is checked
-- (there is no short-circuiting), so all of them may populate the instance
-- cache.
--
-- NOTE(review): the behaviour of 'unify' when the predicates do not unify is
-- not visible here; presumably it fails in the monad -- confirm.
testContext :: ContextM m => [Pred] -> m Bool
testContext context = do
  bindings <- unify context mempty
  let resolved = everywhere (mkT (expandBindings bindings)) context
  verdicts <- mapM consistent resolved
  return (and verdicts)
-- | Check a single predicate produced by 'testContext'.
--
-- * An equality whose two sides are identical holds trivially.
-- * A class predicate @C t1 ... tn@ holds when 'reifyInstancesWithContext'
--   finds at least one usable instance.
-- * Any other shape is rejected with 'error'.
--
-- NOTE(review): an equality whose sides still differ after substitution is
-- not ConT-headed, so it reaches the "Unexpected Pred" error rather than
-- yielding False.  Confirm that 'unify' rules this case out.
consistent :: ContextM m => Pred -> m Bool
consistent typ =
  case typ of
    AppT (AppT EqualityT lhs) rhs | lhs == rhs -> return True
    _ -> case unfoldInstance typ of
      Nothing -> error $ "Unexpected Pred: " ++ pprint typ
      Just (className, typeParameters) ->
        (not . null) <$> reifyInstancesWithContext className typeParameters
-- | Record an instance declaration as 'Undeclared' in the instance cache.
--
-- The entry is keyed by the 'expandType'd instance head.  If the cache
-- already holds a non-empty list for that predicate, nothing is changed.
-- Otherwise (no entry, or an empty placeholder) the entry is replaced by
-- this single 'Undeclared' instance.
--
-- Calls 'error' if the declaration is not an 'InstanceD', or if its head is
-- not a type constructor applied to arguments.
tellInstance :: ContextM m => Dec -> m ()
#if MIN_VERSION_template_haskell(2,11,0)
tellInstance inst@(InstanceD _ _ instanceType _) =
#else
tellInstance inst@(InstanceD _ instanceType _) =
#endif
  -- Match explicitly instead of using a lazy "let Just ..." binding.  That
  -- binding failed later, inside expandType, with an opaque
  -- irrefutable-pattern error.
  case unfoldInstance instanceType of
    Nothing -> error $ "tellInstance - Invalid instance type: " ++ pprint instanceType
    Just (className, typeParameters) -> do
      p <- expandType $ foldInstance className typeParameters
      (mp :: InstMap) <- getPoly
      case Map.lookup p mp of
        Just (_ : _) -> return ()
        _ -> modifyPoly (Map.insert p [Undeclared inst])
tellInstance inst = error $ "tellInstance - Not an instance: " ++ pprint inst
-- | Emit, via 'tell', every cached instance that was recorded with
-- 'tellInstance' rather than found by reification.  These are the
-- 'Undeclared' entries, emitted in 'Map.elems' order.
tellUndeclared :: (MonadWriter [Dec] m, MonadStates InstMap m) => m ()
tellUndeclared = do
  (cache :: InstMap) <- getPoly
  tell [dec | statuses <- Map.elems cache, Undeclared dec <- statuses]
-- | Build the class predicate @C t1 ... tn@ by applying the class
-- constructor to each argument in turn.  This is the inverse of
-- 'unfoldInstance'.
foldInstance :: Name -> [Type] -> Pred
foldInstance className = applyAll (ConT className)
  where
    applyAll :: Type -> [Type] -> Type
    applyAll acc [] = acc
    applyAll acc (arg : rest) = applyAll (AppT acc arg) rest
-- | Split a class predicate @C t1 ... tn@ into the class name and its
-- arguments; this is the inverse of 'foldInstance'.  Returns 'Nothing' when
-- the head of the application spine is not a type constructor ('ConT').
unfoldInstance :: Pred -> Maybe (Name, [Type])
unfoldInstance = collect []
  where
    -- Walk down the left spine, pushing each argument onto the front of the
    -- accumulator so that the arguments come out in source order.
    collect :: [Type] -> Type -> Maybe (Name, [Type])
    collect args (ConT name) = Just (name, args)
    collect args (AppT fun arg) = collect (arg : args) fun
    collect _ _ = Nothing
-- | @noInstance className typeName@ is 'True' when
-- 'reifyInstancesWithContext' accepts no instance of @className@ for the
-- type @ConT typeName@.
--
-- NOTE(review): the type constructor is queried unapplied.  For a
-- parameterised type such as Maybe, the question asked is therefore
-- @C Maybe@, not @C (Maybe a)@.  Presumably callers only pass nullary
-- types -- confirm.
noInstance :: forall m. ContextM m => Name -> Name -> m Bool
noInstance className typeName = do
  found <- reifyInstancesWithContext className [ConT typeName]
  return (null found)
-- Disabled alternative body for noInstance (never compiled).  It reifies
-- typeName and builds the queried type by applying one VarT per type
-- variable binder, so that parameterised types would be queried fully
-- applied.
#if 0
qReify typeName >>= doInfo >>= \typ -> null <$> reifyInstancesWithContext className [typ]
where
doInfo :: Info -> m Type
doInfo (TyConI dec) = doDec dec
doDec :: Dec -> m Type
doDec (NewtypeD cxt name tvbs con decs) = doDec (DataD cxt name tvbs [con] decs)
doDec (DataD _cxt name tvbs _cons _decs) = return $ foldl AppT (ConT name) (map (VarT . toName) tvbs)
doDec (TySynD name tvbs typ) = return $ foldl AppT (ConT name) (map (VarT . toName) tvbs)
toName (PlainTV x) = x
toName (KindedTV x _) = x
#endif
-- Disabled earlier draft of noInstance (never compiled).  It applies fresh
-- type variables to typeName before querying.  It is incomplete as written:
-- the TySynD case alternative has no right-hand side.
#if 0
noInstance className typeName = do
i <- qReify typeName
typ <- case i of
TyConI (TySynD _name _tvbs typ) ->
TyConI (DataD _cxt _name tvbs _fundeps _decs) ->
do vs <- mapM (\c -> VarT <$> runQ (newName [c])) (take (length tvbs) ['a'..'z'])
return $ foldl AppT (ConT typeName) vs
_ -> error $ "noInstance - " ++ show typeName ++ " has an invalid type: " ++ show i
r <- null <$> reifyInstancesWithContext className [typ]
#ifdef DEBUG
trace ("noInstance " ++ show className ++ " " ++ show typeName ++ " -> " ++ show r) (return ())
#endif
return r
#endif