{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

module Axel.Macros where

import Axel.AST
  ( Identifier
  , MacroDefinition
  , Statement(SDataDeclaration, SFunctionDefinition, SLanguagePragma,
          SMacroDefinition, SModuleDeclaration, SQualifiedImport,
          SRestrictedImport, STopLevel, STypeSynonym, STypeclassInstance,
          SUnrestrictedImport)
  , ToHaskell(toHaskell)
  , name
  , statements
  )
import Axel.Denormalize (denormalizeStatement)
import Axel.Error (Error(MacroError))
import Axel.Eval (evalMacro)
import Axel.Normalize (normalizeStatement)
import qualified Axel.Parse as Parse
  ( Expression(LiteralChar, LiteralInt, LiteralString, SExpression,
           Symbol)
  , parseMultiple
  )
import Axel.Utils.Display (Delimiter(Newlines), delimit, isOperator)
import Axel.Utils.Function (uncurry3)
import Axel.Utils.Recursion
  ( Recursive(bottomUpFmap, bottomUpTraverse)
  , exhaustM
  )
import Axel.Utils.Resources (readResource)
import qualified Axel.Utils.Resources as Res
  ( astDefinition
  , macroDefinitionAndEnvironmentHeader
  , macroScaffold
  )
import Axel.Utils.String (replace)

import Control.Lens.Operators ((%~), (^.))
import Control.Lens.Tuple (_1, _2)
import Control.Monad (foldM)
import Control.Monad.Except (MonadError, catchError, throwError)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Control (MonadBaseControl)

import Data.Function ((&))
import Data.Semigroup ((<>))

-- | Assemble the three source texts needed to evaluate one macro application.
--
-- The result triple is, in order:
--
-- 1. the 'Res.astDefinition' resource, unchanged;
-- 2. the 'Res.macroScaffold' resource with @%%%MACRO_NAME%%%@ replaced by the
--    renamed macro and @%%%ARGUMENTS%%%@ replaced by the 'show'n argument
--    list;
-- 3. the 'Res.macroDefinitionAndEnvironmentHeader' resource followed by the
--    Haskell rendering of @environment@ plus the renamed macro definition,
--    joined with 'Newlines'.
--
-- The macro is renamed by suffixing its name (presumably so it cannot clash
-- with a same-named definition in @environment@). Any 'MacroError' raised by
-- 'replaceName' while renaming propagates to the caller.
generateMacroProgram ::
     (MonadBaseControl IO m, MonadError Error m, MonadIO m)
  => MacroDefinition
  -> [Statement]
  -> [Parse.Expression]
  -> m (String, String, String)
generateMacroProgram macroDefinition environment applicationArguments = do
  astDefinition <- liftIO (readResource Res.astDefinition)
  scaffold <- liftIO (fillScaffold <$> readResource Res.macroScaffold)
  header <- liftIO (readResource Res.macroDefinitionAndEnvironmentHeader)
  footer <- renderFooter
  pure (astDefinition, scaffold, header <> footer)
  where
    originalName = macroDefinition ^. name
    -- Operators get a purely symbolic suffix; presumably this keeps the
    -- renamed identifier an operator.
    renamedMacro
      | isOperator originalName = originalName ++ "%%%%%%%%%%"
      | otherwise = originalName ++ "_AXEL_AUTOGENERATED_MACRO_DEFINITION"
    -- The macro name is substituted first, then the arguments.
    fillScaffold =
      replace "%%%ARGUMENTS%%%" (show applicationArguments) .
      replace "%%%MACRO_NAME%%%" renamedMacro
    renderFooter = do
      hygienicDefinition <-
        replaceName originalName renamedMacro (SMacroDefinition macroDefinition)
      pure . delimit Newlines . map toHaskell $
        environment <> [hygienicDefinition]

-- | Run a single round of macro expansion over a whole program.
--
-- The program must be a @(begin stmt ...)@ s-expression. Its statements are
-- expanded and normalized by 'expandMacros', denormalized again, and wrapped
-- in a fresh @(begin ...)@ form.
--
-- Any other input is reported through 'throwError' as a 'MacroError' rather
-- than aborting via 'error'. Callers using 'MonadError' can then handle it
-- like any other expansion failure.
expansionPass ::
     (MonadBaseControl IO m, MonadError Error m, MonadIO m)
  => Parse.Expression
  -> m Parse.Expression
expansionPass programExpr =
  stmtExprsToProgram . map denormalizeStatement <$>
  (programToTopLevelExprs programExpr >>= expandMacros)
  where
    -- Unwrap the top-level @begin@ form, failing in the error monad (not with
    -- a partial 'error') when the input has any other shape.
    programToTopLevelExprs ::
         (MonadError Error n) => Parse.Expression -> n [Parse.Expression]
    programToTopLevelExprs (Parse.SExpression (Parse.Symbol "begin":stmts)) =
      pure stmts
    programToTopLevelExprs _ =
      throwError
        (MacroError
           "programToTopLevelExprs must be passed a top-level program!")
    -- Re-wrap expanded statements as a top-level @begin@ form.
    stmtExprsToProgram :: [Parse.Expression] -> Parse.Expression
    stmtExprsToProgram stmts = Parse.SExpression (Parse.Symbol "begin" : stmts)

-- | Expand macros in a whole program by running 'expansionPass' repeatedly
-- through 'exhaustM'.
--
-- NOTE(review): the stopping condition is defined by 'exhaustM' in
-- "Axel.Utils.Recursion". Presumably it iterates until a pass leaves the
-- program unchanged; confirm there.
exhaustivelyExpandMacros ::
     (MonadBaseControl IO m, MonadError Error m, MonadIO m)
  => Parse.Expression
  -> m Parse.Expression
exhaustivelyExpandMacros program = exhaustM expansionPass program

-- | Expand every macro application in a program's top-level expressions.
--
-- The expressions are processed left to right. Each one is fully expanded
-- using the macros defined by /earlier/ top-level expressions; the statements
-- accepted so far serve as the macro evaluation environment. The result is
-- then normalized:
--
-- * a macro definition is recorded for use by later expressions and is not
--   included in the result;
-- * any other statement is appended to the result when @isStmtNonconflicting@
--   accepts it, and dropped otherwise (module declarations and 'STopLevel's).
--
-- NOTE(review): one list serves as both the macro environment and the
-- returned program, so module declarations never reach the output. Confirm
-- that callers do not need them to survive expansion.
--
-- NOTE(review): only macro calls that appear as children of another
-- s-expression are spliced. A top-level expression that is itself a macro
-- call is not expanded here; confirm this is intended.
--
-- TODO This needs heavy optimization: every macro application generates and
-- evaluates a separate macro program (see 'expandMacroApplication').
expandMacros ::
     (MonadBaseControl IO m, MonadError Error m, MonadIO m)
  => [Parse.Expression]
  -> m [Statement]
expandMacros topLevelExprs =
  fst <$>
  foldM
    (\acc@(stmts, macroDefs) expr -> do
       expandedExpr <- fullyExpandExpr stmts macroDefs expr
       stmt <- normalizeStatement expandedExpr
       pure $
         case stmt of
           SMacroDefinition macroDefinition ->
             acc & _2 %~ (<> [macroDefinition])
           _ ->
             if isStmtNonconflicting stmt
               then acc & _1 %~ (<> [stmt])
               else acc)
    ([], [])
    topLevelExprs
  where
    -- Whether a statement may join the macro environment (and therefore the
    -- result). The 'SMacroDefinition' case is kept for totality even though
    -- the fold above handles macro definitions before consulting this.
    isStmtNonconflicting =
      \case
        SDataDeclaration _ -> True
        SFunctionDefinition _ -> True
        SLanguagePragma _ -> True
        SMacroDefinition _ -> True
        SModuleDeclaration _ -> False
        SQualifiedImport _ -> True
        SRestrictedImport _ -> True
        STopLevel _ -> False
        STypeclassInstance _ -> True
        STypeSynonym _ -> True
        SUnrestrictedImport _ -> True
    -- Walk the expression bottom-up, splicing each macro call's expansion
    -- into its parent's child list, and repeat via 'exhaustM'. The children
    -- are expanded left to right with 'traverse' and flattened once. The
    -- previous left fold rebuilt the list with @acc ++ [x]@ for every child,
    -- which was quadratic in the number of children.
    fullyExpandExpr stmts macroDefs =
      exhaustM $
      bottomUpTraverse
        (\case
           Parse.SExpression xs ->
             Parse.SExpression . concat <$>
             traverse (expandChild stmts macroDefs) xs
           expr -> pure expr)
    -- Expand one child: a non-empty s-expression whose head names a known
    -- macro becomes that macro's expansion (zero or more expressions). Every
    -- other child is kept as a singleton.
    expandChild stmts macroDefs x =
      case x of
        Parse.SExpression (function:args) ->
          lookupMacroDefinition macroDefs function >>= \case
            Just macroDefinition ->
              expandMacroApplication macroDefinition stmts args
            Nothing -> pure [x]
        _ -> pure [x]

-- | Expand a single macro call.
--
-- Builds the macro program for @macroDef@, with @auxEnv@ as its environment
-- and @args@ as the call's arguments. It then runs the program with
-- 'evalMacro' and parses the produced source into the list of expressions
-- that replaces the call.
expandMacroApplication ::
     (MonadBaseControl IO m, MonadError Error m, MonadIO m)
  => MacroDefinition
  -> [Statement]
  -> [Parse.Expression]
  -> m [Parse.Expression]
expandMacroApplication macroDef auxEnv args =
  generateMacroProgram macroDef auxEnv args >>= uncurry3 evalMacro >>=
  Parse.parseMultiple

-- | Find the macro, if any, that @identifierExpr@ refers to.
--
-- Returns 'Nothing' when no definition matches and @Just@ the definition when
-- exactly one does. Throws a 'MacroError' naming the macro when several
-- definitions share that name.
lookupMacroDefinition ::
     (MonadError Error m)
  => [MacroDefinition]
  -> Parse.Expression
  -> m (Maybe MacroDefinition)
lookupMacroDefinition macroDefs identifierExpr =
  case candidates of
    [] -> pure Nothing
    [onlyCandidate] -> pure (Just onlyCandidate)
    firstCandidate:_ ->
      throwError $
      MacroError $
      "Multiple macro definitions named: `" <> firstCandidate ^. name <> "`!"
  where
    candidates =
      filter (\candidate -> isMacroBeingCalled candidate identifierExpr) macroDefs

-- | Does @identifierExpr@ name @macroDef@?
--
-- Only a bare 'Parse.Symbol' whose text equals the macro's name matches;
-- literals and s-expressions never do.
isMacroBeingCalled :: MacroDefinition -> Parse.Expression -> Bool
isMacroBeingCalled macroDef (Parse.Symbol identifier) =
  identifier == macroDef ^. name
isMacroBeingCalled _ _ = False

-- | Remove macro definitions from the immediate statements of a top-level
-- program.
--
-- Only the direct children of an 'STopLevel' are filtered. Any other
-- statement is returned unchanged.
stripMacroDefinitions :: Statement -> Statement
stripMacroDefinitions (STopLevel topLevel) =
  STopLevel (topLevel & statements %~ filter (not . isMacroDefinitionStatement))
stripMacroDefinitions other = other

-- | Is this statement a macro definition ('SMacroDefinition')?
isMacroDefinitionStatement :: Statement -> Bool
isMacroDefinitionStatement =
  \case
    SMacroDefinition _ -> True
    _ -> False

-- | Rename every 'Parse.Symbol' equal to @oldName@ to @newName@ inside a
-- statement.
--
-- The statement is denormalized to an expression, matching symbols are
-- rewritten with 'bottomUpFmap', and the result is normalized again. If that
-- re-normalization fails for any reason, the underlying error is discarded
-- and a 'MacroError' naming @oldName@ is thrown instead.
replaceName ::
     (MonadError Error m)
  => Identifier
  -> Identifier
  -> Statement
  -> m Statement
replaceName oldName newName stmt =
  normalizeStatement renamedExpr `catchError` \_ ->
    throwError (MacroError $ "Invalid macro name: `" <> oldName <> "`!")
  where
    renamedExpr = bottomUpFmap renameSymbol (denormalizeStatement stmt)
    renameSymbol (Parse.Symbol identifier)
      | identifier == oldName = Parse.Symbol newName
    renameSymbol expr = expr