{-# LANGUAGE FlexibleInstances, PatternGuards, ScopedTypeVariables #-}
module Lambdabot.Plugin.Haskell.Pl.RuleLib
(
RewriteRule(..), fire
,
rr,rr0,rr1,rr2,up,down
) where
import Lambdabot.Plugin.Haskell.Pl.Common
import Lambdabot.Plugin.Haskell.Pl.Names
import Data.Array
import qualified Data.Set as S
import Control.Monad.Fix (fix)
-- | A rewrite rule.  This module only constructs these values (through
-- 'rr', 'rr0', 'rr1', 'rr2', 'up' and 'down'); how each combinator is
-- applied is decided by the code that interprets them, which is not part
-- of this module.
data RewriteRule
  = RR Rewrite Rewrite
    -- ^ A pattern/replacement pair built from two 'Rewrite' templates.
    --   NOTE(review): presumably applied with 'fire', which takes the same
    --   two 'Rewrite's (match the first, instantiate the second).
  | CRR (Expr -> Maybe Expr)
    -- ^ A rule given directly as a partial function on expressions.
  | Down RewriteRule RewriteRule
    -- ^ 'down' builds this as a cyclic value whose second field is the
    --   node itself.  NOTE(review): presumably a top-down traversal.
  | Up RewriteRule RewriteRule
    -- ^ 'up' builds this as a cyclic value whose second field is the node
    --   itself.  NOTE(review): presumably a bottom-up traversal.
  | Or [RewriteRule]
    -- ^ 'rr', 'rr1' and 'rr2' use this to group the variants of one rule.
    --   NOTE(review): presumably "try each alternative".
  | OrElse RewriteRule RewriteRule
    -- ^ NOTE(review): presumably the second rule is a fallback for the
    --   first; semantics defined by the interpreter.
  | Then RewriteRule RewriteRule
    -- ^ NOTE(review): presumably sequential composition of two rules.
  | Opt RewriteRule
    -- ^ NOTE(review): presumably an optional (may-fail) application.
  | If RewriteRule RewriteRule
    -- ^ NOTE(review): presumably a rule guarded by another rule.
  | Hard RewriteRule
    -- ^ NOTE(review): meaning not visible in this module.
-- | A rewrite template: an 'MExpr' containing numbered holes, together
-- with the number of holes it was built with.  Holes are numbered
-- @0 .. rid - 1@ by the 'RewriteC' instances.
--
-- Both fields must stay lazy: the @RewriteC (MExpr -> a)@ instance reads
-- only 'rid' from @getRewrite (undefined :: a)@, a record whose 'holes'
-- field is bottom.  Making the fields strict would crash rule construction.
data Rewrite = Rewrite
  { holes :: MExpr -- ^ the template, with 'Hole' placeholders
  , rid :: Int -- ^ number of distinct holes in 'holes'
  }
-- | Values that can be turned into a 'Rewrite' template.  A rule side is
-- written either as a plain 'MExpr' or as a curried function of 'MExpr'
-- arguments; each argument is filled with a fresh 'Hole'.
class RewriteC a where
  -- | Build the template and record how many holes it uses.
  getRewrite :: a -> Rewrite
-- | A plain template is a complete rule side with no holes.
instance RewriteC MExpr where
  getRewrite template = Rewrite template 0
-- | A rule side written as a function of an 'MExpr' is applied to a fresh
-- hole.  That hole's number is the count of holes used by the remaining
-- arguments, so the last argument becomes @Hole 0@ and the first argument
-- gets the highest number.
instance RewriteC a => RewriteC (MExpr -> a) where
  getRewrite rule = Rewrite { holes = holes body, rid = fresh + 1 }
    where
      -- Hole count of the remaining arguments, obtained from the type alone:
      -- only 'rid' is read, so the dummy 'undefined' (and the bottom
      -- 'holes' it yields) is never forced.
      fresh = rid (getRewrite (undefined :: a))
      body = getRewrite (rule (Hole fresh))
-- | Hole bindings, indexed by hole number (as produced by 'match').
type ExprArr = Array Int Expr

-- | Instantiate a template: every @Hole h@ is replaced by the quoted
-- binding @xs ! h@; everything else is copied unchanged.  Partial: each
-- hole number must lie within the array's bounds.
myFire :: ExprArr -> MExpr -> MExpr
myFire xs = go
  where
    go (MApp fun arg) = MApp (go fun) (go arg)
    go (Hole h) = Quote (xs ! h)
    go other = other
-- | Remove duplicates from a list, returning the distinct elements in
-- ascending order.
nub' :: Ord a => [a] -> [a]
nub' = S.toList . S.fromList

-- | Turn the hole bindings collected by 'matchWith' into an array indexed
-- by hole number.  The bindings must form a consistent assignment for
-- exactly the holes @0 .. n-1@: every hole is bound at least once, and all
-- occurrences of a hole are bound to equal values.  Otherwise the result
-- is 'Nothing'.
--
-- Comparing only the number of distinct pairs with @n@ is not enough.  A
-- hole bound to two different values can cancel out a hole that is never
-- bound (e.g. @[(0,x),(0,y)]@ with @n = 2@).  That would produce an array
-- with a duplicated index and an undefined element, and would let a
-- non-matching rule fire.
uniqueArray :: Ord v => Int -> [(Int, v)] -> Maybe (Array Int v)
uniqueArray n lst
  -- 'nub'' sorts the distinct (hole, value) pairs, so their keys come out
  -- ascending; they equal [0 .. n-1] exactly when every hole has one
  -- binding.
  | map fst distinct == [0 .. n - 1] = Just $ listArray (0, n - 1) (map snd distinct)
  | otherwise = Nothing
  where
    distinct = nub' lst
-- | Match a template against an expression.  On success, returns the
-- binding of every hole as an array; 'uniqueArray' rejects inconsistent
-- or incomplete bindings.
match :: Rewrite -> Expr -> Maybe ExprArr
match (Rewrite template holeCount) e =
  matchWith template e >>= uniqueArray holeCount
-- | Fill the holes of a template with the given bindings.
fire' :: Rewrite -> ExprArr -> MExpr
fire' (Rewrite template _) = \bindings -> myFire bindings template
-- | Try to rewrite an expression.  If it matches the first template, the
-- result is the second template instantiated with the same hole bindings.
-- Otherwise the result is 'Nothing'.
fire :: Rewrite -> Rewrite -> Expr -> Maybe Expr
fire r1 r2 e = fromMExpr . fire' r2 <$> match r1 e
-- | Structurally match a template against an expression.  Returns one
-- (hole, expression) pair per hole occurrence, or 'Nothing' when the
-- shapes or quoted parts disagree.  Pairs may repeat or conflict; 'match'
-- (via 'uniqueArray') is responsible for rejecting those.
matchWith :: MExpr -> Expr -> Maybe [(Int, Expr)]
matchWith (MApp patFun patArg) (App expFun expArg) =
  (++) <$> matchWith patFun expFun <*> matchWith patArg expArg
matchWith (Quote lit) e
  | lit == e = Just []
  | otherwise = Nothing
matchWith (Hole k) e = Just [(k, e)]
matchWith _ _ = Nothing
-- | Convert a template back into an ordinary expression.  Quoted parts are
-- unwrapped, and any hole left unfilled becomes the placeholder variable
-- @Hole@.
fromMExpr :: MExpr -> Expr
fromMExpr me = case me of
  Quote e -> e
  Hole _ -> Var Pref "Hole"
  MApp fun arg -> App (fromMExpr fun) (fromMExpr arg)
-- | Abstract hole @n@ out of a template (combinator bracket abstraction).
-- The result no longer contains @Hole n@.  It is intended to behave like
-- the input once applied to whatever hole @n@ stood for.
--
-- NOTE(review): the per-case readings below assume that 'constE', 'idE',
-- 'flipE', 'compE' and 'sE' (defined in Common) denote @const@, @id@,
-- @flip@, @(.)@ and the S combinator, and that `a` is application and
-- `c` is composition.  Confirm this against Common.
transformM :: Int -> MExpr -> MExpr
-- A quoted expression never mentions the hole: \x -> e  ==>  const e
transformM _ (Quote e) = constE `a` Quote e
-- The hole itself becomes id; any other hole is a constant.
transformM n (Hole n') = if n == n' then idE else constE `a` Hole n'
-- Special case: a composition whose left operand alone mentions the hole,
--   \x -> f[x] . g  ==>  flip (.) g . (\x -> f[x])
-- This must precede the general application equation, which matches every
-- MApp.  If the guard fails, matching falls through to that equation.
transformM n (Quote (Var _ ".") `MApp` e1 `MApp` e2)
  | e1 `hasHole` n && not (e2 `hasHole` n)
  = flipE `a` compE `a` e2 `c` transformM n e1
transformM n e@(MApp e1 e2)
  -- Both sides mention the hole: \x -> f[x] (g[x])  ==>  S f' g'
  | fr1 && fr2 = sE `a` transformM n e1 `a` transformM n e2
  -- Only the function part does: \x -> f[x] y  ==>  flip f' y
  | fr1 = flipE `a` transformM n e1 `a` e2
  -- Eta reduction (fr1 is False here): \x -> f x  ==>  f
  | fr2, Hole n' <- e2, n' == n = e1
  -- Only the argument part does: \x -> f (g[x])  ==>  f . g'
  | fr2 = e1 `c` transformM n e2
  -- The hole does not occur at all: \x -> e  ==>  const e
  | otherwise = constE `a` e
  where
    fr1 = e1 `hasHole` n -- does the function part mention hole n?
    fr2 = e2 `hasHole` n -- does the argument part mention hole n?
-- | Does @Hole n@ occur anywhere in the template?
hasHole :: MExpr -> Int -> Bool
hasHole template n = go template
  where
    go (Quote _) = False
    go (Hole k) = k == n
    go (MApp fun arg) = go fun || go arg
-- | The rewrite itself, followed by successively point-free forms.  Each
-- step abstracts hole 0 (the last argument of the rule function) with
-- 'transformM' and renumbers the remaining holes down by one.  A rewrite
-- with @k@ holes yields @k + 1@ variants.
getVariants, getVariants' :: Rewrite -> [Rewrite]
getVariants' r@(Rewrite template holeCount)
  | holeCount == 0 = [r]
  | holeCount >= 1 = r : getVariants (Rewrite shifted (holeCount - 1))
  | otherwise = error "getVariants' : nk went negative"
  where
    shifted = renumber (transformM 0 template)
    -- Hole 0 has just been abstracted away, so every remaining hole
    -- moves down by one.
    renumber (Hole k) = Hole (k - 1)
    renumber (MApp fun arg) = MApp (renumber fun) (renumber arg)
    renumber other = other

getVariants = getVariants'
-- | Pair each variant of the left-hand side with the corresponding variant
-- of the right-hand side, giving one 'RR' rule per abstraction level.
rrList :: RewriteC a => a -> a -> [RewriteRule]
rrList lhs rhs = zipWith RR lhsVariants rhsVariants
  where
    lhsVariants = getVariants (getRewrite lhs)
    rhsVariants = getVariants (getRewrite rhs)
-- | Build a rule from a left- and right-hand side, each given either as a
-- plain template or as a function of 'MExpr' holes.
--
-- * 'rr': the rule together with every abstracted variant.
--
-- * 'rr0': only the rule as written.
--
-- * 'rr1': the rule plus at most one abstracted variant.
--
-- * 'rr2': the rule plus at most two abstracted variants.
rr, rr0, rr1, rr2 :: RewriteC a => a -> a -> RewriteRule
rr lhs rhs = Or (rrList lhs rhs)
rr0 lhs rhs = RR (getRewrite lhs) (getRewrite rhs)
rr1 lhs rhs = Or (take 2 (rrList lhs rhs))
rr2 lhs rhs = Or (take 3 (rrList lhs rhs))
-- | Wrap a rule in a 'Down' (respectively 'Up') node.  The node's second
-- field points back at the node itself: a cyclic value tied with 'fix'.
down, up :: RewriteRule -> RewriteRule
down rule = fix (Down rule)
up rule = fix (Up rule)