{-# LANGUAGE DeriveAnyClass, DeriveGeneric, LambdaCase, TemplateHaskell, TypeApplications #-}
module TreeSitter.GenerateSyntax
( syntaxDatatype
, astDeclarationsForLanguage
) where
import Language.Haskell.TH as TH
import Language.Haskell.TH.Syntax as TH
import TreeSitter.Deserialize (Datatype (..), DatatypeName (..), Field (..), Children(..), Required (..), Type (..), Named (..), Multiple (..))
import Data.List.NonEmpty (NonEmpty (..))
import Data.List
import Data.Foldable
import Data.Text (Text)
import qualified TreeSitter.Unmarshal as TS
import GHC.Generics hiding (Constructor, Datatype)
import Foreign.Ptr
import qualified TreeSitter.Language as TS
import Foreign.C.String
import Data.Aeson hiding (String)
import System.Directory
import System.FilePath.Posix
import TreeSitter.Node
import TreeSitter.Token
import TreeSitter.Symbol (TSSymbol, toHaskellCamelCaseIdentifier, toHaskellPascalCaseIdentifier)
-- | Generate Haskell declarations for every datatype described in a JSON file
-- (presumably the grammar's @node-types.json@ — confirm against callers).
--
-- @filePath@ is resolved against the directory of the module running the
-- splice (using 'location' and the current working directory). The file is
-- decoded as a list of 'Datatype's. Each entry is rendered by
-- 'syntaxDatatype' using the language's full symbol table.
--
-- A decoding failure aborts the splice. The message names the resolved path,
-- so a bad or missing file can be identified from the compiler error.
astDeclarationsForLanguage :: Ptr TS.Language -> FilePath -> Q [Dec]
astDeclarationsForLanguage language filePath = do
  -- Presumably registers the file so that editing it triggers recompilation.
  _ <- TS.addDependentFileRelative filePath
  currentFilename <- loc_filename <$> location
  pwd <- runIO getCurrentDirectory
  let invocationRelativePath = takeDirectory (pwd </> currentFilename) </> filePath
  input <- runIO (eitherDecodeFileStrict' invocationRelativePath)
  allSymbols <- runIO (getAllSymbols language)
  case input of
    Left err ->
      fail ("astDeclarationsForLanguage: could not decode " <> invocationRelativePath <> ": " <> err)
    -- `concat @[]` fixes the decoded container to a list; without it the
    -- Foldable/FromJSON container type would be ambiguous.
    Right datatypes ->
      concat @[] <$> traverse (syntaxDatatype language allSymbols) datatypes
-- | Enumerate every symbol of a tree-sitter language, in symbol-id order.
--
-- Each entry pairs the symbol's name with 'Named' when
-- @ts_language_symbol_type@ reports 0, and with 'Anonymous' otherwise.
-- The list index of an entry equals its 'TSSymbol' id.
-- 'symbolMatchingInstance' relies on this when it uses 'elemIndices' to find
-- symbol ids.
getAllSymbols :: Ptr TS.Language -> IO [(String, Named)]
getAllSymbols language = do
  count <- TS.ts_language_symbol_count language
  -- Take exactly `count` ids starting at 0. The previous `pred count` either
  -- threw (unsigned zero) or wrapped around (signed -1 narrowed to TSSymbol)
  -- for an empty language. The previous `fromIntegral` also silently
  -- truncated counts larger than TSSymbol can hold.
  traverse getSymbol (genericTake count [0 :: TSSymbol ..])
  where
    getSymbol i = do
      cname <- TS.ts_language_symbol_name language i
      n <- peekCString cname
      t <- TS.ts_language_symbol_type language i
      let named = if t == 0 then Named else Anonymous
      pure (n, named)
-- | Generate the declarations for a single tree-sitter 'Datatype'.
--
-- Nothing is generated when a type with the target name (built by
-- 'toNameString') is already defined in the module running the splice.
-- Otherwise the output depends on the constructor:
--
-- * 'SumType': a newtype over the balanced ':+:' sum of its subtypes,
--   applied to a fresh annotation parameter. It derives 'TS.SymbolMatching'
--   with the newtype strategy, plus the stock and anyclass clauses below.
-- * 'ProductType': a record data type (see 'ctorForProductType') plus a
--   'TS.SymbolMatching' instance.
-- * anonymous 'LeafType': a type synonym for 'Token', indexed by the symbol's
--   name and the id returned by @ts_language_symbol_for_name@.
-- * named 'LeafType': a record with @ann@ and @text@ fields plus a
--   'TS.SymbolMatching' instance.
syntaxDatatype :: Ptr TS.Language -> [(String, Named)] -> Datatype -> Q [Dec]
syntaxDatatype language allSymbols datatype = unlessDefinedLocally $ do
  annParam <- newName "a"
  case datatype of
    SumType (DatatypeName _) _ subtypes -> do
      subtypeSum <- fieldTypesToNestedSum subtypes
      wrapped <- normalC typeName [TH.bangType strictness (pure subtypeSum `appT` varT annParam)]
      pure [NewtypeD [] typeName [PlainTV annParam] Nothing wrapped [newtypeDeriving, stockDeriving, anyclassDeriving]]
    ProductType (DatatypeName symbolName) named children fields -> do
      record <- ctorForProductType symbolName annParam children fields
      matching <- symbolMatchingInstance allSymbols typeName named symbolName
      pure (recordDecl annParam record : matching)
    LeafType (DatatypeName symbolName) Anonymous -> do
      -- NOTE(review): the C API presumably returns 0 for an unknown name.
      -- That case is not checked here — confirm whether it should fail.
      tsSymbol <- runIO $ withCStringLen symbolName $ \(cstr, len) ->
        TS.ts_language_symbol_for_name language cstr len False
      let tokenType = ConT ''Token `AppT` LitT (StrTyLit symbolName) `AppT` LitT (NumTyLit (fromIntegral tsSymbol))
      pure [TySynD typeName [] tokenType]
    LeafType leafName@(DatatypeName symbolName) Named -> do
      record <- ctorForLeafType leafName annParam
      matching <- symbolMatchingInstance allSymbols typeName Named symbolName
      pure (recordDecl annParam record : matching)
  where
    -- Run the generator only when no type of this name is already defined in
    -- the current module.
    unlessDefinedLocally generate = do
      definedHere <- lookupTypeName typeNameString >>= maybe (pure False) isLocalName
      if definedHere then pure [] else generate
    typeName = mkName typeNameString
    typeNameString = toNameString (datatypeNameStatus datatype) (getDatatypeName (TreeSitter.Deserialize.datatypeName datatype))
    recordDecl annParam con = DataD [] typeName [PlainTV annParam] Nothing [con] [stockDeriving, anyclassDeriving]
    stockDeriving = DerivClause (Just StockStrategy) (map ConT [''Eq, ''Ord, ''Show, ''Generic, ''Foldable, ''Functor, ''Traversable, ''Generic1])
    anyclassDeriving = DerivClause (Just AnyclassStrategy) [ConT ''TS.Unmarshal]
    newtypeDeriving = DerivClause (Just NewtypeStrategy) [ConT ''TS.SymbolMatching]
-- | Generate a 'TS.SymbolMatching' instance for the type @name@.
--
-- The generated 'symbolMatch' accepts a node whose symbol id is any index at
-- which @(str, named)@ occurs in @allSymbols@. This assumes list index equals
-- symbol id, as produced by 'getAllSymbols'.
--
-- The generated 'showFailure' reports the expected symbol names, computed at
-- compile time and joined with ", ", and the actual node's symbol. Both are
-- rendered with 'debugPrefix' and 'show'. The symbol table and id list are
-- lifted into the generated code.
--
-- NOTE(review): 'showFailure' indexes @allSymbols@ with '!!' at runtime, so
-- it crashes if a node's symbol id is outside the table.
symbolMatchingInstance :: [(String, Named)] -> Name -> Named -> String -> Q [Dec]
symbolMatchingInstance allSymbols name named str =
  [d|instance TS.SymbolMatching $(conT name) where
        showFailure _ node = "expected " <> $(litE (stringL (show expected))) <> " but got " <> show (debugPrefix (allSymbols !! fromIntegral (nodeSymbol node)))
        symbolMatch _ node = elem (nodeSymbol node) matchingIds|]
  where
    matchingIds = elemIndices (str, named) allSymbols
    expected = intercalate ", " (map (debugPrefix . (allSymbols !!)) matchingIds)
-- | Render a symbol for diagnostics. Named symbols appear as-is; anonymous
-- symbols get a leading underscore.
debugPrefix :: (String, Named) -> String
debugPrefix (symbolName, kind) = case kind of
  Named -> symbolName
  Anonymous -> '_' : symbolName
-- | Build the record constructor for a product node.
--
-- Fields appear in this order:
--
-- * @ann@, the annotation type parameter;
-- * one field per grammar field;
-- * @extra_children@, when the node has children.
--
-- Names are camel-cased by 'ctorForTypes'. Each field's type is the nested
-- sum of its possible node types (see 'fieldTypesToNestedSum') applied to the
-- annotation. That type is then wrapped by requiredness and multiplicity:
--
-- * required, multiple: 'NonEmpty'
-- * required, single: unwrapped
-- * optional, multiple: list
-- * optional, single: 'Maybe'
ctorForProductType :: String -> Name -> Maybe Children -> [(String, Field)] -> Q Con
ctorForProductType constructorName typeParameterName children fields =
  ctorForTypes constructorName (annotation : grammarFields ++ childField)
  where
    annotation = ("ann", varT typeParameterName)
    grammarFields = [ (fieldName, typeOfField field) | (fieldName, field) <- fields ]
    childField = case children of
      Nothing -> []
      Just (MkChildren field) -> [("extra_children", typeOfField field)]
    typeOfField (MkField required fieldTypes mult) = wrap required mult
      where
        element = fieldTypesToNestedSum fieldTypes `appT` varT typeParameterName
        wrap Required Multiple = conT ''NonEmpty `appT` element
        wrap Required Single = element
        wrap Optional Multiple = conT ''[] `appT` element
        wrap Optional Single = conT ''Maybe `appT` element
-- | Build the record constructor for a named leaf node. It has an @ann@ field
-- of the annotation type parameter and a @text@ field of type 'Text'.
ctorForLeafType :: DatatypeName -> Name -> Q Con
ctorForLeafType (DatatypeName leafName) annParam = ctorForTypes leafName leafFields
  where
    leafFields =
      [ ("ann", varT annParam)
      , ("text", conT ''Text)
      ]
-- | Build a record constructor from @(fieldName, fieldType)@ pairs.
--
-- The constructor is named after the grammar symbol via @toName Named@.
-- Field names are converted with 'toHaskellCamelCaseIdentifier', and every
-- field uses the 'strictness' bang.
ctorForTypes :: String -> [(String, Q TH.Type)] -> Q Con
ctorForTypes constructorName types =
  recC (toName Named constructorName) [ recordField fieldName fieldType | (fieldName, fieldType) <- types ]
  where
    recordField fieldName fieldType =
      TH.varBangType (mkName (toHaskellCamelCaseIdentifier fieldName)) (TH.bangType strictness fieldType)
-- | Combine one or more node types into a balanced tree of ':+:'.
--
-- A single type becomes just its constructor, named via 'toName'. Otherwise
-- the list is split so the left half gets @length `div` 2@ elements, and each
-- half is combined recursively. The result is not yet applied to the
-- annotation type; callers apply it with 'appT'.
fieldTypesToNestedSum :: NonEmpty TreeSitter.Deserialize.Type -> Q TH.Type
fieldTypesToNestedSum = balanced . toList
  where
    balanced [single] = leaf single
    balanced nodeTypes =
      let half = length nodeTypes `div` 2
          (lefts, rights) = splitAt half nodeTypes
      in conT ''(:+:) `appT` balanced lefts `appT` balanced rights
    leaf (MkType (DatatypeName typeName) named) = conT (toName named typeName)
-- | The bang used for every generated field: no source unpackedness and no
-- source strictness annotation.
strictness :: BangQ
strictness = pure (Bang NoSourceUnpackedness NoSourceStrictness)
-- | Build a TH 'Name' for a grammar symbol, spelled as 'toNameString'
-- renders it.
toName :: Named -> String -> Name
toName named = mkName . toNameString named
-- | Render the Haskell type/constructor name for a grammar symbol. The symbol
-- is converted to PascalCase; anonymous symbols are additionally prefixed
-- with @Anonymous@.
toNameString :: Named -> String -> String
toNameString named str = case named of
    Named -> pascal
    Anonymous -> "Anonymous" <> pascal
  where
    pascal = toHaskellPascalCaseIdentifier str
-- | The package and module a name was defined in. Returns 'Nothing' when the
-- name carries no package or module information (e.g. names from 'mkName').
moduleForName :: Name -> Maybe Module
moduleForName n = do
  pkg <- namePackage n
  modName <- nameModule n
  pure (Module (PkgName pkg) (ModName modName))
-- | True when the name was defined in the module currently being compiled,
-- as reported by 'thisModule'.
isLocalName :: Name -> Q Bool
isLocalName n = do
  here <- thisModule
  pure (moduleForName n == Just here)