-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Retrie.Replace
  ( replace
  , Replacement(..)
  , Change(..)
  ) where

import Control.Monad.Trans.Class
import Control.Monad.Writer.Strict
import Data.Char (isSpace)
import Data.Generics

import Retrie.ExactPrint
import Retrie.Expr
import Retrie.FreeVars
import Retrie.GHC
import Retrie.Subst
import Retrie.Types
import Retrie.Universe

------------------------------------------------------------------------

-- | Specializes 'replaceImpl' to each of the AST types that retrie supports:
-- expressions, statements, types, and patterns (the latter via 'replacePat').
-- A node of any other type falls through 'mkM' and is returned unchanged.
replace
  :: (Data a, MonadIO m) => Context -> a -> TransformT (WriterT Change m) a
replace ctx =
  mkM (replaceImpl @(HsExpr GhcPs) ctx)
    `extM` replaceImpl @(Stmt GhcPs (LHsExpr GhcPs)) ctx
    `extM` replaceImpl @(HsType GhcPs) ctx
    `extM` replacePat ctx

-- | Runs 'replaceImpl' on a pattern, but only when 'dLPat' exposes a
-- 'Located' pattern at the top level. We need a location available at the
-- top level so we can transfer annotations; this ensures we don't try to
-- rewrite a naked 'Pat'. A pattern without one is returned as-is.
replacePat
  :: MonadIO m
  => Context -> LPat GhcPs -> TransformT (WriterT Change m) (LPat GhcPs)
replacePat c p =
  maybe (return p) (fmap cLPat . replaceImpl c) (dLPat p)

-- | Generic replacement function. This is the thing that actually runs the
-- 'Rewriter' carried by the context, instantiates templates, handles parens
-- and other whitespace bookkeeping, and emits resulting 'Replacement's.
--
-- When nothing matches, the input node is returned untouched. On a match the
-- template is grafted, instantiated with the match's substitution, given the
-- old node's annotations, parenthesized where needed, and returned; a
-- 'Change' holding the printed before/after text and the template's imports
-- is written to the 'WriterT' layer.
replaceImpl
  :: forall ast m. (Annotate ast, Matchable (Located ast), MonadIO m)
  => Context -> Located ast -> TransformT (WriterT Change m) (Located ast)
replaceImpl c e = do
  -- We want to match through HsPar so we can make a decision
  -- about whether to keep the parens or not based on the
  -- resulting expression, but we need to know the entry location
  -- of the parens, not the inner expression, so we have to
  -- keep both expressions around.
  outcome <- runRewriter guardResult c (ctxtRewriter c) (getUnparened e)
  case outcome of
    NoMatch -> return e
    MatchResult sub Template{..} -> do
      -- graft template into target module, then substitute for the
      -- quantifiers in the grafted copy
      grafted <- graftA tTemplate
      substituted <- subst sub c grafted
      -- copy appropriate annotations from old expression to template
      addAllAnnsT e substituted
      -- add parens to template if needed
      rewritten <-
        ( mkM (parenify c)
            `extM` parenifyT c
            `extM` parenifyP c
        ) substituted
      -- prune both sides and log the rewrite with the old node's location
      before <- printNoLeadingSpaces <$> pruneA e
      after <- printNoLeadingSpaces <$> pruneA rewritten
      TransformT . lift . tell $
        Change [Replacement (getLoc e) before after] [tImports]
      -- the rewritten node takes e's place in the AST
      return rewritten
  where
    -- Prevent rewriting source of the rewrite itself by refusing to
    -- match under a binding of something that appears in the template.
    guardResult rr@RewriterResult{..} =
      rr
        { rrTransformer =
            fmap (fmap (vetMatch rrOrigin rrQuantifiers)) <$> rrTransformer
        }

    -- Reject a match that lies inside the rewrite's own origin span, or
    -- whose template mentions a name bound in the current context.
    vetMatch origin quantifiers m
      | getLoc e `overlaps` origin = NoMatch
      | capturesBinder m = NoMatch
      | otherwise = m
      where
        capturesBinder (MatchResult _ tmpl) =
          let fvs = freeVars quantifiers (astA (tTemplate tmpl))
          in any (`elemFVs` fvs) (ctxtBinders c)
        capturesBinder _ = False

-- | Records a replacement made. In cases where we cannot use ghc-exactprint
-- to print the resulting AST (e.g. CPP modules), we fall back on splicing
-- strings. Can also be used by external tools (search, linters, etc).
data Replacement = Replacement
  { replLocation :: SrcSpan
    -- ^ Location of the node that was replaced.
  , replOriginal :: String
    -- ^ Exact-printed text of the original node, leading whitespace dropped.
  , replReplacement :: String
    -- ^ Exact-printed text of the new node, leading whitespace dropped.
  }

-- | Used as the writer type during matching to indicate whether any change
-- to the module should be made.
data Change
  = NoChange
    -- ^ No rewrite happened.
  | Change [Replacement] [AnnotatedImports]
    -- ^ The replacements made, plus the imports carried by the templates
    -- that produced them.

-- | 'NoChange' is the identity. Two 'Change's combine by concatenating their
-- replacements and their imports, with the left operand's entries first.
--
-- The combining logic lives in '(<>)' and 'mappend' is left to its default
-- (which is '(<>)'). The previous arrangement, @(<>) = mappend@ with the real
-- definition in 'mappend', is the non-canonical form that GHC reports under
-- @-Wnoncanonical-monoid-instances@.
instance Semigroup Change where
  NoChange <> other = other
  other <> NoChange = other
  Change rs1 is1 <> Change rs2 is2 = Change (rs1 <> rs2) (is1 <> is2)

instance Monoid Change where
  mempty = NoChange

-- | Exact-prints an annotated node, then drops any leading whitespace.
--
-- A node's location accurately points to its first non-space character, but
-- when we exactprint it we might get some leading spaces (if annEntryDelta of
-- the first token is non-zero). This means we can't just splice in the
-- printed text at the desired location and call it a day. Unfortunately, it's
-- hard to find the right annEntryDelta (it may not be the top of the redex)
-- and zero it out. As janky as it seems, it's easier to just drop leading
-- spaces like this.
printNoLeadingSpaces :: Annotate k => Annotated (Located k) -> String
printNoLeadingSpaces annotated = dropWhile isSpace (printA annotated)