{-# LANGUAGE GADTs                    #-}
{-# LANGUAGE StandaloneKindSignatures #-}
module Control.Monad.Loop.Internal where

import           Control.Monad.Except  (ExceptT, MonadError (throwError),
                                        runExceptT)
import           Control.Monad.ST.Lazy (runST)
import           Data.Foldable         (traverse_)
import           Data.Kind             (Type)
import           Data.STRef.Lazy       (modifySTRef, newSTRef, readSTRef)


type Loop :: (Type -> Type) -> Type -> Type
-- | Description of a loop, assembled clause by clause before it is run.
--
-- The second index records what evaluating the description yields, and @m@
-- is the monad in which 'evalLoop' returns that value.
data Loop m a where
  -- | A container whose elements are to be iterated over.
  For    :: t a -> Loop m (t a)
  -- | An inner loop paired with a predicate on its elements.
  While  :: Loop m (t a) -> (a -> Bool) -> Loop m (t a, a -> Bool)
-- | @for@ clause that starts a loop over the given container.
--
-- This is the 'For' constructor under a lower-case name.  The 'Traversable'
-- constraint is imposed by this function, not by the constructor.
for :: Traversable t => t a -> Loop m (t a)
for = For

{- | `while` clause to determine the terminal condition of a loop.

The predicate is only recorded here (this is the 'While' constructor under a
lower-case name).  It is consulted by 'withWhile_' and 'withWhilei_', which
stop at the first element for which it returns 'False'.

@
  for [(0::Int)..] \`while\` (\<10) \`with\` \\(i::Int) -> lift do
    putStrLn "hi"
@

-}
while  :: Traversable t
       => Loop m (t a) -> (a -> Bool) -> Loop m (t a, a -> Bool)
while = While

-- | Turn a 'Loop' description into the data a loop runner needs.
--
-- * 'For' yields the wrapped container unchanged.
-- * 'While' evaluates the inner loop and pairs its result with the predicate.
--
-- The only operations used in @m@ are 'pure' and 'fmap'.
evalLoop :: Monad m => Loop m a -> m a
evalLoop (For xs)           = pure xs
-- The predicate is named @cond@ rather than @pred@ to avoid shadowing the
-- Prelude function of that name.
evalLoop (While inner cond) = fmap (\xs -> (xs, cond)) (evalLoop inner)

-- | Start a for-each style loop.
--
-- Runs the body @k@ on each element of the loop's container, in 'traverse_'
-- order.  The body lives in @'ExceptT' () m@: a @'throwError' ()@ (for
-- example via 'quit') stops the traversal at that element.  Whether the loop
-- finished normally or was cut short is not reported; the result is always
-- @()@.
with_ :: (Traversable t, Monad m)
      => Loop m (t a) -> (a -> ExceptT () m ()) -> m ()
with_ loop k = do
  xs <- evalLoop loop
  -- Normal completion (Right) and early exit (Left) both collapse to ().
  () <$ runExceptT (traverse_ k xs)

-- | Pair every element of a container with its position in traversal order,
-- counting from 0.  The shape of the container is preserved.
--
-- The counter is kept in an 'STRef' from "Data.STRef.Lazy" and threaded
-- through 'traverse' inside the lazy 'ST' monad.
-- NOTE(review): lazy ST was presumably chosen so that infinite containers
-- (as in the @for [0..]@ example) can be consumed incrementally - confirm.
enumerateTrav :: (Traversable t, Integral n) => t a -> t (n, a)
enumerateTrav ts = runST $ do
  counter <- newSTRef 0
  traverse (label counter) ts
  where
    -- Read the current index, bump the counter, and tag the element.
    label counter value = do
      idx <- readSTRef counter
      modifySTRef counter (+ 1)
      return (idx, value)

-- | Start a for-each style loop with access to indices.
--
-- Like 'with_', but the body receives @(index, element)@ pairs produced by
-- 'enumerateTrav', so indices count from 0 in traversal order.  A
-- @'throwError' ()@ in the body stops the traversal; the result is always
-- @()@.
withi_ :: (Traversable t, Monad m, Integral n)
       => Loop m (t a) -> ((n, a) -> ExceptT () m ()) -> m ()
withi_ loop k = do
  xs <- evalLoop loop
  -- Normal completion (Right) and early exit (Left) both collapse to ().
  () <$ runExceptT (traverse_ k (enumerateTrav xs))

-- | Start a for-each style loop with a while clause.
--
-- Elements are visited in 'traverse_' order.  For each one the predicate is
-- checked first: if it holds the body @k@ runs, otherwise the loop stops
-- immediately via @'throwError' ()@ and no later element is visited.  The
-- body itself may also stop the loop the same way.  The result is always
-- @()@.
withWhile_ :: (Traversable t, Monad m)
           => Loop m (t a, a -> Bool) -> (a -> ExceptT () m ()) -> m ()
withWhile_ loop k = do
  -- Named @cond@ rather than @pred@ to avoid shadowing the Prelude function.
  (ts, cond) <- evalLoop loop
  let step x
        | cond x    = k x
        | otherwise = throwError ()
  -- Normal completion (Right) and early exit (Left) both collapse to ().
  () <$ runExceptT (traverse_ step ts)

-- | Start a for-each style loop with a while clause and access to indices.
--
-- Combines 'withWhile_' and 'withi_': elements are paired with their
-- 0-based position by 'enumerateTrav', the predicate is checked on the
-- element only (not on the index), and the body receives the whole
-- @(index, element)@ pair.  The loop stops at the first element that fails
-- the predicate, or when the body throws.  The result is always @()@.
withWhilei_ :: (Traversable t, Monad m, Integral n)
           => Loop m (t a, a -> Bool) -> ((n, a) -> ExceptT () m ()) -> m ()
withWhilei_ loop k = do
  -- Named @cond@ rather than @pred@ to avoid shadowing the Prelude function.
  (ts, cond) <- evalLoop loop
  let step pair@(_, x)
        | cond x    = k pair
        | otherwise = throwError ()
  -- Normal completion (Right) and early exit (Left) both collapse to ().
  () <$ runExceptT (traverse_ step (enumerateTrav ts))

-- | Break out of the loop whose body is currently running.
--
-- Implemented as @'throwError' ()@.  The loop runners in this module execute
-- the body under 'runExceptT' and discard the outcome, so throwing ends that
-- runner's traversal without surfacing an error to its caller.
quit :: Monad m => ExceptT () m a
quit = throwError ()

-- | Intended to break to the outermost loop.
--
-- Defined as 'quit' followed by a recursive call to itself.
--
-- NOTE(review): for a lawful 'Monad' @m@, 'quit' short-circuits '>>' before
-- the recursive call is ever run, so this is observationally the same as
-- 'quit': it ends only the loop runner whose body it executes in, because
-- each runner in this module applies its own 'runExceptT'.  Confirm whether
-- escaping several nested loops is actually required; the unit error type
-- cannot distinguish the two kinds of exit.
cease :: Monad m => ExceptT () m a
cease = quit >> cease