module Text.HPaco.Optimizer
where
import Text.HPaco.AST.AST
import Text.HPaco.AST.Expression
import Text.HPaco.AST.Statement
import qualified Text.HPaco.Writers.Run as Run
import Data.Variant hiding (lookup, elem)
import Control.Monad.State
import Control.Applicative
import System.IO.Unsafe (unsafePerformIO)
import Text.HPaco.AST.Identifier (Identifier)
import Data.Maybe
import Control.Arrow ( (***) )
import qualified Control.Arrow as Arrow
-- | Run the complete optimisation pipeline over an 'AST'.
--
-- The passes are applied in this order: 'optimizeASTDefs', then
-- 'optimizeASTStatements', and finally 'expandDefs'.
optimize :: AST -> AST
optimize ast = expandDefs (optimizeASTStatements (optimizeASTDefs ast))
-- | Replace call statements by the body of the definition they refer to,
-- but only when following the call chain ends in a print, null or
-- source-position statement.
--
-- Only a statement that is itself a 'CallStatement' is considered; calls
-- nested inside other statements are left untouched.  Every definition body
-- is processed with its own name already marked as visited, so a
-- definition is never inlined into itself.
expandDefs :: AST -> AST
expandDefs ast =
    ast { astRootStatement = expand [] (astRootStatement ast)
        , astDefs = map (\(name, body) -> (name, expand [name] body)) defs
        }
  where
    defs = astDefs ast

    -- Expand a statement, keeping it unchanged when expansion is not possible.
    expand :: [Identifier] -> Statement -> Statement
    expand visited stmt = fromMaybe stmt (tryExpand visited stmt)

    -- 'Nothing' means "do not inline"; the first argument lists the
    -- definitions already entered, which stops recursive call chains.
    tryExpand :: [Identifier] -> Statement -> Maybe Statement
    tryExpand visited stmt@(CallStatement name)
        | name `elem` visited = Just stmt
        | otherwise = lookup name defs >>= tryExpand (name : visited)
    tryExpand _ stmt@(PrintStatement _) = Just stmt
    tryExpand _ NullStatement = Just NullStatement
    tryExpand _ stmt@(SourcePositionStatement _ _) = Just stmt
    tryExpand _ _ = Nothing
-- | Optimise the body of every definition in the 'AST'.
--
-- Record-update syntax is used so that every field other than 'astDefs' is
-- carried over unchanged, whatever fields the 'AST' record has.
optimizeASTDefs :: AST -> AST
optimizeASTDefs ast =
    ast { astDefs = map (Arrow.second optimizeStatement) (astDefs ast) }
-- | Optimise the root statement and the body of every definition with
-- 'optimizeStatement'; all other fields of the 'AST' are kept as they are.
optimizeASTStatements :: AST -> AST
optimizeASTStatements ast =
    ast { astRootStatement = optimizeStatement (astRootStatement ast)
        , astDefs = [ (name, optimizeStatement body) | (name, body) <- astDefs ast ]
        }
-- | Simplify a single statement, recursing into its sub-statements.
--
-- Rewrites performed:
--
-- * Printing an integer or float literal becomes printing its 'show'n string.
-- * Printing the empty string and source-position statements become
--   'NullStatement'.
-- * Any other printed expression is run through 'optimizeExpression' until it
--   no longer changes.
-- * Statement sequences drop null statements, splice in nested sequences
--   (at any position, so prints on either side can be fused), fuse adjacent
--   string prints, and collapse to a single statement or 'NullStatement'
--   when one or zero statements remain.
-- * An @if@ whose optimised condition is a boolean literal is replaced by the
--   corresponding optimised branch.
-- * @let@, @for@ and @switch@ statements have their expressions and bodies
--   optimised in place.
--
-- Every other statement is returned unchanged.
optimizeStatement :: Statement -> Statement
optimizeStatement (PrintStatement (IntLiteral i)) =
    PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (FloatLiteral f)) =
    PrintStatement . StringLiteral . show $ f
optimizeStatement (PrintStatement (StringLiteral [])) = NullStatement
optimizeStatement (PrintStatement e)
    | e == e' = PrintStatement e
    -- The expression changed: retry, so that a freshly produced literal
    -- is picked up by the equations above.
    | otherwise = optimizeStatement (PrintStatement e')
  where
    e' = optimizeExpression e
optimizeStatement (SourcePositionStatement _ _) = NullStatement
optimizeStatement (StatementSequence xs) =
    case fusePrints (concatMap splice (map optimizeStatement xs)) of
        [] -> NullStatement
        [s] -> s
        ss -> StatementSequence ss
  where
    -- The children are already optimised, so a nested sequence contains no
    -- null statements or further nested sequences; one level of splicing is
    -- therefore enough.
    splice NullStatement = []
    splice (StatementSequence ss) = ss
    splice s = [s]
optimizeStatement (IfStatement cond true false) =
    case optimizeExpression cond of
        BooleanLiteral True -> true'
        BooleanLiteral False -> false'
        cond' -> IfStatement cond' true' false'
  where
    true' = optimizeStatement true
    false' = optimizeStatement false
optimizeStatement (LetStatement ident e stmt) =
    LetStatement ident (optimizeExpression e) (optimizeStatement stmt)
optimizeStatement (ForStatement iter ident e stmt) =
    ForStatement iter ident (optimizeExpression e) (optimizeStatement stmt)
optimizeStatement (SwitchStatement e branches) =
    SwitchStatement
        (optimizeExpression e)
        (map (optimizeExpression *** optimizeStatement) branches)
optimizeStatement s = s
-- | Merge every run of adjacent string-literal print statements into a single
-- print of the concatenated string.  All other statements are kept in order.
--
-- Implemented as one right fold: each statement is looked at once and the
-- string concatenations nest to the right.
fusePrints :: [Statement] -> [Statement]
fusePrints = foldr fuse []
  where
    fuse (PrintStatement (StringLiteral lhs)) (PrintStatement (StringLiteral rhs) : rest) =
        PrintStatement (StringLiteral (lhs ++ rhs)) : rest
    fuse stmt rest = stmt : rest
-- | Fold an expression to its value when it is entirely constant (as decided
-- by 'isConst'); any other expression is returned unchanged.
optimizeExpression :: Expression -> Expression
optimizeExpression e
    | isConst e = evaluateConstExpression e
    | otherwise = e
-- | Decide whether an expression is built only from literals.
--
-- String, integer, float and boolean literals are constant.  Lists,
-- association lists, escape, unary and binary expressions are constant when
-- all of their sub-expressions are.  Everything else is not constant.
isConst :: Expression -> Bool
isConst expr =
    case expr of
        StringLiteral _ -> True
        IntLiteral _ -> True
        FloatLiteral _ -> True
        BooleanLiteral _ -> True
        ListExpression items -> all isConst items
        AListExpression pairs -> and [ isConst k && isConst v | (k, v) <- pairs ]
        EscapeExpression _ inner -> isConst inner
        UnaryExpression _ operand -> isConst operand
        BinaryExpression _ lhs rhs -> isConst lhs && isConst rhs
        _ -> False
-- | Evaluate an expression at optimisation time and turn the result back into
-- a literal expression.
--
-- The expression is run with 'Run.runExpression' in a state whose scope is
-- 'Null', whose AST is 'defAST' and whose options are 'Run.defaultOptions'.
-- If the resulting value has no literal form (see 'variantToExpression'),
-- the original expression is returned unchanged instead of failing.
evaluateConstExpression :: Expression -> Expression
evaluateConstExpression e = fromMaybe e (variantToExpression result)
  where
    initialState =
        Run.RunState
            { Run.rsScope = Null
            , Run.rsAST = defAST
            , Run.rsOptions = Run.defaultOptions
            }
    -- NOTE(review): 'unsafePerformIO' is only sound here if evaluating a
    -- constant expression has no observable side effects -- confirm against
    -- 'Run.runExpression'.
    result = unsafePerformIO $ evalStateT (Run.runExpression e) initialState

-- | Total conversion from a runtime value to a literal expression.
--
-- Yields 'Nothing' for any value (or nested element) that is not a string,
-- integer, double, boolean, list or association list.
variantToExpression :: Variant -> Maybe Expression
variantToExpression (String s) = Just (StringLiteral s)
variantToExpression (Integer i) = Just (IntLiteral i)
variantToExpression (Double d) = Just (FloatLiteral d)
variantToExpression (Bool b) = Just (BooleanLiteral b)
variantToExpression (List xs) = ListExpression <$> mapM variantToExpression xs
variantToExpression (AList xs) = AListExpression <$> mapM pair xs
  where
    pair (k, v) = (,) <$> variantToExpression k <*> variantToExpression v
variantToExpression _ = Nothing

-- | Convert a runtime value to a literal expression.
--
-- Partial: calls 'error' when the value has no literal form.  Prefer
-- 'variantToExpression' when the input is not known to be convertible.
fromVariant :: Variant -> Expression
fromVariant =
    fromMaybe (error "Text.HPaco.Optimizer.fromVariant: value has no literal expression form")
        . variantToExpression