module Idris.PartialEval(
partial_eval, getSpecApps, specType
, mkPE_TyDecl, mkPE_TermDecl, PEArgType(..)
, pe_app, pe_def, pe_clauses, pe_simple
) where
import Idris.AbsSyntax
import Idris.Delaborate
import Idris.Core.TT
import Idris.Core.CaseTree
import Idris.Core.Evaluate
import Control.Monad.State
import Control.Applicative
import Data.Maybe
import Debug.Trace
-- | Classification of a single argument at a call site that is being
-- considered for specialisation.  The @S@/@D@ suffix says whether the
-- argument position is flagged static or not; the prefix says whether the
-- argument is implicit or explicit.
data PEArgType
  = ImplicitS -- ^ Implicit argument in a static position.
  | ImplicitD -- ^ Implicit argument in a non-static position.
  | ExplicitS -- ^ Explicit argument in a static position.
  | ExplicitD -- ^ Explicit argument in a non-static position.
  | UnifiedD  -- ^ Recorded by 'specType' in place of an 'ImplicitD' argument
              --   that is equal to an earlier one; its binder is dropped.
  deriving (Eq, Show)
-- | The pieces of a specialised definition, as assembled by
-- 'mkPE_TermDecl' and 'mkNewPats'.
data PEDecl = PEDecl
  { pe_app     :: PTerm
    -- ^ Application of the new (specialised) name to its arguments.
  , pe_def     :: PTerm
    -- ^ Application of the original name that the new name stands for.
  , pe_clauses :: [(PTerm, PTerm)]
    -- ^ (left-hand side, right-hand side) pairs defining the new name.
  , pe_simple  :: Bool
    -- ^ 'True' when 'pe_clauses' is just the single pair
    --   @('pe_app', 'pe_def')@.
  }
-- | Specialise the right-hand side of every clause in a definition.
--
-- The second argument gives, per name, an optional limit that is handed on
-- to 'specialise'; repeated entries for a name that both carry a limit are
-- merged by adding their limits.  Non-clause entries ('Left') pass through
-- untouched.
--
-- Returns 'Nothing' only when the definition consists of exactly one entry
-- and the progress check on the values reported by 'specialise' fails.
partial_eval :: Context
             -> [(Name, Maybe Int)]
             -> [Either Term (Term, Term)]
             -> Maybe [Either Term (Term, Term)]
partial_eval ctxt ns_in tms = mapM specClause tms
  where
    limits = mergeLimits ns_in

    -- Fold a later limited entry for the same name into the current one.
    -- Only the first later entry for the name is inspected; if that one
    -- carries no limit, nothing is merged.
    mergeLimits [] = []
    mergeLimits (entry@(nm, Just x) : rest) =
      case lookup nm rest of
        Just (Just y) -> mergeLimits ((nm, Just (x + y)) : dropFirst nm rest)
        _             -> entry : mergeLimits rest
    mergeLimits (entry : rest) = entry : mergeLimits rest

    -- Remove the first entry keyed by the given name, if any.
    dropFirst _ [] = []
    dropFirst nm (e@(m, _) : rest)
      | nm == m   = rest
      | otherwise = e : dropFirst nm rest

    -- Anything that is not a clause is kept as it is.
    specClause (Left t) = Just (Left t)
    -- For a clause, only the right-hand side is specialised.
    specClause (Right (lhs, rhs)) = do
        when (length tms == 1) $ checkProgress limits reductions
        return (Right (lhs, rhs'))
      where
        (rhs', reductions) = specialise ctxt [] (map toLimit limits) rhs

    -- Names without an explicit limit get 2 when 'isTCDict' holds for them
    -- and 65536 otherwise.
    toLimit (nm, Just l) = (nm, l)
    toLimit (nm, Nothing)
      | isTCDict nm ctxt = (nm, 2)
      | otherwise        = (nm, 65536)

    -- Fails if some reported name had a starting limit above 1 and its
    -- reported value is not below that starting limit.
    checkProgress starts reported
      | all progressed reported = return ()
      | otherwise               = Nothing
      where
        progressed (nm, r) =
          case lookup nm starts of
            Just (Just start) -> start <= 1 || r < start
            _                 -> True
-- | Specialise a function type against a list of classified arguments.
--
-- Works in two passes over the Pi binders of the type:
--
--   1. @unifyEq@ walks arguments and binders together, recording each
--      argument with its binder name.  An 'ImplicitD' argument equal to one
--      already recorded is recorded as @('UnifiedD', 'Erased')@ and its
--      binder name is substituted by the earlier binder in the scope.
--   2. @st@ rewrites the binders: static arguments become @Let@ binders
--      carrying the argument value, 'UnifiedD' binders are dropped, and
--      everything else stays a Pi binder.
--
-- Returns the rewritten type together with the recorded argument list.
specType :: [(PEArgType, Term)] -> Type -> (Type, [(PEArgType, Term)])
specType args ty = (st classified unified, classified)
  where
    (unified, recorded) = runState (unifyEq args ty) []
    classified = map fst recorded

    st ((ExplicitS, v) : rest) (Bind n (Pi _ t _) sc) = letBind n t v (st rest sc)
    st ((ImplicitS, v) : rest) (Bind n (Pi _ t _) sc) = letBind n t v (st rest sc)
    st ((UnifiedD, _) : rest) (Bind _ (Pi _ _ _) sc) = st rest sc
    st (_ : rest) (Bind n b@(Pi _ _ _) sc) = Bind n b (st rest sc)
    st _ t = t

    letBind n t v body = Bind n (Let t v) body

    -- Append one (argument, binder name) entry to the state.
    remember entry = modify (++ [entry])

    unifyEq (imp@(ImplicitD, _) : rest) (Bind n b@(Pi _ _ _) sc) = do
        seen <- get
        sc' <- case lookup imp seen of
          -- Same implicit seen before: point this binder at the earlier one.
          Just earlier -> do
            remember ((UnifiedD, Erased), n)
            unifyEq rest (subst n (P Bound earlier Erased) sc)
          Nothing -> do
            remember (imp, n)
            unifyEq rest sc
        return (Bind n b sc')
    unifyEq (arg : rest) (Bind n b@(Pi _ _ _) sc) = do
        remember (arg, n)
        sc' <- unifyEq rest sc
        return (Bind n b sc')
    -- Out of binders (or out of arguments): any leftover arguments are
    -- recorded against the name "_".
    unifyEq rest t = do
        modify (++ zip rest (repeat (sUN "_")))
        return t
-- | Build the high-level type declaration for a specialised function from
-- its classified arguments and its type.
--
-- Only 'ExplicitD' and 'ImplicitD' arguments facing a Pi binder produce a
-- binder in the result; an 'ImplicitD' whose type is a concrete interface
-- (see 'concreteInterface') is omitted, and one whose type is an interface
-- constraint becomes a constraint argument.  Any other argument is skipped
-- without consuming a binder.  Once the arguments run out, what is left of
-- the type is delaborated as the result type.
mkPE_TyDecl :: IState -> [(PEArgType, Term)] -> Type -> PTerm
mkPE_TyDecl ist args ty = mkty args ty
  where
    mkty ((ExplicitD, _) : rest) (Bind n (Pi _ t _) sc) =
        binder expl n t (mkty rest sc)
    mkty ((ImplicitD, _) : rest) (Bind n (Pi _ t _) sc)
        | concreteInterface ist t = mkty rest sc
        | otherwise               = binder plicity n t (mkty rest sc)
      where
        plicity
          | interfaceConstraint ist t = constraint
          | otherwise                 = impl
    mkty (_ : rest) t = mkty rest t
    mkty [] t = delab ist t

    -- A Pi binder whose argument type has been generalised first.
    binder plicity n t body =
        PPi plicity n NoFC (delab ist (generaliseIn t)) body

    generaliseIn tm = evalState (gen tm) 0

    -- Replace every sub-term headed by a function name with a fresh bound
    -- machine-generated name "spec"; the counter keeps the names distinct.
    gen tm
      | (P _ fn _, _) <- unApply tm
      , isFnName fn (tt_ctxt ist)
      = do k <- get
           put (k + 1)
           return (P Bound (sMN k "spec") Erased)
    gen (App s f a) = App s <$> gen f <*> gen a
    gen tm = return tm
-- | 'True' when the term is an application headed by a name that has
-- exactly one entry in the interface table of the given state.
interfaceConstraint :: Idris.AbsSyntax.IState -> TT Name -> Bool
interfaceConstraint ist v =
  case unApply v of
    (P _ c _, _) ->
      case lookupCtxt c (idris_interfaces ist) of
        [_] -> True
        _   -> False
    _ -> False
-- | 'True' when the term is an interface constraint (see
-- 'interfaceConstraint') all of whose arguments are concrete: each one is
-- either a constant, or an application of a name with exactly one type in
-- the context to arguments that are themselves concrete.
concreteInterface :: IState -> TT Name -> Bool
concreteInterface ist v
  | interfaceConstraint ist v =
      case unApply v of
        (P _ _ _, args) -> all concrete args
        _               -> False
  | otherwise = False
  where
    concrete (Constant _) = True
    concrete tm =
      case unApply tm of
        (P _ n _, args)
          | [_] <- lookupTy n (tt_ctxt ist) -> all concrete args
        _ -> False
-- | Build the clauses of a specialised definition from the pattern clauses
-- of the original one.
--
-- Arguments, in order: the state, the original (lhs, rhs) clauses, the
-- classified arguments of the specialised application, the new name, the
-- original name, and the default left- and right-hand sides.
--
-- If, in every original left-hand side, each argument in a non-static
-- position is a plain @V@ or @P@ term, a single clause @(lhs, rhs)@ is
-- enough and the result is flagged simple.  Otherwise one clause is made
-- per original clause: its left-hand side applies the new name to the
-- original 'ExplicitD' patterns, and its right-hand side calls the
-- original name.
mkNewPats :: IState
          -> [(Term, Term)]
          -> [(PEArgType, Term)]
          -> Name
          -> Name
          -> PTerm
          -> PTerm
          -> PEDecl
mkNewPats ist d ns newname sname lhs rhs
  | all (plainDynamic ns . snd . unApply . fst) d
      = PEDecl lhs rhs [(lhs, rhs)] True
  | otherwise
      = PEDecl lhs rhs (map mkClause d) False
  where
    -- Static positions accept anything; other positions must hold a V or P.
    plainDynamic _ [] = True
    plainDynamic ((ImplicitS, _) : kinds) (_ : as) = plainDynamic kinds as
    plainDynamic ((ExplicitS, _) : kinds) (_ : as) = plainDynamic kinds as
    plainDynamic (_ : kinds) (V _ : as) = plainDynamic kinds as
    plainDynamic (_ : kinds) (P _ _ _ : as) = plainDynamic kinds as
    plainDynamic _ _ = False

    mkClause :: (Term, Term) -> (PTerm, PTerm)
    mkClause (oldlhs, _) = (clauseLhs, clauseRhs)
      where
        oldArgs   = snd (unApply oldlhs)
        lhsargs   = mkLHSargs [] ns oldArgs
        clauseLhs = PApp emptyFC (PRef emptyFC [] newname) lhsargs
        clauseRhs = PApp emptyFC (PRef emptyFC [] sname)
                         (mkRHSargs ns lhsargs)

    -- Keep only the ExplicitD patterns.  A static position whose pattern
    -- is a name adds a name-to-static-value substitution, which is applied
    -- to the ExplicitD patterns that follow.
    mkLHSargs sub ((kind, t) : kinds) (a : as) =
      case kind of
        ExplicitD -> pexp (delab ist (substNames sub a)) : mkLHSargs sub kinds as
        ImplicitD -> mkLHSargs sub kinds as
        UnifiedD  -> mkLHSargs sub kinds as
        ImplicitS -> mkLHSargs (extend a t sub) kinds as
        ExplicitS -> mkLHSargs (extend a t sub) kinds as
    mkLHSargs _ _ _ = []

    extend (P _ n _) t sub = (n, t) : sub
    extend _ _ sub = sub

    -- Explicit arguments for the call to the original name: the static
    -- value for ExplicitS, the next left-hand-side argument for ExplicitD.
    mkRHSargs ((ExplicitS, t) : kinds) as = pexp (delab ist t) : mkRHSargs kinds as
    mkRHSargs ((ExplicitD, _) : kinds) (a : as) = a : mkRHSargs kinds as
    mkRHSargs (_ : kinds) as = mkRHSargs kinds as
    mkRHSargs _ _ = []
-- | Turn a (term, replacement) pair into a name substitution, provided the
-- first component is a plain name reference; otherwise 'Nothing'.
mkSubst :: (Term, Term) -> Maybe (Name, Term)
mkSubst pair =
  case pair of
    (P _ n _, t) -> Just (n, t)
    _            -> Nothing
-- | Build the definition of a specialised function.
--
-- Arguments, in order: the state, the new name, the name being
-- specialised, and the classified arguments of the specialised call.
--
-- The default left-hand side applies the new name to the delaborated
-- 'ExplicitD' arguments; the default right-hand side applies the original
-- name to all arguments, with every implicit argument replaced by a
-- placeholder.  If the original name has pattern clauses recorded in the
-- state, 'mkNewPats' decides the final clauses; otherwise the single
-- default clause is used.
mkPE_TermDecl :: IState
              -> Name
              -> Name
              -> [(PEArgType, Term)]
              -> PEDecl
mkPE_TermDecl ist newname sname ns =
  case lookupCtxtExact sname (idris_patdefs ist) of
    Nothing -> PEDecl lhs rhs [(lhs, rhs)] True
    Just d  -> mkNewPats ist (getPats d) ns newname sname lhs rhs
  where
    lhs = PApp emptyFC (PRef emptyFC [] newname) (map pexp explicitDyn)
    rhs = mapPT deImp (delab ist (mkApp (P Ref sname Erased) (map snd ns)))

    -- Only the explicit, non-static arguments appear on the left.
    explicitDyn = [delab ist tm | (ExplicitD, tm) <- ns]

    getPats (ps, _) = map lhsRhs ps
    lhsRhs (_, l, r) = (l, r)

    -- Blank out implicit arguments in every application.
    deImp (PApp fc t as) = PApp fc t (map deImpArg as)
    deImp t = t

    deImpArg a@(PImp _ _ _ _ _) = a { getTm = Placeholder }
    deImpArg a = a
-- | Collect the applications in a term that are candidates for
-- specialisation.
--
-- An application of a name qualifies when the name is not a metavariable,
-- has both static-argument and implicit-argument information recorded, is
-- applied to exactly as many arguments as it has static flags, has at
-- least one static flag set, and is 'specialisable'.  Each result pairs the
-- name with its classified arguments; an explicit non-static argument is
-- replaced by a reference to a placeholder name built from its position.
--
-- The list of names is extended while descending under binders, but the
-- classification itself never consults it.
getSpecApps :: IState
            -> [Name]
            -> Term
            -> [(Name, [(PEArgType, Term)])]
getSpecApps ist env tm = ga env (explicitNames tm)
  where
    staticArg _ isStatic imp arg idx
      | isStatic  = (if implicit then ImplicitS else ExplicitS, arg)
      | implicit  = (ImplicitD, arg)
      | otherwise = (ExplicitD, P Ref (sUN (show idx ++ "arg")) Erased)
      where
        implicit = imparg imp

    imparg (PExp _ _ _ _) = False
    imparg _ = True

    -- NOTE(review): there is deliberately no catch-all equation; this is
    -- partial if the implicit-info list and the static flags differ in
    -- length (only the flags are checked against the arguments below).
    buildApp _ [] [] _ _ = []
    buildApp bound (s : ss) (i : is) (a : as) (k : ks) =
        staticArg bound s i a k : buildApp bound ss is as ks

    ga bound app@(App _ f a)
      | (P _ fn _, args) <- unApply app
      , fn `notElem` map fst (idris_metavars ist)
      = ga bound f ++ ga bound a ++ candidate bound fn args
    ga bound (Bind n (Let _ v) sc) = ga bound v ++ ga (n : bound) sc
    ga bound (Bind n _ sc) = ga (n : bound) sc
    ga _ _ = []

    -- The application itself, if it meets all the conditions.
    candidate bound fn args =
      case ( lookupCtxtExact fn (idris_statics ist)
           , lookupCtxtExact fn (idris_implicits ist) ) of
        (Just statics, Just imps)
          | length statics == length args
          , or statics
          , specialisable (tt_ctxt ist) fn
          -> [(fn, buildApp bound statics imps args [0..])]
        _ -> []
-- | A name can be specialised only if it is defined by a case tree whose
-- compile-time tree passes 'noOverlap'.  Anything else is rejected.
specialisable :: Context -> Name -> Bool
specialisable ctxt n
  | Just (CaseOp _ _ _ _ _ cds) <- lookupDefExact n ctxt
      = noOverlap (snd (cases_compiletime cds))
  | otherwise
      = False
-- | Check a case tree for default branches that sit alongside other
-- alternatives.  A case whose only alternative is a default is looked
-- through; any other case is checked by 'noOverlapAlts'; a tree that is not
-- a case passes.
noOverlap :: SC -> Bool
noOverlap tree =
  case tree of
    Case _ _ [DefaultCase sc] -> noOverlap sc
    Case _ _ alts             -> noOverlapAlts alts
    _                         -> True
-- | Check a list of case alternatives: fails as soon as a default
-- alternative is present.  Constructor, constant and successor
-- alternatives must also have sub-trees that pass 'noOverlap'; the sub-tree
-- of an @FnCase@ alternative is not inspected.
noOverlapAlts alts =
  case alts of
    ConCase _ _ _ sc : rest -> noOverlapAlts rest && noOverlap sc
    FnCase _ _ _ : rest     -> noOverlapAlts rest
    ConstCase _ sc : rest   -> noOverlapAlts rest && noOverlap sc
    SucCase _ sc : rest     -> noOverlapAlts rest && noOverlap sc
    DefaultCase _ : _       -> False
    _                       -> True