{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}
module TreeSitter.GenerateSyntax
( syntaxDatatype
, removeUnderscore
, initUpper
, astDeclarationsForLanguage
) where
import Data.Char
import Language.Haskell.TH as TH
import Data.HashSet (HashSet)
import TreeSitter.Deserialize (Datatype (..), DatatypeName (..), Field (..), Children(..), Required (..), Type (..), Named (..), Multiple (..))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Foldable
import Data.Text (Text)
import qualified Data.HashSet as HashSet
import qualified TreeSitter.Unmarshal as TS
import GHC.Generics hiding (Constructor, Datatype)
import Foreign.Ptr
import qualified TreeSitter.Language as TS
import Foreign.C.String
import Data.Proxy
import Data.Aeson hiding (String)
import System.Directory
import System.FilePath.Posix
import TreeSitter.Node
import TreeSitter.Symbol (escapeOperatorPunctuation)
-- | Generate the declarations for every 'Datatype' described in a JSON file.
--
-- @filePath@ is interpreted relative to the directory of the module that
-- contains the splice: the current working directory is joined with the
-- splice's source file name, and @filePath@ is resolved against that
-- directory. The path is also handed to 'TS.addDependentFileRelative' before
-- anything is read. A JSON decoding failure is reported through 'fail' in 'Q'.
astDeclarationsForLanguage :: Ptr TS.Language -> FilePath -> Q [Dec]
astDeclarationsForLanguage language filePath = do
  _ <- TS.addDependentFileRelative filePath
  spliceModulePath <- fmap loc_filename location
  workingDirectory <- runIO getCurrentDirectory
  let jsonPath = takeDirectory (workingDirectory </> spliceModulePath) </> filePath
  decoded <- runIO (eitherDecodeFileStrict' jsonPath)
  case decoded of
    Left err -> fail err
    -- The type application pins the decoded container to a list.
    Right datatypes -> concat @[] <$> traverse (syntaxDatatype language) datatypes
-- | Produce the declarations for a single 'Datatype': the generated Haskell
-- type followed by its 'TS.SymbolMatching' instance.
--
-- * 'SumType' becomes a @data@ declaration with one constructor per subtype.
-- * 'ProductType' becomes a @data@ declaration with a single record
--   constructor.
-- * 'LeafType' becomes a @data@ declaration when anonymous and a @newtype@
--   when named.
--
-- Every generated type derives 'TS.Unmarshal', 'Eq', 'Ord', 'Show' and
-- 'Generic'.
syntaxDatatype :: Ptr TS.Language -> Datatype -> Q [Dec]
syntaxDatatype language datatype = case datatype of
  SumType (DatatypeName sumName) _ subtypes -> do
    alternatives <- traverse (constructorForSumChoice sumName) subtypes
    instances <- symbolMatchingInstanceForSums language typeName subtypes
    pure (dataDeclaration alternatives : instances)
  ProductType (DatatypeName productName) _ children fields -> do
    recordCon <- ctorForProductType productName children fields
    instances <- symbolMatchingInstance language typeName productName
    pure (dataDeclaration [recordCon] : instances)
  LeafType (DatatypeName leafName) named -> do
    leafCon <- ctorForLeafType named (DatatypeName leafName)
    instances <- symbolMatchingInstance language typeName leafName
    let declaration = case named of
          Anonymous -> dataDeclaration [leafCon]
          Named -> NewtypeD [] typeName [] Nothing leafCon derivations
    pure (declaration : instances)
  where
    -- Name of the generated Haskell type, shared by all three cases.
    typeName = toName (datatypeNameStatus datatype) (getDatatypeName (TreeSitter.Deserialize.datatypeName datatype))
    derivations = [DerivClause Nothing (map ConT [''TS.Unmarshal, ''Eq, ''Ord, ''Show, ''Generic])]
    dataDeclaration cons = DataD [] typeName [] Nothing cons derivations
-- | Build the 'TS.SymbolMatching' instance for a product or leaf type.
--
-- The symbol for @str@ and that symbol's type are looked up in the language
-- at splice time; together with @str@ they determine (via 'TS.symbolToName')
-- which @Grammar.*@ constructor the generated @symbolMatch@ compares against.
-- The generated code refers to @Grammar.Grammar@ and @Grammar.\<constructor\>@
-- by those literal qualified names.
symbolMatchingInstance :: Ptr TS.Language -> Name -> String -> Q [Dec]
symbolMatchingInstance language name str = do
  symbolId <- runIO (withCString str (TS.ts_language_symbol_for_name language))
  symbolKind <- runIO (toEnum <$> TS.ts_language_symbol_type language symbolId)
  let expectedName = litE (stringL (show name))
      grammarType = conT (mkName "Grammar.Grammar")
      expectedSymbol = conE (mkName ("Grammar." <> TS.symbolToName symbolKind str))
  [d|
    instance TS.SymbolMatching $(conT name) where
      showFailure _ actual = "Expected " <> $expectedName <> " but got " <> show (TS.fromTSSymbol (nodeSymbol actual) :: $grammarType)
      symbolMatch _ actual = TS.fromTSSymbol (nodeSymbol actual) == $expectedSymbol|]
-- | Build the 'TS.SymbolMatching' instance for a sum type.
--
-- The generated @symbolMatch@ succeeds when any subtype's @symbolMatch@
-- succeeds; the per-subtype matchers are combined with '||' lifted over the
-- function applicative. @showFailure@ lists the expected subtype names.
--
-- A sum type without subtypes yields a matcher that never succeeds
-- (@const False@) rather than crashing the splice, which is what a bare
-- 'foldr1' over an empty list would do. The language pointer is unused.
symbolMatchingInstanceForSums :: Ptr TS.Language -> Name -> [TreeSitter.Deserialize.Type] -> Q [Dec]
symbolMatchingInstanceForSums _ name subtypes =
  [d|instance TS.SymbolMatching $(conT name) where
      showFailure _ node = "Expected " <> $(litE (stringL (show (map extractn subtypes)))) <> " but got " <> show (TS.fromTSSymbol (nodeSymbol node) :: $(conT (mkName "Grammar.Grammar")))
      symbolMatch _ = $(anySubtypeMatches) |]
  where
    -- Dispatch on the list shape so that the fold only ever sees a NonEmpty,
    -- for which 'foldr1' is total. Non-empty input generates exactly the same
    -- right-nested expression as before.
    anySubtypeMatches = case map perMkType subtypes of
      [] -> [e| const False |]
      m : ms -> foldr1 mkOr (m :| ms)
    perMkType (MkType (DatatypeName n) named) = [e|TS.symbolMatch (Proxy :: Proxy $(conT (toName named n))) |]
    mkOr lhs rhs = [e| (||) <$> $(lhs) <*> $(rhs) |]
    extractn (MkType (DatatypeName n) Named) = toCamelCase n
    extractn (MkType (DatatypeName n) Anonymous) = "Anonymous" <> toCamelCase n
-- | Build the constructor that wraps one alternative of a sum type.
--
-- The constructor is named from the subtype's name followed by the sum
-- type's name (@str@), passed through 'toName'; it carries a single field
-- (no strictness or unpackedness annotation) of the subtype's own type.
constructorForSumChoice :: String -> TreeSitter.Deserialize.Type -> Q Con
constructorForSumChoice str (MkType (DatatypeName n) named) =
  normalC constructorName [TH.bangType plainBang wrappedType]
  where
    constructorName = toName named (n ++ str)
    wrappedType = conT (toName named n)
    plainBang = TH.bang noSourceUnpackedness noSourceStrictness
-- | Build the record constructor for a product type.
--
-- Named fields come first, followed by an @extra_children@ field when the
-- type has children. A field's type is the nested 'Either' of its possible
-- types, wrapped according to whether it is required and/or multiple:
--
-- * required, multiple: 'NonEmpty'
-- * required, single: no wrapper
-- * optional, multiple: list
-- * optional, single: 'Maybe'
ctorForProductType :: String -> Maybe Children -> [(String, Field)] -> Q Con
ctorForProductType constructorName children fields =
  recC (toName Named constructorName) (namedFields ++ extraChildren)
  where
    namedFields = map (uncurry recordField) fields
    extraChildren = case children of
      Nothing -> []
      Just (MkChildren field) -> [recordField "extra_children" field]
    recordField label (MkField required fieldTypes mult) =
      TH.varBangType haskellLabel (TH.bangType plainBang (wrapper elementType))
      where
        -- Field labels lose their underscores and get a tick when they
        -- would clash with a name in 'clashingNames'.
        haskellLabel = mkName (addTickIfNecessary (removeUnderscore label))
        plainBang = TH.bang noSourceUnpackedness noSourceStrictness
        elementType = fieldTypesToNestedEither fieldTypes
        wrapper = case (required, mult) of
          (Required, Multiple) -> appT (conT ''NonEmpty)
          (Required, Single) -> id
          (Optional, Multiple) -> appT (conT ''[])
          (Optional, Single) -> appT (conT ''Maybe)
-- | Build the constructor for a leaf type: a nullary constructor for an
-- anonymous leaf, or a record constructor with a single @bytes :: Text@
-- field for a named leaf.
ctorForLeafType :: Named -> DatatypeName -> Q Con
ctorForLeafType named (DatatypeName name) = case named of
  Anonymous -> normalC constructorName []
  Named -> recC constructorName [TH.varBangType (mkName "bytes") textField]
  where
    constructorName = toName named name
    textField = TH.bangType (TH.bang noSourceUnpackedness noSourceStrictness) (conT ''Text)
-- | Convert a non-empty list of possible field types into a right-nested
-- 'Either' type: a single type is used as is, and @a, b, c@ becomes
-- @Either a (Either b c)@.
fieldTypesToNestedEither :: NonEmpty TreeSitter.Deserialize.Type -> Q TH.Type
fieldTypesToNestedEither xs = foldr1 eitherOf (typeFor <$> xs)
  where
    eitherOf left right = conT ''Either `appT` left `appT` right
    typeFor (MkType (DatatypeName n) named) = conT (toName named n)
-- | Turn a name into an upper-camel-case identifier: underscores are removed
-- with 'removeUnderscore', the result is passed through
-- 'escapeOperatorPunctuation', and the first character is upper-cased.
toCamelCase :: String -> String
toCamelCase str = initUpper (escapeOperatorPunctuation (removeUnderscore str))
-- | Names that cannot be used as Haskell variable or field identifiers
-- because they are reserved words. This is the Haskell 2010 @reservedid@
-- list (minus @_@); GHC rejects each of them as an illegal variable name
-- when it is produced through 'mkName' in a splice.
clashingNames :: HashSet String
clashingNames = HashSet.fromList
  [ "case"
  , "class"
  , "data"
  , "default"
  , "deriving"
  , "do"
  , "else"
  , "foreign"
  , "if"
  , "import"
  , "in"
  , "infix"
  , "infixl"
  , "infixr"
  , "instance"
  , "let"
  , "module"
  , "newtype"
  , "of"
  , "then"
  , "type"
  , "where"
  ]
-- | Append a tick (@'@) to a name that is a member of 'clashingNames';
-- any other name is returned unchanged.
addTickIfNecessary :: String -> String
addTickIfNecessary s =
  if s `HashSet.member` clashingNames
    then s ++ "'"
    else s
-- | Build the Template Haskell 'Name' for a grammar name: the camel-cased
-- string, prefixed with @Anonymous@ for anonymous nodes, and passed through
-- 'addTickIfNecessary'.
toName :: Named -> String -> Name
toName named str = mkName (addTickIfNecessary (prefix <> toCamelCase str))
  where
    prefix = case named of
      Anonymous -> "Anonymous"
      Named -> ""
-- | Upper-case the first character of a string, leaving the rest untouched.
-- The empty string is returned unchanged.
initUpper :: String -> String
initUpper str = case str of
  [] -> []
  first : rest -> toUpper first : rest
-- | Remove every underscore from a string, upper-casing the first character
-- of whatever remains after each removed underscore (so @"foo_bar"@ becomes
-- @"fooBar"@). A trailing underscore is simply dropped.
removeUnderscore :: String -> String
removeUnderscore [] = []
removeUnderscore ('_' : rest) = initUpper (removeUnderscore rest)
removeUnderscore (c : rest) = c : removeUnderscore rest