-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Retrie.SYB
  ( everywhereMWithContextBut
  , GenericCU
  , GenericMC
  , Strategy
  , topDown
  , bottomUp
  , everythingMWithContextBut
  , GenericMCQ
  , module Data.Generics
  ) where

import Control.Monad
import Data.Generics hiding (Fixity(..))

-- | A monadic, context-aware generic rewrite: given the current context and
-- a node of any 'Data' type, yield a node of that same type in @m@.
type GenericMC m c = forall t. Data t => c -> t -> m t

-- | Context update: from the current context, the zero-based index of a
-- child, and the parent node that child belongs to, compute (in @m@) the
-- context in which that child will be visited.
type GenericCU m c = forall p. Data p => c -> Int -> p -> m c

-- | Monadic traversal with pruning and context propagation.
--
-- At every node for which @stop@ returns 'False', the supplied 'Strategy'
-- combines the context-aware rewrite @f@ (applied to the node itself) with a
-- rewrite of the node's immediate children. Each child is visited with a
-- context computed by @upd@ from the parent's context, the child's zero-based
-- index, and the parent node. A node for which @stop@ returns 'True' is
-- returned unchanged: @f@ is not applied to it and its subtree is not visited.
everywhereMWithContextBut
  :: forall m c. Monad m
  => Strategy m    -- ^ Traversal order (see 'topDown' and 'bottomUp')
  -> GenericQ Bool -- ^ Short-circuiting stop condition
  -> GenericCU m c -- ^ Context update function
  -> GenericMC m c -- ^ Context-aware rewrite
  -> GenericMC m c
everywhereMWithContextBut strategy stop upd f = go
  where
    go :: GenericMC m c
    go ctxt x
      | stop x    = return x
      | otherwise = strategy (f ctxt) (h ctxt) x

    -- Rewrite every immediate child of 'parent'. Each child receives the
    -- context that 'upd' derives from the parent's context, the child's index
    -- and the parent itself. The explicit signature keeps 'h' polymorphic in
    -- the node type even if local bindings are not generalised (for example
    -- under MonoLocalBinds). 'go' needs that polymorphism when it applies
    -- 'h' at its own rigid node type.
    h :: GenericMC m c
    h ctxt parent = gforMIndexed parent $ \i child -> do
      ctxt' <- upd ctxt i parent
      go ctxt' child

-- | A monadic, context-aware generic query: given the current context and a
-- node of any 'Data' type, produce a result of type @r@ in @m@.
type GenericMCQ m c r = forall t. Data t => c -> t -> m r

-- | Monadic query with pruning and context propagation.
--
-- For every node for which @stop@ returns 'False', runs @q@ on the node and
-- then recurses into each immediate child in order. All results are combined
-- with 'mconcat', with the node's own result first and its children's results
-- after it, in child order. A node for which @stop@ returns 'True' contributes
-- 'mempty': @q@ is not run on it and its subtree is not visited.
everythingMWithContextBut
  :: forall m c r. (Monad m, Monoid r)
  => GenericQ Bool    -- ^ Short-circuiting stop condition
  -> GenericCU m c    -- ^ Context update function
  -> GenericMCQ m c r -- ^ Context-aware query
  -> GenericMCQ m c r
everythingMWithContextBut stop upd q = go
  where
    go :: GenericMCQ m c r
    go ctxt x
      | stop x    = return mempty
      | otherwise = do
          r <- q ctxt x
          rs <- gforQIndexed x $ \i child -> do
            -- Each child is queried in the context derived from the parent's
            -- context, the child's zero-based index and the parent node.
            ctxt' <- upd ctxt i x
            go ctxt' child
          return $ mconcat (r : rs)

-- | Traversal strategy.
-- Takes a rewrite of a node and a rewrite of that node's children, and builds
-- a single composite rewrite of the node. 'topDown' and 'bottomUp' differ only
-- in the order in which the two rewrites are sequenced.
type Strategy m = forall t. Monad m => (t -> m t) -> (t -> m t) -> t -> m t

-- | Perform a top-down traversal: rewrite the node first, then feed the
-- result to the rewrite of its children.
topDown :: Strategy m
topDown p cs = p >=> cs

-- | Perform a bottom-up traversal: rewrite the node's children first, then
-- rewrite the resulting node itself.
bottomUp :: Strategy m
bottomUp p cs = cs >=> p

-- | 'gmapM' with arguments flipped and providing zero-based index of child
-- to mapped function.
gforMIndexed
  :: (Monad m, Data a) => a -> (forall d. Data d => Int -> d -> m d) -> m a
gforMIndexed x f = snd (gmapAccumM (accumIndex f) (-1) x)
  -- The accumulator is seeded with -1 (standing for the constructor).
  -- 'accumIndex' increments it before each child, so the first child gets 0.

-- | Advance an index accumulator by one and apply @f@ at the new index.
--
-- Used as the step function for 'gmapAccumM' / 'gmapAccumQ' in this module.
-- Seeding it with @-1@ therefore numbers children from 0. The bang pattern
-- evaluates the new index before the pair is built, so the accumulator stays
-- an evaluated 'Int' rather than a growing chain of @(+ 1)@ thunks.
accumIndex :: (Int -> a -> b) -> Int -> a -> (Int, b)
accumIndex f i y = let !i' = i + 1 in (i', f i' y)

-- | Query counterpart of 'gforMIndexed': run @f@ on each immediate child of
-- @x@, passing the child's zero-based index, and collect the results in
-- child order.
gforQIndexed
  :: (Monad m, Data a) => a -> (forall d. Data d => Int -> d -> m r) -> m [r]
gforQIndexed x f = sequence $ snd $ gmapAccumQ (accumIndex f) (-1) x
  -- As in 'gforMIndexed', -1 stands for the constructor, so the first child
  -- is given index 0.