{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Retrie.CPP
( CPP(..)
, addImportsCPP
, parseCPPFile
, parseCPP
, printCPP
, cppFork
) where
import Data.Char (isAlphaNum, isSpace)
import Data.Function (on)
import Data.Functor.Identity
import Data.List (nubBy, sortOn)
import Debug.Trace

import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.IO as Text

import Retrie.ExactPrint
import Retrie.GHC
import Retrie.Replace
-- | A parsed module, together with what is needed to print it back out.
data CPP a
  = NoCPP a
    -- ^ The source contained no CPP directive lines; a single parse.
  | CPP Text [AnnotatedImports] [a]
    -- ^ The source contained CPP: the original text, the imports still to
    -- be inserted when printing, and one value per CPP variant of the text.
-- | Maps over the single module, or over every CPP variant.
instance Functor CPP where
  fmap g cpp = case cpp of
    NoCPP x -> NoCPP (g x)
    CPP source pending xs -> CPP source pending (fmap g xs)
-- | Folds over the single module, or over every CPP variant in order.
instance Foldable CPP where
  foldMap g cpp = case cpp of
    NoCPP x -> g x
    CPP _ _ xs -> foldMap g xs
-- | Traverses the single module, or every CPP variant in order, keeping the
-- original text and pending imports untouched.
instance Traversable CPP where
  traverse g cpp = case cpp of
    NoCPP x -> fmap NoCPP (g x)
    CPP source pending xs -> fmap (CPP source pending) (traverse g xs)
-- | Add imports to a module.
--
-- Without CPP the imports are grafted into the AST right away (via
-- 'insertImports'). With CPP they are prepended to the pending list, which
-- 'printCPP' renders into the header of the original text.
addImportsCPP
  :: [AnnotatedImports]
  -> CPP AnnotatedModule
  -> CPP AnnotatedModule
addImportsCPP is cpp =
  case cpp of
    NoCPP m -> NoCPP (runIdentity (transformA m (insertImports is)))
    CPP orig is' ms -> CPP orig (is ++ is') ms
-- | Read a file as 'Text' and parse it with 'parseCPP'. The file path is
-- passed to the supplied parser along with each source text.
parseCPPFile
  :: (FilePath -> String -> IO AnnotatedModule)
  -> FilePath
  -> IO (CPP AnnotatedModule)
parseCPPFile p fp = do
  contents <- Text.readFile fp
  parseCPP (p fp) contents
-- | Parse source text. If any line starts with @#@, the text is expanded with
-- 'cppFork' and every variant is parsed. Otherwise it is parsed once.
parseCPP
  :: Monad m
  => (String -> m AnnotatedModule)
  -> Text -> m (CPP AnnotatedModule)
parseCPP p orig =
  if any isCPP (Text.lines orig)
    then CPP orig [] <$> traverse parseText (cppFork orig)
    else NoCPP <$> parseText orig
  where
    parseText = p . Text.unpack
-- | Render a module back to source text.
--
-- A 'NoCPP' module is exact-printed directly. For 'CPP', the original text is
-- kept and every 'Replacement' whose location is a real source span is
-- spliced into it, in ascending span order. The following replacements are
-- skipped:
--
-- * one that starts before the end of an already-applied replacement;
-- * one whose replaced lines contain a CPP directive (a warning is emitted
--   via 'trace').
--
-- If imports are pending (see 'addImportsCPP'), they are rendered and placed
-- directly after the last header line. A header line is a @module@ or
-- @import@ line, or a LANGUAGE/OPTIONS/INCLUDE pragma.
printCPP :: [Replacement] -> CPP AnnotatedModule -> String
printCPP _ (NoCPP m) = printA m
printCPP repls (CPP orig is ms) = Text.unpack $ Text.unlines $
  case is of
    [] -> splice "" 1 1 sorted origLines
    _ ->
      splice
        (Text.unlines newHeader)
        (length revHeader + 1)
        1
        sorted
        (reverse revDecls)
  where
    -- Replacements with a real source span, ordered by position.
    sorted = sortOn fst
      [ (r, replReplacement)
      | Replacement{..} <- repls
      , RealSrcSpan r <- [replLocation]
      ]

    origLines = Text.lines orig

    -- Name of the module, so that it is never imported into itself.
    -- Pattern-match instead of using 'head': the constructor is exported, so
    -- an empty variant list is representable.
    mbName = case ms of
      (m:_) -> unLoc <$> hsmodName (unLoc $ astA m)
      [] -> Nothing

    importLines = runIdentity $ fmap astA $ transformA (filterAndFlatten mbName is) $
      mapM $ fmap (Text.pack . dropWhile isSpace . printA) . pruneA

    p t = isImport t || isModule t || isHeaderPragma t

    -- Only file-header pragmas may mark the header. Pragmas such as INLINE
    -- or SPECIALISE can sit at column 0 anywhere in the body, and treating
    -- them as header would splice the new imports into the middle of the file.
    isHeaderPragma t =
      isPragma t && any (`Text.isPrefixOf` pragmaName) ["LANGUAGE", "OPTIONS", "INCLUDE"]
      where
        pragmaName = Text.toUpper $ Text.stripStart $ Text.drop 3 t

    (revDecls, revHeader) = break p (reverse origLines)
    newHeader = reverse revHeader ++ importLines

    -- splice prefix l c rs ts: @ts@ holds the remaining lines, whose head is
    -- line @l@ starting at column @c@; @prefix@ is text to emit before it.
    splice :: Text -> Int -> Int -> [(RealSrcSpan, String)] -> [Text] -> [Text]
    splice _ _ _ _ [] = []
    splice prefix _ _ [] (t:ts) = prefix <> t : ts
    splice prefix l c rs@((r, repl):rs') ts@(t:ts')
      | srcSpanStartLine r > l =
        prefix <> t : splice "" (l+1) 1 rs ts'
      | srcSpanStartLine r < l || srcSpanStartCol r < c =
        -- overlaps an earlier replacement: skip it
        splice prefix l c rs' ts
      | (old, ln:lns) <- splitAt (srcSpanEndLine r - l) ts =
        let
          start = srcSpanStartCol r
          end = srcSpanEndCol r
          -- Column at which 'ln' begins. It is only partially consumed when
          -- the replacement ends on the current line; otherwise it is a
          -- complete source line starting at column 1.
          endLineCol
            | srcSpanEndLine r == l = c
            | otherwise = 1
          prefix' = prefix <> Text.take (start - c) t <> Text.pack repl
          ln' = Text.drop (end - endLineCol) ln
          errMsg = unlines
            [ "Refusing to rewrite across CPP directives."
            , ""
            , "Location: " ++ locStr
            , ""
            , "Original:"
            , ""
            , Text.unpack replacedText
            , ""
            , "Replacement:"
            , ""
            , repl
            ]
          -- The source text the replacement would overwrite.
          replacedText =
            Text.unlines $
              (Text.drop (start - c) t : drop 1 old) ++ [Text.take (end - endLineCol) ln]
          locStr = unpackFS (srcSpanFile r) ++ ":" ++ show l ++ ":" ++ show start
        in
          if any isCPP old
          then trace errMsg $ splice prefix l c rs' ts
          else splice prefix' (srcSpanEndLine r) end rs' (ln':lns)
      | otherwise = error "printCPP: impossible replacement past end of file"
-- | Expand source text into one text per CPP variant, via 'mkCPPTree' and
-- 'cppTreeToList'. Lines not part of a variant are blanked rather than removed.
cppFork :: Text -> [Text]
cppFork contents = cppTreeToList (mkCPPTree contents)
-- | The variants of a CPP file, as produced by 'mkCPPTree'.
data CPPTree
  = Node [Text] CPPTree CPPTree
    -- ^ Lines shared by both alternatives. In file order they come after
    -- the lines produced by each subtree.
  | Leaf [Text]
    -- ^ The first lines of one variant.
-- | How 'mkCPPTree' treats lines inside the innermost enclosing conditional.
data CPPBranch
  = CPPTrue  -- ^ lines are kept
  | CPPFalse -- ^ lines are blanked
  | CPPOmit  -- ^ lines are blanked and no further variants are forked
-- | Build a tree describing every CPP variant of a file.
--
-- The lines are walked from last to first: the input is reversed once, so
-- each line can be consed onto the accumulated suffix. In every variant, each
-- input line yields exactly one output line, either the line itself or @""@,
-- so line numbers are preserved.
--
-- The stack records, per enclosing conditional, whether lines are kept
-- ('CPPTrue') or blanked ('CPPFalse', 'CPPOmit'). Each @#endif@ reached while
-- lines are being kept forks the walk into two subtrees.
--
-- Once an @import@ line has been passed and no conditional is open, all
-- remaining (earlier) lines go into a single leaf, with CPP lines blanked.
--
-- Calls 'error' when an @#if@, @#elif@ or @#else@ has no @#endif@ after it.
mkCPPTree :: Text -> CPPTree
mkCPPTree = go False [] [] . reverse . Text.lines
  where
    go :: Bool -> [CPPBranch] -> [Text] -> [Text] -> CPPTree
    go _ _ suffix [] = Leaf suffix
    go True [] suffix ls =
      Leaf (blankifyAndReverse suffix ls)
    go seenImport st suffix (l:ls) =
      case extractCPPCond l of
        Just If ->
          -- walking backwards, the opening directive ends the conditional
          case st of
            (_:st') -> emptyLine st'
            [] -> error "mkCPPTree: if with empty stack"
        Just ElIf ->
          case st of
            (CPPOmit:_) -> emptyLine st
            (CPPFalse:st') -> emptyLine (CPPOmit:st')
            (CPPTrue:st') ->
              -- Fork: one variant keeps what follows the #elif and omits the
              -- earlier branches; the other blanks everything so far and keeps
              -- the branch above.
              let
                omittedSuffix = replicate (length suffix) ""
              in
                Node
                  []
                  (emptyLine (CPPOmit:st'))
                  (go seenImport (CPPTrue:st') ("":omittedSuffix) ls)
            [] -> error "mkCPPTree: elif with empty stack"
        Just Else ->
          -- switch between the two halves of the conditional
          case st of
            (CPPOmit:_) -> emptyLine st
            (CPPTrue:st') -> emptyLine (CPPFalse:st')
            (CPPFalse:st') -> emptyLine (CPPTrue:st')
            [] -> error "mkCPPTree: else with empty stack"
        Just EndIf ->
          case st of
            -- nested inside a blanked region: no new variants
            (CPPOmit:_) -> emptyLine (CPPOmit:st)
            (CPPFalse:_) -> emptyLine (CPPOmit:st)
            _ ->
              Node
                suffix
                (go seenImport (CPPTrue:st) [""] ls)
                (go seenImport (CPPFalse:st) [""] ls)
        Nothing ->
          case st of
            (CPPOmit:_) -> go seenImport' st ("":suffix) ls
            (CPPFalse:_) -> go seenImport' st ("":suffix) ls
            _ -> go seenImport' st (blankCPP l:suffix) ls
      where
        -- emit a blank for this line and continue with the given stack
        emptyLine st' = go seenImport st' ("":suffix) ls
        seenImport' = seenImport || isImport l
-- | Blank the CPP lines of @ls@, reverse them, and put them in front of
-- @suffix@.
blankifyAndReverse :: [Text] -> [Text] -> [Text]
blankifyAndReverse suffix ls = reverse (map blankCPP ls) ++ suffix
-- | Flatten a 'CPPTree' into one text per leaf, left subtree first. Each text
-- is the leaf's lines followed by the shared lines of every enclosing 'Node'.
cppTreeToList :: CPPTree -> [Text]
cppTreeToList = flatten []
  where
    flatten trailing (Leaf segment) = [Text.unlines (segment ++ trailing)]
    flatten trailing (Node segment left right) =
      let trailing' = segment ++ trailing
      in flatten trailing' left ++ flatten trailing' right
-- | The CPP conditional directives recognised by 'extractCPPCond'.
data CPPCond
  = If    -- ^ opens a conditional
  | ElIf  -- ^ @#elif@
  | Else  -- ^ @#else@
  | EndIf -- ^ @#endif@
-- | Classify a line as a CPP conditional directive.
--
-- Only lines whose first character is @#@ are considered. The directive name
-- is the first whitespace-separated word after the @#@, so @# if@ also counts.
-- @#ifdef@ and @#ifndef@ open a conditional just like @#if@. If they were not
-- recognised, the stack entry pushed by their @#endif@ would never be popped
-- in 'mkCPPTree'. Any other line, including directives such as @#define@,
-- yields 'Nothing'.
extractCPPCond :: Text -> Maybe CPPCond
extractCPPCond t
  | Just ('#', t') <- Text.uncons t =
    case Text.words t' of
      ("if":_) -> Just If
      ("ifdef":_) -> Just If
      ("ifndef":_) -> Just If
      ("elif":_) -> Just ElIf
      ("else":_) -> Just Else
      ("endif":_) -> Just EndIf
      _ -> Nothing
  | otherwise = Nothing
-- | Replace a CPP directive line with an empty line; leave other lines alone.
blankCPP :: Text -> Text
blankCPP t = if isCPP t then Text.empty else t
-- | Is the first character of the line @#@? Leading whitespace is not skipped.
isCPP :: Text -> Bool
isCPP = maybe False ((== '#') . fst) . Text.uncons
-- | Does the line begin with the @import@ keyword (at column 0)?
--
-- A bare prefix test would also match top-level bindings such as
-- @importantValue = ...@, so the keyword must not be followed by an
-- identifier character.
isImport :: Text -> Bool
isImport = startsWithKeyword "import"

-- | Does the line begin with the @module@ keyword (at column 0)? As with
-- 'isImport', identifiers like @moduleName@ do not count.
isModule :: Text -> Bool
isModule = startsWithKeyword "module"

-- | @startsWithKeyword kw t@: @t@ starts with @kw@, and the next character
-- (if any) cannot continue a Haskell identifier.
startsWithKeyword :: String -> Text -> Bool
startsWithKeyword kw t =
  case Text.stripPrefix (Text.pack kw) t of
    Nothing -> False
    Just rest -> maybe True (not . isIdentChar . fst) (Text.uncons rest)
  where
    isIdentChar c = isAlphaNum c || c == '_' || c == '\''
-- | Does the line begin (at column 0) with a pragma opener @{-#@?
isPragma :: Text -> Bool
isPragma t = Text.take 3 t == Text.pack "{-#"
-- | Graft the given imports into a module.
--
-- Imports of the module itself are dropped by 'filterAndFlatten'. The
-- combined import list is de-duplicated with 'eqImportDecl', keeping the
-- first occurrence. Existing imports therefore win, and duplicates already
-- present among them are removed too.
insertImports
  :: Monad m
  => [AnnotatedImports]
  -> Located (HsModule GhcPs)
  -> TransformT m (Located (HsModule GhcPs))
insertImports is (L l m) =
  withImports <$> graftA (filterAndFlatten selfName is)
  where
    selfName = unLoc <$> hsmodName m
    withImports imps =
      L l (m { hsmodImports = nubBy (eqImportDecl `on` unLoc) (hsmodImports m ++ imps) })
-- | Concatenate annotated import lists. When a module name is given, imports
-- of that module are removed, so a module never imports itself.
filterAndFlatten :: Maybe ModuleName -> [AnnotatedImports] -> AnnotatedImports
filterAndFlatten mbName is =
  runIdentity $ transformA (mconcat is) (pure . dropSelf)
  where
    dropSelf :: [LImportDecl GhcPs] -> [LImportDecl GhcPs]
    dropSelf = case mbName of
      Nothing -> id
      Just mn -> filter (\imp -> unLoc (ideclName (unLoc imp)) /= mn)
-- | Do two import declarations denote the same import? Used for
-- de-duplication in 'insertImports'.
--
-- The following are compared:
--
-- * module name
-- * qualification
-- * alias
-- * import/hiding list
-- * package qualifier
-- * @{-# SOURCE #-}@ flag
-- * @safe@ flag
--
-- Other fields are ignored. The module name and alias are compared without
-- their source locations. Otherwise an import parsed from one text never
-- matches the same import parsed from another.
eqImportDecl :: ImportDecl GhcPs -> ImportDecl GhcPs -> Bool
eqImportDecl x y =
  ((==) `on` unLoc . ideclName) x y
    && ((==) `on` ideclQualified) x y
    && ((==) `on` fmap unLoc . ideclAs) x y
    -- NOTE(review): the import/hiding list is still compared as-is; if its
    -- Eq instance includes source locations, equal lists written at
    -- different positions will not match.
    && ((==) `on` ideclHiding) x y
    && ((==) `on` ideclPkgQual) x y
    && ((==) `on` ideclSource) x y
    && ((==) `on` ideclSafe) x y