{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}

-- {-# LANGUAGE TypeOperators #-}
module TreeSitter.GenerateSyntax
( syntaxDatatype
, removeUnderscore
, initUpper
, astDeclarationsForLanguage
-- * Internal functions exposed for testing

) where

import Data.Char
import Language.Haskell.TH as TH
import Data.HashSet (HashSet)
import TreeSitter.Deserialize (Datatype (..), DatatypeName (..), Field (..), Children(..), Required (..), Type (..), Named (..), Multiple (..))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Foldable
import Data.Text (Text)
import qualified Data.HashSet as HashSet
import qualified TreeSitter.Unmarshal as TS
import GHC.Generics hiding (Constructor, Datatype)
import Foreign.Ptr
import qualified TreeSitter.Language as TS
import Foreign.C.String
import Data.Proxy
import Data.Aeson hiding (String)
import System.Directory
import System.FilePath.Posix
import TreeSitter.Node
import TreeSitter.Symbol (escapeOperatorPunctuation)


-- Auto-generate Haskell datatypes from node-types.json
astDeclarationsForLanguage :: Ptr TS.Language -> FilePath -> Q [Dec]
astDeclarationsForLanguage language filePath = do
  _ <- TS.addDependentFileRelative filePath
  currentFilename <- loc_filename <$> location
  pwd             <- runIO getCurrentDirectory
  let invocationRelativePath = takeDirectory (pwd </> currentFilename) </> filePath
  input <- runIO (eitherDecodeFileStrict' invocationRelativePath)
  either fail (fmap (concat @[]) . traverse (syntaxDatatype language)) input


-- Auto-generate Haskell datatypes for sums, products and leaf types
syntaxDatatype :: Ptr TS.Language -> Datatype -> Q [Dec]
syntaxDatatype language datatype = case datatype of
  SumType (DatatypeName datatypeName) _ subtypes -> do
    cons <- traverse (constructorForSumChoice datatypeName) subtypes
    result <- symbolMatchingInstanceForSums language name subtypes
    pure $ generatedDatatype name cons:result
  ProductType (DatatypeName datatypeName) _ children fields -> do
    con <- ctorForProductType datatypeName children fields
    result <- symbolMatchingInstance language name datatypeName
    pure $ generatedDatatype name [con]:result
  LeafType (DatatypeName datatypeName) named -> do
    con <- ctorForLeafType named (DatatypeName datatypeName)
    result <- symbolMatchingInstance language name datatypeName
    pure $ case named of
      Anonymous -> generatedDatatype name [con]:result
      Named -> NewtypeD [] name [] Nothing con deriveClause:result
  where
    name = toName (datatypeNameStatus datatype) (getDatatypeName (TreeSitter.Deserialize.datatypeName datatype))
    deriveClause = [ DerivClause Nothing [ ConT ''TS.Unmarshal, ConT ''Eq, ConT ''Ord, ConT ''Show, ConT ''Generic ] ]
    generatedDatatype name cons = DataD [] name [] Nothing cons deriveClause


-- | Create TH-generated SymbolMatching instances for sums, products, leaves
symbolMatchingInstance :: Ptr TS.Language -> Name -> String -> Q [Dec]
symbolMatchingInstance language name str = do
  tsSymbol <- runIO $ withCString str (TS.ts_language_symbol_for_name language)
  tsSymbolType <- toEnum <$> runIO (TS.ts_language_symbol_type language tsSymbol)
  [d|instance TS.SymbolMatching $(conT name) where
      showFailure _ node = "Expected " <> $(litE (stringL (show name))) <> " but got " <> show (TS.fromTSSymbol (nodeSymbol node) :: $(conT (mkName "Grammar.Grammar")))
      symbolMatch _ node = TS.fromTSSymbol (nodeSymbol node) == $(conE (mkName $ "Grammar." <> TS.symbolToName tsSymbolType str))|]

symbolMatchingInstanceForSums ::  Ptr TS.Language -> Name -> [TreeSitter.Deserialize.Type] -> Q [Dec]
symbolMatchingInstanceForSums _ name subtypes =
  [d|instance TS.SymbolMatching $(conT name) where
      showFailure _ node = "Expected " <> $(litE (stringL (show (map extractn subtypes)))) <> " but got " <> show (TS.fromTSSymbol (nodeSymbol node) :: $(conT (mkName "Grammar.Grammar")))
      symbolMatch _ = $(foldr1 mkOr (perMkType `map` subtypes)) |]
  where perMkType (MkType (DatatypeName n) named) = [e|TS.symbolMatch (Proxy :: Proxy $(conT (toName named n))) |]
        mkOr lhs rhs = [e| (||) <$> $(lhs) <*> $(rhs) |]
        extractn (MkType (DatatypeName n) Named) = toCamelCase n
        extractn (MkType (DatatypeName n) Anonymous) = "Anonymous" <> toCamelCase n


-- | Append string with constructor name (ex., @IfStatementStatement IfStatement@)
constructorForSumChoice :: String -> TreeSitter.Deserialize.Type -> Q Con
constructorForSumChoice str (MkType (DatatypeName n) named) = normalC (toName named (n ++ str)) [child]
  where child = TH.bangType (TH.bang noSourceUnpackedness noSourceStrictness) (conT (toName named n))

-- | Build Q Constructor for product types (nodes with fields)
ctorForProductType :: String -> Maybe Children -> [(String, Field)] -> Q Con
ctorForProductType constructorName children fields = recC (toName Named constructorName) lists where
  lists = fieldList ++ childList
  fieldList = fmap (uncurry toVarBangType) fields
  childList = toList $ fmap toVarBangTypeChild children
  toVarBangType name (MkField required fieldTypes mult) =
    let fieldName = mkName . addTickIfNecessary . removeUnderscore $ name
        strictness = TH.bang noSourceUnpackedness noSourceStrictness
        ftypes = fieldTypesToNestedEither fieldTypes
        fieldContents = case (required, mult) of
          (Required, Multiple) -> appT (conT ''NonEmpty) ftypes
          (Required, Single) -> ftypes
          (Optional, Multiple) -> appT (conT ''[]) ftypes
          (Optional, Single) -> appT (conT ''Maybe) ftypes
    in TH.varBangType fieldName (TH.bangType strictness fieldContents)
  toVarBangTypeChild (MkChildren field) = toVarBangType "extra_children" field


-- | Build Q Constructor for leaf types (nodes with no fields or subtypes)
ctorForLeafType :: Named -> DatatypeName -> Q Con
ctorForLeafType Anonymous (DatatypeName name) = normalC (toName Anonymous name) []
ctorForLeafType Named (DatatypeName name) = recC (toName Named name) [leafBytes] where
  leafBytes = TH.varBangType (mkName "bytes") textValue
  textValue = TH.bangType (TH.bang noSourceUnpackedness noSourceStrictness) (conT ''Text)

-- | Convert field types to Q types
fieldTypesToNestedEither :: NonEmpty TreeSitter.Deserialize.Type -> Q TH.Type
fieldTypesToNestedEither xs = foldr1 combine $ fmap convertToQType xs
  where
    combine convertedQType = appT (appT (conT ''Either) convertedQType)
    convertToQType (MkType (DatatypeName n) named) = conT (toName named n)

-- | Convert snake_case string to CamelCase String
toCamelCase :: String -> String
toCamelCase = initUpper . escapeOperatorPunctuation . removeUnderscore

clashingNames :: HashSet String
clashingNames = HashSet.fromList ["type", "module", "data"]

addTickIfNecessary :: String -> String
addTickIfNecessary s
  | HashSet.member s clashingNames = s ++ "'"
  | otherwise                        = s

-- | Prepend "Anonymous" to named node when false, otherwise use regular toName
toName :: Named -> String -> Name
toName named str = mkName $ addTickIfNecessary $ case named of
  Anonymous -> "Anonymous" <> toCamelCase str
  Named -> toCamelCase str

-- Helper function to output camel cased data type names
initUpper :: String -> String
initUpper (c:cs) = toUpper c : cs
initUpper "" = ""

-- Helper function to remove underscores from output of data type names
removeUnderscore :: String -> String
removeUnderscore = foldr appender ""
  where appender :: Char -> String -> String
        appender '_' cs = initUpper cs
        appender c cs = c : cs