{-# LANGUAGE ImplicitParams #-}
module Lambdabot.Plugin.Haskell.Pl.Optimize (
    optimize,
  ) where

import Lambdabot.Plugin.Haskell.Pl.Common
import Lambdabot.Plugin.Haskell.Pl.Rules
import Lambdabot.Plugin.Haskell.Pl.PrettyPrinter ()

import Data.List (nub)
import Data.Maybe (listToMaybe)

-- | Keep at most the first element of a list.
--
-- Used by the rewriting combinators to commit to the first result of a rule
-- instead of exploring every alternative.
cut :: [a] -> [a]
cut = take 1

-- | Inject a 'Maybe' into any 'MonadPlus': 'Nothing' becomes 'mzero' and
-- @'Just' x@ becomes the monadic unit applied to @x@.
toMonadPlus :: MonadPlus m => Maybe a -> m a
toMonadPlus = maybe mzero return

type Size = Double

-- | The 'size' of an expression, lower is better
--
-- This seems to be a better size for our purposes,
-- despite being "a little" slower because of the wasteful uglyprinting
--
-- The size is the length of the expression's 'show' output plus the sum of
-- small per-name adjustments computed by the local @adjust@.
sizeExpr' :: Expr -> Size
sizeExpr' e = fromIntegral (length (show e)) + adjust e
  where
    -- hackish thing to favor some expressions if the length is the same:
    -- (+ x) --> (x +)
    -- x >>= f --> f =<< x
    -- f $ g x --> f (g x)
    --
    -- Adjustments are summed over both sides of an application and looked
    -- up through lambda bodies; every other expression contributes 0.
    adjust :: Expr -> Size
    adjust (Var _ str) = case str of
      -- disabled alternative: Just n <- readM str = log (n*n+1) / 4
      "uncurry"  -> -4
      -- disabled alternative: "s" -> 5
      "flip"     -> 0.1
      ">>="      -> 0.05
      "$"        -> 0.01
      "subtract" -> 0.01
      "ap"       -> 2
      "liftM2"   -> 1.01
      "return"   -> -2
      "zipWith"  -> -4
      "const"    -> 0 -- -2
      "fmap"     -> -1
      _          -> 0   -- any other name carries no adjustment
    adjust (Lambda _ body) = adjust body
    adjust (App fun arg)   = adjust fun + adjust arg
    adjust _               = 0

-- | Optimize an expression
--
-- The result is the chain of expressions visited while optimizing: it starts
-- with the input itself, and every following element is strictly smaller
-- (according to 'sizeExpr'') than the one before it.  The chain stops as soon
-- as a round of rewriting produces no strictly smaller candidate, so the list
-- is never empty.
optimize :: Expr -> [Expr]
optimize e = improve (sizeExpr' e, e)
  where
    -- Emit the current expression, then continue for as long as
    -- 'simpleStep' finds a strictly smaller successor.
    improve :: (Size, Expr) -> [Expr]
    improve t@(_, cur) = cur : maybe [] improve (simpleStep t)

    -- One round of improvement: collect the rewrites of the current
    -- expression (with @?first = True@), the rewrites of those rewrites
    -- (with @?first = False@), and pick the first candidate whose size is
    -- strictly below the current one.
    simpleStep :: (Size, Expr) -> Maybe (Size, Expr)
    simpleStep (best, cur) = listToMaybe smaller
      where
        chn :: [Expr]
        chn = let ?first = True in step cur

        chnn :: [Expr]
        chnn = let ?first = False in concatMap step chn

        smaller :: [(Size, Expr)]
        smaller =
          [ (size, candidate)
          | candidate <- cur : chn ++ chnn
          , let size = sizeExpr' candidate
          , size < best
          ]

-- | Apply all rewrite rules once
--
-- Runs 'rewrite' with the global rule set 'rules' on the expression and
-- removes duplicate results.
step :: (?first :: Bool) => Expr -> [Expr]
step e = nub (rewrite rules e)

-- | Apply a single rewrite rule
--
-- Returns every expression the rule can produce from @e@; an empty list means
-- the rule did not apply.  The combinators are interpreted as follows:
--
-- * @Up r1 r2@: take at most the first result of @r1@, then apply @r2@ to it;
--   if @r2@ yields nothing, the result of @r1@ is returned instead.
-- * @OrElse r1 r2@: the results of @r1@, or those of @r2@ if @r1@ has none.
-- * @Then r1 r2@: apply @r2@ to each distinct result of @r1@.
-- * @Opt r@: the unchanged expression followed by the results of @r@.
-- * @If p r@: the results of @r@, but only if @p@ produces at least one result.
-- * @Hard r@: the results of @r@, but only while @?first@ is 'True'.
-- * @Or rs@: the concatenated results of all rules in @rs@.
-- * @RR@, @CRR@, @Down@: delegated to 'rewDeep'.
rewrite :: (?first :: Bool) => RewriteRule -> Expr -> [Expr]
rewrite rl e = case rl of
    Up r1 r2 ->
      let firstHit = cut (rewrite r1 e)
          followUp = concatMap (rewrite r2) firstHit
      in if null followUp then firstHit else followUp
    OrElse r1 r2 ->
      case rewrite r1 e of
        []   -> rewrite r2 e
        hits -> hits
    Then r1 r2 -> concatMap (rewrite r2) (nub (rewrite r1 e))
    Opt r      -> e : rewrite r e
    If p r
      | null (rewrite p e) -> []
      | otherwise          -> rewrite r e
    Hard r
      | ?first    -> rewrite r e
      | otherwise -> []
    Or rs   -> concatMap (\r -> rewrite r e) rs
    RR {}   -> rewDeep rl e
    CRR {}  -> rewDeep rl e
    Down {} -> rewDeep rl e

-- | Apply a 'deep' rewrite rule
--
-- The rule is first tried on the expression itself (via 'rew'); after those
-- results come the results of applying it recursively inside either side of
-- an application, with the other side left untouched.
--
-- 'Lambda' and 'Let' nodes are not supported: reaching one raises an 'error'
-- once the result list is consumed past the top-level results.
rewDeep :: (?first :: Bool) => RewriteRule -> Expr -> [Expr]
rewDeep rule e = rew rule e ++ case e of
    Var _ _    -> []
    Lambda _ _ -> error "lambda: optimizer only works for closed expressions"
    Let _ _    -> error "let: optimizer only works for closed expressions"
    App e1 e2  ->
      [ App e1' e2 | e1' <- rewDeep rule e1 ] ++
      [ App e1 e2' | e2' <- rewDeep rule e2 ]

-- | Apply a rewrite rule to an expression
--   in a 'deep' position, i.e. from inside a RR,CRR or Down
--
-- * @RR r1 r2@ fires the rewrite pair with 'fire'; @CRR f@ applies the
--   function directly.  Both yield at most one result.
-- * @Or rs@ concatenates the results of every rule in @rs@.
-- * @Down r1 r2@ takes at most the first result of @r1@ and then applies
--   @r2@ deeply (via 'rewDeep') to it; if that yields nothing, the result of
--   @r1@ is returned instead.
-- * Every other combinator is handed to 'rewrite'.
rew :: (?first :: Bool) => RewriteRule -> Expr -> [Expr]
rew rl e = case rl of
    RR r1 r2   -> toMonadPlus (fire r1 r2 e)
    CRR f      -> toMonadPlus (f e)
    Or rs      -> concatMap (\r -> rew r e) rs
    Down r1 r2 ->
      let firstHit = cut (rew r1 e)
          deeper   = concatMap (rewDeep r2) firstHit
      in if null deeper then firstHit else deeper
    -- The constructors below are listed one by one (rather than with a
    -- wildcard) so that the compiler can still check the match for
    -- exhaustiveness.
    Then {}    -> rewrite rl e
    OrElse {}  -> rewrite rl e
    Up {}      -> rewrite rl e
    Opt {}     -> rewrite rl e
    If {}      -> rewrite rl e
    Hard {}    -> rewrite rl e