module E.Rules(
ARules,
Rule(Rule,ruleHead,ruleBinds,ruleArgs,ruleBody,ruleUniq,ruleName),
RuleType(..),
Rules(..),
applyRules,
arules,
builtinRule,
dropArguments,
fromRules,
ruleUpdate,
mapRBodyArgs,
makeRule,
mapBodies,
printRules,
rulesFromARules
)where
import Control.Monad.Writer(execWriterT,liftM,tell)
import Data.Maybe
import Data.Monoid(Monoid(..))
import qualified Data.Traversable as T
import Data.Binary
import Doc.DocLike
import Doc.PPrint
import Doc.Pretty
import E.Binary()
import E.E
import E.Show()
import E.Subst
import E.Values
import GenUtil
import Info.Types
import Name.Id
import Name.Name
import Name.Names
import Options
import Stats
import StringTable.Atom(toAtom)
import Support.CanType
import Support.FreeVars
import Support.MapBinaryInstance()
import Util.HasSize
import Util.SetLike as S
import qualified Util.Seq as Seq
instance Show Rule where
showsPrec _ r = shows $ ruleName r
emptyRule :: Rule
emptyRule = Rule {
ruleHead = error "ruleHead undefined",
ruleArgs = [],
ruleNArgs = 0,
ruleBinds = [],
ruleType = RuleUser,
ruleBody = error "ruleBody undefined",
ruleName = error "ruleName undefined",
ruleUniq = error "ruleUniq undefined"
}
newtype Rules = Rules (IdMap [Rule])
deriving(HasSize,IsEmpty)
instance Eq Rule where
r1 == r2 = ruleUniq r1 == ruleUniq r2
instance Binary Rules where
put (Rules mp) = put (concat $ values mp)
get = do
rs <- get
return $ fromRules rs
mapBodies :: Monad m => (E -> m E) -> Rules -> m Rules
mapBodies g (Rules mp) = do
let f rule = do
b <- g (ruleBody rule)
return rule { ruleBody = b }
mp' <- T.mapM (mapM f) mp
return $ Rules mp'
instance FreeVars Rule [Id] where
freeVars rule = idSetToList $ freeVars rule
printRules ty (Rules rules) = mapM_ (\r -> printRule r >> putChar '\n') [ r | r <- concat $ values rules, ruleType r == ty ]
putDocMLn' :: Monad m => (String -> m ()) -> Doc -> m ()
putDocMLn' putStr d = displayM putStr (renderPretty 0.80 (optColumns options) d) >> putStr "\n"
printRule Rule {ruleName = n, ruleBinds = vs, ruleBody = e2, ruleHead = head, ruleArgs = args } = do
let e1 = foldl EAp (EVar head) args
let p v = parens $ pprint v <> text "::" <> pprint (getType v)
putDocMLn' putStr $ (tshow n) <+> text "forall" <+> hsep (map p vs) <+> text "."
let ty = pprint $ getType e1
putDocMLn' putStr (indent 2 (pprint e1) <+> text "::" <+> ty )
putDocMLn' putStr $ text " ==>" <+> pprint e2
combineRules as bs = map head $ sortGroupUnder ruleUniq (as ++ bs)
instance Monoid Rules where
mempty = Rules mempty
mappend (Rules x) (Rules y) = Rules $ unionWith combineRules x y
fromRules :: [Rule] -> Rules
fromRules rs = Rules $ fmap snds $ fromList $ sortGroupUnderF fst [ (tvrIdent $ ruleHead r,ruleUpdate r) | r <- rs ]
mapRBodyArgs :: Monad m => (E -> m E) -> Rule -> m Rule
mapRBodyArgs g r = do
let f rule = do
b <- g (ruleBody rule)
as <- mapM g (ruleArgs rule)
return rule { ruleArgs = as, ruleBody = b }
f r
rulesFromARules :: ARules -> [Rule]
rulesFromARules = aruleRules
dropArguments :: [(Int,E)] -> [Rule] -> [Rule]
dropArguments os rs = catMaybes $ map f rs where
f r = do
let g (i,a) | Just v <- lookup i os = do
rs <- match (const Nothing) (ruleBinds r) a v
return (Right rs)
g (i,a) = return (Left a)
as' <- mapM g $ zip naturals (ruleArgs r)
let sb = substLet (concat $ rights as')
sa = substMap $ fromList [ (tvrIdent t,v) | Right ds <- as', (t,v) <- ds ]
return r { ruleArgs = map sa (lefts as'), ruleBody = sb (ruleBody r) }
instance Show ARules where
showsPrec n a = showsPrec n (aruleRules a)
arules xs = ARules { aruleFreeVars = freeVars rs, aruleRules = rs } where
rs = sortUnder ruleNArgs (map f xs)
f rule = rule {
ruleNArgs = length (ruleArgs rule),
ruleBinds = bs,
ruleBody = g (ruleBody rule),
ruleArgs = map g (ruleArgs rule)
} where
bs = map (setProperty prop_RULEBINDER) (ruleBinds rule)
g e = substMap (fromList [ (tvrIdent t, EVar t) | t <- bs ]) e
instance Monoid ARules where
mempty = ARules { aruleFreeVars = mempty, aruleRules = [] }
mappend = joinARules
ruleUpdate rule = rule {
ruleNArgs = length (ruleArgs rule),
ruleBinds = bs,
ruleBody = g (ruleBody rule),
ruleArgs = map g (ruleArgs rule)
} where
bs = map (setProperty prop_RULEBINDER) (ruleBinds rule)
g e = substMap (fromList [ (tvrIdent t, EVar t) | t <- bs ]) e
joinARules ar@(ARules fvsa a) br@(ARules fvsb b)
| [] <- rs = ARules mempty []
| all (== r) rs = ARules (fvsa `mappend` fvsb) (sortUnder (\r -> (ruleNArgs r,ruleUniq r)) (snubUnder ruleUniq $ a ++ b))
| otherwise = error $ "mixing rules!" ++ show (ar,br) where
rs@(r:_) = map ruleHead a ++ map ruleHead b
rsubstMap :: IdMap E -> E -> E
rsubstMap im e = doSubst False True (fmap ( (`mlookup` im) . tvrIdent) (unions $ (freeVars e :: IdMap TVr):map freeVars (values im))) e
applyRules :: MonadStats m => (Id -> Maybe E) -> ARules -> [E] -> m (Maybe (E,[E]))
applyRules lup (ARules _ rs) xs = f rs where
lxs = length xs
f [] = return Nothing
f (r:_) | ruleNArgs r > lxs = return Nothing
f (r:rs) = case sequence (zipWith (match lup (ruleBinds r)) (ruleArgs r) xs) of
Just ss -> do
mtick (ruleName r)
let b = rsubstMap (fromList [ (i,x) | (TVr { tvrIdent = i },x) <- concat ss ]) (ruleBody r)
return $ Just (b,drop (ruleNArgs r) xs)
Nothing -> do f rs
preludeError = toId v_error
ruleError = toAtom "Rule.error/EError"
builtinRule TVr { tvrIdent = n } (ty:s:rs)
| n == preludeError, Just s' <- toString s = do
mtick ruleError
return $ Just ((EError ("Prelude.error: " ++ s') ty),rs)
builtinRule _ _ = return Nothing
makeRule ::
String
-> (Module,Int)
-> RuleType
-> [TVr]
-> TVr
-> [E]
-> E
-> Rule
makeRule name uniq ruleType fvs head args body = rule where
rule = emptyRule {
ruleHead = head,
ruleBinds = fvs,
ruleArgs = args,
ruleType = ruleType,
ruleNArgs = length args,
ruleBody = body,
ruleUniq = uniq,
ruleName = toAtom $ "Rule.User." ++ name
}
match :: Monad m =>
(Id -> Maybe E)
-> [TVr]
-> E
-> E
-> m [(TVr,E)]
match lup vs = \e1 e2 -> liftM Seq.toList $ execWriterT (un e1 e2 () etherealIds) where
bvs :: IdSet
bvs = fromList (map tvrIdent vs)
un _ _ _ c | c `seq` False = undefined
un (EAp a b) (EAp a' b') mm c = do
un a a' mm c
un b b' mm c
un (ELam va ea) (ELam vb eb) mm c = lam va ea vb eb mm c
un (EPi va ea) (EPi vb eb) mm c = lam va ea vb eb mm c
un (EPrim s xs t) (EPrim s' ys t') mm c | length xs == length ys = do
sequence_ [ un x y mm c | x <- xs | y <- ys]
un t t' mm c
un (ESort x) (ESort y) mm c | x == y = return ()
un (ELit (LitInt x t1)) (ELit (LitInt y t2)) mm c | x == y = un t1 t2 mm c
un (ELit LitCons { litName = n, litArgs = xs, litType = t }) (ELit LitCons { litName = n', litArgs = ys, litType = t'}) mm c | n == n' && length xs == length ys = do
sequence_ [ un x y mm c | x <- xs | y <- ys]
un t t' mm c
un (EVar TVr { tvrIdent = i, tvrType = t}) (EVar TVr {tvrIdent = j, tvrType = u}) mm c | i == j = un t u mm c
un (EVar TVr { tvrIdent = i, tvrType = t}) (EVar TVr {tvrIdent = j, tvrType = u}) mm c | isEtherealId i || isEtherealId j = fail "Expressions don't match"
un (EVar tvr@TVr { tvrIdent = i, tvrType = t}) b mm c
| i `member` bvs = tell (Seq.single (tvr,b))
| otherwise = fail $ "Expressions do not unify: " ++ show tvr ++ show b
un a (EVar tvr) mm c | Just b <- lup (tvrIdent tvr), not $ isEVar b = un a b mm c
un a b _ _ = fail $ "Expressions do not unify: " ++ show a ++ show b
lam va ea vb eb mm ~(c:cs) = do
un (tvrType va) (tvrType vb) mm (c:cs)
un (subst va (EVar va { tvrIdent = c }) ea) (subst vb (EVar vb { tvrIdent = c }) eb) mm cs