module Language.Fay.Compiler.Optimizer where
import Control.Applicative
import Control.Arrow (first)
import Control.Monad.Error
import Control.Monad.Writer
import Control.Monad.State
import Data.List
import Data.Maybe
import Language.Fay.Print
import Language.Fay.Types
import Language.Haskell.Exts (QName(..),ModuleName(..),Name(..))
import Language.Haskell.Exts (SrcLoc(..))
import Prelude hiding (exp)
-- | A function name paired with its arity, as produced by 'collectFuncs'.
type FuncArity = (QName,Int)
-- | The optimizer monad: a 'State' computation threading an 'OptState'.
type Optimize = State OptState
-- | State threaded through an optimizer pass.
data OptState = OptState
  { optStmts   :: [JsStmt]
    -- ^ The statements the optimizer was started with.
  , optUncurry :: [QName]
    -- ^ Names of functions for which a call to the uncurried variant was
    -- emitted; an entry is prepended for every rewritten call site.
  }
-- | Run an optimizer pass over a list of top-level statements.
--
-- The pass is run in the 'Optimize' state monad, seeded with the input
-- statements and an empty 'optUncurry' list. For every distinct function name
-- the pass recorded in 'optUncurry', an uncurried copy of its binding is
-- built with 'uncurryBinding', passed through 'tco', and appended after the
-- optimized statements.
runOptimizer :: ([JsStmt] -> Optimize [JsStmt]) -> [JsStmt] -> [JsStmt]
runOptimizer optimizer stmts = newstmts ++ tco uncurriedBindings
  where
    (newstmts, OptState _ uncurried) =
      runState (optimizer stmts) (OptState stmts [])
    -- A name is recorded once per rewritten call site, so the list is
    -- deduplicated first; otherwise a function called N times would get N
    -- identical uncurried definitions appended to the output.
    uncurriedBindings = mapMaybe (uncurryBinding newstmts) (nub uncurried)
-- | The optimization applied to top-level statements; currently this is just
-- 'stripAndUncurry'.
optimizeToplevel :: [JsStmt] -> Optimize [JsStmt]
optimizeToplevel stmts = stripAndUncurry stmts
-- | Tail-call optimization for variable bindings.
--
-- Only 'JsVar' / 'JsMappedVar' bindings whose value has the exact shape
-- @JsFun params [] (Just (JsNew JsThunk [JsFun [] stmts ret]))@ are rewritten;
-- every other statement is returned unchanged.
--
-- For a matching binding the thunk body becomes @stmts@ followed by the
-- return expression (if any) as a 'JsEarlyReturn'. If that body contains a
-- self tail call, it is wrapped in a @while (true)@ loop in which each such
-- call is replaced by parameter updates followed by 'JsContinue'.
tco :: [JsStmt] -> [JsStmt]
tco = map inStmt
  where
    inStmt stmt =
      case stmt of
        JsMappedVar srcloc name exp -> JsMappedVar srcloc name (inject name exp)
        JsVar name exp -> JsVar name (inject name exp)
        e -> e

    -- Rebuild the function so that the thunk body is a plain statement list
    -- (the trailing return expression is folded in as an early return).
    inject name exp =
      case exp of
        JsFun params [] (Just (JsNew JsThunk [JsFun [] stmts ret])) ->
          let returned = [ JsEarlyReturn e | Just e <- [ret] ]
              body = optimize params name (stmts ++ returned)
          in JsFun params [] (Just (JsNew JsThunk [JsFun [] body Nothing]))
        _ -> exp

    -- Wrap the statements in a loop only when at least one tail call was
    -- rewritten; the writer output is used purely as a "found one" marker.
    optimize params name stmts
      | null found = stmts
      | otherwise = looped
      where
        (looped, found) = runWriter makeWhile

        makeWhile = do
          newstmts <- fmap concat (mapM swap stmts)
          return [JsWhile (JsLit (JsBool True)) newstmts]

        -- Only early returns at the top level of the body, or nested inside
        -- the branches of an if statement, are considered.
        swap stmt =
          case stmt of
            JsEarlyReturn e
              | tailCall e -> do
                  tell [()]
                  return (rebind e ++ [JsContinue])
              | otherwise -> return [stmt]
            JsIf p ithen ielse -> do
              newithen <- fmap concat (mapM swap ithen)
              newielse <- fmap concat (mapM swap ielse)
              return [JsIf p newithen newielse]
            e -> return [e]

        -- A self call only counts as a loop iteration when it supplies
        -- exactly one argument per parameter. 'rebind' zips the two lists, so
        -- a call with a different argument count would otherwise silently
        -- drop arguments or leave parameters holding stale values.
        tailCall (JsApp (JsName cname) args) =
          cname == name && length args == length params
        tailCall _ = False

        -- NOTE(review): the updates are emitted one after another, so an
        -- argument that reads a parameter assigned by an earlier update sees
        -- the new value rather than the old one -- confirm that generated
        -- arguments never refer to the parameters directly.
        rebind (JsApp _ args) = zipWith JsUpdate params args
        rebind e = error (show e)
-- | Walk every expression in the given statements, dropping forces around
-- references to known functions and replacing fully applied curried calls
-- with calls to the uncurried variant (see 'walkAndStripForces').
stripAndUncurry :: [JsStmt] -> Optimize [JsStmt]
stripAndUncurry = applyToExpsInStmts stripFuncForces
  where
    stripFuncForces arities expr =
      case expr of
        -- Forcing a name that is known to be a function is dropped.
        JsApp (JsName JsForce) [JsName (JsNameVar f)]
          | isJust (lookup f arities) -> return (JsName (JsNameVar f))
        JsFun ps stmts body -> do
          stmts' <- mapM (applyToExpsInStmt arities stripFuncForces) stmts
          body' <- case body of
            Nothing -> return Nothing
            Just b -> Just <$> recur b
          return (JsFun ps stmts' body')
        -- Try to uncurry the whole application first; if that is not
        -- possible, descend into the function and its arguments instead.
        JsApp fn args ->
          walkAndStripForces arities expr
            >>= maybe (JsApp <$> recur fn <*> mapM recur args) recur
        JsNegApp e -> fmap JsNegApp (recur e)
        JsTernaryIf c t e -> do
          c' <- recur c
          t' <- recur t
          e' <- recur e
          return (JsTernaryIf c' t' e')
        JsParen e -> fmap JsParen (recur e)
        JsUpdateProp obj prop val -> do
          obj' <- recur obj
          val' <- recur val
          return (JsUpdateProp obj' prop val')
        JsList items -> fmap JsList (mapM recur items)
        JsEq l r -> JsEq <$> recur l <*> recur r
        JsInfix op l r -> JsInfix op <$> recur l <*> recur r
        JsObj fields ->
          fmap JsObj (mapM (\(key, val) -> (,) key <$> recur val) fields)
        JsNew name args -> fmap (JsNew name) (mapM recur args)
        other -> return other
      where
        recur = stripFuncForces arities
-- | Try to turn a chain of single-argument applications into one call to the
-- uncurried variant of the function at its head.
--
-- Forces met along the chain are removed, except that a force wrapping the
-- whole expression is re-applied around the result. The rewrite succeeds only
-- when the head is a name found in the arity table and the number of
-- collected arguments equals its arity; the name is then recorded in
-- 'optUncurry'. Otherwise 'Nothing' is returned.
walkAndStripForces :: [FuncArity] -> JsExp -> Optimize (Maybe JsExp)
walkAndStripForces arities = descend True []
  where
    descend outermost collected expr =
      case expr of
        JsApp (JsName JsForce) [inner]
          | outermost -> fmap (fmap reforce) (descend False collected inner)
          | otherwise -> descend False collected inner
        JsApp fn [arg] -> descend False (arg : collected) fn
        JsName (JsNameVar f)
          | Just arity <- lookup f arities
          , length collected == arity -> do
              modify (\s -> s { optUncurry = f : optUncurry s })
              return (Just (JsApp (JsName (JsNameVar (renameUncurried f))) collected))
        _ -> return Nothing

    reforce e = JsApp (JsName JsForce) [e]
-- | Apply an expression transformer to every statement in the list, handing
-- it the arity table collected from that same list.
applyToExpsInStmts :: ([FuncArity] -> JsExp -> Optimize JsExp) -> [JsStmt] -> Optimize [JsStmt]
applyToExpsInStmts f stmts = mapM rewrite stmts
  where
    arities = collectFuncs stmts
    rewrite = applyToExpsInStmt arities f
-- | Apply an expression transformer to the expressions of a single statement.
--
-- Variable bindings, early returns and if statements (condition plus both
-- branches, recursively) are transformed; any other statement is returned
-- untouched.
applyToExpsInStmt :: [FuncArity] -> ([FuncArity] -> JsExp -> Optimize JsExp) -> JsStmt -> Optimize JsStmt
applyToExpsInStmt funcs f = rewriteStmt
  where
    onExp = f funcs

    rewriteStmt stmt =
      case stmt of
        JsMappedVar srcloc name e -> fmap (JsMappedVar srcloc name) (onExp e)
        JsVar name e -> fmap (JsVar name) (onExp e)
        JsEarlyReturn e -> fmap JsEarlyReturn (onExp e)
        JsIf cond ithen ielse -> do
          cond' <- onExp cond
          ithen' <- mapM rewriteStmt ithen
          ielse' <- mapM rewriteStmt ielse
          return (JsIf cond' ithen' ielse')
        other -> return other
-- | Build the arity table for a list of statements.
--
-- Every 'JsVar' / 'JsMappedVar' binding of a 'JsNameVar' whose value has a
-- positive 'expArity' contributes an entry. The entries for the built-in
-- @Fay$@ primitives are appended after them, so a binding from the
-- statements takes precedence under 'lookup'.
collectFuncs :: [JsStmt] -> [FuncArity]
collectFuncs stmts = concatMap collectFunc stmts ++ prim
  where
    collectFunc (JsMappedVar _ name exp) = collectFunc (JsVar name exp)
    collectFunc (JsVar (JsNameVar name) exp)
      | arity > 0 = [(name, arity)]
      where arity = expArity exp
    collectFunc _ = []

    -- Primitive operations, qualified with the "Fay$" module name.
    prim = map (first (Qual (ModuleName "Fay$"))) (unary ++ binary)
    unary = [(Ident "return", 1)]
    -- Each primitive is listed exactly once ("mult" used to appear twice).
    binary =
      [ (Ident op, 2)
      | op <- ["then","bind","mult","add","sub","div"
              ,"eq","neq","gt","lt","gte","lte","and","or"]
      ]
-- | Count how many function expressions are directly nested through their
-- return expressions. Anything that is not a function has arity 0.
expArity :: JsExp -> Int
expArity expr =
  case expr of
    JsFun _ _ (Just inner) -> 1 + expArity inner
    JsFun _ _ Nothing -> 1
    _ -> 0
-- | Manual check: run 'optimizeToplevel' on a hard-coded sample binding and
-- print the optimized statements, followed by the uncurried bindings
-- generated for the names the optimizer recorded.
test :: IO ()
test = do
  putStrLn (printJSPretty optimized)
  putStrLn (printJSPretty (mapMaybe (uncurryBinding optimized) recorded))
  where
    (optimized, OptState _ recorded) =
      runState (optimizeToplevel sample) (OptState sample [])
    sample = [JsMappedVar (SrcLoc {srcFilename = "", srcLine = 1, srcColumn = 1}) (JsNameVar (Qual (ModuleName "Main") (Ident "sum$uncurried"))) (JsFun [JsParam 1,JsParam 2] [] (Just (JsNew JsThunk [JsFun [] [JsVar (JsNameVar (UnQual (Ident "acc"))) (JsName (JsParam 2)),JsIf (JsEq (JsApp (JsName JsForce) [JsName (JsParam 1)]) (JsLit (JsInt 0))) [JsEarlyReturn (JsName (JsNameVar (UnQual (Ident "acc"))))] [],JsVar (JsNameVar (UnQual (Ident "acc"))) (JsName (JsParam 2)),JsVar (JsNameVar (UnQual (Ident "n"))) (JsName (JsParam 1)),JsEarlyReturn (JsApp (JsName (JsNameVar (Qual (ModuleName "Main") (Ident "sum$uncurried")))) [JsApp (JsName (JsNameVar (Qual (ModuleName "Fay$") (Ident "sub$uncurried")))) [JsApp (JsName JsForce) [JsName (JsNameVar (UnQual (Ident "n")))],JsLit (JsInt 1)],JsApp (JsName (JsNameVar (Qual (ModuleName "Fay$") (Ident "add$uncurried")))) [JsApp (JsName JsForce) [JsName (JsNameVar (UnQual (Ident "acc")))],JsApp (JsName JsForce) [JsName (JsNameVar (UnQual (Ident "n")))]]])] Nothing])))]
-- | Find the first binding of the given name among the statements and build
-- its uncurried counterpart.
--
-- The new binding is named with 'renameUncurried'. Its value is obtained by
-- collapsing the leading chain of single-parameter, statement-free functions
-- into one function taking all those parameters in order. Returns 'Nothing'
-- when no binding of that name exists.
uncurryBinding :: [JsStmt] -> QName -> Maybe JsStmt
uncurryBinding stmts qname = listToMaybe [ b | Just b <- map candidate stmts ]
  where
    candidate stmt =
      case stmt of
        JsMappedVar srcloc (JsNameVar name) body
          | name == qname ->
              Just (JsMappedVar srcloc (JsNameVar (renameUncurried name)) (flatten [] body))
        JsVar (JsNameVar name) body
          | name == qname ->
              Just (JsVar (JsNameVar (renameUncurried name)) (flatten [] body))
        _ -> Nothing

    -- Parameters are accumulated innermost-first and reversed at the end.
    flatten params expr =
      case expr of
        JsFun [param] [] (Just inner) -> flatten (param : params) inner
        _ -> JsFun (reverse params) [] (Just expr)
-- | Append the @$uncurried@ suffix to the name part of a qualified or
-- unqualified name. Any other kind of 'QName' is returned unchanged.
renameUncurried :: QName -> QName
renameUncurried = onName
  where
    onName (Qual m n) = Qual m (suffixed n)
    onName (UnQual n) = UnQual (suffixed n)
    onName special = special

    suffixed (Ident s) = Ident (s ++ "$uncurried")
    suffixed (Symbol s) = Symbol (s ++ "$uncurried")