module Lambdabot.Pointful (pointful, ParseResult(..), test, main, combinatorModule) where
import Lambdabot.Parser
import Control.Monad.State
import Data.Generics
import Data.Maybe
import Language.Haskell.Parser
import Language.Haskell.Syntax
import qualified Data.Map as M
-- | Alias for 'extT' that carries its own fixity declaration (below), so
-- chains such as @mkT f `extT'` g `extT'` h@ group to the left.
extT' :: (Typeable a, Typeable b) => (a -> a) -> (b -> b) -> a -> a
extT' generic specific = generic `extT` specific
infixl 9 `extT'`
-- | Source location attached to syntax nodes synthesised by this module
-- (for example the lambdas built for operator sections).
unkLoc :: SrcLoc
unkLoc = SrcLoc "<new>" line column
  where
    line = 1
    column = 1
-- | Apply @f@ repeatedly, starting from @x@, until one application leaves
-- the value unchanged, and return that fixed point.  Does not terminate if
-- no fixed point is ever reached.
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f = go
  where
    go x
      | x' == x = x
      | otherwise = go x'
      where
        x' = f x
-- | Every unqualified ('UnQual') name referenced anywhere inside the given
-- syntax tree, in traversal order, duplicates included.
namesIn :: Data a => a -> [HsName]
namesIn h = everything (++) (mkQ [] unqualified) h
  where
    unqualified (UnQual name) = [name]
    unqualified _ = []
-- | Every variable bound by an 'HsPVar' pattern anywhere inside the given
-- syntax tree, in traversal order.
pVarsIn :: Data a => a -> [HsName]
pVarsIn h = everything (++) (mkQ [] patternVar) h
  where
    patternVar (HsPVar name) = [name]
    patternVar _ = []
-- | The identifier following the given one: the last character is bumped
-- with carry (a 'z' wraps to 'a' and carries left; an all-'z' name grows by
-- one character), via 'succAlpha'.
-- NOTE(review): only 'HsIdent' is handled; an 'HsSymbol' argument is a
-- pattern-match failure.
succName (HsIdent s) = HsIdent (onReversed succAlpha s)
  where
    onReversed g = reverse . g . reverse
-- | Increment a /reversed/ lowercase name by one, least significant
-- character first: leading 'z's become 'a' and carry into the next
-- character; running off the end appends an "a".
succAlpha :: String -> String
succAlpha [] = "a"
succAlpha (c : cs)
  | c == 'z' = 'a' : succAlpha cs
  | otherwise = succ c : cs
-- | Turn a binding whose right-hand side is a lambda into a function binding
-- that takes those parameters directly.  Two shapes are handled:
--
-- * @f = \\x -> e@ becomes @f x = e@;
-- * a single-clause @f a = \\b -> e@ becomes @f a b = e@.
--
-- Only unguarded bindings without a @where@ clause are rewritten; any other
-- declaration is returned unchanged.
optimizeD decl = case decl of
  HsPatBind loc (HsPVar fname) (HsUnGuardedRhs (HsLambda _ pats body)) [] ->
    HsFunBind [HsMatch loc fname pats (HsUnGuardedRhs body) []]
  HsFunBind [HsMatch loc fname pats1 (HsUnGuardedRhs (HsLambda _ pats2 body)) []] ->
    HsFunBind [HsMatch loc fname (pats1 ++ pats2) (HsUnGuardedRhs body) []]
  _ -> decl
-- | Drop parentheses that enclose an entire unguarded right-hand side;
-- any other right-hand side is returned unchanged.
optimizeRhs rhs = case rhs of
  HsUnGuardedRhs (HsParen inner) -> HsUnGuardedRhs inner
  _ -> rhs
-- | A single local simplification step on an expression.  'optimizeOnce'
-- applies it at every expression node and 'optimize' repeats that until
-- nothing changes, so each equation only needs to make one small
-- improvement.  Equations are tried in order; the last one leaves the
-- expression unchanged.
optimizeE :: HsExp -> HsExp
-- Beta-reduce @(\\ident pats -> body) arg@ when the name @ident@ appears
-- exactly once in @body@, or @arg@ is a bare variable, so no work is
-- duplicated.  The substitution is purely syntactic, so it is only done
-- when it cannot change meaning:
--
-- * @ident@ must not be re-bound by a pattern inside @body@; otherwise
--   occurrences referring to the inner binder would be replaced too.
-- * No name used in @arg@ may be bound by @pats@ or by a pattern inside
--   @body@; otherwise it would be captured, e.g. @(\\f x -> f x) x@ must
--   not become @\\x -> x x@.
--
-- When either check fails the redex is simply left in place.
optimizeE (HsApp (HsParen (HsLambda loc (HsPVar ident : pats) body)) arg)
  | (single || simple) && captureFree
  = HsParen (HsLambda loc pats (everywhere (mkT substitute) body))
  where
    single = gcount (mkQ False (== ident)) body == 1
    simple = case arg of HsVar _ -> True; _ -> False
    captureFree = ident `notElem` pVarsIn body
               && all (`notElem` binders) (namesIn arg)
    binders = pVarsIn pats ++ pVarsIn body
    substitute x = if x == HsVar (UnQual ident) then arg else x
-- A wildcard-bound argument is never used, so discard the argument.
optimizeE (HsApp (HsParen (HsLambda loc (HsPWildCard : pats) body)) _)
  = HsParen (HsLambda loc pats body)
-- A lambda with no patterns left is just its body.
optimizeE (HsLambda _ [] b)
  = b
-- Merge directly nested lambdas, unless the inner one re-binds a variable
-- the outer one already binds: @\\x -> \\x -> e@ is legal shadowing, but
-- @\\x x -> e@ is rejected by Haskell as a conflicting definition.
optimizeE (HsLambda loc p1 (HsLambda _ p2 body))
  | all (`notElem` pVarsIn p1) (pVarsIn p2)
  = HsLambda loc (p1 ++ p2) body
-- Collapse doubled parentheses.
optimizeE (HsParen (HsParen x))
  = HsParen x
-- Parentheses around a whole lambda body are unnecessary.
optimizeE (HsLambda l p (HsParen x))
  = HsLambda l p x
-- Variables and literals never need parentheses.
optimizeE (HsParen x@(HsVar _))
  = x
optimizeE (HsParen x@(HsLit _))
  = x
-- A lambda as the right operand of an infix operator needs no parentheses.
optimizeE (HsInfixApp a o (HsParen l@(HsLambda _ _ _)))
  = HsInfixApp a o l
-- Application associates to the left, so @(a b) c@ is @a b c@.
optimizeE (HsApp (HsParen (HsApp a b)) c)
  = HsApp (HsApp a b) c
-- Turn a saturated prefix application of an unqualified symbolic operator
-- back into infix form: @(+) l r@ becomes @l + r@.
optimizeE (HsApp (HsApp (HsVar name@(UnQual (HsSymbol _))) l) r)
  = (HsInfixApp l (HsQVarOp name) r)
optimizeE x = x
-- | Hand out a new variable name.  The state is the last name handed out
-- together with the names that must be avoided; the next name is the first
-- successor (per 'succName') of the last one that is not in the avoid list.
-- It is stored back as the new "last name" and returned.
fresh = do
  (current, used) <- get
  let next = until (`notElem` used) succName (succName current)
  put (next, used)
  return next
-- | Rename, to fresh names, every variable bound by a pattern of each lambda
-- found inside the given term, so that an expanded combinator body cannot
-- clash with names around its use site.  Each lambda's pattern variables
-- are replaced throughout that lambda via a name-to-name lookup table.
rename = everywhereM (mkM renameLambda)
  where
    renameLambda e@(HsLambda _ ps _) = do
      let bound = concatMap pVarsIn ps
      newNames <- mapM (const fresh) bound
      let table = zip bound newNames
      return (everywhere (mkT (\n -> fromMaybe n (lookup n table))) e)
    renameLambda e = return e
-- | One rewriting step of the expansion pass; 'uncombOnce' applies it to
-- every sub-expression and 'uncomb' repeats passes until nothing changes.
-- The state is the last fresh name handed out plus the names to avoid.
--
-- * A variable naming a known combinator is replaced by its lambda
--   definition from 'combinators', with that lambda's pattern variables
--   renamed to fresh names.
-- * Operator sections become explicit lambdas over a fresh variable:
--   @(op arg)@ becomes @\\a -> a op arg@ and @(arg op)@ becomes
--   @\\a -> arg op a@.
-- * Infix applications of variable operators become prefix applications,
--   so a later pass can expand the operator itself.
--
-- NOTE(review): the combinator lookup does not check whether the name is
-- locally re-bound (e.g. a user lambda parameter called @id@); such a
-- variable is still expanded.
uncomb' :: HsExp -> State (HsName, [HsName]) HsExp
uncomb' (HsVar qname)
  | Just def <- M.lookup qname combinators = rename def
uncomb' (HsRightSection op arg) = do
  a <- fresh
  return (HsParen (HsLambda unkLoc [HsPVar a] (HsInfixApp (HsVar (UnQual a)) op arg)))
uncomb' (HsLeftSection arg op) = do
  a <- fresh
  return (HsParen (HsLambda unkLoc [HsPVar a] (HsInfixApp arg op (HsVar (UnQual a)))))
uncomb' (HsInfixApp lf (HsQVarOp name) rf)
  = return (HsParen (HsApp (HsApp (HsVar name) (HsParen lf)) (HsParen rf)))
uncomb' expr = return expr
-- | The combinators 'uncomb' knows how to expand, keyed by their unqualified
-- name and parsed from 'combinatorModule'.  Each definition there must be a
-- simple unguarded @name = body@ binding without a where clause; the body
-- is wrapped in parentheses so it can be spliced into any expression
-- position.
combinators = M.fromList $ map declToTuple defs
  where
    defs = case parseModule combinatorModule of
      ParseOk (HsModule _ _ _ _ d) -> d
      f@(ParseFailed _ _) -> error ("Combinator loading: " ++ show f)
    declToTuple (HsPatBind _ (HsPVar fname) (HsUnGuardedRhs body) [])
      = (UnQual fname, HsParen body)
    -- Anything else (a type signature, a guarded or function-style
    -- definition, a where clause) is a mistake in 'combinatorModule'.
    -- Report it like a parse failure rather than with a bare
    -- "non-exhaustive patterns" crash.
    declToTuple other
      = error ("Combinator loading: unsupported declaration " ++ show other)
-- | Bare names of all known combinators (every key of 'combinators' is an
-- 'UnQual' name).  'uncombOnce' adds them to the names fresh variables
-- must avoid.
recognizedNames = [name | UnQual name <- M.keys combinators]
-- | Haskell source for the combinator definitions that 'uncomb' expands.
-- It is parsed by 'combinators', which accepts only simple
-- @name = expression@ bindings.  The trailing lines are Haskell comments
-- inside the source, so the reader-monad definitions they mention are not
-- loaded: @(>>=)@ and @return@ are left as they are.
combinatorModule = unlines definitions
  where
    definitions =
      [ "(.) = \\f g x -> f (g x) "
      , "($) = \\f x -> f x "
      , "flip = \\f x y -> f y x "
      , "const = \\x _ -> x "
      , "id = \\x -> x "
      , "(=<<) = flip (>>=) "
      , "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) "
      , "join = (>>= id) "
      , "ap = liftM2 id "
      , " "
      , "-- ASSUMED reader monad "
      , "-- (>>=) = (\\f k r -> k (f r) r) "
      , "-- return = const "
      , ""
      ]
-- | One expansion pass: apply 'uncomb'' to every expression in the tree.
-- Fresh-name generation starts from the backtick, the character just
-- before 'a', so the first name handed out is @a@.  It avoids every
-- unqualified name already referenced in the input as well as the
-- combinator names.
uncombOnce :: (Data a) => a -> a
uncombOnce x = evalState pass initialState
  where
    pass = everywhereM (mkM uncomb') x
    initialState = (HsIdent "`", namesIn x ++ recognizedNames)
-- | Expand combinators, sections and infix variable operators until the
-- tree stops changing ('uncombOnce' iterated to a fixed point).
uncomb :: (Eq a, Data a) => a -> a
uncomb tree = stabilize uncombOnce tree
-- | One simplification pass over the whole tree: 'optimizeD' is applied to
-- every declaration, 'optimizeRhs' to every right-hand side and
-- 'optimizeE' to every expression.
optimizeOnce :: (Data a) => a -> a
optimizeOnce x = everywhere step x
  where
    step :: Data b => b -> b
    step = mkT optimizeD `extT'` optimizeRhs `extT'` optimizeE
-- | Simplify until the tree stops changing ('optimizeOnce' iterated to a
-- fixed point).
optimize :: (Eq a, Data a) => a -> a
optimize tree = stabilize optimizeOnce tree
-- | Plugin entry point: expand combinators ('uncomb') and then simplify
-- ('optimize') whatever 'withParsed' hands over.
-- NOTE(review): presumably 'withParsed' (Lambdabot.Parser) parses the user
-- input, applies the transformation and pretty-prints the result; confirm
-- there.
pointful = withParsed (\parsed -> optimize (uncomb parsed))
-- | Debug driver: parse @s@ as a module and, for each top-level
-- declaration, print it as written, after 'uncomb', and after 'uncomb'
-- followed by 'optimize'.  A parse failure is reported via 'fail'.
test s =
  case parseModule s of
    ParseOk (HsModule _ _ _ _ defs) -> mapM_ showStages defs
    failure -> fail (show failure)
  where
    showStages def = do
      let expanded = uncomb def
      putStrLn (prettyPrintInLine def)
      putStrLn (prettyPrintInLine expanded)
      putStrLn (prettyPrintInLine (optimize expanded))
-- | Run 'test' on a fixed batch of sample definitions.
main = test examples
  where
    examples = "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"