{-# LANGUAGE FlexibleInstances, PatternGuards, ScopedTypeVariables #-}

-- | This marvellous module was contributed by Thomas Jäger.
module Lambdabot.Plugin.Haskell.Pl.RuleLib
       (  -- Using rules
          RewriteRule(..), fire
       ,  -- Defining rules
          rr,rr0,rr1,rr2,up,down
       ) where

import Lambdabot.Plugin.Haskell.Pl.Common
import Lambdabot.Plugin.Haskell.Pl.Names

import Data.Array
import qualified Data.Set as S

import Control.Monad.Fix (fix)

-- Next time I do something like this, I'll actually think about the combinator
-- language beforehand, instead of producing something ad hoc like this:
data RewriteRule
  -- | Rewrite the first 'Rewrite' into the second.  Both sides may contain
  --   'Hole's.
  = RR     Rewrite Rewrite
  -- | A Haskell function used as a rule, applied to subexpressions.
  | CRR    (Expr -> Maybe Expr)
  -- | Like 'Up', but applied to subexpressions.
  | Down   RewriteRule RewriteRule
  -- | Apply the first rule, then try the second rule on that result; if the
  --   second rule fails, the result of the first rule is returned.
  | Up     RewriteRule RewriteRule
  -- | Use all of the given rules.
  | Or     [RewriteRule]
  -- | Try the first rule; if it fails, use the second rule.
  | OrElse RewriteRule RewriteRule
  -- | Apply the first rule, then apply the second rule to its result.
  | Then   RewriteRule RewriteRule
  -- | Optionally apply the rule: @Opt x == Or [identity, x]@.
  | Opt    RewriteRule
  -- | Apply the second rule only if the first rule has some results.
  | If     RewriteRule RewriteRule
  -- | Apply the rule only in the first pass.
  | Hard   RewriteRule

-- | An expression with holes, used either as a pattern to match or as a
--   template to fill in.
--
--   The fields are intentionally lazy: the 'RewriteC' instance for functions
--   calls 'getRewrite' on 'undefined' and reads only 'rid', so 'holes' must
--   not be forced when a 'Rewrite' is constructed.
data Rewrite = Rewrite
  { holes :: MExpr  -- ^ Expression with holes
  , rid   :: Int    -- ^ Number of holes
  }

-- What are you gonna do when no recursive modules are possible?
-- | Things that can be turned into a 'Rewrite': a plain 'MExpr' (no holes),
--   or a function taking one 'MExpr' per hole and producing such a thing.
class RewriteC a where
  getRewrite :: a -> Rewrite

-- | A plain 'MExpr' is a rewrite without any holes.
instance RewriteC MExpr where
  getRewrite expr = Rewrite { holes = expr, rid = 0 }

-- | Lift functions to rewrite rules: each 'MExpr' argument becomes a 'Hole'.
--
--   Holes are numbered from the last argument outwards, so for
--   @\\x y -> ...@ the last argument @y@ is @Hole 0@ and @x@ is @Hole 1@.
--   'rid' is the total number of arguments.
instance RewriteC a => RewriteC (MExpr -> a) where
  getRewrite f = Rewrite { holes = body, rid = next + 1 }
    where
      -- Number of holes already used by the remaining arguments.  It depends
      -- only on the type @a@ (hence ScopedTypeVariables); the 'undefined' is
      -- never inspected because only 'rid' is read from the result.
      next = rid (getRewrite (undefined :: a))
      -- Feed a fresh hole to the function and keep the resulting template.
      body = holes (getRewrite (f (Hole next)))


----------------------------------------------------------------------------------------
-- Applying/matching Rewrites

-- | Expressions bound to the holes of a matched 'Rewrite', indexed by hole
--   number.
type ExprArr = Array Int Expr

-- | Fill in the holes in an 'MExpr'.  Every @Hole i@ is replaced by the
--   quoted expression at index @i@ of the array, and everything else is kept.
--   Indexing uses '!', so every hole in the template must be within the
--   array's bounds.
myFire :: ExprArr -> MExpr -> MExpr
myFire arr = go
  where
    go (MApp l r) = MApp (go l) (go r)
    go (Hole i)   = Quote (arr ! i)
    go other      = other

-- | Remove duplicates.  The result is in ascending order because it is
--   produced by going through a 'S.Set'.
nub' :: Ord a => [a] -> [a]
nub' xs = S.elems (S.fromList xs)

-- | Build the hole-binding array for a match, provided it is consistent.
--
--   @lst@ holds one @(hole, value)@ pair per hole occurrence.  Identical pairs
--   (the same hole bound to equal values) are merged first.  The result is
--   'Just' only when, after merging, the keys are exactly @0 .. n-1@, each
--   appearing once.  Otherwise the result is 'Nothing': for example, when a
--   hole is bound to two different values, a hole never occurs, or a key is
--   out of range.
--
--   Comparing only the number of distinct pairs with @n@ is not enough.  A
--   key bound to two different values could then "cover" a missing key, and
--   the resulting array would contain an undefined element.
uniqueArray :: Ord v => Int -> [(Int, v)] -> Maybe (Array Int v)
uniqueArray n lst
  | length distinct == n
  , S.fromList (map fst distinct) == S.fromList [0 .. n - 1]
  -- Build from the merged list so that 'array' never sees duplicate indices.
  = Just $ array (0, n - 1) distinct
  | otherwise = Nothing
  where
    distinct = nub' lst

-- | Try to match a 'Rewrite' against a whole expression.  On success, return
--   the expressions bound to the holes, indexed by hole number.  Matching
--   fails when the shapes differ or when the hole bindings are rejected by
--   'uniqueArray'.
match :: Rewrite -> Expr -> Maybe ExprArr
match (Rewrite template holeCount) e = do
  bindings <- matchWith template e
  uniqueArray holeCount bindings

-- | Instantiate a 'Rewrite' template with the given hole bindings.
fire' :: Rewrite -> ExprArr -> MExpr
fire' (Rewrite template _) = flip myFire template

-- | Apply the rule @lhs ==> rhs@ to an expression.  If @lhs@ matches the
--   whole expression, return @rhs@ with its holes filled in by the matched
--   subexpressions; otherwise return 'Nothing'.
fire :: Rewrite -> Rewrite -> Expr -> Maybe Expr
fire lhs rhs e = case match lhs e of
  Nothing       -> Nothing
  Just bindings -> Just (fromMExpr (fire' rhs bindings))

-- | Match an 'Expr' against an 'MExpr' template and collect one
--   @(hole, expr)@ pair per hole occurrence.  Quoted parts must be equal,
--   holes match anything, and applications are matched component-wise.  The
--   same hole may be reported more than once.
matchWith :: MExpr -> Expr -> Maybe [(Int, Expr)]
matchWith (MApp f x) (App f' x') = do
  fromFun <- matchWith f f'
  fromArg <- matchWith x x'
  return (fromFun ++ fromArg)
matchWith (Quote q) e
  | q == e    = Just []
  | otherwise = Nothing
matchWith (Hole k) e = Just [(k, e)]
matchWith _ _ = Nothing

-- | Convert an instantiated 'MExpr' back to an 'Expr'.  A leftover 'Hole'
--   becomes the variable @Hole@ instead of raising an error.
fromMExpr :: MExpr -> Expr
fromMExpr me = case me of
  MApp f x -> App (fromMExpr f) (fromMExpr x)
  Hole _   -> Var Pref "Hole" -- error "Hole in MExpr"
  Quote e  -> e

----------------------------------------------------------------------------------------
-- Defining rules

-- | Yet another point-free transformation.
--
--   @transformM n m@ reads @m@ as the body of the function @\\hole_n -> m@
--   and returns an equivalent point-free 'MExpr' in which @Hole n@ no longer
--   occurs.  The result is built from 'constE', 'idE', 'sE', 'flipE' and
--   'compE' (presumably @const@, @id@, the S combinator, @flip@ and @(.)@;
--   confirm against "Lambdabot.Plugin.Haskell.Pl.Names").
--
--   Clause order matters: when the guard of the composition special case
--   fails, evaluation falls through to the general application clause.
transformM :: Int -> MExpr -> MExpr
transformM _ (Quote q) = constE `a` Quote q
transformM k (Hole j)
  | k == j    = idE
  | otherwise = constE `a` Hole j
transformM k (Quote (Var _ ".") `MApp` f `MApp` g)
  | f `hasHole` k && not (g `hasHole` k)
  = flipE `a` compE `a` g `c` transformM k f
transformM k whole@(MApp f x)
  | inF && inX               = sE `a` transformM k f `a` transformM k x
  | inF                      = flipE `a` transformM k f `a` x
  -- eta reduction: the hole is exactly the argument and does not occur in f
  | inX, Hole j <- x, j == k = f
  | inX                      = f `c` transformM k x
  | otherwise                = constE `a` whole
  where
    inF = f `hasHole` k
    inX = x `hasHole` k

-- | Does @Hole n@ occur anywhere in the expression?
hasHole :: MExpr -> Int -> Bool
hasHole expr n = occurs expr
  where
    occurs (MApp f x) = occurs f || occurs x
    occurs (Quote _)  = False
    occurs (Hole m)   = m == n

-- | Variants of a rewrite rule, made progressively more point-free.
--
--   The first element is the rule itself.  Each following element abstracts
--   hole 0 away with 'transformM', turning the expression into a function of
--   that hole, and shifts the remaining hole numbers down by one.  A rule
--   with @k@ holes therefore yields @k + 1@ variants.
--
--   Guards are used instead of an n+k pattern, which haddock does not accept.
getVariants, getVariants' :: Rewrite -> [Rewrite]
getVariants' r@(Rewrite body holeCount)
  | holeCount == 0 = [r]
  | holeCount >= 1 =
      r : getVariants (Rewrite (shiftDown (transformM 0 body)) (holeCount - 1))
  | otherwise      = error "getVariants' : nk went negative"
  where
    -- decrement every hole number by one
    shiftDown (Hole j)   = Hole (j - 1)
    shiftDown (MApp f x) = MApp (shiftDown f) (shiftDown x)
    shiftDown other      = other

-- A tracing version was used here for debugging:
--   getVariants r = trace (show vs) vs where vs = getVariants' r
getVariants = getVariants'

-- | The given rule followed by the rules derived from it by iterated
--   point-free transformation (see 'getVariants').  Both sides have the same
--   type and hence the same number of holes, so the two variant lists line
--   up pairwise.
rrList :: RewriteC a => a -> a -> [RewriteRule]
rrList lhs rhs = zipWith RR (variantsOf lhs) (variantsOf rhs)
  where
    variantsOf side = getVariants (getRewrite side)

-- | Construct an 'RR'-based rewrite rule from two sides of the same shape.
--
--   * 'rr'  uses the rule together with all of its point-free variants;
--   * 'rr1' uses at most the first 2 of them;
--   * 'rr2' uses at most the first 3 of them;
--   * 'rr0' uses only the rule itself, with no variants.
rr, rr0, rr1, rr2 :: RewriteC a => a -> a -> RewriteRule
rr  lhs rhs = Or (rrList lhs rhs)
rr1 lhs rhs = Or (take 2 (rrList lhs rhs))
rr2 lhs rhs = Or (take 3 (rrList lhs rhs))
rr0 lhs rhs = RR (getRewrite lhs) (getRewrite rhs)

-- | Apply 'Down' / 'Up' repeatedly.  @down r@ is the cyclic rule @x@ with
--   @x = Down r x@, and likewise @up r@ with 'Up'.  Per the constructor docs,
--   the whole rule is therefore tried again on the result each time @r@
--   succeeds.
down, up :: RewriteRule -> RewriteRule
down rule = fix (Down rule)
up   rule = fix (Up rule)