module Tip.Simplify where
import Tip.Core
import Tip.Scope
import Tip.Fresh
import Data.Generics.Geniplate
import Data.List
import Data.Maybe
import Data.Monoid
import Control.Applicative
import Control.Monad
import qualified Data.Map as Map
import Tip.Writer
-- | Options controlling which rewrites the simplifier performs.
data SimplifyOpts a = SimplifyOpts
  { touch_lets :: Bool
    -- ^ Enable the @let@ rewrites: inlining let-bound values and floating a
    -- @let@ out of a match scrutinee.
  , remove_variable_scrutinee_in_branches :: Bool
    -- ^ When a match scrutinises a local variable, substitute the matched
    -- pattern for that variable inside each branch.
  , should_inline :: Occurrences -> Maybe (Scope a) -> Expr a -> Bool
    -- ^ Decide whether a value may be inlined, given how many times the
    -- variable occurs, the scope (when a theory is available) and the value.
  , inline_match :: Bool
    -- ^ When 'False', an inlined function call that still contains a
    -- @match@ after simplification is rejected (unless the function body
    -- is atomic).
  }
-- | Number of occurrences of a variable in the expression it is bound in;
-- passed as the first argument of 'should_inline'.
newtype Occurrences = Occurrences Int
-- | Like 'gently', except that 'should_inline' always answers 'False'.
gentlyNoInline :: SimplifyOpts a
gentlyNoInline = gently { should_inline = neverInline }
  where
    neverInline _ _ _ = False
-- | Conservative settings: every rewrite is enabled, and a value is inlined
-- only if its variable occurs at most once or the value is atomic.
gently :: SimplifyOpts a
gently =
  SimplifyOpts
    { touch_lets                            = True
    , remove_variable_scrutinee_in_branches = True
    , should_inline                         = inlineCheap
    , inline_match                          = True
    }
  where
    inlineCheap (Occurrences n) _ e = n <= 1 || atomic e
-- | Eager settings: every rewrite is enabled, and a value is inlined if its
-- variable occurs at most once, or if the value is a lambda or an
-- application headed by a constructor.
aggressively :: Name a => SimplifyOpts a
aggressively =
  SimplifyOpts
    { touch_lets                            = True
    , remove_variable_scrutinee_in_branches = True
    , should_inline                         = inlineUseful
    , inline_match                          = True
    }
  where
    inlineUseful (Occurrences n) mscp e = n <= 1 || useful mscp e

    -- Values whose inlining is likely to enable further rewrites.
    useful _    Lam{}      = True
    useful mscp (f :@: _)  = isConstructor mscp f
    useful _    _          = False
-- | Simplify every function definition and assertion of a theory.
--
-- The original theory is used as context for both passes, so
-- inlining sees the unsimplified function definitions. Assertions are
-- simplified with 'inline_match' switched off.
simplifyTheory :: Name a => SimplifyOpts a -> Theory a -> Fresh (Theory a)
simplifyTheory opts thy = do
  let inContext o = simplifyExprIn (Just thy) o
  funcs'   <- mapM (inContext opts) (thy_funcs thy)
  asserts' <- mapM (inContext opts { inline_match = False }) (thy_asserts thy)
  return thy { thy_funcs = funcs', thy_asserts = asserts' }
-- | Simplify without a surrounding theory. No function is inlined from a
-- theory, only literals count as constructors, and no datatype is
-- consulted for missing cases.
simplifyExpr
  :: forall f a. (TransformBiM (WriterT Any Fresh) (Expr a) (f a), Name a)
  => SimplifyOpts a -> f a -> Fresh (f a)
simplifyExpr = simplifyExprIn Nothing
-- | Simplify every expression inside @f a@. When a theory is given, it is
-- used for constructor/datatype lookups and for inlining its
-- non-recursive functions.
--
-- The traversal is bottom-up ('transformBiM'). Every rewrite that changes
-- the expression reports it through 'hooray' (writing @Any True@). That
-- flag decides whether a speculative inlining is kept: the inlineable-'Let'
-- case and the global-function case. At the top level the flag is
-- discarded.
simplifyExprIn
  :: forall f a. (TransformBiM (WriterT Any Fresh) (Expr a) (f a), Name a)
  => Maybe (Theory a) -> SimplifyOpts a -> f a -> Fresh (f a)
simplifyExprIn mthy SimplifyOpts{..} = fmap fst . runWriterT . aux
  where
    aux :: forall f. TransformBiM (WriterT Any Fresh) (Expr a) (f a) => f a -> WriterT Any Fresh (f a)
    aux = transformBiM $ \e0 ->
      -- Use the result of a smart constructor. Progress is reported only if
      -- it actually changed something; otherwise e0 is kept, which also
      -- preserves sharing.
      let share e1
            | e1 /= e0  = hooray $ return e1
            | otherwise = return e0
      in case e0 of
        -- Beta-reduce an applied lambda into a chain of lets.
        Builtin At :@: (Lam vars body:args) ->
          hooray $
            aux (foldr (uncurry Let) body (zip vars args))

        -- Inline a let whose value is atomic or whose variable occurs at
        -- most once. This changes the expression, so it must be reported.
        Let x e body | touch_lets && (atomic e || occurrences x body <= 1) ->
          hooray $ lift ((e // x) body) >>= aux

        -- Speculative inlining via should_inline: keep the result only if
        -- simplifying the inlined body made progress.
        Let x e body | touch_lets && inlineable body x e ->
          do e1 <- lift ((e // x) body)
             (e2, Any simplified) <- lift (runWriterT (aux e1))
             if simplified then hooray $ return e2 else return e0

        -- A match with only a default branch is just that branch.
        Match _ [Case Default body] -> hooray $ return body

        -- A nested match on the same scrutinee in the default branch is
        -- flattened. Inner cases already handled by the outer match are dropped.
        Match e (Case Default (Match e' cases'):cases) | e == e' ->
          hooray $ aux $
            Match e (filter (not . dead . case_pat) cases' ++ cases)
          where
            dead (LitPat l) = LitPat l `elem` map case_pat cases
            dead (ConPat{..}) =
              gbl_name pat_con `elem`
                [ gbl_name pat_con | ConPat{..} <- map case_pat cases ]
            dead Default = False

        -- If the default branch stands for exactly one missing constructor,
        -- name that constructor explicitly.
        Match e (Case Default def:cases)
          | TyCon ty args <- exprType e
          , Just (d, c@Constructor{..}) <- missingCase mscp ty cases -> do
              let gbl = constructor d c args
              pat <- lift (fmap (ConPat gbl) (freshArgs gbl))
              aux (Match e (Case pat def:cases))

        -- A boolean match that returns the scrutinee or its negation.
        Match e [Case _ e1, Case (LitPat (Bool b)) e2]
          | e1 == bool (not b) && e2 == bool b -> hooray $ return e
          | e1 == bool b && e2 == bool (not b) -> hooray $ return (neg e)

        -- Float a let out of the scrutinee.
        Match (Let x e body) alts | touch_lets ->
          aux (Let x e (Match body alts))

        -- The scrutinee is a known constructor or literal: select the branch.
        Match e alts
          | Just e' <- tryMatch mscp e alts -> hooray $ aux e'

        -- In each branch, replace the scrutinised variable by the pattern
        -- it matched.
        Match (Lcl x) alts | remove_variable_scrutinee_in_branches ->
          Match (Lcl x) <$> sequence
            [ Case pat <$> case pat of
                ConPat g bs -> subst ((Gbl g :@: map Lcl bs) /// x) rhs
                LitPat l    -> subst (literal l /// x) rhs
                _           -> return rhs
            | Case pat rhs <- alts
            ]
          where
            -- Re-simplify a branch only if the substitution replaced something.
            subst f e = do
              (e', Any successful) <- lift (runWriterT (f e))
              if successful then aux e' else return e

        -- Equality with a boolean literal.
        Builtin Equal :@: [Builtin (Lit (Bool x)) :@: [], t]
          | x         -> hooray $ return t
          | otherwise -> hooray $ return $ neg t
        Builtin Equal :@: [t, Builtin (Lit (Bool x)) :@: []]
          | x         -> hooray $ return t
          | otherwise -> hooray $ return $ neg t

        -- Equality of two literals.
        Builtin Equal :@: [litView -> Just s, litView -> Just t] ->
          hooray $ return (bool (s == t))

        -- (Dis)equality of two constructor applications. Different
        -- constructors decide it outright. Equal constructors reduce it to a
        -- conjunction (or disjunction) over the arguments.
        Builtin eq_op :@: [Gbl k :@: kargs, Gbl j :@: jargs]
          | Just scp <- mscp
          , Just (_, Constructor kk _ _) <- lookupConstructor (gbl_name k) scp
          , Just (_, Constructor jj _ _) <- lookupConstructor (gbl_name j) scp
          , Just res <-
              case (eq_op, kk == jj) of
                (Equal,    False) -> Just falseExpr
                (Distinct, False) -> Just trueExpr
                (Equal,    True)  -> Just (ands (zipWith (===) kargs jargs))
                (Distinct, True)  -> Just (ors (zipWith (=/=) kargs jargs))
                _                 -> Nothing
          -> hooray $ aux res

        -- Disequality of two literals.
        Builtin Distinct :@: [litView -> Just s, litView -> Just t] ->
          hooray $ return (bool (s /= t))

        -- Boolean connectives go through the smart constructors.
        -- The idempotence rewrites change the expression, so they report
        -- progress.
        Builtin Not :@: [e] -> share (neg e)
        Builtin And :@: [e1, e2]
          | e1 == e2  -> hooray $ return e1
          | otherwise -> share (e1 /\ e2)
        Builtin Or :@: [e1, e2]
          | e1 == e2  -> hooray $ return e1
          | otherwise -> share (e1 \/ e2)
        Builtin Implies :@: [e1, e2] -> share (e1 ==> e2)

        -- Equality at function type becomes a universally quantified
        -- equality of both sides applied to fresh variables.
        Builtin Equal :@: [e1, e2] ->
          case exprType e1 of
            t@(_ :=>: _) -> hooray $ go t e1 e2 []
              where
                go (args :=>: rest) u v lcls =
                  do more <- lift (mapM freshLocal args)
                     go rest (apply u (map Lcl more))
                             (apply v (map Lcl more))
                             (lcls ++ more)
                go _ u v lcls = return (mkQuant Forall lcls (u === v))
            _ -> return e0

        -- Inline a call to a non-recursive theory function when every
        -- argument is inlineable. Keep the result only if it simplified
        -- (subject to inline_match), or if the function body is atomic.
        Gbl gbl@Global{..} :@: ts ->
          case Map.lookup gbl_name inlinings of
            Just Function{..}
              | and [ inlineable func_body x t | (x, t) <- zip func_args ts ] -> do
                  -- Pre-simplify the body without counting it as progress;
                  -- only progress made after substitution should count.
                  func_body <- boo $ aux func_body
                  e1 <-
                    transformTypeInExpr (applyType func_tvs gbl_args) <$>
                      lift (substMany (zip func_args ts) func_body)
                  (e2, Any simplified) <- lift (runWriterT (aux e1))
                  if (simplified && (inline_match || not (containsMatch e2))) || atomic func_body
                    then hooray $ return e2
                    else return (Gbl gbl :@: ts)
            _ -> return (Gbl gbl :@: ts)

        _ -> return e0

    -- Ask the user-supplied policy whether val may be inlined for var in body.
    inlineable body var val = should_inline (Occurrences (occurrences var body)) mscp val

    mscp = fmap scope mthy

    -- A group from topsort is recursive if it has several members, or if
    -- its single member uses its own definition.
    isRecursiveGroup [fun] = defines fun `elem` uses fun
    isRecursiveGroup _     = True

    -- Non-recursive theory functions, keyed by name.
    inlinings =
      case mthy of
        Nothing -> Map.empty
        Just Theory{..} ->
          Map.fromList . map (\fun -> (func_name fun, fun)) .
          concat . filter (not . isRecursiveGroup) . topsort $ thy_funcs

    containsMatch e = not (null [ e' | e'@Match{} <- universe e ])

    -- Replace the local @old@ by a freshened copy of @new@, reporting
    -- progress for each replacement.
    new /// old = transformExprM $ \e ->
      if e == Lcl old then hooray $ lift (freshen new) else return e

    -- Record that a rewrite fired, then run the action.
    hooray x = do
      tell (Any True)
      x

    -- Run an action but discard any progress it reported.
    boo x = censor (const (Any False)) x
-- | Whether a head counts as a constructor. Literals always do. A global
-- does when a scope is available and the name resolves to a constructor
-- in it.
isConstructor :: Name a => Maybe (Scope a) -> Head a -> Bool
isConstructor mscp hd =
  case hd of
    Builtin Lit{} -> True
    Gbl gbl       -> maybe False (isJust . lookupConstructor (gbl_name gbl)) mscp
    _             -> False
-- | If the datatype named @tc@ has exactly one constructor that is not
-- matched by a constructor pattern in @cases@, return the datatype together
-- with that constructor. Otherwise, or without a scope, return 'Nothing'.
missingCase :: Name a => Maybe (Scope a) -> a -> [Case a] -> Maybe (Datatype a, Constructor a)
missingCase mscp tc cases =
  mscp >>= lookupDatatype tc >>= soleUncovered
  where
    covered = [ gbl_name pat_con | ConPat{..} <- map case_pat cases ]

    soleUncovered dt =
      case [ con | con <- data_cons dt, con_name con `notElem` covered ] of
        [con] -> Just (dt, con)
        _     -> Nothing
-- | Statically select a branch when the scrutinee is a constructor (or
-- literal) application. The alternatives are searched in reverse order,
-- so a specific pattern is preferred to an earlier 'Default'. For a
-- constructor pattern, the pattern's variables are let-bound to the
-- arguments. Returns 'Nothing' if no alternative applies or the head is
-- not a constructor.
tryMatch :: Name a => Maybe (Scope a) -> Expr a -> [Case a] -> Maybe (Expr a)
tryMatch mscp (hd :@: args) alts
  | isConstructor mscp hd =
      case [ alt | alt <- reverse alts, fires (case_pat alt) ] of
        Case (ConPat _ lcls) body : _ -> Just (foldr (uncurry Let) body (zip lcls args))
        Case _ body : _               -> Just body
        []                            -> Nothing
  where
    fires pat =
      case (hd, pat) of
        (Gbl gbl,           ConPat gbl' _) -> gbl == gbl'
        (Builtin (Lit lit), LitPat lit')   -> lit == lit'
        (_,                 Default)       -> True
        _                                  -> False
tryMatch _ _ _ = Nothing