{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}

--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Thunk
-- Copyright   :  (c) 2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module defines terms and contexts with thunks, i.e. deferred
-- monadic computations embedded in the term structure.
--
--------------------------------------------------------------------------------

module Data.Comp.Thunk
    (TermT
    ,CxtT
    ,thunk
    ,whnf
    ,whnf'
    ,whnfPr
    ,nf
    ,nfPr
    ,eval
    ,eval2
    ,deepEval
    ,deepEval2
    ,(#>)
    ,(#>>)
    ,AlgT
    ,cataT
    ,cataTM
    ,eqT
    ,strict
    ,strictAt) where

import Data.Comp.Algebra
import Data.Comp.Equality
import Data.Comp.Mapping
import Data.Comp.Ops ((:+:) (..), fromInr)
import Data.Comp.Sum
import Data.Comp.Term
import Data.Foldable hiding (and)

import qualified Data.IntSet as IntSet

import Control.Monad hiding (mapM, sequence)
import Data.Traversable

-- Control.Monad.Fail import is redundant since GHC 8.8.1
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
#endif

import Prelude hiding (foldl, foldl1, foldr, foldr1, mapM, sequence)


-- | Terms with thunks. Every node is either an ordinary node of the
-- signature @f@ (an 'Inr' node) or a thunk (an 'Inl' node), i.e. a
-- computation in the monad @m@ that produces the rest of the term.
type TermT m f = Term ((:+:) m f)

-- | Contexts with thunks: the 'Cxt' counterpart of 'TermT'. The index
-- @h@ determines whether holes may occur (cf. 'NoHole') and @a@ is the
-- type of the holes.
type CxtT m h f a = Cxt h ((:+:) m f) a


-- | Turn a monadic computation into a thunk.
--
-- The computation is stored unevaluated as an 'Inl' node of the sum
-- signature. It is only run when the thunk is forced, e.g. by 'whnf'.
thunk :: m (CxtT m h f a) -> CxtT m h f a
thunk = inject_ Inl

-- | Evaluate thunks at the root of a term until a non-thunk node is
-- found, and return that node.
--
-- Only the root is forced. Thunks inside the subterms of the returned
-- node are left untouched.
whnf :: Monad m => TermT m f -> m (f (TermT m f))
whnf (Term (Inl m)) = m >>= whnf   -- run the thunk, then keep going
whnf (Term (Inr t)) = return t     -- reached an ordinary node

-- | Like 'whnf', but re-wraps the resulting non-thunk node (as an
-- 'Inr' node), so the result is again a term with thunks.
whnf' :: Monad m => TermT m f -> m (TermT m f)
whnf' = liftM (inject_ Inr) . whnf

-- | Evaluate the argument into weak head normal form via 'whnf', then
-- project the resulting top-level node to the subsignature @g@.
--
-- If the projection fails, 'fail' is called in the monad with the
-- message @\"projection failed\"@.
whnfPr :: (MonadFail m, g :<: f) => TermT m f -> m (g (TermT m f))
whnfPr t = do
  res <- whnf t
  case proj res of
    Just res' -> return res'
    Nothing   -> fail "projection failed"

-- | Inspect the topmost non-thunk node of a term (found with 'whnf')
-- using the given continuation.
--
-- The result is itself a thunk, so nothing is evaluated until the
-- returned term is forced. At that point the argument is brought into
-- weak head normal form and its top-level node is passed to @cont@.
eval :: Monad m => (f (TermT m f) -> TermT m f) -> TermT m f -> TermT m f
eval cont t = thunk (liftM cont (whnf t))

infixl 1 #>

-- | Variant of 'eval' with the arguments flipped, so that inspection
-- reads left to right: @t #> cont@ is @'eval' cont t@.
(#>) :: Monad m => TermT m f -> (f (TermT m f) -> TermT m f) -> TermT m f
(#>) = flip eval

-- | Inspect the topmost non-thunk nodes of two terms (found with
-- 'whnf') using the given binary continuation.
--
-- The first term is brought into weak head normal form before the
-- second. As with 'eval', the result is a thunk.
eval2 :: Monad m => (f (TermT m f) -> f (TermT m f) -> TermT m f)
                 -> TermT m f -> TermT m f -> TermT m f
eval2 cont x y = eval (\x' -> eval (cont x') y) x

-- | Evaluate all thunks in a term, producing a thunk-free term.
--
-- Each layer is first brought into weak head normal form with 'whnf'.
-- Then the subterms are normalised recursively, in traversal order.
nf :: (Monad m, Traversable f) => TermT m f -> m (Term f)
nf = liftM Term . mapM nf <=< whnf

-- | Evaluate all thunks while projecting the term to the smaller
-- signature @g@.
--
-- Every layer goes through 'whnfPr'. If any node cannot be projected,
-- the whole computation fails in the monad, as in 'whnfPr'.
nfPr :: (MonadFail m, Traversable g, g :<: f) => TermT m f -> m (Term g)
nfPr = liftM Term . mapM nfPr <=< whnfPr

-- | Inspect a fully evaluated term (see 'nf') using the given
-- continuation.
--
-- Fast path: if every node of @v@ can be projected with 'fromInr'
-- (presumably meaning @v@ contains no thunks), @cont@ is applied
-- directly and no thunk is created. Otherwise the result is a thunk
-- that, when forced, normalises @v@ with 'nf' and passes the result
-- to @cont@.
deepEval :: (Traversable f, Monad m) =>
            (Term f -> TermT m f) -> TermT m f -> TermT m f
deepEval cont v = case deepProject_ fromInr v of
  Just v' -> cont v'
  Nothing -> thunk (liftM cont (nf v))

infixl 1 #>>

-- | Variant of 'deepEval' with the arguments flipped:
-- @t #>> cont@ is @'deepEval' cont t@.
(#>>) :: (Monad m, Traversable f) => TermT m f -> (Term f -> TermT m f) -> TermT m f
(#>>) = flip deepEval

-- | Inspect two fully evaluated terms (see 'nf') using the given
-- binary continuation.
--
-- Built from two nested 'deepEval's: the outer one handles the first
-- term and the inner one the second.
deepEval2 :: (Monad m, Traversable f) =>
             (Term f -> Term f -> TermT m f)
          -> TermT m f -> TermT m f -> TermT m f
deepEval2 cont x y = deepEval (\x' -> deepEval (cont x') y) x

-- | Algebras for the signature @f@ whose carrier is 'TermT' @m g@,
-- i.e. terms with thunks over the signature @g@ and the monad @m@.
type AlgT m f g = Alg f (TermT m g)

-- | Run a monadic catamorphism on a term with thunks.
--
-- Thunks are run as the fold reaches them. Subterms of an ordinary node
-- are folded in traversal order before the algebra is applied to the
-- node.
cataTM :: forall m f a . (Traversable f, Monad m) => AlgM m f a -> TermT m f -> m a
cataTM alg = run
  where
    -- Implemented directly rather than via a generic monadic cata,
    -- which would otherwise need a Traversable m constraint.
    run :: TermT m f -> m a
    run (Term (Inl m)) = m >>= run
    run (Term (Inr t)) = mapM run t >>= alg

-- | Run a (pure) catamorphism on a term with thunks.
--
-- The result is in @m@ because thunks must be run along the way. This
-- is 'cataTM' with the algebra lifted into the monad.
cataT :: (Traversable f, Monad m) => Alg f a -> TermT m f -> m a
cataT alg = cataTM (return . alg)

-- | Make the given functor application strict in all its immediate
-- subterms.
--
-- The result is a thunk. When forced, every immediate subterm of @x@
-- is brought into weak head normal form with 'whnf'' (in traversal
-- order). The node is then injected into @g@ via 'inj'.
strict :: (f :<: g, Traversable f, Monad m) => f (TermT m g) -> TermT m g
strict x = thunk $ liftM (inject_ (Inr . inj)) $ mapM whnf' x

-- | This type represents position representations for a functor
-- @f@. It is a function that extracts a number of components (of
-- polymorphic type @a@) from a functorial value and puts them into a
-- list.
type Pos f = forall a . f a -> [a]

-- | A variant of 'strict' that makes only a subset of the arguments of
-- a functor application strict.
--
-- The first argument selects the strict positions. Arguments at those
-- positions are brought into weak head normal form with 'whnf'' when
-- the resulting thunk is forced. All other arguments are passed through
-- unevaluated.
strictAt :: (f :<: g, Traversable f, Monad m) => Pos f ->  f (TermT m g) -> TermT m g
strictAt p s = thunk $ liftM (inject_ (Inr . inj)) $ mapM run s'
  where
    -- Tag each argument with its position index (via 'number'), so that
    -- the positions selected by @p@ can be recognised in the traversal.
    s' = number s
    -- Indices of the strict positions. Built once here rather than
    -- once per argument inside the per-element test.
    strictIdx = IntSet.fromList [ j | Numbered j _ <- p s' ]
    run (Numbered i e)
      | i `IntSet.member` strictIdx = whnf' e
      | otherwise                   = return e


-- | Decide equality of two terms with thunks.
--
-- Both terms are brought into weak head normal form (@s@ first, then
-- @t@). If their top-level nodes differ, as decided by 'eqMod', the
-- result is 'False'. Otherwise the paired subterms are compared
-- recursively.
--
-- NOTE: every subterm pair is compared, and its thunks forced, even
-- after a mismatch has been found. The results are combined with 'and'
-- only at the end.
eqT :: (EqF f, Foldable f, Functor f, Monad m) => TermT m f -> TermT m f -> m Bool
eqT s t = do
  s' <- whnf s
  t' <- whnf t
  case eqMod s' t' of
    Nothing -> return False
    Just l  -> liftM and $ mapM (uncurry eqT) l