{-# LANGUAGE ImplicitParams #-}
module Plugin.Pl.Optimize (
    optimize,
  ) where

import Data.List (nub)
import Data.Maybe (listToMaybe)

import Plugin.Pl.Common
import Plugin.Pl.Rules
import Plugin.Pl.PrettyPrinter (prettyExpr)

-- | Keep at most the first element of a list.
--
-- Used by the @Up@ / @Down@ combinators to commit to a single rewrite
-- result before trying a follow-up rule on it.
cut :: [a] -> [a]
cut = take 1

-- | Embed a 'Maybe' into any 'MonadPlus'.
--
-- 'Nothing' becomes 'mzero' and @Just x@ becomes @return x@. 'rew' uses
-- this to lift single-shot rule results into the list of candidates.
toMonadPlus :: MonadPlus m => Maybe a -> m a
toMonadPlus = maybe mzero return

-- | Cost of an expression as computed by 'sizeExpr''; smaller is better.
type Size = Integer

-- | Heuristic cost of an expression: 100 per character of its
-- pretty-printed form, plus a per-identifier bias.
--
-- Measuring printed length ranks candidates better for our purposes than a
-- structural size, despite being "a little" slower because every candidate
-- has to be pretty-printed.
sizeExpr' :: Expr -> Size
sizeExpr' e = 100 * fromIntegral (length (prettyExpr e)) + adjust e
  where
    -- Hackish tie-breaker that favours some spellings when the printed
    -- lengths are the same:
    --   (+ x)    --> (x +)
    --   x >>= f  --> f =<< x
    --   f $ g x  --> f (g x)
    -- Biases are summed over identifiers reached through applications and
    -- lambda bodies; every other expression form contributes 0.
    adjust :: Expr -> Size
    adjust (Var _ str)   = bias str
    adjust (Lambda _ e') = adjust e'
    adjust (App e1 e2)   = adjust e1 + adjust e2
    adjust _             = 0

    -- Bias for a single identifier, in the same units as the length term
    -- (100 = one printed character). Unlisted identifiers get 0.
    -- Disabled alternatives: "s" at 500, "const" at -200, and a
    -- numeric-literal bias of log (n*n+1) / 4.
    bias :: String -> Size
    bias str = case str of
      "uncurry"  -> -400
      "flip"     -> 10
      ">>="      -> 5
      "$"        -> 1
      "subtract" -> 1
      "ap"       -> 200
      "liftM2"   -> 101
      "return"   -> -200
      "zipWith"  -> -400
      "const"    -> 0
      "fmap"     -> -100
      _          -> 0

-- | Greedily shrink an expression.
--
-- Returns the chain of successively smaller expressions, starting with the
-- input itself. Each element has a strictly smaller 'sizeExpr'' than the one
-- before it, and the chain stops once no candidate rewrite is smaller.
optimize :: Expr -> [Expr]
optimize e = chain (sizeExpr' e, e)
  where
    -- Emit the current expression, then continue from the next improvement
    -- (if any). Lazy in the tail, so callers can stop early.
    chain :: (Size, Expr) -> [Expr]
    chain t = snd t : maybe [] chain (simpleStep t)

    -- One improvement step. First apply the rules once with ?first = True.
    -- Then apply them again, with ?first = False, to every result. Pick the
    -- first candidate (in that order) that is strictly smaller than the
    -- current expression.
    --
    -- The current expression itself is not re-measured: its size is
    -- exactly curSize, so it could never pass the strict '<' filter.
    simpleStep :: (Size, Expr) -> Maybe (Size, Expr)
    simpleStep (curSize, cur) = listToMaybe smaller
      where
        chn, chnn :: [Expr]
        chn  = let ?first = True  in step cur
        chnn = let ?first = False in step =<< chn

        smaller :: [(Size, Expr)]
        smaller = filter ((< curSize) . fst)
                . map (sizeExpr' &&& id)
                $ chn ++ chnn

-- | All distinct results of one pass of the rule set 'rules' over an
-- expression.
--
-- The implicit @?first@ flag is passed through to 'rewrite', where it
-- decides whether @Hard@ rules may fire.
step :: (?first :: Bool) => Expr -> [Expr]
step e = nub (rewrite rules e)
 
-- | Apply a rewrite rule (with its combinators) to an expression, returning
-- every result in order. An empty list means the rule did not apply.
--
--   * @Up r1 r2@: apply @r1@, keeping only its first result, then apply
--     @r2@ to that. If @r2@ yields nothing, return the @r1@ result.
--   * @OrElse r1 r2@: the results of @r1@, or of @r2@ if @r1@ yields nothing.
--   * @Then r1 r2@: @r2@ applied to each distinct result of @r1@.
--   * @Opt r@: the unchanged expression, followed by the results of @r@.
--   * @If p r@: the results of @r@, but only if @p@ yields at least one
--     result.
--   * @Hard r@: the results of @r@ when @?first@ is set, otherwise nothing.
--   * @Or rs@: the concatenated results of every rule in @rs@.
--   * @RR@, @CRR@, @Down@: delegated to 'rewDeep', which tries them at the
--     root and at every subterm.
rewrite :: (?first :: Bool) => RewriteRule -> Expr -> [Expr]
rewrite rl e = case rl of
    Up r1 r2 ->
      let e'  = cut (rewrite r1 e)
          e'' = rewrite r2 =<< e'
      in if null e'' then e' else e''
    OrElse r1 r2 ->
      let e' = rewrite r1 e
      in if null e' then rewrite r2 e else e'
    Then r1 r2 -> rewrite r2 =<< nub (rewrite r1 e)
    Opt r      -> e : rewrite r e
    If p r
      | null (rewrite p e) -> []
      | otherwise          -> rewrite r e
    Hard r
      | ?first    -> rewrite r e
      | otherwise -> []
    Or rs   -> concatMap (\x -> rewrite x e) rs
    RR {}   -> rewDeep rl e
    CRR {}  -> rewDeep rl e
    Down {} -> rewDeep rl e

-- | Apply a rule at the root of an expression and then at every subterm.
--
-- Root results (via 'rew') come first. They are followed by results from
-- rewriting inside the function part of an application, then inside its
-- argument, with the untouched side rebuilt around each result.
--
-- Only closed expressions are supported: forcing the subterm results of a
-- 'Lambda' or 'Let' raises an error.
rewDeep :: (?first :: Bool) => RewriteRule -> Expr -> [Expr]
rewDeep rule e = rew rule e ++ inside e
  where
    inside (Var _ _)    = []
    inside (Lambda _ _) = error "lambda: optimizer only works for closed expressions"
    inside (Let _ _)    = error "let: optimizer only works for closed expressions"
    inside (App e1 e2)  = map (`App` e2) (rewDeep rule e1)
                       ++ map (e1 `App`) (rewDeep rule e2)

-- | Apply a rule at the root of an expression only.
--
--   * @RR r1 r2@ fires via 'fire'.
--   * @CRR r@ calls the wrapped function.
--   * @Or rs@ collects the root results of every alternative.
--   * @Down r1 r2@ applies @r1@ at the root, keeping only its first result,
--     then applies @r2@ anywhere inside that result via 'rewDeep'. If @r2@
--     yields nothing, it returns the @r1@ result.
--   * All other combinators are handed back to 'rewrite'.
rew :: (?first :: Bool) => RewriteRule -> Expr -> [Expr]
rew (RR r1 r2) e = toMonadPlus (fire r1 r2 e)
rew (CRR r)    e = toMonadPlus (r e)
rew (Or rs)    e = concatMap (\x -> rew x e) rs
rew (Down r1 r2) e =
  let e'  = cut (rew r1 e)
      e'' = rewDeep r2 =<< e'
  in if null e'' then e' else e''
-- Constructors are listed explicitly (rather than with a catch-all) so that
-- adding a new RewriteRule constructor triggers an incomplete-pattern warning.
rew r@(Then   {}) e = rewrite r e
rew r@(OrElse {}) e = rewrite r e
rew r@(Up     {}) e = rewrite r e
rew r@(Opt    {}) e = rewrite r e
rew r@(If     {}) e = rewrite r e
rew r@(Hard   {}) e = rewrite r e