{-# LANGUAGE CPP             #-}
{-# LANGUAGE GADTs           #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns    #-}

module Ide.Plugin.Class.ExactPrint where

import           Control.Monad.Trans.Maybe
import qualified Data.Text                               as T
import           Development.IDE.GHC.Compat
import           Ide.Plugin.Class.Types
import           Ide.Plugin.Class.Utils
import           Language.Haskell.GHC.ExactPrint
import           Language.Haskell.GHC.ExactPrint.Parsers
import           Language.LSP.Types

#if MIN_VERSION_ghc(9,2,0)
import           Data.Either.Extra                       (eitherToMaybe)
import           GHC.Parser.Annotation
#else
import           Control.Monad                           (foldM)
import qualified Data.Map.Strict                         as Map
import           Language.Haskell.GHC.ExactPrint.Types   hiding (GhcPs)
import           Language.Haskell.GHC.ExactPrint.Utils   (rs)
#endif

-- | Render the module source before and after inserting the requested
-- method declarations.
--
-- The result is @(old, new)@: the exact-printed text of the parsed module
-- as-is, and the exact-printed text after 'addMethodDecls' has run on it.
-- The computation yields nothing (in 'MaybeT') when 'makeMethodDecl' fails
-- for any entry of the method group.
makeEditText :: Monad m => ParsedModule -> DynFlags -> AddMinimalMethodsParams -> MaybeT m (T.Text, T.Text)
-- addMethodDecls :: ParsedSource -> [(LHsDecl GhcPs, LHsDecl GhcPs)] -> Range -> Bool -> TransformT Identity (Located HsModule)
#if MIN_VERSION_ghc(9,2,0)
makeEditText pm df AddMinimalMethodsParams{..} = do
    -- Every (name, signature) pair must parse; a single failure aborts the edit.
    List mDecls <- MaybeT . pure $ traverse (makeMethodDecl df) methodGroup
    let ps = makeDeltaAst $ pm_parsed_source pm
        old = T.pack $ exactPrint ps
        -- Only the transformed source is needed; the other two components
        -- returned by 'runTransform' are ignored.
        (ps', _, _) = runTransform (addMethodDecls ps mDecls range withSig)
        new = T.pack $ exactPrint ps'
    pure (old, new)

-- | Parse one method into a pair of declarations: a placeholder binding and
-- its type signature.
--
-- The argument is @(method name, signature text)@. The binding is obtained by
-- parsing @toMethodName mName <> " = _"@, the signature by parsing the
-- signature text as given. The result is @(binding, signature)@, or 'Nothing'
-- if either parse fails.
makeMethodDecl :: DynFlags -> (T.Text, T.Text) -> Maybe (LHsDecl GhcPs, LHsDecl GhcPs)
makeMethodDecl df (mName, sig) = do
    -- The second argument of 'parseDecl' is the file-path label handed to the
    -- parser; the raw name / signature text is reused for it.
    name <- eitherToMaybe $ parseDecl df (T.unpack mName) . T.unpack $ toMethodName mName <> " = _"
    sig' <- eitherToMaybe $ parseDecl df (T.unpack sig) $ T.unpack sig
    pure (name, sig')

-- | Insert method declarations right after the first declaration whose span
-- satisfies 'inRange' for the given range.
--
-- Each element of the list is a @(binding, signature)@ pair. When the flag is
-- 'True' every pair contributes its signature followed by its binding;
-- otherwise only the bindings are inserted. If no declaration matches the
-- range, the source is returned unchanged.
addMethodDecls
    :: (Monad m, HasDecls b)
    => b
    -> [(LHsDecl GhcPs, LHsDecl GhcPs)]
    -> Range
    -> Bool
    -> TransformT m b
addMethodDecls ps mDecls range withSig
    | withSig = go (concatMap (\(decl, sig) -> [sig, decl]) mDecls)
    | otherwise = go (map fst mDecls)
    where
    go inserting = do
        allDecls <- hsDecls ps
        case break (inRange range . getLoc) allDecls of
            (before, L l inst : after) ->
                replaceDecls ps (before ++ L l (addWhere inst) : (map newLine inserting ++ after))
            -- Nothing lies in the requested range, so there is no anchor to
            -- insert after; leave the source untouched instead of failing on
            -- an incomplete pattern.
            (_, []) -> pure ps

    -- Add the `where` keyword to an `instance X` declaration.
    --
    -- The `where` in ghc-9.2 is stored in the instance declaration
    --   directly. More precisely, given an `HsDecl GhcPs`, we have:
    --   InstD --> ClsInstD --> ClsInstDecl --> XCClsInstDecl --> (EpAnn [AddEpAnn], AnnSortKey),
    --   where the `AddEpAnn` list keeps track of the keyword annotations.
    --
    -- An `AnnWhere` entry is prepended to that list unconditionally; any
    -- declaration that is not a class instance is returned as-is.
    --
    -- See the link for the original definition:
    --   https://hackage.haskell.org/package/ghc-9.2.1/docs/Language-Haskell-Syntax-Extension.html#t:XCClsInstDecl
    addWhere (InstD xInstD (ClsInstD ext decl@ClsInstDecl{..})) =
        -- NOTE(review): this lazy pattern only matches the `EpAnn`
        -- constructor; confirm that parsed instances never carry an unused
        -- annotation here.
        let (EpAnn entry anns comments, key) = cid_ext
            whereAnn = AddEpAnn AnnWhere (EpaDelta (SameLine 1) [])
        in InstD xInstD (ClsInstD ext decl
            { cid_ext = (EpAnn entry (whereAnn : anns) comments, key)
            })
    addWhere decl = decl

    -- Combine the declaration's annotation with one built from a delta
    -- position of one line and 'defaultIndent' columns, so that every
    -- inserted declaration starts on its own indented line.
    newLine (L l e) =
        let dp = deltaPos 1 defaultIndent
        in L (noAnnSrcSpanDP (getLoc l) dp <> l) e

#else

-- Pre-9.2 variant: annotations live in a separate map, so the annotations of
-- the freshly parsed methods are merged into the module's (relativised)
-- annotations before the transformation runs.
makeEditText pm df AddMinimalMethodsParams{..} = do
    List parsed <- MaybeT (pure (traverse (makeMethodDecl df) methodGroup))
    let (methodAnns, methodDecls) = unzip parsed
        source = pm_parsed_source pm
        sourceAnns = relativiseApiAnns source (pm_annotations pm)
        startAnns = mergeAnns (mergeAnnList methodAnns) sourceAnns
        (source', (finalAnns, _), _) =
            runTransform startAnns (addMethodDecls source methodDecls range withSig)
        before = T.pack (exactPrint source sourceAnns)
        after = T.pack (exactPrint source' finalAnns)
    pure (before, after)

-- | Parse one @(method name, signature text)@ pair into a placeholder binding
-- (@toMethodName mName <> " = _"@) and a signature declaration, together with
-- the merged annotations of both. Yields 'Nothing' if either parse fails.
makeMethodDecl :: DynFlags -> (T.Text, T.Text) -> Maybe (Anns, (LHsDecl GhcPs, LHsDecl GhcPs))
makeMethodDecl df (mName, sig) = do
    (bodyAnns, bodyDecl) <- parseIndented (T.unpack mName) (T.unpack (toMethodName mName <> " = _"))
    (sigAnns, sigDecl) <- parseIndented sigText sigText
    pure (mergeAnnList [bodyAnns, sigAnns], (bodyDecl, sigDecl))
    where
        sigText = T.unpack sig
        -- Parse a declaration and set its preceding lines to 1 and its
        -- column to 'defaultIndent'; parse errors become 'Nothing'.
        parseIndented label src =
            either (const Nothing) (Just . indent) (parseDecl df label src)
        indent (ann, d) = (setPrecedingLines d 1 defaultIndent ann, d)

-- Pre-9.2 variant: insert the method declarations after the instance
-- declaration found by 'findInstDecl', registering a `where` keyword and a
-- fresh "HsValBinds" annotation key that captures the order of the inserted
-- declarations.
addMethodDecls ps mDecls range withSig = do
    instDecl <- findInstDecl ps range
    freshSpan <- uniqueSrcSpanT
    let inserted
            | withSig = concatMap (\(decl, sig) -> [sig, decl]) mDecls
            | otherwise = map fst mDecls
        instKey = mkAnnKey instDecl
        bindsKey = AnnKey (rs freshSpan) (CN "HsValBinds")
        addWhere annMap = case Map.lookup instKey annMap of
            Just instAnn ->
                let instAnn' = instAnn
                        { annsDP = annsDP instAnn ++ [(G AnnWhere, DP (0, 1))]
                        , annCapturedSpan = Just bindsKey
                        , annSortKey = Just (fmap (rs . getLoc) inserted)
                        }
                    bindsAnn = annNone
                        { annEntryDelta = DP (1, defaultIndent)
                        }
                in Map.insert bindsKey bindsAnn (Map.insert instKey instAnn' annMap)
            -- The instance declaration has no annotation entry at all.
            Nothing -> panic "Ide.Plugin.Class.addMethodPlaceholder"
    modifyAnnsT addWhere
    modifyAnnsT (captureOrderAnnKey bindsKey inserted)
    -- Each insertion goes directly after the instance, so the list is walked
    -- in reverse to end up in the original order.
    foldM (insertAfter instDecl) ps (reverse inserted)

-- | Return the first declaration of the parsed source whose span satisfies
-- 'inRange' for the given range.
--
-- This function is partial: when no declaration matches, forcing the result
-- raises an 'error' naming this function (instead of the uninformative
-- empty-list failure of 'head').
findInstDecl :: ParsedSource -> Range -> Transform (LHsDecl GhcPs)
findInstDecl ps range = firstInRange <$> hsDecls ps
    where
        firstInRange decls = case filter (inRange range . getLoc) decls of
            d : _ -> d
            []    -> error "Ide.Plugin.Class.findInstDecl: no declaration found in the requested range"
#endif