{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
-- Undo pointfree transformations. Plugin code derived from Pl.hs.
module Lambdabot.Plugin.Haskell.Pointful (pointfulPlugin) where

import Lambdabot.Module as Lmb (Module)
import Lambdabot.Plugin
import Lambdabot.Util.Parser (withParsed, prettyPrintInLine)

import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as M
import Data.List
import Data.Maybe
import Language.Haskell.Exts as Hs

-- | The @pointful@ command (and its aliases): print a pointful version of
-- the given code, one output message per line.
pointfulPlugin :: Lmb.Module ()
pointfulPlugin = newModule
    { moduleCmds = return
        [ (command "pointful")
            { aliases = ["pointy","repoint","unpointless","unpl","unpf"]
            , help = say "pointful . Make code pointier."
            , process = mapM_ say . lines . pointful
            }
        ]
    }

---- Utilities ----

-- | Placeholder source location for syntax nodes synthesized by this module.
unkLoc :: SrcLoc
unkLoc = SrcLoc "" 1 1

-- | Apply @f@ repeatedly until the result no longer changes (by '==').
-- Does not terminate if @f@ never reaches a fixed point.
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f x = let x' = f x in if x' == x then x else stabilize f x'

-- | Every 'Name' occurring anywhere in a term: binders, variables, operators,
-- constructors. Used as a conservative set of names that a freshly invented
-- binder must not collide with.
allNames :: Data d => d -> S.Set Name
allNames = everything S.union (mkQ S.empty nameSet)
  where
    nameSet :: Name -> S.Set Name
    nameSet = S.singleton

-- | Variables bound by the top patterns or binders of a term: pattern
-- variables (including as-pattern names), the function name of a 'Match',
-- and the variables of a pattern binding. Stops at expressions.
varsBoundHere :: Data d => d -> S.Set Name
varsBoundHere (cast -> Just (PVar name)) = S.singleton name
-- an as-pattern @name\@pat@ binds @name@ as well as everything bound in @pat@
varsBoundHere (cast -> Just (PAsPat name pat)) = S.insert name (varsBoundHere pat)
varsBoundHere (cast -> Just (Match _ name _ _ _ _)) = S.singleton name
varsBoundHere (cast -> Just (PatBind _ pat _ _)) = varsBoundHere pat
varsBoundHere (cast -> Just (_ :: Exp)) = S.empty
varsBoundHere d = S.unions (gmapQ varsBoundHere d)

-- Note: the tempting idea of using a pattern synonym for the frequent
-- (cast -> Just _) patterns causes compiler crashes with GHC versions
-- before 8; cf. https://ghc.haskell.org/trac/ghc/ticket/11336

-- | Generic fold over the unqualified variable occurrences of a term.
-- @var name bound@ is called for each @Var (UnQual name)@, where @bound@ is
-- the set of names bound by enclosing lambdas, lets, case alternatives,
-- pattern bindings and function matches. @sum@ combines the results of
-- sibling subterms.
foldFreeVars :: forall a d. Data d => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
foldFreeVars var sum e = runReader (go e) S.empty
  where
    go :: forall d. Data d => d -> Reader (S.Set Name) a
    go (cast -> Just (Var (UnQual name))) =
        asks (var name)
    go (cast -> Just (Lambda _ ps exp)) =
        bind [varsBoundHere ps] $ go exp
    go (cast -> Just (Let bs exp)) =
        bind [varsBoundHere bs] $ collect [go bs, go exp]
    go (cast -> Just (Alt _ pat exp bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
    go (cast -> Just (PatBind _ pat exp bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
    go (cast -> Just (Match _ _ ps _ exp bs)) =
        bind [varsBoundHere ps, varsBoundHere bs] $ collect [go exp, go bs]
    go d = collect (gmapQ go d)

    collect :: forall m. Monad m => [m a] -> m a
    collect ms = sum `liftM` sequence ms

    -- extend the set of bound names for the enclosed computation
    bind :: forall a b. Ord a => [S.Set a] -> Reader (S.Set a) b -> Reader (S.Set a) b
    bind ss = local (S.unions ss `S.union`)

-- | Free (unqualified) variables of a term.
freeVars :: Data d => d -> S.Set Name
freeVars = foldFreeVars (\name bv -> S.singleton name `S.difference` bv) S.unions

-- | Number of free occurrences of a variable in a term.
countOcc :: Data d => Name -> d -> Int
countOcc name = foldFreeVars var sum
  where
    sum = foldl' (+) 0
    var name' bv = if name /= name' || name' `S.member` bv then 0 else 1

-- | Capture-avoiding substitution.
--
-- @subst@ maps variables to their replacements. @bv@ is the set of names
-- that must not be bound without renaming: initially the free variables of
-- the replacements, extended with every binder passed on the way down.
-- Binders in @bv@ are renamed by 'renameBinds'. The fresh names they receive
-- also avoid every name occurring in the binding construct, so they cannot
-- capture free variables of the scope.
substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
substAvoiding subst bv = base `extT` exp `extT` alt `extT` decl `extT` match
  where
    base :: Data d => d -> d
    base = gmapT (substAvoiding subst bv)

    exp e@(Var (UnQual name)) = fromMaybe e (M.lookup name subst)
    exp node@(Lambda sloc ps body) =
        let (subst', bv', ps') = renameBinds (allNames node) subst bv ps
        in Lambda sloc ps' (substAvoiding subst' bv' body)
    exp node@(Let bs body) =
        let (subst', bv', bs') = renameBinds (allNames node) subst bv bs
        in Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' body)
    exp d = base d

    alt node@(Alt sloc pat rhs bs) =
        let avoid = allNames node
            (subst1, bv1, pat') = renameBinds avoid subst bv pat
            (subst', bv', bs') = renameBinds avoid subst1 bv1 bs
        in Alt sloc pat' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')

    -- the bound pattern itself was already renamed by the enclosing binding
    -- group; only the where-bindings are renamed here
    decl node@(PatBind sloc pat rhs bs) =
        let (subst', bv', bs') = renameBinds (allNames node) subst bv bs
        in PatBind sloc pat (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')
    decl d = base d

    match node@(Match sloc name ps typ rhs bs) =
        let avoid = allNames node
            (subst1, bv1, ps') = renameBinds avoid subst bv ps
            (subst', bv', bs') = renameBinds avoid subst1 bv1 bs
        in Match sloc name ps' typ (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')

-- | Rename the local binders at the top of a term (pattern variables,
-- as-pattern names, function names of matches, pattern-binding patterns)
-- without descending into nested expressions.
--
-- A binder not in @bv@ keeps its name. It is added to @bv@, and any pending
-- substitution for it is dropped, because the binder shadows it. A binder
-- already in @bv@ gets a fresh name that avoids both @bv@ and @avoid@, and
-- the substitution is extended to map the old name to the new one. Repeated
-- binders (e.g. several matches of one function) are renamed consistently.
--
-- Returns the updated substitution, the extended @bv@, and the renamed term.
-- @avoid@ is only consulted when a fresh name is actually needed.
renameBinds :: Data d
            => S.Set Name
            -> M.Map Name Exp
            -> S.Set Name
            -> d
            -> (M.Map Name Exp, S.Set Name, d)
renameBinds avoid subst bv d = (subst', bv', d')
  where
    (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)

    go, base :: Data d => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
    go = base `extM` pat `extM` match `extM` decl `extM` exp
    base d = gmapM go d

    pat (PVar name) = PVar `fmap` rename name
    pat (PAsPat name p) = do
        name' <- rename name
        p' <- go p
        return (PAsPat name' p')
    pat d = base d

    match (Match sloc name ps typ rhs bs) = do
        name' <- rename name
        return $ Match sloc name' ps typ rhs bs

    decl (PatBind sloc p rhs bs) = do
        p' <- go p
        return $ PatBind sloc p' rhs bs
    decl d = base d

    -- nested expressions are left alone; the caller substitutes into them
    exp (e :: Exp) = return e

    rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
    rename name = do
        (sub, bound, ass) <- get
        case (name `M.lookup` ass, name `S.member` bound) of
            (Just name', _) ->
                return name'
            (_, False) -> do
                put (M.delete name sub, S.insert name bound, ass)
                return name
            _ -> do
                let name' = freshNameAvoiding name (bound `S.union` avoid)
                put ( M.insert name (Var (UnQual name')) sub
                    , S.insert name' bound
                    , M.insert name name' ass
                    )
                return name'

-- | Generate a name not in @forbidden@, derived from @name@. The trailing
-- suffix characters are stripped (digits for identifiers, @?@/@#@ for
-- symbols), then the shortest non-empty suffix that yields an allowed name
-- is appended. 'head' is safe here: the candidate list is infinite, and a
-- finite forbidden set excludes only finitely many candidates.
freshNameAvoiding :: Name -> S.Set Name -> Name
freshNameAvoiding name forbidden = con (pre ++ suf)
  where
    (con, nm, cs) = case name of
        Ident n -> (Ident, n, "0123456789")
        Symbol n -> (Symbol, n, "?#")
    pre = reverse . dropWhile (`elem` cs) . reverse $ nm
    sufs = [1..] >>= flip replicateM cs
    suf = head $ dropWhile (\s -> con (pre ++ s) `S.member` forbidden) sufs

---- Optimization (removing explicit lambdas) and restoration of infix ops ----

-- | Declaration-level clean-ups.
optimizeD :: Decl -> Decl
-- move lambda patterns into LHS
optimizeD (PatBind locat (PVar fname) (UnGuardedRhs lam@(Lambda _ pats rhs)) Nothing) =
    let (subst, bv, pats') = renameBinds (allNames lam) M.empty (S.singleton fname) pats
        rhs' = substAvoiding subst bv rhs
    in FunBind [Match locat fname pats' Nothing (UnGuardedRhs rhs') Nothing]
-- combine function binding and lambda
optimizeD (FunBind [Match locat fname pats1 Nothing (UnGuardedRhs lam@(Lambda _ pats2 rhs)) Nothing]) =
    let (subst, bv, pats2') = renameBinds (allNames lam) M.empty (varsBoundHere pats1) pats2
        rhs' = substAvoiding subst bv rhs
    in FunBind [Match locat fname (pats1 ++ pats2') Nothing (UnGuardedRhs rhs') Nothing]
optimizeD x = x

-- | Remove parentheses around a whole unguarded right-hand side.
optimizeRhs :: Rhs -> Rhs
optimizeRhs (UnGuardedRhs (Paren x)) = UnGuardedRhs x
optimizeRhs x = x

-- | Expression-level simplifications; the first matching rule wins.
optimizeE :: Exp -> Exp
-- apply ((\x z -> ...x...) y) yielding (\z -> ...y...) if there is only one
-- occurrence of x or y is simple
optimizeE (App lam@(Lambda locat (PVar ident : pats) body) arg)
    | single || simple arg
    = let (subst, bv, pats') = renameBinds (allNames lam) (M.singleton ident arg) (freeVars arg) pats
      in Paren (Lambda locat pats' (substAvoiding subst bv body))
  where
    single = countOcc ident body <= 1
    simple e = case e of
        Var _ -> True
        Lit _ -> True
        Paren e' -> simple e'
        _ -> False
-- apply ((\_ z -> ...) y) yielding (\z -> ...)
optimizeE (App (Lambda locat (PWildCard : pats) body) _) =
    Paren (Lambda locat pats body)
-- remove 0-arg lambdas resulting from application rules
optimizeE (Lambda _ [] b) = b
-- replace (\x -> \y -> z) with (\x y -> z)
optimizeE (Lambda locat p1 inner@(Lambda _ p2 body)) =
    let (subst, bv, p2') = renameBinds (allNames inner) M.empty (varsBoundHere p1) p2
        body' = substAvoiding subst bv body
    in Lambda locat (p1 ++ p2') body'
-- remove double parens
optimizeE (Paren (Paren x)) = Paren x
-- remove parens around applied lambdas (the pretty printer restores them)
optimizeE (App (Paren (x@Lambda{})) y) = App x y
-- remove lambda body parens
optimizeE (Lambda l p (Paren x)) = Lambda l p x
-- remove var, lit parens
optimizeE (Paren x@(Var _)) = x
optimizeE (Paren x@(Lit _)) = x
-- remove infix+lambda parens
optimizeE (InfixApp a o (Paren l@(Lambda _ _ _))) = InfixApp a o l
-- remove infix+app parens
optimizeE (InfixApp (Paren a@App{}) o l) = InfixApp a o l
optimizeE (InfixApp a o (Paren l@App{})) = InfixApp a o l
-- remove left-assoc application parens
optimizeE (App (Paren (App a b)) c) = App (App a b) c
-- restore infix
optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
    InfixApp l (QVarOp name') r
-- eta reduce: (\... v -> e v) becomes (\... -> e) when v is not free in e
optimizeE (Lambda l ps@(_:_) (App e (Var (UnQual v))))
    | free && last ps == PVar v
    = Lambda l (init ps) e
  where
    free = countOcc v e == 0
-- fail
optimizeE x = x

---- Decombinatorization ----

-- | One step of rewriting sections, infix applications and reader-monad
-- binds into explicit lambdas and prefix applications.
uncomb' :: Exp -> Exp
uncomb' (Paren (Paren e)) = Paren e
-- eliminate sections; the fresh variable must clash neither with the
-- operand's free variables nor with the operator itself
uncomb' (RightSection op' arg) =
    let a = freshNameAvoiding (Ident "a") (freeVars arg `S.union` allNames op')
    in Paren (Lambda unkLoc [PVar a] (InfixApp (Var (UnQual a)) op' arg))
uncomb' (LeftSection arg op') =
    let a = freshNameAvoiding (Ident "a") (freeVars arg `S.union` allNames op')
    in Paren (Lambda unkLoc [PVar a] (InfixApp arg op' (Var (UnQual a))))
-- infix to prefix for canonicality
uncomb' (InfixApp lf (QVarOp name') rf) =
    Paren (App (App (Var name') (Paren lf)) (Paren rf))
-- Expand (>>=) when it is obviously the reader monad:
-- rewrite: (>>=) (\x -> e)
-- to: (\ a b -> a ((\ x -> e) b) b)
uncomb' (App (Var (UnQual (Symbol ">>="))) (Paren lam@Lambda{})) =
    let a = freshNameAvoiding (Ident "a") (freeVars lam)
        b = freshNameAvoiding (Ident "b") (freeVars lam)
    in Paren (Lambda unkLoc [PVar a, PVar b]
                (App (App (Var (UnQual a)) (Paren (App lam (Var (UnQual b)))))
                     (Var (UnQual b))))
-- rewrite: ((>>=) e1) (\x y -> e2)
-- to: (\a -> (\x y -> e2) (e1 a) a)
uncomb' (App (App (Var (UnQual (Symbol ">>="))) e1) (Paren lam@(Lambda _ (_:_:_) _))) =
    let a = freshNameAvoiding (Ident "a") (freeVars [e1, lam])
    in Paren (Lambda unkLoc [PVar a]
                (App (App lam (App e1 (Var (UnQual a)))) (Var (UnQual a))))
-- fail
uncomb' expr = expr

---- Simple combinator definitions ---

-- | Combinator name to its (parenthesized) lambda definition, parsed from
-- 'combinatorModule'. Fails loudly if that module does not parse.
combinators :: M.Map Name Exp
combinators = M.fromList $ map declToTuple defs
  where
    defs = case parseModule combinatorModule of
        ParseOk (Hs.Module _ _ _ _ _ _ d) -> d
        f@(ParseFailed _ _) -> error ("Combinator loading: " ++ show f)
    declToTuple (PatBind _ (PVar fname) (UnGuardedRhs body) Nothing) = (fname, Paren body)
    declToTuple _ = error "Pointful Plugin error: can't convert declaration to tuple"

-- | Source of the combinator definitions that get inlined.
-- (<=<) is Kleisli composition: (f <=< g) x = g x >>= f.
combinatorModule :: String
combinatorModule = unlines
    [ "(.) = \\f g x -> f (g x) "
    , "($) = \\f x -> f x "
    , "flip = \\f x y -> f y x "
    , "const = \\x _ -> x "
    , "id = \\x -> x "
    , "(=<<) = flip (>>=) "
    , "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) "
    , "join = (>>= id) "
    , "ap = liftM2 id "
    , "(>=>) = flip (<=<) "
    , "(<=<) = \\f g x -> g x >>= f "
    , " "
    , "-- ASSUMED reader monad "
    , "-- (>>=) = (\\f k r -> k (f r) r) "
    , "-- return = const "
    , ""
    ]

---- Top level ----

-- | Inline every known combinator, avoiding capture of its free variables.
unfoldCombinators :: (Data a) => a -> a
unfoldCombinators = substAvoiding combinators (freeVars combinators)

-- | One bottom-up pass of 'uncomb''.
uncombOnce :: (Data a) => a -> a
uncombOnce x = everywhere (mkT uncomb') x

-- | Apply 'uncombOnce' until nothing changes.
uncomb :: (Eq a, Data a) => a -> a
uncomb = stabilize uncombOnce

-- | One bottom-up pass of the declaration, rhs and expression optimizations.
optimizeOnce :: (Data a) => a -> a
optimizeOnce x = everywhere (mkT optimizeD `extT` optimizeRhs `extT` optimizeE) x

-- | Apply 'optimizeOnce' until nothing changes.
optimize :: (Eq a, Data a) => a -> a
optimize = stabilize optimizeOnce

-- | Parse the input, inline combinators and decombinatorize to a fixed
-- point, then optimize to a fixed point, and pretty-print the result.
pointful :: String -> String
pointful = withParsed (stabilize (optimize . uncomb) . stabilize (unfoldCombinators . uncomb))

-- TODO: merge this into a proper test suite once one exists
-- test s = case parseModule s of
--     f@(ParseFailed _ _) -> fail (show f)
--     ParseOk (Hs.Module _ _ _ _ _ _ defs) ->
--         flip mapM_ defs $ \def -> do
--             putStrLn . prettyPrintInLine $ def
--             putStrLn . prettyPrintInLine . uncomb $ def
--             putStrLn . prettyPrintInLine . optimize . uncomb $ def
--             putStrLn . prettyPrintInLine . stabilize (optimize . uncomb) $ def
--             putStrLn ""
--
-- main = test "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"
--