-- | Driving strategy of the supercompiler: reachability pruning, the
-- tie/drive loop that builds residual functions, termination checks and
-- matching of expressions against previously seen signatures.
module Optimus.Strategy where

import Control.Monad.Identity
import Control.Monad.State
import Data.Char
import Data.List
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Maybe

import Optimus.Trace

import Flite.Identify
import Flite.Fresh
--import Flite.Inline
--import Flite.Pretty
import Optimus.Pretty
import Flite.Syntax
import Flite.Traversals

import Optimus.Generalise
import Optimus.Homeo
import Optimus.Inline
import Optimus.Simplify
import Optimus.Uniplate
import Optimus.Util

import Data.Generics.Uniplate

--------------------------------------------------
--------------------------------------------------

-- | Names reachable from a root function by repeatedly following the
-- names that 'calls' reports for the right-hand sides of the matching
-- declarations. The root itself is always part of the result.
reachable :: Prog -> Id -> Set.Set Id
reachable p = visit Set.empty
  where
    -- Depth-first walk; the accumulated set doubles as the "visited" marker.
    visit :: Set.Set Id -> Id -> Set.Set Id
    visit seen f
      | f `Set.member` seen = seen
      | otherwise = foldl' visit (f `Set.insert` seen) (callees f)

    -- Every name called from any declaration of @f@ in the program.
    callees f = [ g | d <- lookupFuncs f p, g <- (calls . funcRhs) d ]

-- | Keep only the declarations whose name is reachable from @f@,
-- preserving their original order.
onlyReachable :: Prog -> Id -> Prog
onlyReachable p f = [ d | d <- p, funcName d `Set.member` reachable p f ]

-- | Supercompile the single function named @f_@.
--
-- The program is first desugared, then the body of @f_@ is driven with
-- 'tie'. The result is the residual program reachable from @f_@ followed
-- by every other declaration of the desugared program.
--
-- Calls 'error' when @f_@ is not defined in the program.
supercompileFunc :: String -> Prog -> Prog
supercompileFunc f_ p_ =
  onlyReachable p'' f_ ++ [ d | d@(Func g _ _) <- p, g /= f_ ]
  where
    p = t_S "Desugar program" $ freshProg (desugar . funcReuse) p_
    m = t_M $ byFuncName p

    Func f a r =
      fromMaybe (error $ "Could not find '" ++ f_ ++ "'") $
        Map.lookup (t_sc f_ f_) m

    -- Residual functions are named f<n> and fresh variables v<n>; both
    -- counters start just above the largest index already used.
    initialState = SCState (nextIndex 'f' (funcs p)) [] m []

    (main, SCState _ p' _ _) =
      snd (runFresh (runStateT (tie r) initialState) "v" (nextIndex 'v' (allNames p)))

    p'' = Func f a main : p'

    -- One more than the largest numeric suffix among names of the form
    -- <prefix><digits>, or 0 when there is none. The suffix must be
    -- non-empty and consist of ASCII digits only, otherwise 'read' would
    -- fail (e.g. on a function called plain "f").
    nextIndex :: (Num n, Ord n, Read n) => Char -> [String] -> n
    nextIndex prefix names =
      1 + maximum ((-1) : [ read digits
                          | c : digits <- names
                          , c == prefix
                          , not (null digits)
                          , all isDigit digits ])

-- | Supercompile each named function in turn (left to right), then run
-- the final inlining and simplification passes and prune everything not
-- reachable from @"main"@.
supercompileMany :: [String] -> Prog -> Prog
supercompileMany fs p =
  flip onlyReachable "main" $
    freshProg
      (finalSimplification simplifyProg . flip onlyReachable "main" <=< finalInlining progInline)
      (foldl' (flip supercompileFunc) p fs)

-- | Supercompile @"main"@, then run the final inlining and simplification
-- passes and prune everything not reachable from @"main"@.
supercompile :: Prog -> Prog
supercompile =
  flip onlyReachable "main"
    . freshProg
        (finalSimplification simplifyProg . flip onlyReachable "main" <=< finalInlining progInline)
    . supercompileFunc "main"

type SCStateT m a = StateT SCState m a

-- | State threaded through the supercompiler.
data SCState = SCState
  { scCount :: Int                    -- index used for the next residual name (@f<n>@)
  , scResidual :: Prog                -- residual declarations added by 'scAddDecl'
  , scFuncMap :: (Map.Map Id Decl)    -- declarations by name, consulted when unfolding
  , scRho :: [Decl]                   -- signatures recorded by 'scAddRho', newest first
  }

-- | Advance the residual-function counter by one.
scIncCount :: StateT SCState Fresh ()
scIncCount = modify $ \st -> st { scCount = scCount st + 1 }

-- | Record a residual declaration (both in the residual program and in the
-- function map) and return a call to it applied to its own arguments.
scAddDecl :: Decl -> StateT SCState Fresh Exp
scAddDecl d@(Func fId fArgs _) = do
  modify $ \st ->
    st { scResidual = d : scResidual st
       , scFuncMap = t_M $ Map.insert fId d (scFuncMap st)
       }
  return $ App (Fun fId) fArgs

-- | Push a signature onto the front of rho.
scAddRho :: Decl -> StateT SCState Fresh ()
scAddRho d = modify $ \st -> st { scRho = d : scRho st }

-- | Build a declaration named @i@ whose arguments are the sorted free
-- variables of @q@ and whose body is @r@ passed through 'autoAlphaExp'.
buildSig :: Id -> Exp -> Exp -> Decl
buildSig i q r = Func i (map Var $ sort $ freeVars q) (autoAlphaExp r)

-- | Residualise an expression.
--
-- Simple expressions are returned untouched. Otherwise, if the expression
-- matches a signature already in rho, the corresponding call is returned;
-- if not, a new residual function @f<count>@ is created by driving the
-- expression.
tie :: Exp -> SCStateT Fresh Exp
tie x
  | simpleExpr x = return x
  | otherwise = do
      SCState count _ _ rho <- get
      -- The signature is fixed from the counter value read *before* it is
      -- incremented below.
      let fHead = buildSig ('f' : show count) x
      case findExp x $ t_Rho (show rho) rho of
        Just e -> t_T ("Seen!" ++ show e) return e
        Nothing -> t_T ("Making f" ++ show count) $ do
          scIncCount
          fRhs <- drive fHead (isTerm rho x)
          scAddDecl (fHead fRhs)

-- | Handle an expression @x@ flagged as embedding the earlier expression
-- @y@: reuse a matching signature from rho if there is one, otherwise
-- generalise @x@ against @y@ and tie the children of the result.
tieOrGen :: (Exp -> Decl) -> Exp -> Exp -> SCStateT Fresh Exp
tieOrGen s x y = do
  SCState _ _ _ rho <- get
  case findExp x rho of
    Just e -> t_T ("Seen!" ++ show e) return e
    Nothing -> do
      scAddRho (s x)
      x' <- lift (generalise1 x y)
      scAddRho (s x')
      let (cs, gen) = uniplate $ t_D ("generalised to " ++ show x') x'
      liftM gen (mapM tie cs)

-- | Act on a termination verdict. Every case that records a signature
-- builds it from the expression with @s@.
--
-- * 'NonTerm': record the signature, unfold once more and keep driving.
-- * 'SimplTerm': record the signature and tie the children.
-- * 'HomeoTerm': defer to 'tieOrGen'.
-- * 'NoUnfold': record the signature and return the expression as is.
drive :: (Exp -> Decl) -> Unfold -> SCStateT Fresh Exp
drive s (NonTerm x) = t_D (show x) $ do
  scAddRho (s x)
  -- Read the state after the signature was added so that the unfolding
  -- step sees the up-to-date rho.
  SCState _ _ m rho <- get
  next <- lift (unfoldT (flip Map.lookup m) rho x)
  drive s next
drive s (SimplTerm x) = t_D ("Simple Termination" ++ show x) $ do
  scAddRho (s x)
  let (cs, gen) = uniplate x
  liftM gen (mapM tie cs)
drive s (HomeoTerm x y) =
  t_D ("Homeomorphic Embedding of" ++ show x ++ "\nto" ++ show y) $
    tieOrGen s x y
drive s (NoUnfold x) =
  t_D ("No Unfolds Remaining" ++ show x) $
    scAddRho (s x) >> return x

-- | Verdict of the termination check for one candidate expression.
data Unfold
  = NonTerm { ntExp :: Exp }
  | SimplTerm { stExp :: Exp }
  | HomeoTerm { htExp :: Exp, htHomeo :: Exp }
  | NoUnfold { nuExp :: Exp }
  deriving Show

-- | Classify an expression: 'SimplTerm' when 'simpleTerm' holds,
-- 'HomeoTerm' (paired with the expression found by 'homeoTerm') when some
-- signature in rho triggers, and 'NonTerm' otherwise.
isTerm :: [Decl] -> Exp -> Unfold
isTerm rho x
  | simpleTerm x = SimplTerm x
  | otherwise = maybe (NonTerm x) (HomeoTerm x) (homeoTerm rho x)

-- | Pick the first candidate that yields a 'NonTerm' verdict.
--
-- When no candidate does, the first candidate of the list is used; when
-- the list is empty the result is @NoUnfold def@.
filterTerms :: Exp -> [Fresh Unfold] -> Fresh Unfold
filterTerms def ufs = go ufs
  where
    go :: [Fresh Unfold] -> Fresh Unfold
    go [] = case ufs of
      [] -> return (NoUnfold def)
      -- Nothing qualified: fall back to the very first candidate.
      (firstUnfold : _) -> t_U "No non-terminating expression found." firstUnfold
    go (x : xs) = x >>= \x' -> case x' of
      NonTerm _ -> t_U "Found a non-terminating unfold." (return x')
      _ -> t_U ("Skipped as " ++ skipReason x') go xs

    -- Human-readable reason used only in the trace message.
    skipReason :: Unfold -> String
    skipReason (SimplTerm _) = "Simple Termination."
    skipReason (HomeoTerm _ _) = "Homeomorphic Embedding."
    skipReason _ = "Unknown Problem."

-- Move simplify down here \/. And freshen those variables!
-- | Try to unfold one function call inside @e@. Candidates are the
-- expression itself followed by each sub-expression given by 'holes';
-- each successful unfolding is put back in its context, simplified and
-- classified with 'isTerm'. The choice among them is made by 'filterTerms'.
unfoldT :: (Id -> Maybe Decl) -> [Decl] -> Exp -> Fresh Unfold
unfoldT m rho e = t_U "\nChecking unfold of: " $ filterTerms e unfolds
  where
    -- Should this be holes? Should freshen anything that redefines a
    -- global variable or just block it.
    unfolds :: [Fresh Unfold]
    unfolds =
      [ do inlined <- c'
           simplified <- simplify (h inlined)
           return (isTerm rho simplified)
      | (c, h) <- (e, id) : holes e
      , c' <- maybeToList (unfold c)
      ]

    -- Only (possibly nullary) applications of known functions unfold.
    unfold :: Exp -> Maybe (Fresh Exp)
    unfold (App (Fun f) xs) = m f >>= maybeInline xs
    unfold (Fun f) = unfold $ App (Fun f) []
    unfold _ = Nothing

-- | First signature in the list that the expression matches, returned as
-- a call to that signature (see 'matchExp').
findExp :: Exp -> [Decl] -> Maybe Exp
-- Slowdown here. Do properly
findExp e ds = listToMaybe $ mapMaybe (matchExp (autoAlphaExp e)) ds

-- Must be linear and alphaed
-- | Match an expression against the body of a declaration, treating the
-- declaration's variable arguments as pattern variables. On success the
-- result is a call to the declaration applied to the matched expressions
-- in argument order; it is 'Nothing' if the match fails or if some
-- argument was never bound.
matchExp :: Exp -> Decl -> Maybe Exp
matchExp e (Func f a r)
  | success = fmap (App (Fun f)) (sequence [ join (Map.lookup v mapping) | Var v <- a ])
  | otherwise = Nothing
  where
    -- Every argument variable starts out unbound.
    unbound = Map.fromList [ (v, Nothing) | Var v <- a ]
    (success, mapping) = runState (matchExp' e (autoAlphaExp r)) unbound

-- | Worker for 'matchExp'; the state maps each pattern variable to the
-- expression bound to it so far.
matchExp' :: Exp -> Exp -> State (Map.Map Id (Maybe Exp)) Bool
matchExp' e (Var v) = do
  bindings <- get
  case v `Map.lookup` bindings of
    -- Not a pattern variable: must be exactly the same variable.
    Nothing -> case e of
      Var v' -> return (v == v')
      _ -> return False
    -- Unbound pattern variable: bind it.
    Just Nothing -> put (Map.insert v (Just e) bindings) >> return True
    -- Already bound: the earlier binding must be equal.
    Just (Just e') -> return $ e == e'
-- maybe deal with variable bindings
matchExp' e e' =
  liftM (\childResults -> (e =~ e') && and childResults) $
    zipWithM matchExp' (children e) (children e')

--------------------------------------------------
--------------------------------------------------

-- | Expressions that 'tie' returns unchanged: variables, constructors,
-- integers and primitive function names.
simpleExpr :: Exp -> Bool
simpleExpr (Var _) = True
simpleExpr (Con _) = True
simpleExpr (Int _) = True
-- simpleExpr (Fun _) = True
simpleExpr (Fun f) = f `elem` primitives
simpleExpr _ = False

-- | Expressions classified as 'SimplTerm': a variable, a constructor
-- application, a primitive application, or a case whose subject is one
-- of those three.
simpleTerm :: Exp -> Bool
simpleTerm (Var _) = True
simpleTerm (App (Con _) _) = True
simpleTerm (App (Fun f) _) = f `elem` primitives
simpleTerm (Case (Var _) _) = True
simpleTerm (Case (App (Con _) _) _) = True
simpleTerm (Case (App (Fun f) _) _) = f `elem` primitives
simpleTerm _ = False

-- | Body of the first declaration in the list for which @rhs <|| y@ holds.
homeoTerm :: [Decl] -> Exp -> Maybe Exp
homeoTerm es y = listToMaybe [ rhs | Func _ _ rhs <- es, rhs <|| y ]