-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Retrie.CPP
  ( CPP(..)
  , addImportsCPP
  , parseCPPFile
  , parseCPP
  , printCPP
    -- ** Internal interface exported for tests
  , cppFork
  ) where

import Data.Char (isSpace)
import Data.Function (on)
import Data.Functor.Identity
import Data.List (nubBy, sortOn)
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.IO as Text
import Debug.Trace
import Retrie.ExactPrint
import Retrie.GHC
import Retrie.Replace
#if __GLASGOW_HASKELL__ < 904
#else
import GHC.Types.PkgQual
#endif

-- Note [CPP]
-- We can't just run the pre-processor on files and then rewrite them, because
-- the rewrites will apply to a module that never exists as code! Exactprint
-- has no support for roundtripping CPP, because the GHC parser doesn't
-- actually parse it (it looks for the pragma and then delegates to the
-- pre-processor).
--
-- To solve this, we instead generate all possible versions of the module
-- (exponential in the number of #if directives :-P). We then apply rewrites
-- to all versions, and collect all the 'Replacement's that they generate.
-- We can then use these to splice results back into the original file.
--
-- Surprisingly, this works. It depends on a few observations:
--
-- * We don't need to actually evaluate any CPP directives. This is because
--   we want all versions of the file.
--
-- * Since we don't need to evaluate, we can simply replace all CPP directives
--   with blank lines and the locations of all AST elements in each version of
--   the module will be exactly the same as in the original module. This is the
--   key to splicing properly.
--
-- * Replacements can be spliced in directly with no smarts about binders, etc,
--   because retrie did the instantiation during matching.
--

-- The CPP Type ----------------------------------------------------------------

-- | A parsed module, possibly forked into several versions by CPP.
--
-- * 'NoCPP': the source contained no CPP directive lines; holds the single
--   parsed module.
-- * 'CPP': holds the original source text, the imports still to be inserted
--   when printing, and one value per version of the module produced by
--   'cppFork' (see Note [CPP]).
data CPP a = NoCPP a | CPP Text [AnnotatedImports] [a]

-- | Maps over the single module, or over every CPP version; the original
-- text and pending imports are carried along untouched.
instance Functor CPP where
  fmap f cpp = case cpp of
    NoCPP x -> NoCPP (f x)
    CPP orig is xs -> CPP orig is (fmap f xs)

-- | Folds over the single module, or over every CPP version in order.
instance Foldable CPP where
  foldMap f cpp = case cpp of
    NoCPP x -> f x
    CPP _ _ xs -> foldMap f xs

-- | Traverses the single module, or every CPP version left to right,
-- rebuilding the same constructor around the results.
instance Traversable CPP where
  traverse f cpp = case cpp of
    NoCPP x -> fmap NoCPP (f x)
    CPP orig is xs -> fmap (CPP orig is) (traverse f xs)

-- | Add imports to a module.
--
-- For 'NoCPP' the imports are inserted into the AST right away (via
-- 'insertImports'). For 'CPP' they are only recorded, in front of any
-- already-pending imports; 'printCPP' later writes them into the header of
-- the original text.
addImportsCPP
  :: [AnnotatedImports]
  -> CPP AnnotatedModule
  -> CPP AnnotatedModule
addImportsCPP newImps cpp = case cpp of
  NoCPP m -> NoCPP (runIdentity (transformA m (insertImports newImps)))
  CPP orig pending ms -> CPP orig (newImps ++ pending) ms

-- Parsing a CPP Module --------------------------------------------------------

-- | Read a file and parse it with 'parseCPP'.
--
-- The parser is given the file path as its first argument.
parseCPPFile
  :: (FilePath -> String -> IO AnnotatedModule)
  -> FilePath
  -> IO (CPP AnnotatedModule)
parseCPPFile parser path = do
  -- Data.Text.IO.readFile reads the whole file strictly.
  contents <- Text.readFile path
  parseCPP (parser path) contents

-- | Parse module text, forking it first if it contains CPP directives.
--
-- If any line starts with @#@, every version produced by 'cppFork' is
-- parsed (in order) and the original text is kept for printing. Otherwise
-- the text is parsed once and wrapped in 'NoCPP'.
parseCPP
  :: Monad m
  => (String -> m AnnotatedModule)
  -> Text -> m (CPP AnnotatedModule)
parseCPP parser orig
  | hasCPP = fmap (CPP orig []) (mapM (parser . Text.unpack) (cppFork orig))
  | otherwise = fmap NoCPP (parser (Text.unpack orig))
  where
    hasCPP = any isCPP (Text.lines orig)

-- Printing a CPP Module -------------------------------------------------------

-- | Render a module back to source text.
--
-- For 'NoCPP' the annotated AST is exact-printed. For 'CPP' the replacements
-- are spliced into the original text (see Note [CPP]). Pending imports are
-- rendered and placed after the module header. The header is the original
-- text up to the last line starting with @import@, @module@ or @{-#@.
printCPP :: [Replacement] -> CPP AnnotatedModule -> String
printCPP _ (NoCPP m) = printA m
-- printCPP _ (NoCPP m) = error $ "printCPP:m=" ++ showAstA m
printCPP repls (CPP orig is ms) = Text.unpack $ Text.unlines $
  case is of
    [] -> splice "" 1 1 sorted origLines
    _ -> case reverse revDecls of
      -- Nothing follows the header (e.g. the file ends with an import or a
      -- pragma line). 'splice' returns [] for empty input, which would drop
      -- the header prefix and print an empty file, so emit the header as-is.
      [] -> newHeader
      decls ->
        splice
          (Text.unlines newHeader)
          (length revHeader + 1)
          1
          sorted
          decls
  where
    -- Replacements that have a real source span, in source order.
    sorted :: [(RealSrcSpan, String)]
    sorted =
      sortOn fst
        [ (r, replReplacement)
        | Replacement{..} <- repls
        , Just r <- [getRealSpan replLocation]
        ]

    origLines :: [Text]
    origLines = Text.lines orig

    -- Name of the module, used by 'filterAndFlatten' to drop self-imports.
    -- Taken from the first parsed version; matching on the list (rather than
    -- using 'head') keeps this total if there are no versions.
    mbName :: Maybe ModuleName
    mbName = case ms of
      m : _ -> unLoc <$> hsmodName (unLoc (astA m))
      [] -> Nothing

    -- Pending imports, each exact-printed with leading whitespace removed.
    importLines :: [Text]
    importLines =
      runIdentity $ fmap astA $ transformA (filterAndFlatten mbName is) $
        mapM (fmap (Text.pack . dropWhile isSpace . printA) . pruneA)

    -- Heuristic header detection: scanning from the bottom, the header ends
    -- at the first line that looks like an import, module or pragma line.
    -- NOTE(review): a column-0 pragma after the declarations (e.g. INLINE)
    -- or a multi-line import list will move this boundary; confirm intent.
    p :: Text -> Bool
    p t = isImport t || isModule t || isPragma t
    (revDecls, revHeader) = break p (reverse origLines)

    newHeader :: [Text]
    newHeader = reverse revHeader ++ importLines

-- | Splice replacement text into the original lines of a module.
--
-- Arguments, in order:
--
-- * text already produced for the current output line; it is emitted in
--   front of the first remaining input line;
-- * the source line and (1-based) column at which the first remaining input
--   line starts;
-- * the pending replacements, sorted by location;
-- * the remaining input lines.
--
-- A replacement that starts before the current position is dropped; this
-- happens when several CPP versions produced the same rewrite. A replacement
-- whose span covers a CPP directive line is skipped, with a 'trace' warning.
splice :: Text -> Int -> Int -> [(RealSrcSpan, String)] -> [Text] -> [Text]
splice _ _ _ _ [] = []
splice prefix _ _ [] (cur : rest) = (prefix <> cur) : rest
splice prefix line col pending@((loc, repl) : later) input@(cur : rest)
  | startLine > line =
      -- Next rewrite is not on this line. Output line.
      (prefix <> cur) : splice "" (line + 1) 1 pending rest
  | startLine < line || startCol < col =
      -- Next rewrite starts before current position. This happens when
      -- the same rewrite is made in multiple versions of the CPP'd module.
      -- Drop the duplicate rewrite and keep going.
      splice prefix line col later input
  | (spanned, endLn : after) <- splitAt (endLine - line) input =
      -- The next rewrite starts on this line. 'spanned' holds the lines
      -- before its last line; 'endLn' is the line on which it ends.
      -- For an example of a CPP directive inside the span, see the
      -- CPPConflict test.
      if any isCPP spanned
        then trace (conflictMsg spanned endLn) (splice prefix line col later input)
        else
          splice
            (prefix <> Text.take (startCol - col) cur <> Text.pack repl)
            endLine
            endCol
            later
            (Text.drop (endCol - col) endLn : after)
  | otherwise = error "printCPP: impossible replacement past end of file"
  where
    startLine = srcSpanStartLine loc
    startCol = srcSpanStartCol loc
    endLine = srcSpanEndLine loc
    endCol = srcSpanEndCol loc

    -- Diagnostic for a rewrite that would cross a CPP directive.
    conflictMsg :: [Text] -> Text -> String
    conflictMsg spanned endLn =
      let
        original =
          Text.unlines
            (((prefix <> cur) : drop 1 spanned) ++ [Text.take (endCol - col) endLn])
        location =
          unpackFS (srcSpanFile loc) ++ ":" ++ show line ++ ":" ++ show startCol
      in
        unlines
          [ "Refusing to rewrite across CPP directives."
          , ""
          , "Location: " ++ location
          , ""
          , "Original:"
          , ""
          , Text.unpack original
          , ""
          , "Replacement:"
          , ""
          , repl
          ]

-- Forking the module ----------------------------------------------------------

-- | Split module text containing CPP directives into the versions described
-- by 'mkCPPTree', one 'Text' per version.
cppFork :: Text -> [Text]
cppFork src = cppTreeToList (mkCPPTree src)

-- | Tree representing the module. Each #endif becomes a 'Node'.
--
-- A 'Leaf' holds lines of one version. A 'Node' holds lines that are shared
-- by both subtrees and are placed below the lines of each of them.
data CPPTree
  = Node
      [Text]   -- lines placed below both alternatives
      CPPTree  -- first alternative
      CPPTree  -- second alternative
  | Leaf [Text]

-- | Stack entry used to keep track of how many #ifs we are nested into
-- (the module is walked bottom-up, so "until" below means "going upwards").
-- Controls whether we emit lines into each version of the module.
data CPPBranch
  = CPPTrue
    -- ^ print until an 'else'
  | CPPFalse
    -- ^ print blanks until an 'else' or 'endif'
  | CPPOmit
    -- ^ print blanks until an 'endif'

-- | Build CPPTree from lines of the module.
--
-- The stack ('CPPBranch' list) has one entry per conditional we are inside.
-- Because lines are walked upwards, a line that 'extractCPPCond' classifies
-- as 'EndIf' pushes an entry and one classified as 'If' pops it. Hidden
-- lines and CPP directive lines are replaced by blank lines, so every version
-- keeps the original line numbering.
mkCPPTree :: Text -> CPPTree
mkCPPTree = go False [] [] . reverse . Text.lines
  -- We reverse the lines once up front, then process the module from bottom
  -- to top, branching at #endifs. If we were to process from top to bottom,
  -- we'd have to reverse each version later, rather than reversing the original
  -- once. This also makes it easy to spot import statements and stop branching
  -- since we don't care about differences in imports.
  where
    -- go seenImport stack suffix remaining
    --   seenImport: an import line has been passed (see Note [Imports])
    --   suffix:     output lines already produced below this point, top-first
    --   remaining:  unprocessed lines, bottom-first
    go :: Bool -> [CPPBranch] -> [Text] -> [Text] -> CPPTree
    go _ _ suffix [] = Leaf suffix
    go True [] suffix ls =
      Leaf (blankifyAndReverse suffix ls) -- See Note [Imports]
    go seenImport st suffix (l:ls) =
      case extractCPPCond l of
        Just If -> -- pops from stack
          case st of
            (_:st') -> emptyLine st'
            [] -> error "mkCPPTree: if with empty stack"
        Just ElIf -> -- stack same size
          case st of
            (CPPOmit:_) -> emptyLine st
            (CPPFalse:st') -> emptyLine (CPPOmit:st')
            (CPPTrue:st') -> -- See Note [ElIf]
              let
                omittedSuffix :: [Text]
                omittedSuffix = replicate (length suffix) ""
              in
                Node
                  []
                  (emptyLine (CPPOmit:st'))
                  (go seenImport (CPPTrue:st') ("":omittedSuffix) ls)
            [] -> error "mkCPPTree: elif with empty stack"
        Just Else -> -- stack same size
          case st of
            (CPPOmit:_) -> emptyLine st
            (CPPTrue:st') -> emptyLine (CPPFalse:st')
            (CPPFalse:st') -> emptyLine (CPPTrue:st')
            [] -> error "mkCPPTree: else with empty stack"
        Just EndIf -> -- push to stack
          case st of
            (CPPOmit:_) -> emptyLine (CPPOmit:st)
            (CPPFalse:_) -> emptyLine (CPPOmit:st)
            _ ->
              -- Fork: lines above are shown in one version, hidden in the
              -- other; 'suffix' (everything below) is shared by both.
              Node
                suffix
                (go seenImport (CPPTrue:st) [""] ls)
                (go seenImport (CPPFalse:st) [""] ls)
        Nothing -> -- stack same size
          case st of
            (CPPOmit:_) -> go seenImport' st ("":suffix) ls
            (CPPFalse:_) -> go seenImport' st ("":suffix) ls
            _ -> go seenImport' st (blankCPP l:suffix) ls
      where
        -- Consume the current line, emitting a blank line in its place.
        emptyLine :: [CPPBranch] -> CPPTree
        emptyLine st' = go seenImport st' ("":suffix) ls

        seenImport' :: Bool
        seenImport' = seenImport || isImport l

    -- Prepend the remaining lines (CPP lines blanked) onto the suffix. They
    -- are bottom-first, so prepending one at a time restores source order.
    blankifyAndReverse :: [Text] -> [Text] -> [Text]
    blankifyAndReverse suffix [] = suffix
    blankifyAndReverse suffix (l:ls) = blankifyAndReverse (blankCPP l:suffix) ls

-- Note [Imports]
-- If we have seen an import statement, and have an empty stack, that means all
-- conditionals above this point only control imports/exports, etc. Retrie
-- doesn't match in those places anyway, and the imports don't matter because
-- we only parse, no renaming. As a micro-optimization, we can stop branching.
-- This saves forking the module in the common case that CPP is used to choose
-- imports. We have to wait for stack to be empty because we might have seen an
-- import in one branch, but there is a decl in the other branch.

-- Note [ElIf]
-- The way we handle #elif is pretty subtle. Some observations:
-- If we're on the CPPOmit branch, keep omitting up to the next #if, like usual.
-- If we're on the CPPFalse branch, we didn't show the #elif, but either we
-- showed the #else, or this whole #if might not output anything. So either way,
-- we need to omit up to the next #if.
-- If we're on the CPPTrue branch, we definitely showed the #elif, so we need to
-- fork with a Node. One side of the branch omits up to the next #if. The other
-- side is as if we have omitted everything from the last #endif, and we
-- continue showing up from here. This will show whatever is above the #elif.
-- It is crucial we do this branching on the CPPTrue branch, so any #elif
-- above this point is also handled correctly.

-- | Expand a 'CPPTree' into module versions: one 'Text' per 'Leaf', in
-- left-to-right order. The lines stored at a 'Node' are appended below the
-- lines of each of its subtrees.
cppTreeToList :: CPPTree -> [Text]
cppTreeToList tree = expand [] tree []
  where
    -- expand below t acc: prepend the versions of t (each followed by the
    -- lines in 'below') onto acc.
    expand :: [Text] -> CPPTree -> [Text] -> [Text]
    expand below (Leaf ls) acc = Text.unlines (ls ++ below) : acc
    expand below (Node ls left right) acc =
      let below' = ls ++ below -- right-nested
      in expand below' left (expand below' right acc)

-- Spotting CPP directives -----------------------------------------------------

-- | The CPP conditional directives that 'mkCPPTree' tracks.
data CPPCond = If | ElIf | Else | EndIf
  deriving (Eq, Show)

-- | Classify a line as a conditional directive, if it is one.
--
-- The directive keyword is the first word after a @#@ in column 0, so both
-- @#if X@ and @# if X@ are recognised. @#ifdef@ and @#ifndef@ open a
-- conditional just like @#if@, so they are classified as 'If'. Otherwise
-- 'mkCPPTree' never pops the entry its matching @#endif@ pushed, and lines
-- above the conditional end up blanked in some versions of the module.
extractCPPCond :: Text -> Maybe CPPCond
extractCPPCond t
  | Just ('#', t') <- Text.uncons t =
    case Text.words t' of
      ("if":_) -> Just If
      ("ifdef":_) -> Just If
      ("ifndef":_) -> Just If
      ("else":_) -> Just Else
      ("elif":_) -> Just ElIf
      ("endif":_) -> Just EndIf
      _ -> Nothing
  | otherwise = Nothing

-- | Replace a CPP directive line (one starting with @#@) by an empty line, so
-- the line still exists but contributes no code; other lines are unchanged.
blankCPP :: Text -> Text
blankCPP t = if isCPP t then "" else t

-- | Does the line start with @#@ in column 0?
isCPP :: Text -> Bool
isCPP line = fmap fst (Text.uncons line) == Just '#'

-- | Does the line start with the @import@ keyword in column 0?
--
-- The keyword must be followed by whitespace or the end of the line. Without
-- this check, top-level declarations such as @importList = ...@ would be
-- taken for imports. 'printCPP' would then treat them as part of the module
-- header, and 'mkCPPTree' would stop forking too early.
isImport :: Text -> Bool
isImport t = case Text.stripPrefix "import" t of
  Nothing -> False
  Just rest -> maybe True (isSpace . fst) (Text.uncons rest)

-- | Does the line start with the @module@ keyword in column 0?
--
-- The keyword must be followed by whitespace or the end of the line, so that
-- declarations such as @moduleName = ...@ are not mistaken for the module
-- header line when 'printCPP' looks for the end of the header.
isModule :: Text -> Bool
isModule t = case Text.stripPrefix "module" t of
  Nothing -> False
  Just rest -> maybe True (isSpace . fst) (Text.uncons rest)

-- | Does the line start with a pragma opener @{-#@ in column 0?
isPragma :: Text -> Bool
isPragma line = Text.take 3 line == "{-#"

-------------------------------------------------------------------------------
-- This would make more sense in Retrie.Expr, but that creates an import cycle.
-- Ironic, I know.

-- | Append imports to a module's import list.
--
-- Imports of the module itself are filtered out by 'filterAndFlatten'. The
-- combined list is de-duplicated with 'eqImportDecl', keeping the first
-- occurrence, so existing imports stay where they are and duplicates among
-- the new ones are dropped.
insertImports
  :: Monad m
  => [AnnotatedImports]   -- ^ imports and their annotations
  -> Located HsModule     -- ^ target module
  -> TransformT m (Located HsModule)
insertImports is (L l m) = do
  let selfName = unLoc <$> hsmodName m
  newImps <- graftA (filterAndFlatten selfName is)
  let sameImport = eqImportDecl `on` unLoc
      merged = nubBy sameImport (hsmodImports m ++ newImps)
  return (L l m { hsmodImports = merged })

-- | Concatenate several annotated import lists into one. When a module name
-- is given, imports of that module (i.e. self-imports) are removed.
filterAndFlatten :: Maybe ModuleName -> [AnnotatedImports] -> AnnotatedImports
filterAndFlatten mbName is =
  runIdentity $ transformA (mconcat is) (return . dropSelf)
  where
    dropSelf :: [LImportDecl GhcPs] -> [LImportDecl GhcPs]
    dropSelf = case mbName of
      Nothing -> id
      Just mn -> filter (\imp -> unLoc (ideclName (unLoc imp)) /= mn)

-- | Equality on import declarations, used to avoid adding an import that is
-- already present.
--
-- Compares the module name, qualification, alias, import/hiding list,
-- package qualifier, SOURCE flag and safe flag. ideclImplicit and
-- ideclSourceSrc are intentionally left out: the former doesn't matter for
-- this check, the latter is prone to whitespace issues.
eqImportDecl :: ImportDecl GhcPs -> ImportDecl GhcPs -> Bool
eqImportDecl x y = and
  [ sameBy (unLoc . ideclName)
  , sameBy ideclQualified
  , sameBy ideclAs
  , sameBy ideclHiding
#if __GLASGOW_HASKELL__ < 904
  , sameBy ideclPkgQual
#else
  , (eqRawPkgQual `on` ideclPkgQual) x y
#endif
  , sameBy ideclSource
  , sameBy ideclSafe
  ]
  where
    -- Do both declarations agree on the given field?
    sameBy :: Eq a => (ImportDecl GhcPs -> a) -> Bool
    sameBy field = field x == field y
#if __GLASGOW_HASKELL__ < 904
#else
    eqRawPkgQual (RawPkgQual s) (RawPkgQual s') = s == s'
    eqRawPkgQual NoRawPkgQual NoRawPkgQual = True
    eqRawPkgQual _ _ = False
#endif