{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}
module Retrie.Rewrites.Function (dfnsToRewrites) where
import Control.Monad
import Control.Monad.State.Lazy
import Data.List
import Data.Maybe
import Data.Traversable
import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Types
import Retrie.Util
dfnsToRewrites
:: [(FastString, Direction)]
-> AnnotatedModule
-> IO (UniqFM [Rewrite (LHsExpr GhcPs)])
dfnsToRewrites specs am = fmap astA $ transformA am $ \ (L _ m) -> do
let
fsMap = uniqBag specs
rrs <- sequence
[ do
fe <- mkLocatedHsVar fRdrName
imps <- getImports dir (hsmodName m)
(fName,) . concat <$>
forM (unLoc $ mg_alts $ fun_matches f) (matchToRewrites fe imps dir)
#if __GLASGOW_HASKELL__ < 806
| L _ (ValD f@FunBind{}) <- hsmodDecls m
#else
| L _ (ValD _ f@FunBind{}) <- hsmodDecls m
#endif
, let fRdrName = fun_id f
, let fName = occNameFS (occName (unLoc fRdrName))
, dir <- fromMaybe [] (lookupUFM fsMap fName)
]
return $ listToUFM_C (++) rrs
getImports
:: Direction -> Maybe (Located ModuleName) -> TransformT IO AnnotatedImports
getImports RightToLeft (Just (L _ mn)) =
lift $ liftIO $ parseImports ["import " ++ moduleNameString mn]
getImports _ _ = return mempty
matchToRewrites
:: LHsExpr GhcPs
-> AnnotatedImports
-> Direction
-> LMatch GhcPs (LHsExpr GhcPs)
-> TransformT IO [Rewrite (LHsExpr GhcPs)]
matchToRewrites e imps dir (L _ alt) = do
let
pats = m_pats alt
grhss = m_grhss alt
qss <- for (zip (inits pats) (tails pats)) $
makeFunctionQuery e imps dir grhss mkApps
qs <- backtickRules e imps dir grhss pats
return $ qs ++ concat qss
type AppBuilder =
LHsExpr GhcPs -> [LHsExpr GhcPs] -> TransformT IO (LHsExpr GhcPs)
makeFunctionQuery
:: LHsExpr GhcPs
-> AnnotatedImports
-> Direction
-> GRHSs GhcPs (LHsExpr GhcPs)
-> AppBuilder
-> ([LPat GhcPs], [LPat GhcPs])
-> TransformT IO [Rewrite (LHsExpr GhcPs)]
makeFunctionQuery e imps dir grhss mkAppFn (argpats, bndpats) = do
let
#if __GLASGOW_HASKELL__ < 806
GRHSs rhss lbs = grhss
#else
GRHSs _ rhss lbs = grhss
#endif
bs = collectPatsBinders argpats
(es,(_,bs')) <- runStateT (mapM patToExpr argpats) (wildSupply bs, bs)
lhs <- mkAppFn e es
for rhss $ \ grhs -> do
le <- mkLet (unLoc lbs) (grhsToExpr grhs)
rhs <- mkLams bndpats le
let
(pat, temp) =
case dir of
LeftToRight -> (lhs,rhs)
RightToLeft -> (rhs,lhs)
p <- pruneA pat
t <- pruneA temp
return $ addRewriteImports imps $ mkRewrite (mkQs bs') p t
backtickRules
:: LHsExpr GhcPs
-> AnnotatedImports
-> Direction
-> GRHSs GhcPs (LHsExpr GhcPs)
-> [LPat GhcPs]
-> TransformT IO [Rewrite (LHsExpr GhcPs)]
backtickRules e imps dir@LeftToRight grhss ps@[p1, p2] = do
let
both, left, right :: AppBuilder
#if __GLASGOW_HASKELL__ < 806
both op [l, r] = mkLoc (OpApp l op PlaceHolder r)
both _ _ = fail "backtickRules - both: impossible!"
left op [l] = mkLoc (SectionL l op)
left _ _ = fail "backtickRules - left: impossible!"
right op [r] = mkLoc (SectionR op r)
right _ _ = fail "backtickRules - right: impossible!"
#else
both op [l, r] = mkLoc (OpApp noExt l op r)
both _ _ = fail "backtickRules - both: impossible!"
left op [l] = mkLoc (SectionL noExt l op)
left _ _ = fail "backtickRules - left: impossible!"
right op [r] = mkLoc (SectionR noExt op r)
right _ _ = fail "backtickRules - right: impossible!"
#endif
qs <- makeFunctionQuery e imps dir grhss both (ps, [])
qsl <- makeFunctionQuery e imps dir grhss left ([p1], [p2])
qsr <- makeFunctionQuery e imps dir grhss right ([p2], [p1])
return $ qs ++ qsl ++ qsr
backtickRules _ _ _ _ _ = return []