{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
-- Undo pointfree transformations. Plugin code derived from Pl.hs.
module Lambdabot.Plugin.Haskell.Pointful (pointfulPlugin) where

import Lambdabot.Module as Lmb (Module)
import Lambdabot.Plugin
import Lambdabot.Util.Parser (withParsed, prettyPrintInLine)

import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as M
import Data.List
import Data.Maybe
import Language.Haskell.Exts.Simple as Hs

-- | The lambdabot module providing the @pointful@ command and its aliases.
-- The rewritten code is sent back one output line per message.
pointfulPlugin :: Lmb.Module ()
pointfulPlugin = newModule { moduleCmds = return [pointfulCmd] }
  where
    pointfulCmd = (command "pointful")
        { aliases = ["pointy","repoint","unpointless","unpl","unpf"]
        , help    = say "pointful <expr>. Make code pointier."
        , process = \input -> mapM_ say (lines (pointful input))
        }

---- Utilities ----

-- | Iterate @f@ starting from @x@ until a fixed point is reached and return
-- it, i.e. the first value @v@ of the iteration with @f v == v@.  Diverges
-- if the iteration never settles.
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f = go
  where
    go current
        | next == current = current
        | otherwise       = go next
      where
        next = f current

-- | Variables bound by the outermost binding constructs of @d@.
--
-- These are plain pattern variables, as-pattern names (@n@ in @n\@p@), the
-- function names of prefix and infix clauses, and the variables on the
-- left-hand side of a pattern binding.  The search does not descend into
-- expressions.  It also skips the argument patterns and local bindings of
-- a clause or pattern binding, which scope more narrowly.
varsBoundHere :: Data d => d -> S.Set Name
varsBoundHere (cast -> Just (PVar name)) = S.singleton name
-- an as-pattern binds its name in addition to whatever the inner pattern binds
varsBoundHere (cast -> Just (PAsPat name pat)) = S.insert name (varsBoundHere pat)
varsBoundHere (cast -> Just (Match name _ _ _)) = S.singleton name
-- infix clauses (x `op` y = ...) bind only the operator name at this level;
-- without this case the generic fallback below collected their argument
-- pattern variables instead
varsBoundHere (cast -> Just (InfixMatch _ name _ _ _)) = S.singleton name
varsBoundHere (cast -> Just (PatBind pat _ _)) = varsBoundHere pat
varsBoundHere (cast -> Just (_ :: Exp)) = S.empty
varsBoundHere d = S.unions (gmapQ varsBoundHere d)

-- note: the tempting idea of using a pattern synonym for the frequent
-- (cast -> Just _) patterns causes compiler crashes with ghc before
-- version 8; cf. https://ghc.haskell.org/trac/ghc/ticket/11336

-- | Generic fold over the unqualified variable occurrences in @d@.
--
-- For every occurrence, @var name bound@ is called.  Here @bound@ is the
-- set of names bound by the enclosing lambdas, lets, case alternatives,
-- pattern bindings and function clauses (prefix and infix).  @sum@
-- combines the results of sibling subterms.
foldFreeVars :: forall a d. Data d => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
foldFreeVars var sum e = runReader (go e) S.empty where
    go :: forall d. Data d => d -> Reader (S.Set Name) a
    go (cast -> Just (Var (UnQual name))) =
        asks (var name)
    go (cast -> Just (Lambda ps exp)) =
        bind [varsBoundHere ps] $ go exp
    go (cast -> Just (Let bs exp)) =
        bind [varsBoundHere bs] $ collect [go bs, go exp]
    go (cast -> Just (Alt pat exp bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
    go (cast -> Just (PatBind pat exp bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
    go (cast -> Just (Match _ ps exp bs)) =
        bind [varsBoundHere ps, varsBoundHere bs] $ collect [go exp, go bs]
    -- infix clauses (x `op` y = ...) also bind their left operand.  Without
    -- this case they fell through to the generic traversal, and their
    -- argument variables were reported as free in the body.
    go (cast -> Just (InfixMatch p _ ps exp bs)) =
        bind [varsBoundHere p, varsBoundHere ps, varsBoundHere bs] $ collect [go exp, go bs]
    go d = collect (gmapQ go d)

    -- run sibling traversals and combine their results with @sum@
    collect :: forall m. Monad m => [m a] -> m a
    collect ms = sum `liftM` sequence ms

    -- extend the set of bound names for the enclosed traversal
    bind :: forall a b. Ord a => [S.Set a] -> Reader (S.Set a) b -> Reader (S.Set a) b
    bind ss = local (S.unions ss `S.union`)

-- | The set of unqualified variables that occur free in @d@.
freeVars :: Data d => d -> S.Set Name
freeVars = foldFreeVars occurrence S.unions where
    occurrence name bound
        | name `S.member` bound = S.empty
        | otherwise             = S.singleton name

-- | Number of free occurrences of the variable @name@ in @d@.  Occurrences
-- under a binder that rebinds @name@ are not counted.
countOcc :: Data d => Name -> d -> Int
countOcc name = foldFreeVars occurrence (foldl' (+) 0) where
    occurrence name' bound
        | name' == name && not (name' `S.member` bound) = 1
        | otherwise                                     = 0

-- | Capture-avoiding substitution of the unqualified variables in @subst@
-- throughout @d@.
--
-- @bv@ holds the names that binders inside @d@ must not reuse; the
-- top-level caller passes the free variables of the substituted
-- expressions.  Binders met on the way down are added to it.  Binders are
-- processed with 'renameBinds': a binder that clashes with @bv@ is renamed
-- (and the renaming joins the substitution), while a non-clashing binder
-- removes its name from the substitution for its scope.
substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
substAvoiding subst bv = base `extT` exp `extT` alt `extT` decl `extT` match where
    base :: Data d => d -> d
    base = gmapT (substAvoiding subst bv)

    exp e@(Var (UnQual name)) =
        fromMaybe e (M.lookup name subst)
    exp (Lambda ps exp) =
        let (subst', bv', ps') = renameBinds subst bv ps
        in  Lambda ps' (substAvoiding subst' bv' exp)
    exp (Let bs exp) =
        let (subst', bv', bs') = renameBinds subst bv bs
        in  Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' exp)
    exp d = base d

    alt (Alt pat exp bs) =
        let (subst1, bv1, pat') = renameBinds subst bv pat
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  Alt pat' (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')

    -- the left-hand side is not renamed here; when the binding sits in a
    -- let/where group, renameBinds on that group has already handled it
    decl (PatBind pat exp bs) =
        let (subst', bv', bs') = renameBinds subst bv bs in
        PatBind pat (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')
    decl d = base d

    match (Match name ps exp bs) =
        let (subst1, bv1, ps') = renameBinds subst bv ps
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  Match name ps' (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')
    -- infix clauses (x `op` y = ...): the left operand is an argument
    -- pattern as well.  This constructor used to be missing, and any infix
    -- clause crashed with a pattern-match failure.
    match (InfixMatch p name ps exp bs) =
        let (subst1, bv1, (p', ps')) = renameBinds subst bv (p, ps)
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  InfixMatch p' name ps' (substAvoiding subst' bv' exp) (substAvoiding subst' bv' bs')

-- | Rename the binders introduced directly by @d@ so that they avoid the
-- names in @bv@.  These binders are pattern variables, as-pattern names,
-- and the function names of clauses.  Nested expressions are left
-- untouched, and so are the argument patterns and local bindings of
-- clauses; the caller applies the returned substitution to the binders'
-- scope.
--
-- Each binder is handled as follows:
--
--   * if it was already renamed during this traversal, the same new name is
--     reused;
--   * if it is not in @bv@, it is kept, added to @bv@, and removed from the
--     substitution (it shadows any substituted variable of that name);
--   * otherwise it is renamed to a fresh name avoiding @bv@, and the
--     substitution maps the old name to the new variable.
--
-- Returns the updated substitution, the updated avoided-name set, and the
-- renamed @d@.
renameBinds :: Data d => M.Map Name Exp -> S.Set Name -> d -> (M.Map Name Exp, S.Set Name, d)
renameBinds subst bv d = (subst', bv', d') where
    (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)

    go, base :: Data d => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
    go = base `extM` pat `extM` match `extM` decl `extM` exp
    base d = gmapM go d

    pat (PVar name) = PVar `fmap` rename name
    -- the name of an as-pattern (n@p) is a binder just like a plain variable
    pat (PAsPat name p) = do
        name' <- rename name
        p' <- go p
        return $ PAsPat name' p'
    pat d = base d

    -- A clause binds only its function name at this level.  Its argument
    -- patterns scope over the clause body and are handled by the caller.
    -- Both the prefix and the infix (x `op` y = ...) forms are covered;
    -- the infix form used to crash here with a pattern-match failure.
    match (Match name ps exp bs) = do
        name' <- rename name
        return $ Match name' ps exp bs
    match (InfixMatch p name ps exp bs) = do
        name' <- rename name
        return $ InfixMatch p name' ps exp bs

    -- a pattern binding binds the variables of its left-hand side only
    decl (PatBind pat exp bs) = do
        pat' <- go pat
        return $ PatBind pat' exp bs
    decl d = base d

    -- expressions introduce no binders at this level
    exp (e :: Exp) = return e

    rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
    rename name = do
        (subst, bv, ass) <- get
        case (name `M.lookup` ass, name `S.member` bv) of
            (Just name', _) ->
                return name'
            (_, False) -> do
                put (M.delete name subst, S.insert name bv, ass)
                return name
            _ -> do
                let name' = freshNameAvoiding name bv
                put (M.insert name (Var (UnQual name')) subst,
                     S.insert name' bv, M.insert name name' ass)
                return name'

-- | Produce a variant of @name@ that does not occur in @forbidden@.
--
-- Trailing suffix characters are stripped first: digits for identifiers,
-- @?@ and @#@ for operator symbols.  Suffixes of increasing length are then
-- tried in order until an unused name is found.  The result always carries
-- a suffix, even when @name@ itself is not forbidden.
freshNameAvoiding :: Name -> S.Set Name -> Name
freshNameAvoiding name forbidden = head (filter unused candidates) where
    (rebuild, raw, suffixChars) = case name of
         Ident  s -> (Ident,  s, "0123456789")
         Symbol s -> (Symbol, s, "?#")
    stem = reverse (dropWhile (`elem` suffixChars) (reverse raw))
    -- infinite list, so the 'head' above always succeeds
    candidates = [ rebuild (stem ++ suffix)
                 | len <- [1 ..]
                 , suffix <- replicateM len suffixChars ]
    unused candidate = not (candidate `S.member` forbidden)

---- Optimization (removing explicit lambdas) and restoration of infix ops ----

-- | Move lambda parameters into the left-hand side of a definition.
-- @f = \\x -> e@ becomes @f x = e@, and @f x = \\y -> e@ becomes
-- @f x y = e@.
--
-- Lambda parameters are renamed when they clash with the function name
-- (first form) or with the existing parameters (second form).  Fresh names
-- also avoid every variable free in the lambda, so a renamed parameter can
-- never capture a free variable of the body (e.g. @f x = \\x -> x0@ must
-- not become @f x x0 = x0@).
optimizeD :: Decl -> Decl
optimizeD (PatBind (PVar fname) (UnGuardedRhs lam@(Lambda pats rhs)) Nothing) =
    let avoid = S.insert fname (freeVars lam)
        (subst, bv, pats') = renameBinds M.empty avoid pats
        rhs' = substAvoiding subst bv rhs
    in  FunBind [Match fname pats' (UnGuardedRhs rhs') Nothing]
---- combine function binding and lambda
optimizeD (FunBind [Match fname pats1 (UnGuardedRhs lam@(Lambda pats2 rhs)) Nothing]) =
    let avoid = varsBoundHere pats1 `S.union` freeVars lam
        (subst, bv, pats2') = renameBinds M.empty avoid pats2
        rhs' = substAvoiding subst bv rhs
    in  FunBind [Match fname (pats1 ++ pats2') (UnGuardedRhs rhs') Nothing]
optimizeD x = x

-- | Drop the parentheses around an unguarded right-hand side; any other
-- right-hand side is returned as is.
optimizeRhs :: Rhs -> Rhs
optimizeRhs rhs = case rhs of
    UnGuardedRhs (Paren inner) -> UnGuardedRhs inner
    _                          -> rhs

-- | One local simplification step on an expression.  'optimize' applies it
-- everywhere, repeatedly, until nothing changes.
optimizeE :: Exp -> Exp
-- apply ((\x z -> ...x...) y) yielding (\z -> ...y...) if there is only one x or y is simple
optimizeE e@(App (Lambda (PVar ident : pats) body) arg) | single || simple arg =
     -- The remaining binders must be renamed away from two sets of names:
     -- the free variables of the argument (it is substituted under those
     -- binders) and the free variables of the lambda itself (a freshly
     -- chosen name must not capture them).  Together these are exactly the
     -- free variables of the whole application.
     let (subst, bv, pats') = renameBinds (M.singleton ident arg) (freeVars e) pats
     in  Paren (Lambda pats' (substAvoiding subst bv body))
  where
    single = countOcc ident body <= 1
    simple x = case x of Var _ -> True; Lit _ -> True; Paren x' -> simple x'; _ -> False
-- apply ((\_ z -> ...) y) yielding (\z -> ...)
optimizeE (App (Lambda (PWildCard : pats) body) _) =
    Paren (Lambda pats body)
-- remove 0-arg lambdas resulting from application rules
optimizeE (Lambda [] b) =
    b
-- replace (\x -> \y -> z) with (\x y -> z); inner binders that reuse an
-- outer parameter name are renamed, otherwise (\x -> \x -> x) would become
-- the invalid (\x x -> x)
optimizeE (Lambda p1 (Lambda p2 body)) =
    let avoid = varsBoundHere p1 `S.union` freeVars (Lambda p2 body)
        (subst, bv, p2') = renameBinds M.empty avoid p2
        body' = substAvoiding subst bv body
    in  Lambda (p1 ++ p2') body'
-- remove double parens
optimizeE (Paren (Paren x)) =
    Paren x
-- remove parens around applied lambdas (the pretty printer restores them)
optimizeE (App (Paren (x@Lambda{})) y) =
    App x y
-- remove lambda body parens
optimizeE (Lambda p (Paren x)) =
    Lambda p x
-- remove var, lit parens
optimizeE (Paren x@(Var _)) =
    x
optimizeE (Paren x@(Lit _)) =
    x
-- remove infix+lambda parens
optimizeE (InfixApp a o (Paren l@(Lambda _ _))) =
    InfixApp a o l
-- remove infix+app parens
optimizeE (InfixApp (Paren a@App{}) o l) =
    InfixApp a o l
optimizeE (InfixApp a o (Paren l@App{})) =
    InfixApp a o l
-- remove left-assoc application parens
optimizeE (App (Paren (App a b)) c) =
    App (App a b) c
-- restore infix
optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
    (InfixApp l (QVarOp name') r)
-- eta reduce
optimizeE (Lambda ps@(_:_) (App e (Var (UnQual v))))
    | free && last ps == PVar v = Lambda (init ps) e
  where free = countOcc v e == 0
-- fail
optimizeE x = x

---- Decombinatorization ----

-- | Unqualified variable names referenced by an operator.  A section's
-- operator may be a backticked variable, as in (`a` x).  The fresh lambda
-- binder introduced when expanding such a section must not capture it.
qopVars :: QOp -> S.Set Name
qopVars (QVarOp qname) = freeVars (Var qname)
qopVars _ = S.empty

-- | One step of decombinatorization; 'uncombOnce' applies it to every
-- subterm.  It expands sections into lambdas and turns infix applications
-- into prefix ones.  It also expands @(>>=)@ in two shapes, treating it as
-- the reader monad's bind.
uncomb' :: Exp -> Exp

uncomb' (Paren (Paren e)) = Paren e

-- eliminate sections
uncomb' (RightSection op' arg) =
    let a = freshNameAvoiding (Ident "a") (freeVars arg `S.union` qopVars op')
    in  (Paren (Lambda [PVar a] (InfixApp (Var (UnQual a)) op' arg)))
uncomb' (LeftSection arg op') =
    let a = freshNameAvoiding (Ident "a") (freeVars arg `S.union` qopVars op')
    in  (Paren (Lambda [PVar a] (InfixApp arg op' (Var (UnQual a)))))
-- infix to prefix for canonicality
uncomb' (InfixApp lf (QVarOp name') rf) =
    (Paren (App (App (Var name') (Paren lf)) (Paren rf)))

-- Expand (>>=) when it is obviously the reader monad:

-- rewrite: (>>=) (\x -> e)
-- to:      (\ a b -> a ((\ x -> e) b) b)
uncomb' (App (Var (UnQual (Symbol ">>="))) (Paren lam@Lambda{})) =
   let a = freshNameAvoiding (Ident "a") (freeVars lam)
       b = freshNameAvoiding (Ident "b") (freeVars lam)
   in  (Paren (Lambda [PVar a, PVar b]
           (App (App (Var (UnQual a)) (Paren (App lam (Var (UnQual b))))) (Var (UnQual b)))))
-- rewrite: ((>>=) e1) (\x y -> e2)
-- to:      (\a -> (\x y -> e2) (e1 a) a)
uncomb' (App (App (Var (UnQual (Symbol ">>="))) e1) (Paren lam@(Lambda (_:_:_) _))) =
    let a = freshNameAvoiding (Ident "a") (freeVars [e1,lam])
    in  (Paren (Lambda [PVar a]
            (App (App lam (App e1 (Var (UnQual a)))) (Var (UnQual a)))))

-- fail
uncomb' expr = expr

---- Simple combinator definitions ---
-- | The known combinators, mapped from name to parenthesised definition.
-- They are obtained by parsing 'combinatorModule'.  A parse failure or an
-- unexpected declaration shape aborts with an error.
combinators :: M.Map Name Exp
combinators = M.fromList (map toEntry decls)
  where
    decls = case parseModule combinatorModule of
        ParseOk (Hs.Module _ _ _ ds) -> ds
        failure@(ParseFailed _ _)   -> error ("Combinator loading: " ++ show failure)
    toEntry (PatBind (PVar fname) (UnGuardedRhs body) Nothing) = (fname, Paren body)
    toEntry _ = error "Pointful Plugin error: can't convert declaration to tuple"

-- | Haskell source of the combinator definitions parsed by 'combinators'.
-- Each combinator is defined as a lambda, or in terms of other combinators.
--
-- @(<=<)@ is Kleisli composition, @(f <=< g) x = g x >>= f@: @g x@ is the
-- monadic value and @f@ the continuation.  It was previously written as
-- @f >>= g x@, which has the operands swapped and does not even typecheck.
combinatorModule :: String
combinatorModule = unlines [
  "(.)    = \\f g x -> f (g x)                                          ",
  "($)    = \\f x   -> f x                                              ",
  "flip   = \\f x y -> f y x                                            ",
  "const  = \\x _ -> x                                                  ",
  "id     = \\x -> x                                                    ",
  "(=<<)  = flip (>>=)                                                  ",
  "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) ",
  "join   = (>>= id)                                                    ",
  "ap     = liftM2 id                                                   ",
  "(>=>)  = flip (<=<)                                                  ",
  "(<=<)  = \\f g x -> g x >>= f                                        ",
  "                                                                     ",
  "-- ASSUMED reader monad                                              ",
  "-- (>>=)  = (\\f k r -> k (f r) r)                                   ",
  "-- return = const                                                    ",
  ""]

---- Top level ----

-- | Replace every free occurrence of a combinator from 'combinators' by its
-- definition.  Binders in the input that clash with the free variables of
-- those definitions are renamed.
unfoldCombinators :: (Data a) => a -> a
unfoldCombinators x = substAvoiding combinators reserved x
  where
    reserved = freeVars combinators

-- | Apply 'uncomb'' once to every subterm (via 'everywhere').
uncombOnce :: (Data a) => a -> a
uncombOnce = everywhere (mkT uncomb')

-- | Apply 'uncombOnce' until the term no longer changes.
uncomb :: (Eq a, Data a) => a -> a
uncomb term = stabilize uncombOnce term

-- | Apply the declaration, right-hand-side and expression simplifications
-- once to every subterm (via 'everywhere').
optimizeOnce :: (Data a) => a -> a
optimizeOnce = everywhere step
  where
    step :: Data b => b -> b
    step = mkT optimizeD `extT` optimizeRhs `extT` optimizeE

-- | Apply 'optimizeOnce' until the term no longer changes.
optimize :: (Eq a, Data a) => a -> a
optimize term = stabilize optimizeOnce term

-- | Rewrite point-free code into pointful form.  The input is handled by
-- 'withParsed'.
--
-- The rewrite runs in two phases, each until a fixed point:
--
--   1. decombinatorize and unfold the known combinators;
--   2. keep decombinatorizing while simplifying the result.
pointful :: String -> String
pointful = withParsed (\ast -> simplify (expand ast))
  where
    expand :: (Data a, Eq a) => a -> a
    expand = stabilize (unfoldCombinators . uncomb)

    simplify :: (Data a, Eq a) => a -> a
    simplify = stabilize (optimize . uncomb)

-- TODO: merge this into a proper test suite once one exists
-- test s = case parseModule s of
--   f@(ParseFailed _ _) -> fail (show f)
--   ParseOk (Hs.Module _ _ _ defs) ->
--     flip mapM_ defs $ \def -> do
--       putStrLn . prettyPrintInLine  $ def
--       putStrLn . prettyPrintInLine  . uncomb $ def
--       putStrLn . prettyPrintInLine  . optimize . uncomb $ def
--       putStrLn . prettyPrintInLine  . stabilize (optimize . uncomb) $ def
--       putStrLn ""
--
-- main = test "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"
--