module Agda.Auto.CaseSplit where
import Agda.Utils.Impossible
#include "../undefined.h"
import Data.IORef
import Data.List (findIndex, union)
import qualified Data.IntMap as IntMap
import Agda.Auto.NarrowingSearch
import Agda.Auto.Syntax
import Agda.Auto.SearchControl
import Agda.Auto.Typecheck
-- | Name given to the scrutinee of a split that turned out to have no
-- possible constructor, i.e. the variable of an absurd clause.
abspatvarname :: String
abspatvarname = "\0absurdPattern"

-- Search costs charged by 'caseSplitSearch'' for splitting on a variable.
-- A cost is subtracted from the remaining search depth, so more expensive
-- splits leave less budget for the rest of the search.

-- | Cost of a split that is not otherwise favoured.
costCaseSplitVeryHigh :: Int
costCaseSplitVeryHigh = 10000

-- | Cost of splitting a visible, non-avoided variable when some other
-- variable blocks the goal.
costCaseSplitHigh :: Int
costCaseSplitHigh = 5000

-- | Cost of a favoured split (e.g. on a variable that blocks the goal).
costCaseSplitLow :: Int
costCaseSplitLow = 2000

-- | Extra cost per level of constructor nesting of the split variable.
costAddVarDepth :: Int
costAddVarDepth = 1000
-- | A value tagged with its hiding annotation ('FMode'); the code in this
-- module distinguishes 'Hidden' from 'NotHidden'.
data HI a = HI FMode a

-- | Forget the hiding annotations of a list of tagged values.
drophid :: [HI a] -> [a]
drophid = map payload
  where
    payload (HI _ x) = x
-- | A case-split pattern together with its hiding annotation.
type CSPat o = HI (CSPatI o)

-- | A case-split context: for each pattern variable its name (or 'NoId')
-- and its type, together with its hiding annotation.  Variables are
-- referred to by their (0-based) position in this list.
type CSCtx o = [HI (MId, MExp o)]

-- | The pattern forms manipulated by the case splitter.
data CSPatI o = CSPatConApp (ConstRef o) [CSPat o] -- ^ constructor applied to sub-patterns
              | CSPatVar Nat                       -- ^ pattern variable, given by its position in the 'CSCtx'
              | CSPatExp (MExp o)                  -- ^ pattern fixed to an expression (created when unification eliminates a variable)
              | CSWith (MExp o)                    -- ^ NOTE(review): not produced here; 'replacep'/'renamep' reject it
              | CSAbsurd                           -- ^ presumably an absurd pattern; 'replacep'/'renamep' reject it
              | CSOmittedArg                       -- ^ omitted argument; left untouched by renaming/substitution

-- | A candidate solution: a list of clauses, each given by the clause's
-- context, its patterns and its right-hand side ('Nothing' for an absurd
-- clause).
type Sol o = [(CSCtx o, [CSPat o], Maybe (MExp o))]
-- | Entry point of the case-split search.
--
-- Runs 'caseSplitSearch'' with a per-branch proof search.  For a branch's
-- context, goal type and termination environment, that search:
--
-- * creates a fresh meta-variable for the right-hand side;
-- * runs 'topSearch' on it.  The constants @chints@ and the recursive
--   definition @recdef@ are used as hints.  The local termination side
--   condition and, when @meqr@ is given, the equational-reasoning state are
--   attached;
-- * returns the solution recorded by the solution handler (the last one, if
--   several are found), or 'Nothing'.
--
-- @ticks@ is passed through to 'topSearch'.  The per-branch solution
-- counter starts at 1.  @nsolwanted@ is not used by this code.
-- Before the search starts, the type of the i-th (0-based) context entry is
-- lifted by @i + 1@; 'seqctx' applies the opposite shift.
caseSplitSearch :: forall o . IORef Int -> Int -> [ConstRef o] -> Maybe (EqReasoningConsts o) -> Int -> Int -> ConstRef o -> CSCtx o -> MExp o -> [CSPat o] -> IO [Sol o]
caseSplitSearch ticks nsolwanted chints meqr depthinterval depth recdef ctx tt pats = do
  let branchsearch depth ctx tt termcheckenv = do
        nsol <- newIORef 1
        m <- initMeta
        sol <- newIORef Nothing
        let trm = Meta m
            -- record every solution found for the right-hand side meta
            hsol = do
              trm' <- expandExp trm
              writeIORef sol (Just trm')
            initcon = mpret $ Sidecondition
              (localTerminationSidecond termcheckenv recdef trm)
              (maybe id (\eqr -> mpret . Sidecondition (calcEqRState eqr trm)) meqr
                 (tcSearch False (map (\(name, t) -> (name, closify t)) (drophid ctx)) (closify tt) trm))
        recdefd <- readIORef recdef
        let env = RIEnv { rieHints = (recdef, HMRecCall) : map (\x -> (x, HMNormal)) chints
                        , rieDefFreeVars = cddeffreevars recdefd
                        , rieEqReasoningConsts = meqr
                        }
        -- only the side effect on @sol@ matters; the result is discarded
        _ <- topSearch ticks nsol hsol env initcon depth (depth + 1)
        readIORef sol
      ctx' = ff 1 ctx
      -- lift the i-th entry (0-based) by i + 1
      ff _ [] = []
      ff n (HI hid (name, t) : rest) = HI hid (name, lift n t) : ff (n + 1) rest
  caseSplitSearch' branchsearch depthinterval depth recdef ctx' tt pats
-- | Core of the case-split search.  It is parameterised over the
-- per-branch proof search @branchsearch@ (see 'caseSplitSearch').
--
-- For the current context, goal type and patterns it first tries to solve
-- the goal directly.  If that gives no solution, it tries splitting on each
-- context variable in turn, from the last position down to 0, and returns
-- the solutions of the first split that yields any.  A split combines the
-- solutions found for each constructor, taking every combination of them.
-- Each split is charged a cost (see @cost@ below) that is subtracted from
-- the remaining depth.  A search whose depth has become negative returns
-- @[]@.
--
-- @depthinterval@ is not used here.
caseSplitSearch' :: forall o . (Int -> CSCtx o -> MExp o -> ([Nat], Nat, [Nat]) -> IO (Maybe (MExp o))) -> Int -> Int -> ConstRef o -> CSCtx o -> MExp o -> [CSPat o] -> IO [Sol o]
caseSplitSearch' branchsearch depthinterval depth recdef ctx tt pats = do
    recdefd <- readIORef recdef
    rc depth (cddeffreevars recdefd) ctx tt pats
  where
    -- rc depth nscrutavoid ctx tt pats: search with @depth@ budget left.
    -- Splitting on one of the last @nscrutavoid@ context entries is charged
    -- more (initially @cddeffreevars@ of the definition).
    rc :: Int -> Int -> CSCtx o -> MExp o -> [CSPat o] -> IO [Sol o]
    rc depth _ _ _ _ | depth < 0 = return []
    rc depth nscrutavoid ctx tt pats = do
        mblkvar <- getblks tt
        fork mblkvar
      where
        -- Solve directly, or else split on variables length ctx - 1, ..., 0
        -- and keep the first split that yields a solution.
        fork :: [Nat] -> IO [Sol o]
        fork mblkvar = do
            sols1 <- dobody
            case sols1 of
              (_:_) -> return sols1
              [] -> firstSplit [nv - x | x <- [0 .. nv]]
          where
            nv = length ctx - 1
            firstSplit [] = return []
            firstSplit (v : vs) = do
              sols2 <- splitvar mblkvar v
              case sols2 of
                (_:_) -> return sols2
                [] -> firstSplit vs

        -- Solve the goal in the current context after reordering the
        -- context with findperm/applyperm.
        dobody :: IO [Sol o]
        dobody =
          case findperm (map snd (drophid ctx)) of
            Just perm -> do
              let (ctx', tt', pats') = applyperm perm ctx tt pats
              res <- branchsearch depth ctx' tt' (localTerminationEnv pats')
              return $ case res of
                Just trm -> [[(ctx', pats', Just trm)]]
                Nothing -> []
            Nothing -> __IMPOSSIBLE__ -- findperm found no ordering

        -- Split on the context variable at position @scrut@.  Only
        -- variables whose type is an applied datatype can be split.
        splitvar :: [Nat] -> Nat -> IO [Sol o]
        splitvar mblkvar scrut =
            case rm scruttype of
              App _ _ (Const c) _ -> do
                cd <- readIORef c
                case cdcont cd of
                  Datatype cons _ -> do
                    sols <- dobranches cons
                    return (map absurdIfEmpty sols)
                  _ -> return []
              _ -> return []
          where
            scruttype = infertypevar ctx scrut

            -- A solution with no clauses means every constructor was ruled
            -- out.  Replace it by a single clause without right-hand side.
            -- In that clause the scrutinee is renamed to abspatvarname.
            absurdIfEmpty :: Sol o -> Sol o
            absurdIfEmpty [] =
              case findperm (map snd (drophid ctx)) of
                Just perm ->
                  let HI scrhid (_, scrt) = ctx !! scrut
                      ctx1 = take scrut ctx ++ (HI scrhid (Id abspatvarname, scrt) : drop (scrut + 1) ctx)
                      (ctx', _, pats') = applyperm perm ctx1 tt pats
                  in [(ctx', pats', Nothing)]
                Nothing -> __IMPOSSIBLE__ -- findperm found no ordering
            absurdIfEmpty sol = sol

            -- All ways of covering the given constructors.  Each element of
            -- the result concatenates the clauses found for each constructor.
            dobranches :: [ConstRef o] -> IO [Sol o]
            dobranches [] = return [[]]
            dobranches (con : cons) = do
                cond <- readIORef con
                -- telescope: one new variable per Pi-bound argument of the
                -- constructor type (numbered from scrut upwards), plus the
                -- constructor's result type lifted into the context.
                let telescope t = case rm t of
                      Pi _ h _ it (Abs name ot) ->
                        let (xs, inft) = telescope ot
                        in ((Pair h (scrut + length xs), name, lift (scrut + length xs + 1) it) : xs, inft)
                      _ -> ([], lift scrut t)
                    (newvars, inftype) = telescope (cdtype cond)
                    -- the constructor applied to its new variables
                    constrapp = mm $ App Nothing (mm OKVal) (Const con)
                      (foldr (\(Pair h v, _, _) rest -> mm $ ALCons h (mm $ App Nothing (mm OKVal) (Var v) (mm ALNil)) rest) (mm ALNil) newvars)
                    pconstrapp = CSPatConApp con (map (\(Pair hid v, _, _) -> HI hid (CSPatVar v)) newvars)
                    thesub = replace scrut (length newvars) constrapp
                    Id newvarprefix = fst (drophid ctx !! scrut)
                    substEntry (HI hid (name, t)) = HI hid (name, thesub t)
                    -- anonymous constructor arguments inherit the scrutinee's name
                    newEntry (Pair hid _, name, t) = HI hid (Id (case name of {NoId -> newvarprefix; Id s -> s}), t)
                    ctx1 = map substEntry (take scrut ctx)
                      ++ reverse (map newEntry newvars)
                      ++ map substEntry (drop (scrut + 1) ctx)
                    tt' = thesub tt
                    pats' = map (replacep scrut (length newvars) pconstrapp constrapp) pats
                    scruttype' = thesub scruttype
                case unifyexp inftype scruttype' of
                  -- The result type does not unify with the scrutinee's type.
                  -- Skip the constructor if the two are provably different;
                  -- otherwise give up on this split.
                  Nothing -> do
                    res <- notequal scrut (length newvars) scruttype' inftype
                    if res then dobranches cons else return []
                  Just unif -> do
                    let (ctx2, tt2, pats2) = removevar ctx1 tt' pats' unif
                        nothid = case ctx !! scrut of
                          HI Hidden _ -> False
                          HI NotHidden _ -> True
                        -- visible and not among the last nscrutavoid entries
                        cheap = scrut < length ctx - nscrutavoid && nothid
                        -- With no blocking variables: a cheap split costs
                        -- costCaseSplitLow, plus costAddVarDepth per level of
                        -- constructor nesting of the variable.  Otherwise the
                        -- split costs costCaseSplitLow when the variable blocks
                        -- the goal, and costCaseSplitHigh when it is cheap.
                        -- Every remaining case costs costCaseSplitVeryHigh.
                        cost
                          | null mblkvar =
                              if cheap
                                then costCaseSplitLow + costAddVarDepth * depthofvar scrut pats
                                else costCaseSplitVeryHigh
                          | scrut `elem` mblkvar = costCaseSplitLow
                          | cheap = costCaseSplitHigh
                          | otherwise = costCaseSplitVeryHigh
                    sols <- rc (depth - cost) (length ctx - 1 - scrut) ctx2 tt2 pats2
                    case sols of
                      [] -> return []
                      _ -> do
                        sols2 <- dobranches cons
                        return [sol ++ sol2 | sol <- sols, sol2 <- sols2]
-- | The type recorded for the context variable at (0-based) position @v@.
infertypevar :: CSCtx o -> Nat -> MExp o
infertypevar ctx v = case ctx !! v of
  HI _ (_, t) -> t
-- | @replace sv nnew re e@ substitutes @re@ for variable @sv@ in @e@.
--
-- Variable indices are counted from the top of @e@.  Under @n@ binders,
-- variable @sv@ appears as @sv + n@ and @re@ is lifted by @n@.  The
-- substituted term is beta-reduced against the variable's arguments.
-- Variables above @sv@ are shifted by @nnew - 1@, because @sv@ is being
-- replaced by @nnew@ new variables.  Locally bound variables are left alone.
replace :: Nat -> Nat -> MExp o -> MExp o -> MExp o
replace sv nnew re = r 0
  where
    r n e =
      case rm e of
        App uid ok elr@(Var v) args
          | v < n -> mm $ App uid ok elr (rs n args)   -- bound inside e
          | v - n == sv -> betareduce (lift n re) (rs n args)
          | v - n > sv -> mm $ App uid ok (Var (v + nnew - 1)) (rs n args)
          | otherwise -> mm $ App uid ok elr (rs n args)
        App uid ok elr@(Const _) args ->
          mm $ App uid ok elr (rs n args)
        Lam hid (Abs mid e) -> mm $ Lam hid (Abs mid (r (n + 1) e))
        Pi uid hid possdep it (Abs mid ot) -> mm $ Pi uid hid possdep (r n it) (Abs mid (r (n + 1) ot))
        Sort{} -> e
        AbsurdLambda{} -> e
    rs n es =
      case rm es of
        ALNil -> mm ALNil
        ALCons hid a as -> mm $ ALCons hid (r n a) (rs n as)
        ALProj{} -> __IMPOSSIBLE__
        ALConPar as -> mm $ ALConPar (rs n as)
-- | Apply @e@ to the argument list @args@, reducing on the way.
--
-- If @e@ is an application, the arguments are appended to its own
-- arguments.  If @e@ is a lambda, the first argument is substituted for
-- its bound variable and the rest are applied to the result.
betareduce :: MExp o -> MArgList o -> MExp o
betareduce e args =
  case rm args of
    ALNil -> e
    ALCons _ arg rest ->
      case rm e of
        App uid ok elr eargs -> mm $ App uid ok elr (concatargs eargs args)
        Lam _ (Abs _ body) -> betareduce (replace 0 0 arg body) rest
        _ -> __IMPOSSIBLE__
    ALProj{} -> __IMPOSSIBLE__
    ALConPar _ -> __IMPOSSIBLE__
-- | Append two argument lists.  Parameter markers ('ALConPar') in the
-- first list are kept.  The first list must not contain projections.
concatargs :: MArgList o -> MArgList o -> MArgList o
concatargs xs ys = case rm xs of
  ALNil -> ys
  ALCons hid x rest -> mm $ ALCons hid x (concatargs rest ys)
  ALProj{} -> __IMPOSSIBLE__
  -- recurse on the wrapped list; recursing on @xs@ itself never terminates
  ALConPar rest -> mm $ ALConPar (concatargs rest ys)
-- | Equality of application heads: the same variable or the same constant.
eqelr :: Elr o -> Elr o -> Bool
eqelr x y = case (x, y) of
  (Var i, Var j) -> i == j
  (Const c, Const d) -> c == d
  _ -> False
-- | The pattern-level counterpart of 'replace'.
--
-- Pattern variable @sv@ becomes @rp@.  Variables above @sv@ are shifted by
-- @nnew - 1@.  Expression patterns get @re@ substituted for @sv@ via
-- 'replace'.  Omitted arguments are kept; any other pattern form is
-- impossible here.
replacep :: Nat -> Nat -> CSPatI o -> MExp o -> CSPat o -> CSPat o
replacep sv nnew rp re = r
  where
    r (HI hid (CSPatConApp c ps)) = HI hid (CSPatConApp c (map r ps))
    r (HI hid (CSPatVar v))
      | v == sv = HI hid rp
      | v > sv = HI hid (CSPatVar (v + nnew - 1))
      | otherwise = HI hid (CSPatVar v)
    r (HI hid (CSPatExp e)) = HI hid (CSPatExp $ replace sv nnew re e)
    r p@(HI _ CSOmittedArg) = p
    r _ = __IMPOSSIBLE__
-- | Unwrap an 'MM' that is required to be a value rather than a meta.
rm :: MM a b -> a
rm m = case m of
  NotM x -> x
  Meta{} -> __IMPOSSIBLE__
-- | Wrap a value as a non-meta 'MM'.
mm :: a -> MM a b
mm x = NotM x
-- | Syntactic first-order unification of two expressions.
--
-- A variable applied to no arguments, on either side, can be bound.  The
-- bindings are collected in an association list (newest first), which is
-- returned on success.  If a variable is already bound, its binding is
-- unified instead.  There is no occurs check, and indices are not adjusted
-- under binders.  In argument lists, an 'ALConPar' marker on one side
-- skips one argument on the other side.
unifyexp :: MExp o -> MExp o -> Maybe [(Nat, MExp o)]
unifyexp e1 e2 = go e1 e2 Just []
  where
    go a b k sub = case (rm a, rm b) of
      (App _ _ h1 as1, App _ _ h2 as2) | eqelr h1 h2 -> goArgs as1 as2 k sub
      (Lam m1 (Abs _ b1), Lam m2 (Abs _ b2)) | m1 == m2 -> go b1 b2 k sub
      (Pi _ m1 _ i1 (Abs _ o1), Pi _ m2 _ i2 (Abs _ o2)) | m1 == m2 -> go i1 i2 (go o1 o2 k) sub
      (Sort _, Sort _) -> k sub
      (App _ _ (Var v) (NotM ALNil), _) ->
        case lookup v sub of
          Nothing -> k ((v, b) : sub)
          Just a' -> go a' b k sub
      (_, App _ _ (Var v) (NotM ALNil)) ->
        case lookup v sub of
          Nothing -> k ((v, a) : sub)
          Just b' -> go a b' k sub
      _ -> Nothing
    goArgs xs ys k sub = case (rm xs, rm ys) of
      (ALNil, ALNil) -> k sub
      (ALCons m1 x1 rest1, ALCons m2 x2 rest2) | m1 == m2 -> go x1 x2 (goArgs rest1 rest2 k) sub
      (ALConPar rest1, ALCons _ _ rest2) -> goArgs rest1 rest2 k sub
      (ALCons _ _ rest1, ALConPar rest2) -> goArgs rest1 rest2 k sub
      (ALConPar rest1, ALConPar rest2) -> goArgs rest1 rest2 k sub
      _ -> Nothing
-- | @lift n e@ adds @n@ (which may be negative) to every variable of @e@
-- that is not bound inside @e@.
lift :: Nat -> MExp o -> MExp o
lift 0 = id
lift n = shift 0
  where
    shift bound e =
      case rm e of
        App uid ok (Var v) args | v >= bound -> mm $ App uid ok (Var (v + n)) (shiftArgs bound args)
        App uid ok elr args -> mm $ App uid ok elr (shiftArgs bound args)
        Lam hid (Abs mid body) -> mm $ Lam hid (Abs mid (shift (bound + 1) body))
        Pi uid hid possdep it (Abs mid ot) -> mm $ Pi uid hid possdep (shift bound it) (Abs mid (shift (bound + 1) ot))
        Sort{} -> e
        AbsurdLambda{} -> e
    shiftArgs bound es =
      case rm es of
        ALNil -> mm ALNil
        ALCons hid a as -> mm $ ALCons hid (shift bound a) (shiftArgs bound as)
        ALProj{} -> __IMPOSSIBLE__
        ALConPar as -> mm $ ALConPar (shiftArgs bound as)
-- | Eliminate the variables bound by a unifier (as produced by 'unifyexp').
--
-- For each binding @(v, e)@ in turn:
--
-- * variable @v@ is dropped from the context;
-- * @e@ is substituted for it in the remaining context types, the goal @tt@
--   and the patterns, where @v@ becomes a 'CSPatExp' pattern;
-- * @e@ is also substituted in the remaining bindings, whose variables
--   above @v@ are renumbered down by one.
--
-- An occurrence of @v@ inside @e@ itself is replaced by @__IMPOSSIBLE__@,
-- so a cyclic binding fails only if that occurrence is forced.
removevar :: CSCtx o -> MExp o -> [CSPat o] -> [(Nat, MExp o)] -> (CSCtx o, MExp o, [CSPat o])
removevar ctx tt pats [] = (ctx, tt, pats)
removevar ctx tt pats ((v, e) : unif) =
  let e2 = replace v 0 (__IMPOSSIBLE__) e
      thesub = replace v 0 e2
      substEntry (HI hid (name, t)) = HI hid (name, thesub t)
      ctx1 = map substEntry (take v ctx) ++ map substEntry (drop (v + 1) ctx)
      tt' = thesub tt
      pats' = map (replacep v 0 (CSPatExp e2) e2) pats
      unif' = map (\(uv, ue) -> (if uv > v then uv - 1 else uv, thesub ue)) unif
  in removevar ctx1 tt' pats' unif'
-- | @notequal firstnew nnew e1 e2@ tries to prove that the applications
-- @e1@ and @e2@ can never be equal.
--
-- The new variables are those with indices @firstnew .. firstnew + nnew - 1@.
-- The two argument lists are walked in parallel.  When @e2@ has an unapplied
-- new variable, it is bound to the corresponding subterm of @e1@.  A repeated
-- new variable is compared against its first binding.  The result is 'True'
-- as soon as two corresponding subterms are applications of different
-- constructors.  Otherwise the result is 'False'.
notequal :: Nat -> Nat -> MExp o -> MExp o -> IO Bool
notequal firstnew nnew e1 e2 =
  case (rm e1, rm e2) of
    -- the final continuation is reached when no difference was found
    (App _ _ _ es1, App _ _ _ es2) -> rs es1 es2 (\_ -> return False) []
    _ -> __IMPOSSIBLE__
  where
    -- compare argument lists pairwise; stop (and continue) at a mismatch in shape
    rs :: MArgList o -> MArgList o -> ([(Nat, MExp o)] -> IO Bool) -> [(Nat, MExp o)] -> IO Bool
    rs es1 es2 cont unifier2 =
      case (rm es1, rm es2) of
        (ALCons _ e1 es1, ALCons _ e2 es2) -> r e1 e2 (rs es1 es2 cont) unifier2
        (ALConPar es1, ALConPar es2) -> rs es1 es2 cont unifier2
        _ -> cont unifier2
    -- compare two subterms; only e2 may contain the new variables
    r :: MExp o -> MExp o -> ([(Nat, MExp o)] -> IO Bool) -> [(Nat, MExp o)] -> IO Bool
    r e1 e2 cont unifier2 = case rm e2 of
        App _ _ (Var v2) es2 | firstnew <= v2 && v2 < firstnew + nnew ->
          case rm es2 of
            ALNil ->
              case lookup v2 unifier2 of
                Nothing -> cont ((v2, e1) : unifier2)
                -- already bound: compare with the earlier binding
                Just e2' -> cc e1 e2'
            ALCons{} -> cont unifier2
            ALProj{} -> __IMPOSSIBLE__
            ALConPar{} -> __IMPOSSIBLE__
        _ -> cc e1 e2
      where
        -- constructor clash check; anything else gives no information
        cc e1 e2 = case (rm e1, rm e2) of
          (App _ _ (Const c1) es1, App _ _ (Const c2) es2) -> do
            cd1 <- readIORef c1
            cd2 <- readIORef c2
            case (cdcont cd1, cdcont cd2) of
              (Constructor{}, Constructor{}) ->
                if c1 == c2 then
                  rs es1 es2 cont unifier2
                else
                  return True
              _ -> cont unifier2
          _ -> cont unifier2
-- | Find an order for context entries so that dependencies are respected.
--
-- Entry @i@ has type @ts !! i@.  In the returned permutation, each entry's
-- type mentions only entries that come later in the list.  The result is
-- 'Nothing' if no such order exists, i.e. the dependencies are cyclic.
--
-- Method: keep, for every entry, the number of not-yet-placed types that
-- mention it.  Repeatedly place an entry whose count is 0 (the smallest
-- index first), mark it as placed (-1), and decrement the counts of the
-- variables its own type mentions.
findperm :: [MExp o] -> Maybe [Nat]
findperm ts = go initial [] (length ts)
  where
    frees = map freevars ts
    initial = IntMap.fromList [(i, length (filter (elem i) frees)) | i <- [0 .. length ts - 1]]
    go _ perm 0 = Just (reverse perm)
    go counts perm n =
      case [i | (i, 0) <- IntMap.toList counts] of
        [] -> Nothing
        i : _ ->
          -- the decrements commute, so the fold direction is irrelevant
          let counts' = foldr (IntMap.adjust (subtract 1)) (IntMap.insert i (-1) counts) (frees !! i)
          in go counts' (i : perm) (n - 1)
-- | The variables occurring free in an expression, without duplicates.
-- Variables bound by a 'Lam' or 'Pi' inside the expression are skipped.
freevars :: MExp o -> [Nat]
freevars = f 0
  where
    -- n = number of binders passed so far
    f n e = case rm e of
      App _ _ (Var v) args | v >= n -> union [v - n] (fs n args)
      -- constant head, or a variable bound inside the expression
      App _ _ _ args -> fs n args
      Lam _ (Abs _ b) -> f (n + 1) b
      Pi _ _ _ it (Abs _ ot) -> union (f n it) (f (n + 1) ot)
      Sort{} -> []
      AbsurdLambda{} -> []
    fs n es = case rm es of
      ALNil -> []
      ALCons _ e es -> union (f n e) (fs n es)
      ALProj{} -> __IMPOSSIBLE__
      ALConPar es -> fs n es
-- | Reorder a context according to a permutation from 'findperm'.
--
-- Variables in the context types, the goal and the patterns are renamed
-- through the inverse of @perm@ ('ren').  The entries are placed in the
-- order @perm@ gives, and their types are re-shifted with 'seqctx'.
applyperm :: [Nat] -> CSCtx o -> MExp o -> [CSPat o] -> (CSCtx o, MExp o, [CSPat o])
applyperm perm ctx tt pats = (seqctx reordered, rename rn tt, map (renamep rn) pats)
  where
    rn = ren perm
    renamed = map (\(HI hid (name, t)) -> HI hid (name, rename rn t)) ctx
    reordered = map (renamed !!) perm

-- | @ren n i@ is the position of @i@ in the list @n@.  It fails when the
-- result is forced if @i@ does not occur.
ren n i = j
  where
    Just j = findIndex (== i) n
-- | Rename the free variables of an expression with @ren@.
--
-- Under @j@ binders, a variable @v >= j@ stands for free variable @v - j@
-- and becomes @ren (v - j) + j@.  Locally bound variables are unchanged.
rename :: (Nat -> Nat) -> MExp o -> MExp o
rename ren = r 0
  where
    r j e =
      case rm e of
        App uid ok (Var v) args | v >= j -> mm $ App uid ok (Var (ren (v - j) + j)) (rs j args)
        App uid ok elr args -> mm $ App uid ok elr (rs j args)
        Lam hid (Abs mid b) -> mm $ Lam hid (Abs mid (r (j + 1) b))
        Pi uid hid possdep it (Abs mid ot) -> mm $ Pi uid hid possdep (r j it) (Abs mid (r (j + 1) ot))
        Sort{} -> e
        AbsurdLambda{} -> e
    rs j es =
      case rm es of
        ALNil -> mm ALNil
        ALCons hid a as -> mm $ ALCons hid (r j a) (rs j as)
        ALProj{} -> __IMPOSSIBLE__
        ALConPar as -> mm $ ALConPar (rs j as)
-- | Rename the variables of a pattern with @ren@.  This is the
-- pattern-level counterpart of 'rename'.  Omitted arguments are kept; any
-- other pattern form is impossible here.
renamep :: (Nat -> Nat) -> CSPat o -> CSPat o
renamep ren (HI hid pat) =
  case pat of
    CSPatConApp c subpats -> HI hid (CSPatConApp c (map (renamep ren) subpats))
    CSPatVar i -> HI hid (CSPatVar (ren i))
    CSPatExp e -> HI hid (CSPatExp (rename ren e))
    CSOmittedArg -> HI hid CSOmittedArg
    _ -> __IMPOSSIBLE__
-- | Lift the type of the i-th (0-based) context entry by @-(i + 1)@.
--
-- This is the inverse of the shift applied before the search (which lifts
-- entry i by @i + 1@).
seqctx :: CSCtx o -> CSCtx o
seqctx = r (-1)
  where
    r _ [] = []
    r n (HI hid (name, t) : rest) = HI hid (name, lift n t) : r (n - 1) rest
-- | Nesting depth of pattern variable @v@ in @pats@: the number of
-- constructor patterns around its occurrence.  @v@ must occur exactly once
-- among the variable and constructor patterns; otherwise forcing the
-- result fails.
depthofvar :: Nat -> [CSPat o] -> Nat
depthofvar v pats = depth
  where
    [depth] = concatMap (go 0) (drophid pats)
    go d (CSPatConApp _ ps) = concatMap (go (d + 1)) (drophid ps)
    go d (CSPatVar v') | v' == v = [d]
    go _ _ = []
-- | Summarise a clause's patterns for 'localTerminationSidecond'.
--
-- The result is @(is, size, vars)@:
--
-- * @is@: the (0-based) positions of the top-level patterns that are
--   constructor patterns;
-- * @size@: the number of constructor-pattern nodes inside them, plus the
--   constant applications found in their expression sub-patterns;
-- * @vars@: the pattern variables inside them, together with the heads of
--   variable applications in expression sub-patterns.
localTerminationEnv :: [CSPat o] -> ([Nat], Nat, [Nat])
localTerminationEnv pats =
  -- g: walk the top-level patterns, i = current position
  let g _ [] = ([], 0, [])
      g i (hp@(HI _ p) : ps) = case p of
        CSPatConApp{} ->
          let (size, vars) = h hp
              (is, size', vars') = g (i + 1) ps
          in (i : is, size + size', vars ++ vars')
        _ -> g (i + 1) ps
      -- h/hs: size and variables of a pattern / list of patterns
      h (HI _ p) = case p of
        CSPatConApp c ps ->
          let (size, vars) = hs ps
          in (size + 1, vars)
        CSPatVar n -> (0, [n])
        CSPatExp e -> he e
        _ -> (0, [])
      hs [] = (0, [])
      hs (p : ps) =
        let (size, vars) = h p
            (size', vars') = hs ps
        in (size + size', vars ++ vars')
      -- he/hes: the same for expression patterns; a variable application
      -- counts only its head variable
      he e = case rm e of
        App _ _ (Var v) _ -> (0, [v])
        App _ _ (Const _) args ->
          let (size, vars) = hes args
          in (size + 1, vars)
        _ -> (0, [])
      hes as = case rm as of
        ALNil -> (0, [])
        ALCons _ a as ->
          let (size, vars) = he a
              (size', vars') = hes as
          in (size + size', vars ++ vars')
        ALProj{} -> __IMPOSSIBLE__
        ALConPar as -> hes as
  in g 0 pats
-- | Side condition that rejects right-hand sides whose recursive calls are
-- not obviously decreasing.
--
-- @localTerminationSidecond (is, size, vars) reccallc b@ walks @b@.  The
-- triple comes from 'localTerminationEnv' applied to the clause's
-- patterns.  At every application of @reccallc@, each argument at a
-- position listed in @is@ must meet three conditions:
--
-- * it is built only from constructors and variables from @vars@;
-- * each variable is used at most once across those arguments;
-- * fewer than @size@ constructors are used in total.
--
-- Any other call produces an 'Error', and so does every call when
-- @size == 0@.  Arguments at other positions are not checked.  The
-- mmpcase/mbpcase/mmcase combinators presumably wait for meta-variables in
-- the term to be instantiated — TODO confirm.
localTerminationSidecond :: ([Nat], Nat, [Nat]) -> ConstRef o -> MExp o -> EE (MyPB o)
localTerminationSidecond (is, size, vars) reccallc b =
    ok b
  where
    -- check every subterm; recursive calls are handed to okcall
    ok e = mmpcase (False, prioNo, Nothing) e $ \e -> case e of
      App _ _ elr args -> mpret $ Sidecondition
        (oks args)
        (case elr of
           Const c | c == reccallc ->
             if size == 0
               then mpret (Error "localTerminationSidecond: no size to decrement")
               else okcall 0 size vars args
           _ -> mpret OK
        )
      Lam _ (Abs _ e) -> ok e
      Pi _ _ _ it (Abs _ ot) -> mpret $ Sidecondition
        (ok it)
        (ok ot)
      Sort{} -> mpret OK
      AbsurdLambda{} -> mpret OK
    oks as = mmpcase (False, prioNo, Nothing) as $ \as -> case as of
      ALNil -> mpret OK
      ALCons _ a as -> mpret $ Sidecondition
        (ok a)
        (oks as)
      ALProj eas _ _ as -> mpret $ Sidecondition (oks eas) (oks as)
      ALConPar as -> oks as
    -- arguments of a recursive call; i = argument position, size/vars =
    -- remaining budget
    okcall i size vars as = mmpcase (False, prioNo, Nothing) as $ \as -> case as of
      ALNil -> mpret OK
      ALCons _ a as | elem i is ->
        mbpcase prioNo Nothing (he size vars a) $ \x -> case x of
          Nothing -> mpret $ Error "localTerminationSidecond: reccall not ok"
          Just (size', vars') -> okcall (i + 1) size' vars' as
      ALCons _ a as -> okcall (i + 1) size vars as
      ALProj{} -> mpret OK
      ALConPar as -> __IMPOSSIBLE__
    -- consume one argument: a variable must still be available, and a
    -- constructor costs one unit of size (the last unit may not be used)
    he size vars e = mmcase e $ \e -> case e of
      App _ _ (Var v) _ ->
        case remove v vars of
          Nothing -> mbret Nothing
          Just vars' -> mbret $ Just (size, vars')
      App _ _ (Const c) args -> do
        cd <- readIORef c
        case cdcont cd of
          Constructor{} ->
            if size == 1
              then mbret Nothing
              else hes (size - 1) vars args
          _ -> mbret Nothing
      _ -> mbret Nothing
    hes size vars as = mmcase as $ \as -> case as of
      ALNil -> mbret $ Just (size, vars)
      ALCons _ a as ->
        mbcase (he size vars a) $ \x -> case x of
          Nothing -> mbret Nothing
          Just (size', vars') -> hes size' vars' as
      ALProj{} -> __IMPOSSIBLE__
      ALConPar as -> __IMPOSSIBLE__
    -- delete the first occurrence of x, or Nothing if it is absent
    remove _ [] = Nothing
    remove x (y : ys)
      | x == y = Just ys
      | otherwise = fmap (y :) (remove x ys)
-- | Variables that block evaluation of the goal type @tt@.
--
-- If head-normalisation of @tt@ reports blockers and the last one is a
-- variable application, the result is just that variable.  Otherwise, if
-- @tt@ is an application of a datatype, the result collects the variable
-- that blocks each of its arguments in the same way.  That list has no
-- duplicates and the most recently found variable comes first.  In all
-- other cases the result is empty.
getblks :: MExp o -> IO [Nat]
getblks tt = do
    NotB (hntt, blks) <- hnn_blks (Clos [] tt)
    case lastBlockingVar blks of
      Just v -> return [v]
      Nothing -> case hntt of
        HNApp _ (Const c) args -> do
          cd <- readIORef c
          case cdcont cd of
            Datatype{} -> collect [] args
            _ -> return []
        _ -> return []
  where
    -- the variable of the last blocker, if it is a variable application
    lastBlockingVar blks = case reverse blks of
      HNApp _ (Var v) _ : _ -> Just v
      _ -> Nothing
    collect vs args = do
      NotB hnargs <- hnarglist args
      case hnargs of
        HNALCons _ a as -> do
          NotB (_, blks) <- hnn_blks a
          let vs' = case lastBlockingVar blks of
                Just v | v `notElem` vs -> v : vs
                _ -> vs
          collect vs' as
        _ -> return vs