{-# LANGUAGE FlexibleContexts #-}



module Control.Search.Combinator.Until (until,limit,glimit) where



import Prelude hiding (until)

import Data.Int



import Control.Search.Language

import Control.Search.GeneratorInfo

import Control.Search.MemoReader

import Control.Search.Generator

import Control.Search.Combinator.Failure

import Control.Search.Stat



import Control.Monatron.Monatron hiding (Abort, L, state, cont)

import Control.Monatron.Zipper hiding (i,r)



-- | Evaluator for the @until@ combinator.
--
-- Combines two sub-evaluators into a single 'Eval' record (passed through
-- 'commentEval').  The left evaluator @lsuper'@ is first wrapped with
-- @evalStat cond@.  The generated tree state carries an @is_fst@ flag
-- (initialised to true) and a union holding the tree state of one side.
-- The @try@ hooks emit a test on @readStat cond@: when it holds, the left
-- side is deleted and released, @is_fst@ is cleared and the right side is
-- initialised and started; otherwise the left side's own hook is used.
--
-- Arguments, in order: the switching condition; an id used in the names of
-- the generated structs and of the reference-count field; the left
-- evaluator; the right evaluator.
--
-- NOTE(review): 'deleteH' is not implemented; forcing it raises 'error'.
untilLoop :: (Evalable m, ReaderM SeqPos m) => Stat -> Int -> (Eval m) -> (Eval m) -> Eval m

untilLoop cond uid lsuper' rsuper = commentEval c

                -- Struct declarations of both sides plus the four declared in 'mystructs'.
 where c = Eval { structs     = structs lsuper @++@ structs rsuper @++@ mystructs 

                , toString    = "until" ++ show uid ++ "(" ++ show cond ++ "," ++ toString lsuper' ++ "," ++ toString rsuper ++ ")"

                -- Tree state: the "is_fst" flag (assigned true) and a union with
                -- one member per side.  The union's initialiser sets up the eval
                -- state of the first side only.
                , treeState_   = [entry ("is_fst",Bool,assign true)

                                ,("until_union", Union [(SType s3,"fst"),(SType s4,"snd")], 

         				 \i -> 

                                            let j = xpath i FirstS

                                            in  initSubEvalState j s1 fs1 FirstS)

                                ]

                -- Initialisation is delegated to the left evaluator, addressed
                -- through the first-side path.
                , initH       = \i -> inits lsuper (i `xpath` FirstS)

                -- "until_complete" starts out true; 'dec_ref' ANDs the
                -- completeness of the sub-evaluators into it.
                , evalState_  = [("until_complete",Bool,const $ return true)]

                , pushLeftH    = push pushLeft

                , pushRightH   = push pushRight

                -- Four candidate statements are prepared: nextSame / nextDiff of
                -- the sub-evaluators, each generated with the SeqPos reader fixed
                -- to one side.  The outer 'swfst' looks at is_fst of i, the inner
                -- ones at is_fst of j (i re-based on "popped_estate"): same side
                -- on both gives nextSame, different sides give nextDiff.
                , nextSameH    = \i -> let j = i `withBase` "popped_estate"

                                      in do let nS1 = local (const FirstS)  $ inSeq nextSame i

                                            let nS2 = local (const SecondS) $ inSeq nextSame i

                                            let nD1 = local (const FirstS)  $ inSeq nextDiff i

                                            let nD2 = local (const SecondS) $ inSeq nextDiff i

                                            swfst i (swfst j nS1 nD1) (swfst j nD2 nS2)

                , nextDiffH    = inSeq nextDiff

                , -- MAIN ENTRY POINT FOR NEW NODE

                  --   if (fst) {

                  --       if (seq_union.fst.evalState->cont) {

                  --       } else {

         	 --       }

                  --   } else {

                  --       if (seq_union.snd.evalState->cont) {

                  --       } else {

         	 --	  }

                  --   }

                  -- The body of the active side is generated with an abort action
                  -- that releases this node's reference via 'dec_ref'.  Note that
                  -- j is defined recursively: its own abort action mentions j.
                  -- The first argument of f (y) is not used.
         	 bodyH       = \i -> 

                                 let f y z iscomplete pos = 

                                       do compl <- iscomplete (i `xpath` pos)

                                          let j = i `xpath` pos `onAbort` (comment "untilLoop.bodyE" >>> dec_ref i j compl pos)

                                          bodyE z j

         			 in do let s1 = local (const FirstS)  $ f in1 lsuper liscomplete FirstS

                                           s2 = local (const SecondS) $ f in2 rsuper riscomplete SecondS

                                       swfst i s1 s2

                , addH        = inSeq $ addE

                -- failH / returnH (and tryLH further down) use the sub-evaluator's
                -- hook for the side chosen by 'inSeq'' and also release the
                -- reference with 'dec_ref': combined through @>>>@ after the hook
                -- for fail and tryL, attached with onCommit for return.
                , failH       = \i -> inSeq' (\super j iscomplete pos -> iscomplete j >>= \compl -> (failE super j @>>>@ return (dec_ref i j compl pos))) i

                , returnH     = \i -> inSeq' (\super j iscomplete pos -> iscomplete j >>= \compl -> (returnE super (j `onCommit` dec_ref i j compl pos))) i

--                , continue    = \_ -> return true

                 -- IF THE CURRENT STATUS IS NOT FAILED

         	 -- EITHER (is_fst)

         	 --   if (<CONDITION>) {   // SWITCH TO NEW SEARCH

         	 --   } else {

         	 --       <TRY-REC>

          	 --   }

         	 -- OR      (!is_fst)

                -- Both try hooks share 'tryX'; only the left-side hook differs.
                , tryH        = tryX tryE

                , startTryH   = tryX startTryE

                , tryLH       = \i -> inSeq' (\super j iscomplete pos -> iscomplete j >>= \compl -> (tryE_ super j @>>>@ return (dec_ref i j compl pos))) i

                -- Array and variable lists are the concatenation of both sides.
                , boolArraysE  = boolArraysE lsuper ++ boolArraysE rsuper

                , intArraysE  = intArraysE lsuper ++ intArraysE rsuper

                , intVarsE    = intVarsE lsuper ++ intVarsE rsuper

                -- NOTE(review): not implemented; any use of this hook fails with
                -- the message below.
                , deleteH     = error "untilLoop.deleteE NOT YET IMPLEMENTED"

                -- The combination can branch when at least one side can.
                , canBranch   = canBranch lsuper >>= \l -> canBranch rsuper >>= \r -> return (l || r)

                -- Completeness is read from the "until_complete" eval-state field.
                , complete = \i -> return $ estate i @=> "until_complete"

--                , complete = const $ return false

                }

       -- Select stmY when the evaluator on side pos declares at least one
       -- eval-state field, and stmN when it declares none.
       needSide_ = \pos stmY stmN -> case pos of { FirstS -> if (length (evalState_ lsuper) == 0) then stmN else stmY;

                                                   SecondS -> if (length (evalState_ rsuper) == 0) then stmN else stmY;

                                                 }

       -- 'needSide_' with mempty as the "no eval state" alternative.
       needSide :: Monoid m => SeqPos -> m -> m

       needSide = \pos stm -> needSide_ pos stm mempty

       -- The structs declared by this combinator, in the pair shape expected
       -- by the 'structs' field.
       mystructs = ([s1,s2],[s3,s4])

       -- Eval-state struct of the left side: a reference-count field followed by
       -- the left evaluator's own eval-state fields; no fields at all when the
       -- left evaluator has no eval state.
       s1        = Struct ("LeftEvalState"  ++ show uid)  $ needSide FirstS $ {- (Bool, "cont") : -} (Int, "ref_count_until" ++ show uid) : [(ty, field) | (field,ty,_) <- evalState_ lsuper]

       -- Field initialisers of the left eval state.
       fs1       = [(field,init) | (field,ty,init) <- evalState_ lsuper ]

       -- Same as s1 / fs1, for the right side.
       s2        = Struct ("RightEvalState" ++ show uid) $ needSide SecondS $ {- (Bool, "cont") : -} (Int, "ref_count_until" ++ show uid) : [(ty, field) | (field,ty,_) <- evalState_ rsuper]

       fs2       = [(field,init) | (field,ty,init) <- evalState_ rsuper ]

       -- Tree-state struct of the left side: an "evalState" pointer (only when
       -- the side has eval state) followed by the left tree-state fields.
       s3        = Struct ("LeftTreeState"  ++ show uid) $ needSide FirstS [(Pointer $ SType s1, "evalState")] ++ [(ty, field) | (field,ty,_) <- treeState_ lsuper]

       -- NOTE(review): fs3 is not referenced anywhere else in this definition.
       fs3       = [(field,init) | (field,ty,init) <- treeState_ lsuper]

       -- Tree-state struct of the right side, built like s3.
       s4        = Struct ("RightTreeState" ++ show uid) $ needSide SecondS [(Pointer $ SType s2, "evalState")] ++ [(ty, field) | (field,ty,_) <- treeState_ rsuper]

       -- Re-root the info i onto one side of the union, passing the matching
       -- eval-state and tree-state pointer types to withPath.
       xpath i FirstS  = withPath i in1 (Pointer $ SType s1) (Pointer $ SType s3)

       xpath i SecondS  = withPath i in2 (Pointer $ SType s2) (Pointer $ SType s4)

       -- Accessors for the two members of "until_union".
       in1       = \state -> state @-> "until_union" @-> "fst"

       in2       = \state -> state @-> "until_union" @-> "snd"

       -- The "is_fst" flag in the tree state of i.
       is_fst    = \i -> tstate i @-> "is_fst"

       -- seqSwitch takes one alternative per side; presumably it selects by the
       -- SeqPos held in the reader -- confirm against its definition.
       -- NOTE(review): withSeq is not referenced anywhere else in this definition.
       withSeq f = seqSwitch (f lsuper in1) (f rsuper in2)

       withSeq_ f = seqSwitch (f lsuper in1 FirstS) (f rsuper in2 SecondS)

       -- Apply a hook f to the evaluator of each side with i re-rooted onto
       -- that side (the accessor argument ins is not used).
       inSeq  f  = \i -> withSeq_ $ \super ins pos -> f super (i `xpath` pos)

       -- Like inSeq, but f also receives the side's completeness query and
       -- its position.
       inSeq' f  = \i -> seqSwitch (f lsuper (i `xpath` FirstS) liscomplete FirstS)  

                                   (f rsuper (i `xpath` SecondS) riscomplete SecondS)

       -- Release one reference to the eval state of side pos.
       -- With eval state: decrement the counter; when it reaches 0, fold
       -- iscomplete into "until_complete" and delete the eval state.
       -- Without eval state: only fold iscomplete into "until_complete".
       dec_ref   = \i j iscomplete pos -> needSide_ pos (dec (ref_countx j $ "until" ++ show uid) >>>

                                                         ifthen (ref_countx j ("until" ++ show uid) @== 0) (

                                                        {-       DebugValue ("until" ++ show uid ++ ": left branch finished with complete") iscomplete
                                                           >>> DebugValue ("until" ++ show uid ++ ": until's previous completeness was") (complet i)
                                                           >>> -} (complet i <== (complet i &&& iscomplete)) >>> Delete (estate j)

                                                         )

                                                        ) (complet i <== (complet i &&& iscomplete))

       -- Push hooks: dir is the sub-evaluator hook to use (pushLeft or
       -- pushRight, see pushLeftH / pushRightH above).
       push dir  = \i -> seqSwitch (push1 dir i) (push2 dir i)

       -- Push on the left side; the commit action copies "is_fst" and the
       -- "evalState" pointer and increments the reference count.
       push1 dir = \i -> 

                          let j = i `xpath` FirstS

                          in  dir lsuper (j `onCommit` (   mkCopy i "is_fst"

                                                       >>> mkCopy j "evalState"

                                                       >>> inc (ref_countx j $ "until" ++ show uid)

                                                       ))

       -- Push on the right side; same commit action as push1.
       push2 dir = \i -> 

                          let j = i `xpath` SecondS

                          in  dir rsuper (j `onCommit` (    mkCopy i "is_fst"

                                                       >>> mkCopy j "evalState"

                                                       >>> inc (ref_countx j $ "until" ++ show uid)

                                                      ))

       -- The left evaluator wrapped with evalStat for the condition.
       lsuper = evalStat cond lsuper'

       -- The "until_complete" field in the eval state of i.
       complet  = \i -> estate i @=> "until_complete"

       -- NOTE(review): the left completeness query uses the unwrapped lsuper',
       -- not lsuper -- confirm this is intended.
       liscomplete = complete lsuper'

       riscomplete = complete rsuper

       -- Set up the eval state of one side: allocate struct s and set the
       -- reference count to 1 (skipped when the side has no eval state),
       -- combined through @>>>@ with inite applied to the initialisers fs.
       initSubEvalState = \j s fs pos -> return (needSide pos (    (estate j <== New s)  

					                       >>> (ref_countx j ("until" ++ show uid) <== 1)

--			                                       >>> (cont j <== true)

                                                              )

                                                )

                                         @>>>@ inite fs j

       -- Shared implementation of tryH / startTryH; x is the sub-evaluator hook.
       -- First alternative (left side): emit
       --     if (<cond>) <switch to right side> else <x of left side>
       -- where the switch deletes the left side, releases its reference, clears
       -- is_fst, initialises the right side and starts it with startTryE.
       -- Second alternative (right side): x of the right evaluator.
       -- j1, j2 and j2b carry abort actions that release the reference; j2 and
       -- j2b differ only in the comment string they emit.
       tryX        = \x i -> do lc <- liscomplete (i `xpath` FirstS)

                                rc <- riscomplete (i `xpath` SecondS)

                                let j1  = i `xpath` FirstS `onAbort` (comment "untilLoop.tryE j1" >>> dec_ref i j1 lc FirstS)

                                    j2  = i `xpath` SecondS `onAbort` (comment "untilLoop.tryE j2" >>> dec_ref i j2 rc SecondS)

                                    j2b = i `xpath` SecondS `onAbort` (comment "untilLoop.tryE j2b" >>> dec_ref i j2b rc SecondS)

                                seqSwitch (x       lsuper j1 >>= \try1 ->

                                                   deleteE lsuper j1 >>= \delete1 ->

                                                   -- The right-side statements are generated with the
                                                   -- SeqPos reader set to SecondS.
                                                   (local (const SecondS) $

                                                     do stmt1 <- inits rsuper j2b

                                                        stmt2 <- startTryE rsuper j2b

                                                        ini <- initSubEvalState j2b s2 fs2 SecondS

                                                        return (   delete1

         						      >>> dec_ref i j1 lc FirstS

                                                     	      >>> (is_fst i <== false)

         						      >>> ini

                                                               >>> comment "initTreeState_ j2b rsuper" 

         						      >>> stmt1 

                                                               >>> comment "tryE rsuper j2b" 

         						      >>> comment ("length: " ++ show (length (abort_ j2b)))

         						      >>> stmt2)

                                                   ) >>= \start2 -> readStat cond >>= \r -> return $ IfThenElse (r j1) ({- (DebugOutput $ "until" ++ show uid ++ " switches") >>> -} start2) try1

                                                  )

                                                  (x rsuper j2) 

       -- Choose between the first-side statement t and the second-side statement
       -- e.  A run-time test on is_fst is only emitted when both sides can
       -- branch; otherwise t is used when the left side can branch, e when not.
       swfst i t e = do  b1 <- canBranch lsuper

                         b2 <- canBranch rsuper

                         if (b1 && b2) then do { tt <- t; ee <- e; return $ IfThenElse (is_fst i) tt ee }

                                       else if b1 then t

                                                  else e





-- | @limit n stat s@ runs the search @s@ until the statistic @stat@ is at
-- least @n@; from then on 'until' switches to the 'failure' search.
limit :: Int32 -> Stat -> Search -> Search
limit n stat s = until reached s failure
  where
    -- Switching condition: @stat >= n@, with @n@ lifted to a constant statistic.
    reached = stat #>= constStat (const (IVal n))



-- | @glimit cond s@ runs the search @s@ until the statistic @cond@ holds;
-- from then on 'until' switches to the 'failure' search.
glimit :: Stat -> Search -> Search
glimit cond s = until cond s fallback
  where
    -- Search that takes over once the condition holds.
    fallback = failure



-- | @until cond s1 s2@ combines two searches: @s1@ is used first and, once
-- the statistic @cond@ holds, the search continues with @s2@ (the switch
-- itself is generated by 'untilLoop').
--
-- Order of effects in the evaluator builder is significant and is kept as
-- is: the evaluator of @s2@ is built first, then that of @s1@, and only
-- afterwards is the id read from the state and the counter incremented.
until
  :: Stat    -- ^ condition that triggers the switch
  -> Search  -- ^ search used first
  -> Search  -- ^ search used after the switch
  -> Search
until cond (Search { mkeval = mkFirst,  runsearch = runFirst  })
           (Search { mkeval = mkSecond, runsearch = runSecond }) =
  Search
    { mkeval = \super -> do
        -- Each sub-search gets the parent evaluator lifted into its own
        -- position in the transformer stack.
        second <- mkSecond (mapE (L . L . L . mmap (mmap runL . runL) . runL) super)
        first  <- mkFirst  (mapE (L . L . mmap (mmap runL . runL) . runL) super)
        -- Draw the id for this combinator and advance the counter.
        uid <- get
        put (uid + 1)
        return $ mapE (L . mmap L . runL) $ memoLoop $
          untilLoop cond uid
            (mapE (mmap L . runL) first)
            (mapE (mmap L . runL . runL) second)
      -- The reader layer is run with FirstS as its value.
    , runsearch = runSecond . runFirst . runL . rReaderT FirstS . runL
    }