module IRTS.LangOpts(inlineAll) where
import Idris.Core.CaseTree
import Idris.Core.TT
import IRTS.Lang
import Control.Monad.State hiding (lift)
import Data.List
import Debug.Trace
-- | Run the inliner over every declaration in the list.  All of the
-- declarations are first collected into one context so that each body can
-- see every other definition while it is being processed.
inlineAll :: [(Name, LDecl)] -> [(Name, LDecl)]
inlineAll lds = map inlineOne lds
  where
    allDefs = addAlist lds emptyContext
    inlineOne (name, decl) = (name, doInline allDefs decl)
-- | Produce a fresh machine-generated name, built from the current counter
-- and the tag @"in"@, and advance the counter by one.
nextN :: State Int Name
nextN = do
  current <- get
  modify (+ 1)
  return (sMN current "in")
-- | Inline within a single declaration.
--
-- Constructor declarations are returned unchanged.  For a function, the body
-- is evaluated with 'eval' (the function's own name is put in the list of
-- names being unfolded, so it is not unfolded inside itself), then
-- 'caseFloat' is iterated and finally 'eta' is applied.  The original
-- argument names are replaced by the renamed ones from @initNames@; if the
-- result is itself a lambda, its binders are appended to the argument list.
doInline :: LDefs -> LDecl -> LDecl
doInline defs d@(LConstructor _ _ _) = d
doInline defs (LFun opts topn args exp)
  = case res of
      LLam args' body -> LFun opts topn (map snd initNames ++ args') body
      _ -> LFun opts topn (map snd initNames) res
  where
    -- The fresh-name counter starts at the number of arguments.
    inl = evalState (eval [] initEnv [topn] defs exp)
                    (length args)
    res = eta $ caseFloats 10 inl

    -- Repeat 'caseFloat' until it reaches a fixed point or the fuel runs
    -- out.  The fuel must decrease on every pass to bound the iteration.
    caseFloats :: Int -> LExp -> LExp
    caseFloats 0 tm = tm
    caseFloats n tm
      = let tm' = caseFloat tm in
            if tm' == tm
               then tm'
               else caseFloats (n - 1) tm'

    -- Pair each original argument with its renamed version, numbered by
    -- position.
    initNames = zipWith (\n i -> (n, newn n i)) args [0..]
    initEnv = map (\(n, n') -> (n, LV n')) initNames

    newn (UN n) i = MN i n
    newn _ i = sMN i "arg"
-- | Apply an expression to the arguments still waiting on the stack.
-- With an empty stack the expression is returned untouched; an existing
-- application simply has the stacked arguments appended to its own;
-- anything else is wrapped in a new application.
unload :: [LExp] -> LExp -> LExp
unload pending tm
  | null pending = tm
  | otherwise = case tm of
      LApp tc fn given -> LApp tc fn (given ++ pending)
      _ -> LApp False tm pending
-- | Bind lambda arguments to stacked values, pairwise, until either list
-- runs out.  Returns the extended environment together with whatever
-- argument names and stack entries are left over.
takeStk :: [(Name, LExp)] -> [Name] -> [LExp] ->
           ([(Name, LExp)], [Name], [LExp])
takeStk env binders stack = case (binders, stack) of
  (b : bs, v : vs) -> takeStk ((b, v) : env) bs vs
  _ -> (env, binders, stack)
-- | Partially evaluate an expression, inlining as it goes.
--
-- The arguments are, in order: the stack of already-evaluated arguments still
-- waiting to be applied to the result, the environment mapping names to the
-- expressions they stand for, the names currently being unfolded (a name in
-- this list is not unfolded again), the global definitions, and the
-- expression itself.  Binders are renamed with fresh names from 'nextN'.
eval :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LExp -> State Int LExp
eval stk env seen defs (LLazyApp fn es)
  = do es' <- mapM (eval [] env seen defs) es
       return (unload stk (LLazyApp fn es'))
eval stk env seen defs (LForce e)
  = do e' <- eval [] env seen defs e
       return . unload stk $ case e' of
         -- Forcing an explicit delay cancels out.
         LLazyExp inner -> inner
         -- Forcing a lazy application becomes an ordinary application.
         LLazyApp fn es -> LApp False (LV fn) es
         _ -> LForce e'
eval stk env seen defs (LLazyExp e)
  = do e' <- eval [] env seen defs e
       return (unload stk (LLazyExp e'))
-- @io_bind@ with nothing on the stack: abstract over a fresh variable, pass
-- it as the extra argument to both the action and the continuation, and turn
-- the bind into a let.
eval [] env seen defs (LApp _ (LV fn) [_, _, _, act, LLam [arg] k])
  | fn == sUN "io_bind"
  = do world <- nextN
       let worldEnv = (world, LV world) : env
       act' <- eval [] worldEnv seen defs (LApp False act [LV world])
       result <- nextN
       k' <- eval [] ((arg, LV result) : worldEnv) seen defs
                  (LApp False k [LV world])
       return (LLam [world] (LLet result act' k'))
-- @io_bind@ with at least one stacked argument: the top of the stack is used
-- as the extra argument instead of a fresh variable.
eval (world : stk) env seen defs (LApp _ (LV fn) [_, _, _, act, LLam [arg] k])
  | fn == sUN "io_bind"
  = do act' <- eval [] env seen defs (LApp False act [world])
       result <- nextN
       k' <- eval stk ((arg, LV result) : env) seen defs (LApp False k [world])
       return (LLet result act' k')
-- General application: evaluate the arguments, push them, evaluate the head.
eval stk env seen defs (LApp _ f es)
  = mapM (eval [] env seen defs) es >>= \es' ->
      eval (es' ++ stk) env seen defs f
eval stk env seen defs (LLet n val sc)
  = do fresh <- nextN
       val' <- eval [] env seen defs val
       sc' <- eval stk ((n, LV fresh) : env) seen defs sc
       return (LLet fresh val' sc')
eval stk env seen defs (LProj e i)
  = do e' <- eval [] env seen defs e
       return (unload stk (LProj e' i))
eval stk env seen defs (LCon loc i n es)
  = do es' <- mapM (eval [] env seen defs) es
       return (unload stk (LCon loc i n es'))
eval stk env seen defs (LCase ty e alts)
  = do -- Branches are evaluated before the scrutinee.
       alts' <- mapM (evalAlt stk env seen defs) alts
       e' <- eval [] env seen defs e
       let rebuild branches = conOpt (LCase ty e' (replaceInAlts e' branches))
       case getLams (map getRHS alts') of
         [] -> return (rebuild alts')
         -- Every branch is a lambda: hoist the shared binders out of the case.
         args -> do reduced <- mapM (dropArgs args) alts'
                    return (LLam args (rebuild reduced))
eval stk env seen defs (LOp f es)
  = do es' <- mapM (eval [] env seen defs) es
       return (unload stk (LOp f es'))
eval stk env seen defs (LForeign t s args)
  = do args' <- mapM evalArg args
       return (unload stk (LForeign t s args'))
  where
    evalArg (ty, a) = do a' <- eval [] env seen defs a
                         return (ty, a')
eval stk env seen defs (LLam args sc)
  = case takeStk env args stk of
      -- Fully applied: carry on with the body and the remaining stack.
      (env', [], stk') -> eval stk' env' seen defs sc
      -- Binders left over: rename them and evaluate under the lambda.
      (env', args', stk') ->
        do renamed <- mapM (\old -> do new <- nextN
                                       return (old, new)) args'
           let bodyEnv = map (\(old, new) -> (old, LV new)) renamed ++ env'
           sc' <- eval [] bodyEnv seen defs sc
           return (unload stk' (LLam (map snd renamed) sc'))
eval stk env seen defs var@(LV n)
  = case lookup n env of
      Just t
        | t /= LV n && n `notElem` seen ->
            eval stk env (n : seen) defs t
        | otherwise -> return (unload stk t)
      Nothing
        -- A global marked Inline which is not already being unfolded.
        | n `notElem` seen,
          Just (LFun opts _ args body) <- lookupCtxtExact n defs,
          Inline `elem` opts ->
            apply stk env (n : seen) defs var args body
        -- A constructor takes the whole stack as its arguments.
        | Just (LConstructor cname tag _) <- lookupCtxtExact n defs ->
            return (LCon Nothing tag cname stk)
        | otherwise -> return (unload stk var)
eval stk env seen defs t = return (unload stk t)
-- | Evaluate the right-hand side of one case branch with 'eval'.  For a
-- constructor branch the pattern variables are given fresh names first and
-- the body is evaluated with those renamings added to the environment.
evalAlt stk env seen defs alt = case alt of
  LConCase i n es rhs ->
    do renamed <- mapM (\old -> do new <- nextN
                                   return (old, new)) es
       let env' = map (\(old, new) -> (old, LV new)) renamed ++ env
       rhs' <- eval stk env' seen defs rhs
       return (LConCase i n (map snd renamed) rhs')
  LConstCase c e ->
    do e' <- eval stk env seen defs e
       return (LConstCase c e')
  LDefaultCase e ->
    do e' <- eval stk env seen defs e
       return (LDefaultCase e')
-- | Unfold a function definition: its arguments and body are packaged as a
-- lambda and handed to 'eval', which consumes the stacked arguments.  The
-- expression for the function itself (fifth argument) is not used.
apply :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LExp ->
         [Name] -> LExp -> State Int LExp
apply stk env seen defs _ args body
  = let unfolded = LLam args body in
        eval stk env seen defs unfolded
-- | Strip the shared lambda prefix from the right-hand side of a case
-- branch, renaming the branch's own binders to the shared names @as@.
-- Used when the lambdas of all branches are hoisted out of the case.
dropArgs :: [Name] -> LAlt -> State Int LAlt
dropArgs as (LConCase i n es rhs) = LConCase i n es <$> dropLamArgs as rhs
dropArgs as (LConstCase c rhs) = LConstCase c <$> dropLamArgs as rhs
dropArgs as (LDefaultCase rhs) = LDefaultCase <$> dropLamArgs as rhs

-- | Remove the first @length as@ binders of a lambda, substituting the names
-- @as@ for them in the body.
dropLamArgs :: [Name] -> LExp -> State Int LExp
dropLamArgs as (LLam args rhs)
  = do let (old, rest) = splitAt (length as) args
       rhs' <- eval [] (zipWith (\o n -> (o, LV n)) old as) [] emptyContext rhs
       -- The shared prefix is the shortest binder list of all the branches,
       -- so this branch may bind more names than are hoisted.  Those extra
       -- binders must be kept, otherwise their variables become free.
       return (if null rest then rhs' else LLam rest rhs')
-- Not expected to occur, since 'getLams' only yields a prefix when every
-- branch is a lambda; applying the expression to the hoisted names keeps
-- the meaning intact should it ever happen.
dropLamArgs as e = return (unload (map LV as) e)
-- | One pass of case-of-case floating.
--
-- When the scrutinee of a case is itself a case, and the inner case either
-- has a single branch or has a constructor application on every right-hand
-- side, the outer branches are pushed into each inner branch so that
-- 'conOpt' can resolve them.  All other forms are traversed structurally;
-- expressions not listed here are returned unchanged.
caseFloat :: LExp -> LExp
caseFloat (LApp tc f xs) = LApp tc (caseFloat f) (map caseFloat xs)
caseFloat (LLazyExp e) = LLazyExp (caseFloat e)
caseFloat (LForce e) = LForce (caseFloat e)
caseFloat (LCon up i n xs) = LCon up i n (map caseFloat xs)
caseFloat (LOp f xs) = LOp f (map caseFloat xs)
caseFloat (LLam ns body) = LLam ns (caseFloat body)
caseFloat (LLet v val body) = LLet v (caseFloat val) (caseFloat body)
caseFloat (LCase _ (LCase ct scrut inner) outer)
  | all conRHS inner || length inner == 1
  = conOpt $ replaceInCase (LCase ct (caseFloat scrut) (map pushOuter inner))
  where
    -- Does this branch return a constructor application directly?
    conRHS alt = case alt of
      LConCase _ _ _ (LCon _ _ _ _) -> True
      LConstCase _ (LCon _ _ _ _) -> True
      LDefaultCase (LCon _ _ _ _) -> True
      _ -> False

    -- Scrutinise an inner right-hand side with the outer branches.
    wrap rhs = caseFloat (conOpt (LCase Shared (caseFloat rhs) outer))

    pushOuter (LConCase i n es rhs) = LConCase i n es (wrap rhs)
    pushOuter (LConstCase c rhs) = LConstCase c (wrap rhs)
    pushOuter (LDefaultCase rhs) = LDefaultCase (wrap rhs)
caseFloat (LCase ct scrut alts)
  = conOpt $ replaceInCase (LCase ct (caseFloat scrut) (map floatAlt alts))
  where
    floatAlt alt = case alt of
      LConCase i n es rhs -> LConCase i n es (caseFloat rhs)
      LConstCase c rhs -> LConstCase c (caseFloat rhs)
      LDefaultCase rhs -> LDefaultCase (caseFloat rhs)
caseFloat other = other
-- | Resolve a case whose scrutinee is a known constructor application.
-- The first branch for the same constructor name is selected and its pattern
-- variables are substituted by the constructor's arguments; a default branch
-- met before any such branch is taken instead.  Any other expression is
-- returned unchanged.
conOpt :: LExp -> LExp
conOpt (LCase _ (LCon _ _ con args) alts) = choose alts
  where
    choose [] = error "Can't happen pickAlt - impossible case found"
    choose (LConCase _ con' es rhs : rest)
      | con == con' = foldr (\(v, tm) acc -> lsubst v tm acc) rhs (zip es args)
      | otherwise = choose rest
    choose (LDefaultCase rhs : _) = rhs
    choose (_ : rest) = choose rest
conOpt tm = tm
-- | For a case expression, rewrite its branches with 'replaceInAlts' using
-- the case's own scrutinee; every other expression is left as it is.
replaceInCase :: LExp -> LExp
replaceInCase tm = case tm of
  LCase ty scrut alts -> LCase ty scrut (replaceInAlts scrut alts)
  _ -> tm
-- | Rewrite every branch with 'replaceInAlt' (which may expand one branch
-- into several) and then discard duplicated constructor branches.
replaceInAlts :: LExp -> [LAlt] -> [LAlt]
replaceInAlts scrut alts = dropDups expanded
  where
    expanded = concatMap (replaceInAlt scrut) alts
-- | Keep only the first branch for each constructor name; later constructor
-- branches with a name already seen are removed.  Constant and default
-- branches are never removed.
dropDups [] = []
dropDups (alt : rest) = case alt of
  LConCase _ name _ _ -> alt : dropDups (filter (differsFrom name) rest)
  _ -> alt : dropDups rest
  where
    differsFrom name (LConCase _ other _ _) = name /= other
    differsFrom _ _ = True
-- | Simplify one branch of a case on a variable.
--
-- In a constructor branch, a right-hand side which merely rebuilds the
-- matched constructor is replaced by the scrutinee variable (see
-- 'replaceExp').  A default branch which immediately cases on the same
-- variable again is replaced by the branches of that inner case.  Everything
-- else is returned as a singleton.
replaceInAlt :: LExp -> LAlt -> [LAlt]
replaceInAlt scrut alt = case (scrut, alt) of
  (LV _, LConCase tag con args rhs) ->
    [LConCase tag con args
       (replaceExp (LCon Nothing tag con (map LV args)) scrut rhs)]
  (LV var, LDefaultCase (LCase _ (LV var') inner))
    | var == var' -> inner
  _ -> [alt]
-- | @replaceExp old new tm@ yields @new@ when @tm@ is the constructor
-- application @old@, written either as a constructor node or as an
-- application of a variable with the constructor's name, with the same
-- arguments.  Only the top of @tm@ is inspected; otherwise @tm@ is returned.
replaceExp :: LExp -> LExp -> LExp -> LExp
replaceExp old new tm = case (old, tm) of
  (LCon _ _ con args, LCon _ _ con' args')
    | con == con' && args == args' -> new
  (LCon _ _ con args, LApp _ (LV con') args')
    | con == con' && args == args' -> new
  _ -> tm
-- | The right-hand side of a case branch, whatever kind of branch it is.
getRHS alt = case alt of
  LConCase _ _ _ rhs -> rhs
  LConstCase _ rhs -> rhs
  LDefaultCase rhs -> rhs
-- | The lambda binders shared by a list of expressions: the shortest binder
-- list if every expression is a lambda, and the empty list otherwise
-- (including when there are no expressions at all).
getLams tms = case tms of
  LLam args _ : rest -> getLamPrefix args rest
  _ -> []
-- | Worker for 'getLams'.  Carries the shortest binder list seen so far
-- through the remaining expressions, giving up with the empty list as soon
-- as one of them is not a lambda.
getLamPrefix shortest [] = shortest
getLamPrefix shortest (LLam args _ : rest)
  = getLamPrefix (if length args < length shortest then args else shortest) rest
getLamPrefix _ _ = []
-- | Eta-reduce throughout an expression: a lambda whose body applies some
-- function to exactly the lambda's own binders, in order, is replaced by
-- that function.  Expression forms not listed are returned unchanged.
eta :: LExp -> LExp
eta tm = case tm of
  LApp tc f xs -> LApp tc (eta f) (map eta xs)
  LLazyApp n xs -> LLazyApp n (map eta xs)
  LLazyExp e -> LLazyExp (eta e)
  LForce e -> LForce (eta e)
  LLet n val body -> LLet n (eta val) (eta body)
  LLam ns (LApp _ f xs)
    | xs == map LV ns -> eta f
  LLam ns body -> LLam ns (eta body)
  LProj e i -> LProj (eta e) i
  LCon a t n xs -> LCon a t n (map eta xs)
  LCase ct e alts -> LCase ct (eta e) (map (fmap eta) alts)
  LOp f xs -> LOp f (map eta xs)
  _ -> tm