{-# LANGUAGE TemplateHaskellQuotes, DerivingVia #-}
module AST.TH.Internal.Utils
(
TypeInfo(..), TypeContents(..), CtrTypePattern(..), NodeWitnesses(..)
, makeTypeInfo, makeNodeOf
, parts, toTuple, matchType
, applicativeStyle, unapply, getVar, makeConstructorVars
, consPat, simplifyContext
) where
import AST.Class.Nodes
import AST.Knot (Knot(..), GetKnot, type (#))
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State (StateT(..), evalStateT, execStateT, gets, modify)
import Data.Foldable (traverse_)
import Data.List (nub)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as Set
import Generic.Data (Generically(..))
import GHC.Generics (Generic)
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D
import Prelude.Compat
-- | What the Template Haskell generators in this module know about a reified
-- data type. Values are built by 'makeTypeInfo'.
data TypeInfo = TypeInfo
-- The name that was passed to 'makeTypeInfo' and reified.
{ tiName :: Name
-- The type constructor applied to all of its type variables except the last
-- one (the first component returned by 'parts').
, tiInstance :: Type
-- The name of the last type variable (the second component returned by
-- 'parts'), which is expected to be the knot variable.
, tiVar :: Name
-- The contents gathered by running 'childrenTypes' on 'tiInstance' applied to
-- 'tiVar'.
, tiContents :: TypeContents
-- The constructors of the reified data type.
, tiCons :: [D.ConstructorInfo]
} deriving Show
-- | The types found while walking a data type's fields, split into three sets.
-- The 'Semigroup' and 'Monoid' instances are derived via 'Generically'
-- (presumably combining the sets field by field - confirm against generic-data).
data TypeContents = TypeContents
-- Node types: the @ast@ of fields shaped like @k # ast@, and the argument of
-- a constraint-variable application in a 'KNodesConstraint' instance.
{ tcChildren :: Set Type
-- Types that are embedded rather than expanded: a type-variable headed type
-- applied to the knot variable, a type that could not be reified as a
-- datatype, or the subject of a nested 'KNodesConstraint'.
, tcEmbeds :: Set Type
-- Any other constraint found in a 'KNodesConstraint' instance, with the
-- instance's constraint variable replaced by 'constraintVar'.
, tcOthers :: Set Type
} deriving (Show, Generic)
deriving (Semigroup, Monoid) via Generically TypeContents
-- | Reify the data type called @name@ and collect the information the
-- generators need: its instance head, its knot variable, its constructors and
-- the types reachable from its fields.
--
-- Fails in 'Q' when 'parts' rejects the reified type.
makeTypeInfo :: Name -> Q TypeInfo
makeTypeInfo name =
    D.reifyDatatype name >>= build
    where
        build info =
            do
                (inst, knotVar) <- parts info
                -- Start the walk from the type applied to its own knot
                -- variable, with an empty set of already visited types.
                let fullType = inst `AppT` VarT knotVar
                found <- childrenTypes knotVar fullType `evalStateT` mempty
                pure TypeInfo
                    { tiName = name
                    , tiInstance = inst
                    , tiVar = knotVar
                    , tiContents = found
                    , tiCons = D.datatypeCons info
                    }
-- | Split a reified data type into its instance head and its knot variable.
--
-- The result is the type constructor applied to every type variable but the
-- last, paired with the name of the last type variable. The last variable must
-- either have no kind annotation or be annotated with kind 'Knot'.
--
-- Fails in 'Q' when the type has no type variables, or when the last variable
-- has any other kind annotation.
parts :: D.DatatypeInfo -> Q (Type, Name)
parts info =
    case reverse (D.datatypeVars info) of
        [] -> fail "expected type constructor which requires arguments"
        knotBndr : revLeading ->
            case knotBndr of
                KindedTV var (ConT knot) | knot == ''Knot -> pure (res, var)
                PlainTV var -> pure (res, var)
                _ -> fail "expected last argument to be a knot variable"
            where
                -- The leading variables, restored to declaration order.
                leading = map (VarT . D.tvName) (reverse revLeading)
                res = foldl AppT (ConT (D.datatypeName info)) leading
-- | Collect the 'TypeContents' contributed by a single type, where @var@ is
-- the knot variable of the data type being processed.
--
-- The state is the set of types already visited; a type seen before
-- contributes nothing, which also stops the walk on recursive types.
childrenTypes ::
    Name -> Type -> StateT (Set Type) Q TypeContents
childrenTypes var typ =
    do
        seen <- gets (Set.member typ)
        if seen
            then pure mempty
            else
                do
                    modify (Set.insert typ)
                    collect (matchType var typ)
    where
        collect pat =
            case pat of
                -- A direct node: record its type as a child.
                NodeFofX ast -> pure (mempty { tcChildren = Set.singleton ast })
                -- Something applied to the knot variable: expand named types,
                -- and record type-variable headed ones as embeds.
                XofF ast ->
                    case unapply ast of
                        (ConT name, as) -> childrenTypesFromTypeName name as
                        (x@VarT{}, as) ->
                            pure (mempty { tcEmbeds = Set.singleton (foldl AppT x as) })
                        _ -> pure mempty
                -- A wrapper around a recognised pattern: look inside it.
                Tof _ inner -> collect inner
                Other{} -> pure mempty
-- | Split a type application into its head and its arguments, in application
-- order. Kind signatures met while descending along the head are dropped.
unapply :: Type -> (Type, [Type])
unapply typ =
    peel typ []
    where
        -- Walk down the function side, pushing each argument in front of the
        -- ones already collected so they end up in left-to-right order.
        peel (SigT inner _) acc = peel inner acc
        peel (AppT fun arg) acc = peel fun (arg : acc)
        peel hd acc = (hd, acc)
-- | Classify a constructor field type relative to the knot variable @var@.
--
-- * @GetKnot var ('Knot ast)@ and @var # ast@ are nodes ('NodeFofX').
-- * @ast var@, where @ast@ is not the bare 'GetKnot' constructor, is 'XofF'.
-- * @t typ@, where @typ@ is itself one of the above, is 'Tof'.
-- * Everything else is 'Other'.
--
-- The equations are tried in this order; a failed guard falls through to the
-- next equation.
matchType :: Name -> Type -> CtrTypePattern
matchType var (AppT (AppT (ConT getKnot) (VarT k)) (AppT (PromotedT knotCon) node))
    | getKnot == ''GetKnot, knotCon == 'Knot, k == var = NodeFofX node
matchType var (AppT (AppT (ConT tieOp) (VarT k)) node)
    | tieOp == ''(#), k == var = NodeFofX node
matchType var (AppT body (VarT k))
    | k == var, body /= ConT ''GetKnot = XofF body
matchType var whole@(AppT outer inner) =
    case matchType var inner of
        -- Nothing of interest inside: report the whole application as-is.
        Other{} -> Other whole
        pat -> Tof outer pat
matchType _ other = Other other
-- | The shape of a constructor field type as classified by 'matchType',
-- relative to the knot variable @k@ of the data type.
data CtrTypePattern
-- A node: @GetKnot k ('Knot ast)@ or @k # ast@; carries @ast@.
= NodeFofX Type
-- A type applied to the knot variable, @ast k@; carries @ast@.
| XofF Type
-- An application @t x@ whose argument @x@ matched one of the patterns above;
-- carries @t@ and the pattern found for @x@.
| Tof Type CtrTypePattern
-- Any other type, carried unchanged.
| Other Type
deriving Show
-- | The name of a type variable, looking through any kind signatures wrapped
-- around it. 'Nothing' for every other type.
getVar :: Type -> Maybe Name
getVar typ =
    case typ of
        VarT name -> Just name
        SigT inner _ -> getVar inner
        _ -> Nothing
-- | Collect the 'TypeContents' of the named type @name@ applied to @args@
-- (the knot argument is not part of @args@).
--
-- The function first asks for 'KNodesConstraint' instances of the applied type:
--
-- * No instance: the type is reified as a datatype and each constructor field,
--   with @args@ substituted for the type's variables, is walked with
--   'childrenTypes'. If the type cannot be reified it is recorded as an embed.
-- * Exactly one type-synonym instance whose head is @name@ applied to type
--   variables: @args@ are substituted into its right-hand side, which is then
--   interpreted by 'childrenTypesFromChildrenConstraint'.
-- * Anything else is reported as a failure.
--
-- Failures are raised with 'fail' in 'Q' (as 'parts' does), so they are
-- reported as ordinary errors of the splice being run.
childrenTypesFromTypeName ::
    Name -> [Type] -> StateT (Set Type) Q TypeContents
childrenTypesFromTypeName name args =
    reifyInstances ''KNodesConstraint [typ, VarT constraintVar] & lift
    >>=
    \case
        [] ->
            -- 'recover' turns a reification failure into 'Nothing' instead of
            -- aborting the splice.
            D.reifyDatatype name
            <&> Just
            & recover (pure Nothing)
            & lift
            >>=
            \case
                Just info ->
                    do
                        (_, var) <- parts info & lift
                        D.datatypeCons info >>= D.constructorFields
                            <&> D.applySubstitution substs
                            & traverse (childrenTypes var)
                            <&> mconcat
                    where
                        -- 'zip' stops at the shorter list, so type variables
                        -- without a corresponding argument stay unsubstituted.
                        substs =
                            zip (D.datatypeVars info) args
                            <&> Lens._1 %~ D.tvName
                            & Map.fromList
                Nothing ->
                    pure mempty { tcEmbeds = Set.singleton typ }
        [TySynInstD ccI (TySynEqn [typI, VarT cI] x)]
            | ccI == ''KNodesConstraint ->
                case unapply typI of
                    (ConT n1, argsI) | n1 == name ->
                        case traverse getVar argsI of
                            Nothing ->
                                fail
                                ( "TODO: Support KNodesConstraint of flexible instances "
                                    <> show typ
                                ) & lift
                            Just argNames ->
                                childrenTypesFromChildrenConstraint cI (D.applySubstitution substs x)
                                where
                                    substs = zip argNames args & Map.fromList
                    _ ->
                        fail
                        ( "reifyInstances returned an instance for the wrong type: "
                            <> show (name, typI)
                        ) & lift
        xs -> fail ("Malformed KNodesConstraint instance: " <> show xs) & lift
    where
        typ = foldl AppT (ConT name) args
-- | The fixed type-variable name used to stand for the constraint argument:
-- it is the variable passed when reifying 'KNodesConstraint' instances, and
-- the one substituted for an instance's own constraint variable in the
-- constraints stored in 'tcOthers'.
constraintVar :: Name
constraintVar = mkName "constraint"
-- | Interpret the right-hand side of a 'KNodesConstraint' instance, where
-- @c0@ is the instance's constraint variable.
--
-- * @c0 x@ records @x@ as a child.
-- * @KNodesConstraint x c0@ records @x@ as an embed.
-- * A tuple of constraints is interpreted component by component.
-- * Any other constraint is recorded in 'tcOthers' with @c0@ replaced by
--   'constraintVar'.
--
-- An application of a type variable other than @c0@ is not supported and is
-- reported with 'fail' in 'Q' (as 'parts' does), so it shows up as an
-- ordinary error of the splice being run.
childrenTypesFromChildrenConstraint ::
    Name -> Type -> StateT (Set Type) Q TypeContents
childrenTypesFromChildrenConstraint c0 c@(AppT (VarT c1) x)
    | c0 == c1 = pure mempty { tcChildren = Set.singleton x }
    | otherwise = fail ("TODO: Unsupported ChildrenConstraint " <> show c) & lift
childrenTypesFromChildrenConstraint c0 constraints =
    case unapply constraints of
        (ConT cc1, [x, VarT c1])
            | cc1 == ''KNodesConstraint && c0 == c1 ->
                pure mempty { tcEmbeds = Set.singleton x }
        (TupleT{}, xs) ->
            traverse (childrenTypesFromChildrenConstraint c0) xs <&> mconcat
        _ -> pure mempty { tcOthers = Set.singleton (D.applySubstitution subst constraints) }
    where
        -- Rename the instance's constraint variable to the shared one.
        subst = mempty & Lens.at c0 ?~ VarT constraintVar
-- | Apply the tuple type constructor whose arity is the number of given
-- types to those types, in order.
toTuple :: Foldable t => t Type -> Type
toTuple components =
    foldl AppT tupleCon components
    where
        tupleCon = TupleT (length components)
-- | Build the expression @pure f \<*\> x0 \<*\> x1 ...@ from a function
-- expression and its argument expressions. With no arguments the result is
-- just @pure f@.
applicativeStyle :: Exp -> [Exp] -> Exp
applicativeStyle f args =
    foldl step (VarE 'pure `AppE` f) args
    where
        step acc arg = InfixE (Just acc) (VarE '(<*>)) (Just arg)
-- | Pair every field type of a constructor with a variable name of the form
-- @_\<prefix\>\<index\>@, with indices counting up from 0.
makeConstructorVars :: String -> D.ConstructorInfo -> [(Type, Name)]
makeConstructorVars prefix cons =
    zip (D.constructorFields cons) names
    where
        -- Infinite supply of names; 'zip' takes as many as there are fields.
        names = [ mkName (('_' : prefix) <> show i) | i <- [0 :: Int ..] ]
-- | A pattern matching the given constructor, binding each field to the
-- variable name paired with it (the paired types are ignored).
consPat :: D.ConstructorInfo -> [(Type, Name)] -> Pat
consPat cons vars =
    ConP (D.constructorName cons) [ VarP v | (_, v) <- vars ]
-- | Simplify a list of constraints by expanding class constraints through
-- their instances.
--
-- A constraint on a class applied to concrete arguments, for which exactly one
-- instance is found, is replaced by that instance's context (after unifying
-- the constraint with the instance head), and the context is simplified in
-- turn. Constraints whose only argument is a type variable, constraints with
-- no instance or several instances, and constraints not headed by a class name
-- are kept. The result holds each kept constraint once, in 'Set' order.
simplifyContext :: [Pred] -> CxtQ
simplifyContext preds =
    Set.toList . snd <$> execStateT (visitAll preds) initial
    where
        -- First component: class applications already expanded (prevents
        -- revisiting). Second component: the constraints to keep.
        initial = (mempty :: Set (Name, [Type]), mempty :: Set Pred)
        visitAll ps = traverse_ visit (map unapply ps)
        visit (cls, args@[VarT _]) = keep cls args
        visit (ConT cls, args) =
            do
                seen <- Lens.use (Lens._1 . Lens.contains key)
                if seen
                    then pure ()
                    else
                        do
                            Lens._1 . Lens.contains key .= True
                            instances <- reifyInstances cls args & lift
                            case instances of
                                [InstanceD _ context instHead _] ->
                                    do
                                        subst <-
                                            D.unifyTypes [foldl AppT (ConT cls) args, instHead]
                                            & lift
                                        visitAll (D.applySubstitution subst context)
                                _ -> keep (ConT cls) args
            where
                key = (cls, args)
        visit (hd, args) = keep hd args
        keep hd args = Lens._2 . Lens.contains (foldl AppT hd args) .= True
-- | The witness constructors generated for a data type by 'makeNodeOf', with
-- lookups from a type to the expression for its witness constructor.
data NodeWitnesses = NodeWitnesses
-- Expression for the witness constructor of a node type.
{ nodeWit :: Type -> Exp
-- Expression for the witness constructor of an embedded type.
, embedWit :: Type -> Exp
-- Names of the witness constructors generated for node types.
, nodeWitCtrs :: [Name]
-- Names of the witness constructors generated for embedded types.
, embedWitCtrs :: [Name]
}
-- | Build the GADT constructors of the witness type for a data type, together
-- with the lookups and constructor names collected in 'NodeWitnesses'.
--
-- For every distinct node type @t@ among the constructor fields a constructor
-- @W_\<Type\>_\<t\>@ of type @KWitness inst t@ is produced. For every distinct
-- embedded type @t@ a constructor @E_\<Type\>_\<t\>@ of type
-- @KWitness t node -> KWitness inst node@ is produced. Here @inst@ is
-- 'tiInstance' and the name parts are derived from the unqualified type names.
--
-- Naming a type that is not built from constructors, applications, variables
-- and kind signatures raises an 'error'.
makeNodeOf :: TypeInfo -> ([Con], NodeWitnesses)
makeNodeOf info =
    (witnessCons, witnesses)
    where
        -- Node constructors come first, then embed constructors.
        witnessCons =
            [ GadtC [ctr] [] (nodeGadtType t) | (t, ctr) <- nodes ]
            <> [ GadtC [ctr] [] (embedGadtType t) | (t, ctr) <- embeds ]
        witnesses =
            NodeWitnesses
            { nodeWit = getWit (Map.fromList [ (t, ConE ctr) | (t, ctr) <- nodes ])
            , embedWit = getWit (Map.fromList [ (t, ConE ctr) | (t, ctr) <- embeds ])
            , nodeWitCtrs = map snd nodes
            , embedWitCtrs = map snd embeds
            }
        -- The part of a shown name after its last dot.
        lastComponent str = reverse (takeWhile (/= '.') (reverse str))
        typeLabel = lastComponent (show (tiName info))
        fieldPats =
            [ matchType (tiVar info) field
            | con <- tiCons info
            , field <- D.constructorFields con
            ]
        -- A string identifying a type, for use inside a constructor name.
        label (ConT x) = lastComponent (show x)
        label (AppT x y) = label x <> "_" <> label y
        label (VarT x) = takeWhile (/= '_') (show x)
        label (SigT x _) = label x
        label x = error ("TODO: Witness name generator is partial! Need to support " <> show x)
        nodes =
            [ (t, mkName ("W_" <> typeLabel <> "_" <> label t))
            | t <- nub (concatMap nodeTypes fieldPats)
            ]
        nodeTypes (NodeFofX t) = [t]
        nodeTypes (Tof _ pat) = nodeTypes pat
        nodeTypes _ = []
        nodeGadtType t = ConT ''KWitness `AppT` tiInstance info `AppT` t
        embeds =
            [ (t, mkName ("E_" <> typeLabel <> "_" <> label t))
            | t <- nub (concatMap embedTypes fieldPats)
            ]
        embedTypes (XofF t) = [t]
        embedTypes (Tof _ pat) = embedTypes pat
        embedTypes _ = []
        embedGadtType t =
            ArrowT
            `AppT` (ConT ''KWitness `AppT` t `AppT` VarT nodeVar)
            `AppT` (ConT ''KWitness `AppT` tiInstance info `AppT` VarT nodeVar)
        nodeVar = mkName "node"
-- | Look up the witness expression for a type. When the type has no entry
-- the result is a string literal expression describing the failed lookup
-- (the type and the whole map), rather than an error.
getWit :: Map Type Exp -> Type -> Exp
getWit m k =
    fromMaybe missing (Map.lookup k m)
    where
        missing = LitE (StringL ("Cant find witness for " <> show k <> " in " <> show m))