-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
module Retrie.Elaborate
  ( defaultElaborations
  , elaborateRewritesInternal
  ) where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import "list-t" ListT
import Data.Maybe

import Retrie.Context
import Retrie.ExactPrint
import Retrie.Expr
import Retrie.Fixity
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Rewrites
import Retrie.Subst
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Universe

-- | The elaboration rules used by default, expressed as rewrite specs.
--
-- Currently this is a single ad-hoc equation relating an application written
-- with '$' to the same application written with explicit parentheses.
defaultElaborations :: [RewriteSpec]
defaultElaborations =
  [ Adhoc "forall f x. f $ x = f (x)"
  ]

-- | Elaborate the patterns of a list of rewrites.
--
-- Arguments, in order: the fixity environment, the elaboration rewrites, and
-- the rewrites whose patterns should be elaborated.
--
-- When there are no elaborations the rewrites are returned unchanged.
-- Otherwise each rewrite is expanded with 'elaborateOne' and the resulting
-- lists are concatenated, preserving the order of the input rewrites.
elaborateRewritesInternal
  :: FixityEnv
  -> [Rewrite Universe]
  -> [Rewrite Universe]
  -> IO [Rewrite Universe]
elaborateRewritesInternal _ [] rewrites = return rewrites
elaborateRewritesInternal fixityEnv elaborations rewrites =
  concat <$> mapM (elaborateOne fixityEnv elaborator) rewrites
  where
    -- All elaboration rewrites combined (monoidally) into one 'Rewriter'.
    elaborator :: Rewriter
    elaborator = foldMap mkRewriter elaborations

-- | Expand one rewrite into the list of rewrites obtained by elaborating its
-- pattern ('qPattern') with the given elaborator.
--
-- The pattern is traversed top-down in the 'ListT' monad, with 'elaborate'
-- applied at each node, so every combination of choices made during the
-- traversal yields one alternative pattern. Each alternative is paired with
-- the remaining fields of the original rewrite, which are left untouched.
elaborateOne :: FixityEnv -> Rewriter -> Rewrite Universe -> IO [Rewrite Universe]
elaborateOne fixityEnv elaborator rr = do
  -- 'toList' collects every alternative produced by the 'ListT' traversal,
  -- giving a single annotated list of candidate patterns.
  -- NOTE(review): @const False@ is presumably the "stop here" predicate of
  -- 'everywhereMWithContextBut', i.e. no subtree is skipped -- confirm.
  patterns <-
    transformA (qPattern rr) $ toList .
      everywhereMWithContextBut topDown
        (const False) (\c i x -> lift $ updateContext c i x) elaborate ctxt
  -- 'sequenceA' turns the annotated list into a list of annotated patterns.
  return [ rr { qPattern = pattern } | pattern <- sequenceA patterns ]
  where
    -- Initial context: the elaborator is installed as the context's rewriter,
    -- and the final argument is left empty ('mempty').
    ctxt :: Context
    ctxt = emptyContext fixityEnv elaborator mempty

-- | Generic elaboration step, dispatching on the type of the node.
--
-- Located expressions, statements and types are handled by 'elaborateImpl';
-- patterns are handled by 'elaboratePat'. Values of any other type fall
-- through the 'mkM' / 'extM' chain and are returned as-is.
elaborate
  :: (Data a, MonadIO m) => Context -> a -> ListT (TransformT m) a
elaborate c =
  mkM (elaborateImpl @(HsExpr GhcPs) c)
    `extM` (elaborateImpl @(Stmt GhcPs (LHsExpr GhcPs)) c)
    `extM` (elaborateImpl @(HsType GhcPs) c)
    `extM` (elaboratePat c)

-- | Elaborate a pattern, but only when it carries a location.
--
-- If 'dLPat' can expose a located pattern, it is elaborated with
-- 'elaborateImpl' and each result is converted back with 'cLPat'. Otherwise
-- the pattern is returned unchanged as the sole alternative.
elaboratePat :: MonadIO m => Context -> LPat GhcPs -> ListT (TransformT m) (LPat GhcPs)
-- We need to ensure we have a location available at the top level so we can
-- transfer annotations. This ensures we don't try to rewrite a naked Pat.
elaboratePat c p
  | Just lp <- dLPat p = cLPat <$> elaborateImpl c lp
  | otherwise = return p

-- | Produce every elaboration of a single located node.
--
-- The context's rewriter ('ctxtRewriter') is matched against the node after
-- passing it through 'getUnparened'. For each valid match, the rewrite's
-- template is instantiated and becomes one alternative.
--
-- The resulting 'ListT' yields the original node first, followed by one
-- alternative per valid match, in match order.
elaborateImpl
  :: forall ast m. (Annotate ast, Matchable (Located ast), MonadIO m)
  => Context -> Located ast -> ListT (TransformT m) (Located ast)
elaborateImpl ctxt e = do
  elaborations <- lift $ do
    matches <- runMatcher ctxt (ctxtRewriter ctxt) (getUnparened e)
    validMatches <- allMatches ctxt matches
    -- The generator's pattern keeps only 'MatchResult' values; anything else
    -- in 'validMatches' is silently dropped by the comprehension.
    forM [ (sub, tmpl) | MatchResult sub tmpl <- validMatches ] $ \(sub, Template{..}) -> do
      -- graft template into target
      t' <- graftA tTemplate
      -- substitute for quantifiers in grafted template
      r <- subst sub ctxt t'
      -- copy appropriate annotations from old expression to template
      addAllAnnsT e r
      -- add parens to template if needed
      (mkM (parenify ctxt) `extM` parenifyT ctxt `extM` parenifyP ctxt) r

  -- Offer the untouched node as the first alternative, then each elaboration.
  fromFoldable (e : elaborations)

-- | Collect every 'valid' match.
-- Runs the user's 'MatchResultTransformer' on each candidate and sanity
-- checks the result, keeping all candidates that pass (not just the first).
--
-- A candidate is kept when the transformer returns a 'MatchResult' whose
-- substitution has a binding for every quantifier of the originating rewrite.
-- Kept results are converted from 'Universe' to the requested AST type with
-- 'project'.
allMatches
  :: (Matchable ast, MonadIO m)
  => Context
  -> [(Substitution, RewriterResult Universe)]
  -> TransformT m [MatchResult ast]
allMatches _ [] = return []
allMatches ctxt matchResults = do
  -- Run each rewrite's transformer, remembering the rewrite's quantifiers so
  -- the transformed result can be validated below.
  results <-
    forM matchResults $ \(sub, RewriterResult{..}) -> do
      result <- lift $ liftIO $ rrTransformer ctxt $ MatchResult sub rrTemplate
      return (rrQuantifiers, result)
  return
    [ project <$> result
    | (quantifiers, result@(MatchResult sub' _)) <- results
      -- Check that all quantifiers from the original rewrite have mappings
      -- in the resulting substitution. This is mostly to prevent a bad
      -- user-defined MatchResultTransformer from causing havoc.
    , all (\q -> isJust (lookupSubst q sub')) (qList quantifiers)
    ]