-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Retrie.Context
  ( ContextUpdater
  , updateContext
  , emptyContext
  ) where

import Control.Monad.IO.Class
import Data.Char (isDigit)
import Data.Either (partitionEithers)
import Data.Generics hiding (Fixity)
import Data.List
import Data.Maybe

import Retrie.AlphaEnv
import Retrie.ExactPrint
import Retrie.Fixity
import Retrie.FreeVars
import Retrie.GHC
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Universe

-------------------------------------------------------------------------------

-- | Type of context update functions for 'apply'.
-- When defining your own 'ContextUpdater', you probably want to extend
-- 'updateContext' using SYB combinators such as 'mkQ' and 'extQ'.
--
-- This is exactly the type of 'updateContext': a 'GenericCU' over
-- 'TransformT' that must work for any 'MonadIO' base monad.
type ContextUpdater =
  forall m. MonadIO m => GenericCU (TransformT m) Context

-- | Default context update function.
--
-- Given the parent's 'Context' @c@ and the index @i@ of the child about to be
-- visited, compute the 'Context' for that child. Each AST node type handled
-- below gets its own handler; any other node type passes @c@ through
-- unchanged.
updateContext :: forall m. MonadIO m => GenericCU (TransformT m) Context
updateContext c i =
  const (pure c)
    `extQ` (pure . updExp)
    `extQ` (pure . updType)
    `extQ` (pure . updMatch)
    `extQ` (pure . updGRHSs)
    `extQ` (pure . updGRHS)
    `extQ` (pure . updStmt)
    `extQ` (pure . updPat)
    `extQ` updStmtList
    `extQ` (pure . updHsBind)
    `extQ` (pure . updTyClDecl)
  where
    -- The parent's context with the parent precedence reset to 'NeverParen'.
    neverParen :: Context
    neverParen = c { ctxtParentPrec = NeverParen }

    updExp :: HsExpr GhcPs -> Context
    updExp expr =
      case expr of
        -- For an application, (i - firstChild) is 0 for the function (left
        -- child) and 1 for the argument (right child). The left child thus
        -- sees precedence 10, so a nested HsApp there is NOT parenthesized;
        -- the right child sees 11, so every non-atomic child is parenthesized.
        HsApp{} -> c { ctxtParentPrec = HasPrec appFixity }
        -- Children of an operator application see the operator's fixity.
        OpApp _ _ op _ -> c { ctxtParentPrec = HasPrec (lookupOp op (ctxtFixityEnv c)) }
        -- Let-bound names are put in scope for every child of the let.
        HsLet _ binds _ -> addInScope neverParen (collectLocalBinders (unLoc binds))
        _ -> neverParen
      where
        appFixity = Fixity (SourceText "HsApp") (10 + i - firstChild) InfixL

    updType :: HsType GhcPs -> Context
    updType ty =
      case ty of
        -- Only children after the first one of a type application are marked
        -- as sitting inside an application.
        HsAppTy{} | i > firstChild -> c { ctxtParentPrec = IsHsAppsTy }
        _ -> neverParen

    updMatch :: Match GhcPs (LHsExpr GhcPs) -> Context
    updMatch match = addInScope base (collectPatsBinders (m_pats match))
      where
        -- Child index 2 is the m_pats field; it is marked as a left-hand side.
        base
          | i == 2 = c { ctxtParentPrec = IsLhs }
          | otherwise = neverParen

    updGRHSs :: GRHSs GhcPs (LHsExpr GhcPs) -> Context
    updGRHSs grhss =
      addInScope neverParen (collectLocalBinders (unLoc (grhssLocalBinds grhss)))

    updGRHS :: GRHS GhcPs (LHsExpr GhcPs) -> Context
#if __GLASGOW_HASKELL__ < 900
    updGRHS XGRHS{} = neverParen
#endif
    updGRHS (GRHS _ guards _)
      -- Guard binders are in scope over the body (right child) only; while
      -- visiting the guards themselves only the substitution is adjusted.
      | i > firstChild = addInScope neverParen guardBinders
      | otherwise = fst (updateSubstitution neverParen guardBinders)
      where
        guardBinders = collectLStmtsBinders guards

    updStmt :: Stmt GhcPs (LHsExpr GhcPs) -> Context
    updStmt = const neverParen

    -- A statement list is visited cons cell by cons cell: child 0 is the head
    -- statement, child 1 the rest of the list. Plain lists carry no
    -- Trees-That-Grow extension field, hence 0 rather than 'firstChild'.
    updStmtList :: [LStmt GhcPs (LHsExpr GhcPs)] -> TransformT m Context
    updStmtList [] = pure neverParen
    updStmtList (stmt : _)
      -- The head's binders are in scope over the tail (right child); dependent
      -- rewrites are added here if necessary.
      | i > 0 = insertDependentRewrites neverParen stmtBinders stmt
      -- lets are recursive in do-blocks, so their binders are in scope within
      -- the let statement as well.
      | L _ (LetStmt _ (L _ localBinds)) <- stmt =
          pure (addInScope neverParen (collectLocalBinders localBinds))
      | otherwise = pure (fst (updateSubstitution neverParen stmtBinders))
      where
        stmtBinders = collectLStmtBinders stmt

    updHsBind :: HsBind GhcPs -> Context
    updHsBind bind =
      case bind of
        -- The function's own name goes both into scope and into the binder list.
        FunBind{fun_id = fid} ->
          let name = unLoc fid
          in addBinders (addInScope neverParen [name]) [name]
        _ -> neverParen

    updTyClDecl :: TyClDecl GhcPs -> Context
    updTyClDecl decl =
      case decl of
        SynDecl{tcdLName = lname} -> declares (unLoc lname)
        DataDecl{tcdLName = lname} -> declares (unLoc lname)
        ClassDecl{tcdLName = lname} -> declares (unLoc lname)
        _ -> neverParen
      where
        -- The declared type/class name is put in scope for every child.
        declares name = addInScope neverParen [name]

    updPat :: Pat GhcPs -> Context
    updPat = const neverParen

-- | Create an empty 'Context' with given 'FixityEnv', rewriter, and dependent
-- rewrite generator.
--
-- The resulting context has no binders, an empty alpha-renaming environment
-- ('emptyAlphaEnv'), no pending substitution, and parent precedence
-- 'NeverParen'.
emptyContext :: FixityEnv -> Rewriter -> Rewriter -> Context
emptyContext fixities rewriter dependents =
  Context
    { ctxtBinders = []
    , ctxtDependents = dependents
    , ctxtFixityEnv = fixities
    , ctxtInScope = emptyAlphaEnv
    , ctxtParentPrec = NeverParen
    , ctxtRewriter = rewriter
    , ctxtSubst = Nothing
    }

-- Index of the first syntactic child of a GHC AST constructor.
-- Trees-That-Grow adds an extension point as the first child (index 0)
-- everywhere, so the first real sub-term sits at index 1.
firstChild :: Int
firstChild = 1

-- | Add dependent rewrites to 'ctxtRewriter' if necessary.
--
-- Runs the context's 'ctxtDependents' rewriter against @x@. The binders @bs@
-- are always added to the scope of the returned context. If the dependents
-- rewriter matched, the template's dependent rewrites are turned into local
-- rewriters over the new scope:
--
-- * the template's own dependents are combined in front of 'ctxtRewriter';
-- * the rewrites returned by 'rewritesWithDependents' for them are combined
--   in front of 'ctxtDependents'.
insertDependentRewrites
  :: (Matchable k, MonadIO m) => Context -> [RdrName] -> k -> TransformT m Context
insertDependentRewrites c bs x = do
  result <- runRewriter id c (ctxtDependents c) x
  let scoped = addInScope c bs
  pure $ case result of
    NoMatch -> scoped
    MatchResult _ Template{tDependents = mDeps} ->
      let
        direct = fromMaybe [] mDeps
        transitive = rewritesWithDependents direct
        localize = foldMap (mkLocalRewriter (ctxtInScope scoped))
      in
        scoped
          { ctxtRewriter = localize direct <> ctxtRewriter scoped
          , ctxtDependents = localize transitive <> ctxtDependents scoped
          }

-- | Add set of binders to 'ctxtInScope'.
--
-- The binders first go through 'updateSubstitution', so a binder that would
-- capture a free variable of the pending substitution is added to the
-- alpha-renaming environment under its renamed form.
addInScope :: Context -> [RdrName] -> Context
addInScope c bs =
  let (c', scopedBinders) = updateSubstitution c bs
      env' = foldr extendAlphaEnv (ctxtInScope c') scopedBinders
  in c' { ctxtInScope = env' }

-- | Add set of binders to 'ctxtBinders'.
--
-- The new binders are placed in front of the existing ones.
addBinders :: Context -> [RdrName] -> Context
addBinders c bs = c { ctxtBinders = bs <> ctxtBinders c }

-- Capture-avoiding substitution
--------------------------------------------------------------------------------

-- | Update the Context's substitution appropriately for a set of binders.
-- Returns a new Context and a potentially alpha-renamed set of binders.
--
-- With no pending substitution both inputs are returned unchanged. Otherwise
-- the binders are deleted from the substitution, and every binder occurring
-- free in what remains is renamed via 'updateBinder'; the renaming is
-- recorded in the substitution.
updateSubstitution :: Context -> [RdrName] -> (Context, [RdrName])
updateSubstitution c rdrs = maybe (c, rdrs) withSubst (ctxtSubst c)
  where
    withSubst sub =
      let
        -- This prevents substituting for 'x' under a binding for 'x'.
        shadowed = deleteSubst sub (map rdrFS rdrs)
        -- Free vars of the substitution that a binder could capture.
        fvs = substFVs shadowed
        -- Split binders into those left alone and those needing a new name.
        (noncapturing, capturing) = partitionEithers (map (updateBinder fvs) rdrs)
        -- Map each capturing binder's old name to its fresh name.
        renamings = [ (rdrFS old, HoleRdr new) | (old, new) <- capturing ]
        alphaSub = foldl' (\s (k, v) -> extendSubst s k v) shadowed renamings
        -- There are no telescopes in source Haskell, so order doesn't matter.
        -- Capturing should be rare, so put it first to avoid quadratic append.
        renamed = map snd capturing ++ noncapturing
      in (c { ctxtSubst = Just alphaSub }, renamed)

-- | Check if RdrName is in FreeVars.
--
-- If so, return a pair of it and its new name (Right).
-- If not, return it unchanged (Left).
updateBinder :: FreeVars -> RdrName -> Either RdrName (RdrName, RdrName)
updateBinder fvs rdr =
  if rdr `elemFVs` fvs
    then Right (rdr, renameBinder rdr fvs)
    else Left rdr

-- | Given a RdrName, rename it to something not in given FreeVars.
--
--   x => x1
--   x1 => x2
--   x9 => x10
--
-- etc.
--
-- Only works on unqualified RdrNames. This is fine, as we only use this to
-- rename local binders.
--
-- Any trailing digits of the name are parsed as the starting counter, and
-- candidates are tried in increasing order until one is not in @fvs@. The
-- counter is an 'Integer' so it cannot wrap around (as a fixed-width 'Int'
-- would for a binder ending in a very long run of digits, producing a
-- negative, invalid suffix such as @x-9223372036854775808@).
renameBinder :: RdrName -> FreeVars -> RdrName
renameBinder rdr fvs = firstFree start
  where
    -- The reversed occurrence name, split into its trailing digits and the rest.
    (revDigits, revBase) = span isDigit $ reverse $ occNameString $ occName rdr

    baseName :: String
    baseName = reverse revBase

    -- Continue counting after an existing numeric suffix. 'read' cannot fail
    -- here: 'revDigits' is a non-empty run of ASCII digits ('isDigit').
    start :: Integer
    start
      | null revDigits = 1
      | otherwise = read (reverse revDigits) + 1

    -- Try candidates in order. Assumes 'fvs' is finite (it is computed from
    -- a substitution), so the search terminates.
    firstFree :: Integer -> RdrName
    firstFree k =
      let rdr' = mkVarUnqual $ mkFastString $ baseName ++ show k
      in if rdr' `elemFVs` fvs then firstFree (k + 1) else rdr'