{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Hedgehog.Internal.Tree (
Tree(..)
, Node(..)
, fromNode
, unfold
, unfoldForest
, expand
, prune
, render
) where
import Control.Applicative (Alternative(..))
import Control.Monad (MonadPlus(..), ap, join)
import Control.Monad.Base (MonadBase(..))
import Control.Monad.Catch (MonadThrow(..), MonadCatch(..), Exception)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Morph (MFunctor(..), MMonad(..))
import Control.Monad.Primitive (PrimMonad(..))
import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.State.Class (MonadState(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Resource (MonadResource(..))
import Control.Monad.Writer.Class (MonadWriter(..))
#if MIN_VERSION_base(4,9,0)
import Data.Functor.Classes (Show1(..), showsPrec1)
import Data.Functor.Classes (showsUnaryWith, showsBinaryWith)
#endif
import Hedgehog.Internal.Distributive
-- | A tree whose root node is produced by running an effect in @m@.
newtype Tree m a = Tree
  { -- Run the effect that yields the root 'Node' of the tree.
    runTree :: m (Node m a)
  }

-- | A single node of a 'Tree': a value together with its child trees.
data Node m a = Node
  { -- The value stored at this node.
    nodeValue :: a
    -- The child trees hanging off this node.
  , nodeChildren :: [Tree m a]
  }
-- | Turn an already-built 'Node' into a 'Tree' by injecting it into @m@
--   with 'pure'.
fromNode :: Applicative m => Node m a -> Tree m a
fromNode node =
  Tree (pure node)
-- | Build a tree from a seed value.
--
--   The seed becomes the value of the root node; the children are the trees
--   produced by 'unfoldForest' for the same function and seed.
unfold :: Monad m => (a -> [a]) -> a -> Tree m a
unfold step seed =
  Tree (pure (Node seed (unfoldForest step seed)))
-- | Build a list of trees from a seed value.
--
--   The function is applied to the seed, and every resulting value is turned
--   into a tree with 'unfold'.
unfoldForest :: Monad m => (a -> [a]) -> a -> [Tree m a]
unfoldForest step seed =
  map (unfold step) (step seed)
-- | Add extra children to every node of a tree.
--
--   At each node the existing children are expanded recursively, and the
--   trees obtained from 'unfoldForest' on the node's value are appended
--   after them.
expand :: Monad m => (a -> [a]) -> Tree m a -> Tree m a
expand step tree =
  Tree $
    runTree tree >>= \(Node value children) ->
      let
        expanded =
          map (expand step) children
        fresh =
          unfoldForest step value
      in
        pure (Node value (expanded ++ fresh))
-- | Discard all children of the root node, keeping only its value.
--
--   The root effect is still run to obtain that value.
prune :: Monad m => Tree m a -> Tree m a
prune tree =
  Tree $
    runTree tree >>= \(Node value _) ->
      pure (Node value [])
-- | Maps a function over the node's value and over every child tree.
instance Functor m => Functor (Node m) where
  fmap f (Node value children) =
    Node (f value) [fmap f child | child <- children]
-- | Maps a function over every value in the tree by mapping it over the
--   node produced by the underlying effect.
instance Functor m => Functor (Tree m) where
  fmap f tree =
    Tree (fmap f <$> runTree tree)
-- | Applicative structure for 'Node'.
--
--   'pure' is defined directly (a node with no children) rather than as
--   @pure = return@: the latter is a non-canonical definition that GHC warns
--   about, and it would loop if 'return' were ever defined in terms of
--   'pure'.
instance Monad m => Applicative (Node m) where
  pure x =
    Node x []

  -- Application is derived from the 'Monad' instance.
  (<*>) =
    ap
-- | Applicative structure for 'Tree'.
--
--   'pure' is defined directly (a childless node injected into @m@) rather
--   than as @pure = return@: the latter is a non-canonical definition that
--   GHC warns about, and it would loop if 'return' were ever defined in
--   terms of 'pure'.
instance Monad m => Applicative (Tree m) where
  pure x =
    Tree (pure (Node x []))

  -- Application is derived from the 'Monad' instance.
  (<*>) =
    ap
-- | Monadic bind for 'Node'.
--
--   The continuation is applied to the node's value to obtain the new root.
--   The original children, each re-bound with the same continuation, come
--   first, followed by the children of the continuation's result.
instance Monad m => Monad (Node m) where
  return x =
    Node x []

  Node x xs >>= k =
    case k x of
      Node y ys ->
        Node y (map rebind xs ++ ys)
    where
      -- Push the continuation into the node produced by a child tree.
      rebind child =
        Tree (fmap (>>= k) (runTree child))
-- | Monadic bind for 'Tree'.
--
--   The root effect of the first tree is run, then the root effect of the
--   tree returned by the continuation. The resulting node carries the
--   continuation's value; its children are the first tree's children
--   re-bound with the same continuation, followed by the continuation's own
--   children.
instance Monad m => Monad (Tree m) where
  return x =
    Tree (pure (Node x []))

  tree >>= k =
    Tree $
      runTree tree >>= \(Node x xs) ->
        runTree (k x) >>= \(Node y ys) ->
          pure (Node y (map (>>= k) xs ++ ys))
-- | 'Alternative' for 'Tree', delegating to the 'MonadPlus' operations.
instance MonadPlus m => Alternative (Tree m) where
  empty =
    mzero

  left <|> right =
    mplus left right
-- | 'MonadPlus' for 'Tree', lifted from the underlying monad: both
--   operations act on the effect that produces the root node.
instance MonadPlus m => MonadPlus (Tree m) where
  mzero =
    Tree mzero

  mplus (Tree left) (Tree right) =
    Tree (mplus left right)
-- | Lifts an action of the underlying monad into a tree whose root holds
--   the action's result and has no children.
instance MonadTrans Tree where
  lift action =
    Tree (action >>= \x -> pure (Node x []))
-- | Changes the monad of a node's child trees; the node's own value is
--   left untouched.
instance MFunctor Node where
  hoist nat (Node value children) =
    Node value (map (hoist nat) children)
-- | Changes the monad of a tree: the node inside the root effect is hoisted
--   first, then the transformation is applied to the root effect itself.
instance MFunctor Tree where
  hoist nat tree =
    Tree (nat (hoist nat <$> runTree tree))
-- | Apply 'embedTree' with the given function to every child of a node,
--   keeping the node's value.
embedNode :: Monad m => (t (Node t b) -> Tree m (Node t b)) -> Node t b -> Node m b
embedNode f (Node value children) =
  Node value (map (embedTree f) children)
-- | Re-express a tree over @t@ as a tree over @m@.
--
--   The function is applied to the tree's root effect, giving a
--   @'Tree' m@ computation that yields the original node; that node is then
--   converted with 'embedNode' and wrapped back into a tree.
embedTree :: Monad m => (t (Node t b) -> Tree m (Node t b)) -> Tree t b -> Tree m b
embedTree f (Tree m) =
  f m >>= \node ->
    Tree (pure (embedNode f node))
-- | 'embed' for 'Tree' is implemented by 'embedTree'.
instance MMonad Tree where
  embed =
    embedTree
-- | Move the transformer @t@ from inside a node to the outside.
--
--   A node is built whose value is the original value injected with 'pure'
--   and whose children are the distributed child trees (each wrapped with
--   'pure'); that node is turned into a tree with 'fromNode', lifted into
--   @t@, and the two layers of @t@ are collapsed with 'join'.
distributeNode :: Transformer t Tree m => Node (t m) a -> t (Tree m) a
distributeNode (Node x xs) =
  join (lift (fromNode (Node (pure x) (map (pure . distributeTree) xs))))
-- | Move the transformer @t@ from inside a tree to the outside.
--
--   The tree's root effect (in @t m@) is hoisted into @t ('Tree' m)@ with
--   'lift', and the node it yields is handed to 'distributeNode'.
distributeTree :: Transformer t Tree m => Tree (t m) a -> t (Tree m) a
distributeTree tree =
  hoist lift (runTree tree) >>= distributeNode
-- | 'distribute' for 'Tree' is implemented by 'distributeTree'.
instance Distributive Tree where
  distribute tree =
    distributeTree tree
-- | Primitive operations are run in the underlying monad and lifted into
--   the tree; the state token type is the underlying monad's.
instance PrimMonad m => PrimMonad (Tree m) where
  type PrimState (Tree m) = PrimState m

  primitive f =
    lift (primitive f)
-- | 'IO' actions are run in the underlying monad and lifted into the tree.
instance MonadIO m => MonadIO (Tree m) where
  liftIO io =
    lift (liftIO io)
-- | Base-monad actions are lifted through the underlying monad into the
--   tree.
instance MonadBase b m => MonadBase b (Tree m) where
  liftBase action =
    lift (liftBase action)
-- | Exceptions are thrown in the underlying monad and lifted into the tree.
instance MonadThrow m => MonadThrow (Tree m) where
  throwM err =
    lift (throwM err)
-- | Install an exception handler on every child tree of a node, using
--   'handleTree'. The node's own value is unchanged.
handleNode :: (Exception e, MonadCatch m) => (e -> Tree m a) -> Node m a -> Node m a
handleNode onErr (Node value children) =
  Node value (map (handleTree onErr) children)
-- | Install an exception handler on a tree.
--
--   The root effect is run under 'catch'; if it throws, the root effect of
--   the handler's tree is run instead. Whichever node results then has the
--   same handler installed on its children via 'handleNode'.
handleTree :: (Exception e, MonadCatch m) => (e -> Tree m a) -> Tree m a -> Tree m a
handleTree onErr tree =
  Tree
    ( handleNode onErr
        <$> catch (runTree tree) (\err -> runTree (onErr err))
    )
-- | 'catch' for 'Tree' is 'handleTree' with its arguments swapped.
instance MonadCatch m => MonadCatch (Tree m) where
  catch action handler =
    handleTree handler action
-- | Apply 'localTree' with the given environment modification to every
--   child tree of a node. The node's own value is unchanged.
localNode :: MonadReader r m => (r -> r) -> Node m a -> Node m a
localNode f (Node value children) =
  Node value (map (localTree f) children)
-- | Run a whole tree in a modified reader environment.
--
--   The root effect is run under 'local', and the node it yields has the
--   same modification applied to its children via 'localNode'.
localTree :: MonadReader r m => (r -> r) -> Tree m a -> Tree m a
localTree f (Tree m) =
  Tree $ do
    node <- local f m
    pure (localNode f node)
-- | Reader operations for 'Tree': 'ask' is lifted from the underlying
--   monad, 'local' is implemented by 'localTree'.
instance MonadReader r m => MonadReader r (Tree m) where
  ask =
    lift ask

  local f tree =
    localTree f tree
-- | State operations for 'Tree', each lifted from the underlying monad.
instance MonadState s m => MonadState s (Tree m) where
  get =
    lift get

  put s =
    lift (put s)

  state f =
    lift (state f)
-- | Pair a node's value with the given writer output, and apply
--   'listenTree' with that same output to every child tree.
listenNode :: MonadWriter w m => w -> Node m a -> Node m (a, w)
listenNode w (Node value children) =
  Node (value, w) (map (listenTree w) children)
-- | Pair every value in a tree with writer output.
--
--   The first argument is the output to prepend. The root effect is run
--   under 'listen'; its output is appended to the given one with 'mappend',
--   and the combined output is passed to 'listenNode' for the yielded node.
listenTree :: MonadWriter w m => w -> Tree m a -> Tree m (a, w)
listenTree w0 (Tree m) =
  Tree $
    listen m >>= \(node, w) ->
      pure (listenNode (mappend w0 w) node)
-- | Strip the output-modification function from a node's value, and apply
--   'passTree' to every child tree.
--
--   The function paired with this node's value is not applied here; it is
--   simply dropped from the value.
passNode :: MonadWriter w m => Node m (a, w -> w) -> Node m a
passNode (Node (value, _) children) =
  Node value (map passTree children)
-- | Implementation of 'pass' for trees.
--
--   The root effect is run under the underlying monad's 'pass', using the
--   function paired with the root value, so the writer output produced by
--   that effect is actually modified. Previously the function was silently
--   thrown away, which made 'pass' a no-op on the writer output.
--
--   The function is then stripped from the node by 'passNode', which also
--   recurses into the children, so each child's effect is in turn modified
--   by the function paired with that child's value.
passTree :: MonadWriter w m => Tree m (a, w -> w) -> Tree m a
passTree (Tree m) =
  Tree . pass $ do
    node@(Node (_, f) _) <- m
    pure (passNode node, f)
-- | Writer operations for 'Tree'.
--
--   'writer' and 'tell' are lifted from the underlying monad; 'listen'
--   is 'listenTree' starting from 'mempty'; 'pass' is 'passTree'.
instance MonadWriter w m => MonadWriter w (Tree m) where
  writer aw =
    lift (writer aw)

  tell w =
    lift (tell w)

  listen tree =
    listenTree mempty tree

  pass tree =
    passTree tree
-- | Install an error handler on every child tree of a node, using
--   'handleErrorTree'. The node's own value is unchanged.
handleErrorNode :: MonadError e m => (e -> Tree m a) -> Node m a -> Node m a
handleErrorNode onErr (Node value children) =
  Node value (map (handleErrorTree onErr) children)
-- | Install an error handler on a tree.
--
--   The root effect is run under 'catchError'; if it fails, the root effect
--   of the handler's tree is run instead. Whichever node results then has
--   the same handler installed on its children via 'handleErrorNode'.
handleErrorTree :: MonadError e m => (e -> Tree m a) -> Tree m a -> Tree m a
handleErrorTree onErr tree =
  Tree
    ( handleErrorNode onErr
        <$> catchError (runTree tree) (\err -> runTree (onErr err))
    )
-- | Error operations for 'Tree': 'throwError' is lifted from the underlying
--   monad, 'catchError' is 'handleErrorTree' with its arguments swapped.
instance MonadError e m => MonadError e (Tree m) where
  throwError err =
    lift (throwError err)

  catchError action handler =
    handleErrorTree handler action
-- | Resource actions are run in the underlying monad and lifted into the
--   tree.
instance MonadResource m => MonadResource (Tree m) where
  liftResourceT action =
    lift (liftResourceT action)
#if MIN_VERSION_base(4,9,0)
-- | Shows a node through its 'Show1' instance.
instance (Show1 m, Show a) => Show (Node m a) where
  showsPrec d node =
    showsPrec1 d node
-- | Shows a tree through its 'Show1' instance.
instance (Show1 m, Show a) => Show (Tree m a) where
  showsPrec d tree =
    showsPrec1 d tree
-- | Shows a node as the constructor name @Node@ applied to its value and
--   its list of child trees.
instance Show1 m => Show1 (Node m) where
  liftShowsPrec sp sl d (Node value children) =
    showsBinaryWith sp showForest "Node" d value children
    where
      -- Shower for a single child tree.
      showTree =
        liftShowsPrec sp sl

      -- Shower for a list of child trees, as required when lifting
      -- through the list type.
      showTrees =
        liftShowList sp sl

      -- Shower for the whole list of children.
      showForest =
        liftShowsPrec showTree showTrees
-- | Shows a tree as the constructor name @Tree@ applied to its underlying
--   effect, which is shown through the 'Show1' instance of @m@.
instance Show1 m => Show1 (Tree m) where
  liftShowsPrec sp sl d (Tree m) =
    showsUnaryWith showEffect "Tree" d m
    where
      -- Shower for the node held inside the effect.
      showNode =
        liftShowsPrec sp sl

      -- Shower for a list of nodes, as required when lifting through @m@.
      showNodes =
        liftShowList sp sl

      -- Shower for the effect that yields the root node.
      showEffect =
        liftShowsPrec showNode showNodes
#endif
-- | Render a tree of strings as a list of output lines.
--
--   The root effect is run to obtain the node. The node's label (after
--   'renderNode') is split into lines, and the lines rendered for its
--   children by 'renderForestLines' follow.
--
--   A node always contributes at least one line. An empty label would
--   otherwise produce no line at all (@lines "" == []@), and when the node
--   is itself a child, the branch connector meant for it would be attached
--   to the first line of its own children.
renderTreeLines :: Monad m => Tree m String -> m [String]
renderTreeLines (Tree m) = do
  Node x xs0 <- m
  xs <- renderForestLines xs0
  let
    own =
      case lines (renderNode x) of
        [] ->
          [""]
        ls ->
          ls
  pure (own ++ xs)
-- | Prepare a node label for rendering.
--
--   A label of exactly one character gets a single space prepended; every
--   other label (including the empty one) is returned unchanged.
renderNode :: String -> String
renderNode label
  | [_] <- label =
      ' ' : label
  | otherwise =
      label
-- | Render a list of sibling trees as output lines with branch connectors.
--
--   Each tree is rendered with 'renderTreeLines'. The first line of a tree
--   is prefixed with a connector (@" ├╼"@ for every tree but the last,
--   @" └╼"@ for the last); the remaining lines of that tree get a
--   continuation prefix (@" │ "@, or blanks under the last tree).
--
--   All prefixes are three columns wide. The continuation prefix under the
--   last tree was previously a single space, which shifted its continuation
--   lines and descendants two columns to the left of its siblings'.
renderForestLines :: Monad m => [Tree m String] -> m [String]
renderForestLines xs0 =
  let
    -- Prefix the first line with @hd@ and every following line with @other@.
    shift hd other =
      zipWith (++) (hd : repeat other)
  in
    case xs0 of
      [] ->
        pure []

      [x] -> do
        s <- renderTreeLines x
        pure $
          shift " └╼" "   " s

      x : xs -> do
        s <- renderTreeLines x
        ss <- renderForestLines xs
        pure $
          shift " ├╼" " │ " s ++ ss
-- | Render a tree of strings as a single string: the lines produced by
--   'renderTreeLines', joined with 'unlines'.
render :: Monad m => Tree m String -> m String
render tree =
  unlines <$> renderTreeLines tree