{-# LANGUAGE FlexibleContexts #-}
module Control.Search.Combinator.Or ((<|>)) where
import Control.Search.Language
import Control.Search.GeneratorInfo
import Control.Search.Generator
import Control.Search.MemoReader
import Control.Search.Memo
import Control.Monatron.Monatron hiding (Abort, L, state, cont)
import Control.Monatron.Zipper hiding (i,r)
-- | Struct definition for the left branch's evaluation state.
-- It is named @LeftEvalState<uid>@ and holds a @parent@ field (hook type
-- @TreeState@), an @Int@ @ref_count@ field, and then every field listed in
-- @evalState_ lsuper@ as a (type, name) pair. @rsuper@ is not used here.
xs1 uid lsuper rsuper = Struct structName allFields
  where
    structName = "LeftEvalState" ++ show uid
    inheritedFields =
      map (\(fieldName, fieldTy, _) -> (fieldTy, fieldName)) (evalState_ lsuper)
    allFields =
      (THook "TreeState", "parent") : (Int, "ref_count") : inheritedFields
-- | (field name, initializer) pairs for every entry of @evalState_ lsuper@;
-- the type component of each entry is dropped. @uid@ and @rsuper@ are unused.
xfs1 uid lsuper rsuper = map nameAndInit (evalState_ lsuper)
  where
    nameAndInit (fieldName, _, initializer) = (fieldName, initializer)
-- | Struct definition for the right branch's evaluation state, named
-- @RightEvalState<uid>@. Its field list is an @Int@ @ref_count@ followed by
-- the fields of @evalState_ rsuper@ as (type, name) pairs, passed through
-- 'xneedSide' for the 'SecondS' side (which yields 'mempty' when @rsuper@ has
-- no evaluation-state fields).
xs2 uid lsuper rsuper = Struct structName gatedFields
  where
    structName = "RightEvalState" ++ show uid
    inheritedFields =
      map (\(fieldName, fieldTy, _) -> (fieldTy, fieldName)) (evalState_ rsuper)
    gatedFields =
      xneedSide uid lsuper rsuper SecondS ((Int, "ref_count") : inheritedFields)
-- | (field name, initializer) pairs for every entry of @evalState_ rsuper@;
-- the type component of each entry is dropped. @uid@ and @lsuper@ are unused.
xfs2 uid lsuper rsuper = map nameAndInit (evalState_ rsuper)
  where
    nameAndInit (fieldName, _, initializer) = (fieldName, initializer)
-- | Evaluation-state type for the requested side: the left struct ('xs1')
-- for 'FirstS', the right struct ('xs2') for 'SecondS', wrapped in 'SType'.
xet uid lsuper rsuper side =
  case side of
    FirstS  -> SType (xs1 uid lsuper rsuper)
    SecondS -> SType (xs2 uid lsuper rsuper)
-- | Struct definition for the left branch's tree state, named
-- @LeftTreeState<uid>@. It starts with an @evalState@ pointer to the left
-- evaluation-state struct ('xs1'), followed by the fields of
-- @treeState_ lsuper@ as (type, name) pairs.
xs3 uid lsuper rsuper = Struct structName (evalStatePtr : inheritedFields)
  where
    structName = "LeftTreeState" ++ show uid
    evalStatePtr = (Pointer (SType (xs1 uid lsuper rsuper)), "evalState")
    inheritedFields =
      map (\(fieldName, fieldTy, _) -> (fieldTy, fieldName)) (treeState_ lsuper)
-- | (field name, initializer) pairs for every entry of @treeState_ lsuper@;
-- the type component of each entry is dropped. @uid@ and @rsuper@ are unused.
xfs3 uid lsuper rsuper = map nameAndInit (treeState_ lsuper)
  where
    nameAndInit (fieldName, _, initializer) = (fieldName, initializer)
-- | Struct definition for the right branch's tree state, named
-- @RightTreeState<uid>@. The @evalState@ pointer to the right
-- evaluation-state struct ('xs2') is only included when 'xneedSide' keeps it
-- for 'SecondS'; the fields of @treeState_ rsuper@ are always appended.
xs4 uid lsuper rsuper = Struct structName (optionalEvalPtr ++ inheritedFields)
  where
    structName = "RightTreeState" ++ show uid
    optionalEvalPtr =
      xneedSide uid lsuper rsuper SecondS
        [(Pointer (SType (xs2 uid lsuper rsuper)), "evalState")]
    inheritedFields =
      map (\(fieldName, fieldTy, _) -> (fieldTy, fieldName)) (treeState_ rsuper)
-- | Tree-state type for the requested side: the left struct ('xs3') for
-- 'FirstS', the right struct ('xs4') for 'SecondS', wrapped in 'SType'.
xst uid lsuper rsuper side =
  case side of
    FirstS  -> SType (xs3 uid lsuper rsuper)
    SecondS -> SType (xs4 uid lsuper rsuper)
-- | Gate a monoidal value on whether the given side needs it.
--
-- For 'FirstS' the value is returned unchanged. For 'SecondS' the value is
-- replaced by 'mempty' when @evalState_ rsuper@ is empty, and returned
-- unchanged otherwise. @uid@ and @lsuper@ are not inspected.
xneedSide :: Monoid m => Int -> Eval n -> Eval n -> SeqPos -> m -> m
xneedSide uid lsuper rsuper = \pos stm ->
  case pos of
    FirstS -> stm
    SecondS ->
      -- Match on the list shape rather than computing its length: the
      -- emptiness test is O(1) and does not traverse the field list.
      case evalState_ rsuper of
        [] -> mempty
        _  -> stm
-- | Build the 'Eval' for the "or" combinator from a left evaluator ('lsuper')
-- and a right evaluator ('rsuper').
--
-- The combined tree state carries an @is_fst@ flag and an @or_union@ with an
-- @fst@ member (left tree state) and an @snd@ member (right tree state).
-- Most hooks forward to 'lsuper' through the @fst@ path or to 'rsuper'
-- through the @snd@ path. @uid@ is used to make generated names unique
-- (struct names, the @or<uid>@ label and the @or_tstate<uid>@ variable).
orLoop :: (ReaderM SeqPos m, Evalable m) => Int -> (Eval m) -> (Eval m) -> Eval m
orLoop uid (lsuper) (rsuper) = commentEval $
Eval { structs = structs lsuper @++@ structs rsuper @++@ mystructs
, toString = "or" ++ show uid ++ "(" ++ toString lsuper ++ "," ++ toString rsuper ++ ")"
-- Tree state: the is_fst flag (initialised to true) and the or_union.
-- The or_union initialiser works on the left path: it allocates a fresh left
-- eval state (s1), sets its ref_count to 1, stores the base tree state as
-- parent, then runs the left tree-state (fs3) and eval-state (fs1) initialisers.
, treeState_ = [entry ("is_fst",Bool,assign true)
, ("or_union",Union [(SType s3,"fst"),(SType s4,"snd")],
\i ->
let j = withPath i in1 (et FirstS) (st FirstS)
in do cc <- cachedClone i (cloneBase j)
return ( (estate j <== New s1)
>>> (ref_count j <== 1)
>>> (parent j <== baseTstate j)
>>> cc
)
@>>>@ mseqs [init (j `withClone` (\k -> inc $ ref_count k)) | (f,init) <- fs3]
@>>>@ inite fs1 j
)]
-- Initialisation only touches the left side.
, initH = \i -> initE lsuper (withPath i in1 (et FirstS) (st FirstS))
-- The combinator itself contributes no eval-state fields.
, evalState_ = []
, pushLeftH = push pushLeft
, pushRightH = push pushRight
-- Generates all four nextSame/nextDiff variants (per side) and selects one
-- from the is_fst flags of the current state (i) and of the same info rebased
-- on "popped_estate" (j): matching flags use nextSame, differing flags use nextDiff.
, nextSameH = \i -> let j = i `withBase` "popped_estate"
in do nS1 <- local (const FirstS) $ inSeq nextSame i
nS2 <- local (const SecondS) $ inSeq nextSame i
nD1 <- local (const FirstS) $ inSeq nextDiff i
nD2 <- local (const SecondS) $ inSeq nextDiff i
return $ IfThenElse (is_fst i)
(IfThenElse (is_fst j) nS1 nD1)
(IfThenElse (is_fst j) nD2 nS2)
, nextDiffH = \i -> inSeq nextDiff i
-- Body: branch on is_fst; each side's body is generated under the matching
-- SeqPos, with the dec_ref statement attached as its onAbort action.
, bodyH = \i ->
let f y z p =
let j = withPath i y (et p) (st p)
in dec_ref i >>= \deref -> bodyE z (j `onAbort` deref)
in IfThenElse (is_fst i) @$ local (const FirstS) (f in1 lsuper FirstS)
@. local (const SecondS) (f in2 rsuper SecondS)
, addH = inSeq $ addE
-- Failure forwards to the active side and then appends the dec_ref statement.
, failH = \i -> inSeq failE i @>>>@ dec_ref i
-- Return: the side's dec_ref statement is installed as the onCommit action
-- (wrapped in marker comments) before forwarding to that side's returnE.
, returnH = \i ->
let j1 deref = (withPath i in1 (et FirstS) (st FirstS)) `onCommit` (comment "returnE-deref-j1" >>> deref >>> comment "end returnE-deref-j1")
j2 deref = (withPath i in2 (et SecondS) (st SecondS)) `onCommit` (comment "returnE-deref-j2" >>> deref >>> comment "end returnE-deref-j2")
in seqSwitch (dec_ref1 i >>= returnE lsuper . j1)
(dec_ref2 (j2 Skip) >>= returnE rsuper . j2)
, tryH = \i ->
do dr <- dec_ref i
inSeq (\super j -> tryE super (j `onAbort` (comment "Combinator/Or tryH onAbort" >>> dr ))) i
-- startTry is always generated with the position set to FirstS.
, startTryH = \i -> local (const FirstS) $ inSeq startTryE i
, tryLH = \i -> inSeq tryE_ i @>>>@ dec_ref i
-- Array and variable declarations are the concatenation of both sides'.
, boolArraysE = boolArraysE lsuper ++ boolArraysE rsuper
, intArraysE = intArraysE lsuper ++ intArraysE rsuper
, intVarsE = intVarsE lsuper ++ intVarsE rsuper
, deleteH = deleteMe
, canBranch = return True
-- Completeness is a conditional on is_fst between the two sides' results.
, complete = \i -> do sid1 <- complete lsuper (withPath i in1 (et FirstS) (st FirstS))
sid2 <- complete rsuper (withPath i in2 (et SecondS) (st SecondS))
return $ (Cond (tstate i @-> "is_fst") sid1 sid2)
}
-- Local shorthands: s1/s2 are the left/right eval-state structs, s3/s4 the
-- left/right tree-state structs, fs1/fs2/fs3 the matching initialiser lists.
where mystructs = ([s1,s2],[s3,s4])
s1 = xs1 uid lsuper rsuper
s2 = xs2 uid lsuper rsuper
s3 = xs3 uid lsuper rsuper
s4 = xs4 uid lsuper rsuper
fs1 = xfs1 uid lsuper rsuper
fs2 = xfs2 uid lsuper rsuper
fs3 = xfs3 uid lsuper rsuper
et = xet uid lsuper rsuper
st = xst uid lsuper rsuper
needSide = xneedSide uid lsuper rsuper
parent = \i -> estate i @=> "parent"
-- withSeq and withSeq_ have identical definitions; both hand seqSwitch the
-- left variant (lsuper, in1, FirstS) and the right variant (rsuper, in2, SecondS).
-- NOTE(review): seqSwitch is imported; presumably it selects by the current SeqPos - confirm.
withSeq f = seqSwitch (f lsuper in1 FirstS) (f rsuper in2 SecondS)
withSeq_ f = seqSwitch (f lsuper in1 FirstS) (f rsuper in2 SecondS)
-- Run f on the side's evaluator with the info re-pathed into that side's union member.
inSeq f = \i -> withSeq_ $ \super ins pos -> f super (withPath i ins (et pos) (st pos))
dec_ref = \i -> seqSwitch (dec_ref1 i) (dec_ref2 $ withPath i in2 (et SecondS) (st SecondS))
-- Left-side release: decrement the left ref_count; when it reaches 0 and the
-- left side is not complete, emit the switch to the right side: declare
-- or_tstate<uid>, copy the parent into its base tree state, clear is_fst,
-- delete the left eval state, allocate the right eval state (only if needSide
-- keeps it), then run the right side's initialisers and startTryE.
dec_ref1 = \i -> let j1 = withPath i in1 (et FirstS) (st FirstS)
i' = resetClone $ resetAbort $ resetCommit $ i `withBase` ("or_tstate" ++ show uid)
j2 = withPath i' in2 (et SecondS) (st SecondS)
in (local (const SecondS) $
do stmt1 <- inits rsuper j2
stmt2 <- startTryE rsuper j2
ini <- inite fs2 j2
compl <- complete lsuper j1
return ( dec (ref_count j1)
>>> (ifthen (ref_count j1 @== 0) $
(
(ifthen (Not compl) $
( SHook ("TreeState or_tstate" ++ show uid ++ ";")
>>> (baseTstate j2 <== parent j1)
>>> (is_fst i' <== false)
>>> comment "orLoop-dec_ref1-Delete" >>> Delete (estate j1)
>>> needSide SecondS (estate j2 <== New s2)
>>> needSide SecondS (ref_count j2 <== 1)
>>> ini
>>> stmt1 >>> stmt2
)
)
)
)
)
)
-- Right-side release: decrement the right ref_count and delete the eval state
-- at 0; the whole statement is dropped by needSide when rsuper has no
-- eval-state fields.
-- NOTE(review): the result of `complete rsuper` (compl) is bound but not used - confirm this is intended.
dec_ref2 = \j -> (complete rsuper (withPath (resetClone $ resetAbort $ resetCommit $ j `withBase` ("or_tstate" ++ show uid)) in2 (et SecondS) (st SecondS)) >>= \compl -> (return $ needSide SecondS $ dec (ref_count j) >>> ifthen (ref_count j @== 0) ( comment "orLoop-dec_ref2-Delete" >>> Delete (estate j))))
-- Push hooks: forward to the side's push with an onCommit action that copies
-- is_fst and the evalState pointer and increments ref_count (on the right
-- side the last two are gated by needSide).
push dir = \i -> seqSwitch (push1 dir i) (push2 dir i)
push1 dir = \i ->
let j = withPath i in1 (et FirstS) (st FirstS)
in dir lsuper (j `onCommit` ( mkCopy i "is_fst"
>>> mkCopy j "evalState"
>>> inc (ref_count j)
))
push2 dir = \i ->
let j = withPath i in2 (et SecondS) (st SecondS)
in dir rsuper (j `onCommit` ( mkCopy i "is_fst"
>>> needSide SecondS (mkCopy j "evalState")
>>> needSide SecondS (inc (ref_count j))
))
-- Accessors into the tree state: the two union members and the side flag.
in1 = \state -> state @-> "or_union" @-> "fst"
in2 = \state -> state @-> "or_union" @-> "snd"
is_fst = \i -> tstate i @-> "is_fst"
-- Delete forwards to the active side's deleteE and then appends the dec_ref statement.
deleteMe = \i -> seqSwitch (deleteE lsuper (withPath i in1 (et FirstS) (st FirstS))) (deleteE rsuper (withPath i in2 (et SecondS) (st SecondS))) @>>>@ dec_ref i
-- | Combine two searches with the "or" combinator implemented by 'orLoop'.
--
-- The evaluator of the right search (@s2@) is built first, then the evaluator
-- of the left search (@s1@); only after both does the combinator draw its own
-- @uid@ from the state counter and increment it. The order of these steps
-- is significant because each one may consume identifiers from the same
-- counter, so it must not be rearranged.
--
-- 'runsearch' unwraps the extra layers added here, supplying 'FirstS' as the
-- initial reader value, and then runs the left and right searches' own
-- runners (@runs1@ then @runs2@ in application order).
(<|>)
  :: Search
  -> Search
  -> Search
s1 <|> s2 =
  -- Nested 'case' expressions are kept as in the original formulation for
  -- unpacking both 'Search' values before building the combined record.
  case s1 of
    Search { mkeval = evals1, runsearch = runs1 } ->
      case s2 of
        Search { mkeval = evals2, runsearch = runs2 } ->
          Search
            { mkeval =
                \super -> do { s2' <- evals2 $ mapE (L . L . L . mmap (mmap runL . runL) . runL) super
                             ; s1' <- evals1 $ mapE (L . L . mmap (mmap runL . runL) . runL) super
                             ; uid <- get
                             ; put (uid + 1)
                             ; return $ mapE (L . mmap L . runL) $
                                 orLoop uid (mapE (L . mmap (mmap L) . runL . runL) s1')
                                            (mapE (L . mmap (mmap L) . runL . runL . runL) s2')
                             }
            , runsearch = runs2 . runs1 . runL . rReaderT FirstS . runL
            }