-- | N-ary trees.

{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
module Math.Combinat.Trees.Nary 
  (      
    -- * Types
    module Data.Tree
  , Tree(..)
    -- * Regular trees  
  , ternaryTrees
  , regularNaryTrees
  , semiRegularTrees
  , countTernaryTrees
  , countRegularNaryTrees
    -- * \"derivation trees\"
  , derivTrees
    -- * ASCII drawings
  , asciiTreeVertical_
  , asciiTreeVertical
  , asciiTreeVerticalLeavesOnly
    -- * Graphviz drawing
  , Dot
  , graphvizDotTree  
  , graphvizDotForest
    -- * Classifying nodes
  , classifyTreeNode
  , isTreeLeaf  , isTreeNode
  , isTreeLeaf_ , isTreeNode_
  , treeNodeNumberOfChildren 
    -- * Counting nodes
  , countTreeNodes
  , countTreeLeaves
  , countTreeLabelsWith
  , countTreeNodesWith 
    -- * Left and right spines
  , leftSpine  , leftSpine_
  , rightSpine , rightSpine_
  , leftSpineLength , rightSpineLength
    -- * Unique labels
  , addUniqueLabelsTree
  , addUniqueLabelsForest
  , addUniqueLabelsTree_
  , addUniqueLabelsForest_
    -- * Labelling by depth
  , labelDepthTree
  , labelDepthForest
  , labelDepthTree_
  , labelDepthForest_
    -- * Labelling by number of children
  , labelNChildrenTree
  , labelNChildrenForest
  , labelNChildrenTree_
  , labelNChildrenForest_
    
  ) where


--------------------------------------------------------------------------------

import Data.Tree
import Data.List

import Control.Applicative

--import Control.Monad.State
import Control.Monad.Trans.State
import Data.Traversable (traverse)

import Math.Combinat.Sets                  ( listTensor )
import Math.Combinat.Partitions.Multiset   ( partitionMultiset )
import Math.Combinat.Compositions          ( compositions )
import Math.Combinat.Numbers               ( factorial, binomial )

import Math.Combinat.Trees.Graphviz ( Dot , graphvizDotForest , graphvizDotTree )

import Math.Combinat.Classes
import Math.Combinat.ASCII as ASCII
import Math.Combinat.Helper

--------------------------------------------------------------------------------

-- Counts the internal vertices of a tree, i.e. those with at least one
-- child; leaves contribute nothing.
instance HasNumberOfNodes (Tree a) where
  numberOfNodes = countInner
    where
      countInner :: Tree b -> Int
      countInner (Node _ [])   = 0
      countInner (Node _ kids) = 1 + sum' (map countInner kids)

-- Counts the leaves of a tree (vertices without children).
instance HasNumberOfLeaves (Tree a) where
  numberOfLeaves = countLeaves
    where
      countLeaves :: Tree b -> Int
      countLeaves (Node _ [])   = 1
      countLeaves (Node _ kids) = sum' (map countLeaves kids)

--------------------------------------------------------------------------------

-- | @regularNaryTrees d n@ returns the list of (rooted) trees on @n@ nodes where each
-- node has exactly @d@ children. Note that the leaves do not count in @n@.
-- Naive algorithm.
regularNaryTrees 
  :: Int         -- ^ degree = number of children of each node
  -> Int         -- ^ number of nodes
  -> [Tree ()]
regularNaryTrees d = build
  where
    build :: Int -> [Tree ()]
    build 0 = [Node () []]
    build n =
      -- the root uses one node; distribute the remaining @n-1@ nodes among
      -- the @d@ children in every way, then combine every choice of subtrees
      concatMap (map (Node ()) . listTensor . map build) (compositions d (n-1))
  
-- | Ternary trees on @n@ nodes, i.e. @regularNaryTrees 3 n@.
ternaryTrees :: Int -> [Tree ()]
ternaryTrees n = regularNaryTrees 3 n

-- | The number of trees on @n@ nodes in which every node has exactly @d@
-- children (the Fuss-Catalan numbers). We have 
--
-- > length (regularNaryTrees d n) == countRegularNaryTrees d n == \frac {1} {(d-1)n+1} \binom {dn} {n} 
--
countRegularNaryTrees :: (Integral a, Integral b) => a -> b -> Integer
countRegularNaryTrees d n = binomial (dd * nn) nn `div` denominator
  where
    dd, nn, denominator :: Integer
    dd          = fromIntegral d
    nn          = fromIntegral n
    denominator = (dd - 1) * nn + 1

-- | The number of ternary trees on @n@ nodes: @\# = \\frac {1} {2n+1} \\binom {3n} {n}@
countTernaryTrees :: Integral a => a -> Integer
countTernaryTrees n = countRegularNaryTrees (3 :: Int) n

--------------------------------------------------------------------------------

-- | All trees on @n@ nodes where the number of children of every node is
-- an element of the given set. Example:
--
-- > autoTabulate RowMajor (Right 5) $ map asciiTreeVertical 
-- >                                 $ map labelNChildrenTree_ 
-- >                                 $ semiRegularTrees [2,3] 2
-- >
-- > [ length $ semiRegularTrees [2,3] n | n<-[0..] ] == [1,2,10,66,498,4066,34970,312066,2862562,26824386,...]
--
-- The latter sequence is A027307 in OEIS: <https://oeis.org/A027307>
--
-- Remark: clearly, we have
--
-- > semiRegularTrees [d] n == regularNaryTrees d n
--
-- 
semiRegularTrees 
  :: [Int]         -- ^ set of allowed number of children
  -> Int           -- ^ number of nodes
  -> [Tree ()]
semiRegularTrees [] n
  | n == 0    = [Node () []]
  | otherwise = []
semiRegularTrees dset_ n
  | all (>= 1) dset_ = build n
  | otherwise        = error "semiRegularTrees: expecting a list of positive integers"
  where
    -- sorted and deduplicated, so each allowed degree is tried exactly once
    dset :: [Int]
    dset = map head (group (sort dset_))

    build :: Int -> [Tree ()]
    build 0 = [Node () []]
    build m =
      [ Node () kids
      | d    <- dset
      , is   <- compositions d (m-1)
      , kids <- listTensor (map build is)
      ]
           
{- 

NOTES:

A006318 = [ length $ semiRegularTrees [1,2] n | n<-[0..] ] == [1,2,6,22,90,394,1806,8558,41586,206098,1037718.. ]
??      = [ length $ semiRegularTrees [1,3] n | n<-[0..] ] == [1,2,8,44,280,1936,14128,107088,834912,6652608 .. ]
??      = [ length $ semiRegularTrees [1,4] n | n<-[0..] ] == [1,2,10,74,642,6082,60970,635818,6826690 ..]

A027307 = [ length $ semiRegularTrees [2,3] n | n<-[0..] ] == [1,2,10,66,498,4066,34970,312066,2862562,26824386,...]
A219534 = [ length $ semiRegularTrees [2,4] n | n<-[0..] ] == [1,2,12,100,968,10208,113792,1318832 ..]
??      = [ length $ semiRegularTrees [2,5] n | n<-[0..] ] == [1,2,14,142,1690,21994,303126,4348102 ..]

A144097 = [ length $ semiRegularTrees [3,4] n | n<-[0..] ] == [1,2,14,134,1482,17818,226214,2984206,40503890..]

A107708 = [ length $ semiRegularTrees [1,2,3]   n | n<-[0..] ] == [1,3,18,144,1323,13176,138348,1507977 .. ]
??      = [ length $ semiRegularTrees [1,2,3,4] n | n<-[0..] ] == [1,4,40,560,9120,161856,3036800,59242240 .. ] 

-}
             
--------------------------------------------------------------------------------

-- | Vertical ASCII drawing of a tree, without labels. Example:
--
-- > autoTabulate RowMajor (Right 5) $ map asciiTreeVertical_ $ regularNaryTrees 2 4 
--
-- Nodes are denoted by @\@@, leaves by @*@.
--
asciiTreeVertical_ :: Tree a -> ASCII
asciiTreeVertical_ tree = ASCII.asciiFromLines (render tree)
  where
    render :: Tree b -> [String]
    render (Node _ []) = ["-*"]
    render (Node _ cs) = concat (mapWithFirstLast attach (map render cs))

    -- Prefixes the rendering of one child. The two flags come from
    -- 'mapWithFirstLast' (presumably "is first child" / "is last child").
    -- The first line gets a branch marker, the remaining lines get an
    -- indent, and a "| " spacer line is appended unless the second flag is set.
    attach :: Bool -> Bool -> [String] -> [String]
    attach isFirst isLast (l:ls) = (branch ++ l) : map (indent ++) ls ++ gap
      where
        indent = if isLast then "  " else "| "
        gap    = if isLast then []   else ["| "]
        branch
          | isLast && not isFirst = "\\-"
          | isFirst               = "@-"
          | otherwise             = "+-"

-- Unlabelled trees are drawn with 'asciiTreeVertical_'.
instance DrawASCII (Tree ()) where
  ascii t = asciiTreeVertical_ t

-- | Prints all labels. Example:
-- 
-- > asciiTreeVertical $ addUniqueLabelsTree_ $ (regularNaryTrees 3 9) !! 666
--
-- Nodes are denoted by @(label)@, leaves by @label@.
--
asciiTreeVertical :: Show a => Tree a -> ASCII
asciiTreeVertical tree = ASCII.asciiFromLines (render tree)
  where
    render :: Show b => Tree b -> [String]
    render (Node x []) = ["-- " ++ show x]
    render (Node x cs) = concat (mapWithFirstLast (attach (show x)) (map render cs))

    -- Prefixes the rendering of one child of a node whose shown label is
    -- @label@. Every prefix below is @length label + 4@ characters wide, so
    -- the children line up after the parent's parenthesised label.
    -- The flags come from 'mapWithFirstLast' (presumably first / last child).
    attach :: String -> Bool -> Bool -> [String] -> [String]
    attach label isFirst isLast (l:ls) = (branch ++ l) : map (indent ++) ls ++ gap
      where
        blanks = map (const ' ') label
        dashes = map (const '-') label
        rail   = " |" ++ blanks ++ "  "
        indent = if isLast then "  " ++ blanks ++ "  " else rail
        gap    = if isLast then [] else [rail]
        branch
          | isLast && not isFirst = " \\" ++ dashes ++ "--"
          | isFirst               = "-(" ++ label ++ ")-"
          | otherwise             = " +" ++ dashes ++ "--"

-- | Prints the labels of the leaves, but not of the internal nodes.
-- Internal nodes are drawn as in 'asciiTreeVertical_'.
asciiTreeVerticalLeavesOnly :: Show a => Tree a -> ASCII
asciiTreeVerticalLeavesOnly tree = ASCII.asciiFromLines (render tree)
  where
    render :: Show b => Tree b -> [String]
    render (Node x []) = ["- " ++ show x]
    render (Node _ cs) = concat (mapWithFirstLast attach (map render cs))

    -- Prefixes the rendering of one child. The flags come from
    -- 'mapWithFirstLast' (presumably "is first child" / "is last child").
    attach :: Bool -> Bool -> [String] -> [String]
    attach isFirst isLast (l:ls) = (branch ++ l) : map (indent ++) ls ++ gap
      where
        indent = if isLast then "  " else "| "
        gap    = if isLast then []   else ["| "]
        branch
          | isLast && not isFirst = "\\-"
          | isFirst               = "@-"
          | otherwise             = "+-"
  
--------------------------------------------------------------------------------
  
-- | The leftmost spine: the labels of the internal nodes visited when always
-- descending into the first child, paired with the label of the leaf where
-- that descent ends.
leftSpine :: Tree a -> ([a],a)
leftSpine (Node x [])    = ([], x)
leftSpine (Node x (c:_)) = let (xs, y) = leftSpine c in (x:xs, y)

-- | The rightmost spine: the labels of the internal nodes visited when always
-- descending into the last child, paired with the label of the leaf where
-- that descent ends.
rightSpine :: Tree a -> ([a],a)
rightSpine (Node x []) = ([], x)
rightSpine (Node x cs) = let (xs, y) = rightSpine (last cs) in (x:xs, y)

-- | The leftmost spine without the leaf node: the labels of the internal
-- nodes met when always descending into the first child.
leftSpine_ :: Tree a -> [a]
leftSpine_ (Node _ [])    = []
leftSpine_ (Node x (c:_)) = x : leftSpine_ c

-- | The rightmost spine without the leaf node: the labels of the internal
-- nodes met when always descending into the last child.
rightSpine_ :: Tree a -> [a]
rightSpine_ (Node _ []) = []
rightSpine_ (Node x cs) = x : rightSpine_ (last cs)

-- | The length (number of edges) of the left spine 
--
-- > leftSpineLength tree == length (leftSpine_ tree)
--
leftSpineLength :: Tree a -> Int
leftSpineLength = walk 0
  where
    walk :: Int -> Tree b -> Int
    walk acc (Node _ [])    = acc
    walk acc (Node _ (c:_)) = walk (acc + 1) c
  
-- | The length (number of edges) of the right spine 
--
-- > rightSpineLength tree == length (rightSpine_ tree)
--
rightSpineLength :: Tree a -> Int
rightSpineLength = walk 0
  where
    walk :: Int -> Tree b -> Int
    walk acc (Node _ []) = acc
    walk acc (Node _ cs) = walk (acc + 1) (last cs)

--------------------------------------------------------------------------------

-- | Tags the root label: 'Left' if the root is a leaf (no children),
-- 'Right' if it is an internal node.
classifyTreeNode :: Tree a -> Either a a
classifyTreeNode (Node x cs)
  | null cs   = Left x
  | otherwise = Right x

-- | The root label if the root is a leaf, otherwise 'Nothing'.
isTreeLeaf :: Tree a -> Maybe a
isTreeLeaf (Node x []) = Just x
isTreeLeaf _           = Nothing

-- | The root label if the root is an internal node, otherwise 'Nothing'.
isTreeNode :: Tree a -> Maybe a
isTreeNode (Node _ []) = Nothing
isTreeNode (Node x _)  = Just x

-- | Whether the root has no children.
isTreeLeaf_ :: Tree a -> Bool
isTreeLeaf_ (Node _ cs) = null cs
  
-- | Whether the root has at least one child.
isTreeNode_ :: Tree a -> Bool
isTreeNode_ (Node _ cs) = not (null cs)

-- | The number of direct children of the root.
treeNodeNumberOfChildren :: Tree a -> Int
treeNodeNumberOfChildren = length . subForest

--------------------------------------------------------------------------------
-- counting

-- | The number of internal nodes (vertices with at least one child);
-- leaves are not counted.
countTreeNodes :: Tree a -> Int
countTreeNodes (Node _ []) = 0
countTreeNodes (Node _ cs) = 1 + sum (map countTreeNodes cs)

-- | The number of leaves (vertices without children).
countTreeLeaves :: Tree a -> Int
countTreeLeaves (Node _ []) = 1
countTreeLeaves (Node _ cs) = sum (map countTreeLeaves cs)

-- | The number of vertices (internal nodes and leaves alike) whose label
-- satisfies the predicate.
countTreeLabelsWith :: (a -> Bool) -> Tree a -> Int
countTreeLabelsWith f = length . filter f . flatten

-- | The number of vertices (internal nodes and leaves alike) for which the
-- predicate holds on the subtree rooted at that vertex.
countTreeNodesWith :: (Tree a -> Bool) -> Tree a -> Int
countTreeNodesWith f = length . filter f . subtrees
  where
    -- every subtree, the tree itself first
    subtrees :: Tree b -> [Tree b]
    subtrees t = t : concatMap subtrees (subForest t)

--------------------------------------------------------------------------------

-- | Adds unique labels to the nodes (including leaves) of a 'Tree', by
-- labelling the one-element forest with 'addUniqueLabelsForest'.
addUniqueLabelsTree :: Tree a -> Tree (a,Int)
addUniqueLabelsTree = head . addUniqueLabelsForest . (: [])

-- | Adds unique labels to the nodes (including leaves) of a 'Forest'.
--
-- The labels are @1, 2, 3, ...@, assigned in preorder. The trees of the
-- forest are labelled one after the other. Within a tree, the root comes
-- before its children, and children go from left to right; this is the
-- order in which 'traverse' visits a 'Tree'.
addUniqueLabelsForest :: Forest a -> Forest (a,Int)
addUniqueLabelsForest forest = evalState (traverse (traverse tag) forest) 1
  where
    -- pair the label with the current counter value, then bump the counter
    tag :: b -> State Int (b, Int)
    tag x = state (\i -> ((x, i), i + 1))

-- | Replaces every label by the unique integer 'addUniqueLabelsTree' assigns.
addUniqueLabelsTree_ :: Tree a -> Tree Int
addUniqueLabelsTree_ t = snd <$> addUniqueLabelsTree t

-- | Replaces every label by the unique integer 'addUniqueLabelsForest' assigns.
addUniqueLabelsForest_ :: Forest a -> Forest Int
addUniqueLabelsForest_ forest = [ snd <$> t | t <- addUniqueLabelsForest forest ]

--------------------------------------------------------------------------------
    
-- | Attaches the depth to each node. The depth of the root is 0, and every
-- child is one deeper than its parent.
labelDepthTree :: Tree a -> Tree (a,Int)
labelDepthTree = annotate 0
  where
    annotate :: Int -> Tree b -> Tree (b, Int)
    annotate d (Node x ts) = Node (x, d) (map (annotate (d + 1)) ts)

-- | 'labelDepthTree' applied to each tree (each root has depth 0).
labelDepthForest :: Forest a -> Forest (a,Int)
labelDepthForest = map labelDepthTree
    
-- | Replaces every label by its depth (root = 0).
labelDepthTree_ :: Tree a -> Tree Int
labelDepthTree_ t = snd <$> labelDepthTree t

-- | Replaces every label by its depth (each root = 0).
labelDepthForest_ :: Forest a -> Forest Int
labelDepthForest_ forest = [ snd <$> t | t <- labelDepthForest forest ]

--------------------------------------------------------------------------------

-- | Attaches the number of children to each node (0 for leaves).
labelNChildrenTree :: Tree a -> Tree (a,Int)
labelNChildrenTree = unfoldTree step
  where
    -- rebuild the same shape, pairing each label with its child count
    step :: Tree b -> ((b, Int), [Tree b])
    step (Node x kids) = ((x, length kids), kids)
  
-- | 'labelNChildrenTree' applied to each tree of the forest.
labelNChildrenForest :: Forest a -> Forest (a,Int)
labelNChildrenForest = map labelNChildrenTree

-- | Replaces every label by the number of children of that node.
labelNChildrenTree_ :: Tree a -> Tree Int
labelNChildrenTree_ t = snd <$> labelNChildrenTree t

-- | Replaces every label by the number of children of that node.
labelNChildrenForest_ :: Forest a -> Forest Int
labelNChildrenForest_ forest = [ snd <$> t | t <- labelNChildrenForest forest ]
    
--------------------------------------------------------------------------------

-- | Computes the set of equivalence classes of rooted trees (in the 
-- sense that the children of a node are /unordered/) 
-- with @length xs@ leaves, where the multiset of heights of 
-- the leaves equals the given list @xs@.
-- The height is defined as the number of /edges/ from the leaf to the root. 
--
-- TODO: better name?
derivTrees :: [Int] -> [Tree ()]
derivTrees xs = derivTrees' [ x + 1 | x <- xs ]
  -- the worker measures heights in vertices rather than edges, hence the +1

-- Worker for 'derivTrees'. Here the "height" of a leaf is the number of
-- /vertices/ on the path from the root to it (one more than the number of
-- edges), so every entry must be at least 1; otherwise there is no tree.
--
-- The leaf heights are split (as a multiset partition) into blocks, one
-- block per child of the root; each block, lowered by one, is solved
-- recursively. Children are unordered, so identical blocks are treated as
-- a group whose subtrees are chosen as a multiset (with repetition, without
-- regard to order). Choosing them independently would emit e.g. both
-- @[A,B]@ and @[B,A]@ for two identical blocks, i.e. the same tree twice.
derivTrees' :: [Int] -> [Tree ()]
derivTrees' [] = []
derivTrees' [n]
  | n >= 1    = [unfoldTree chain 1]
  | otherwise = []
  where
    -- a bare path of n vertices ending in the single leaf
    chain :: Int -> ((), [Int])
    chain k = if k < n then ((), [k+1]) else ((), [])
derivTrees' ks
  | all (> 0) ks =
      [ Node () (concat choice)
      | part   <- partitionMultiset ks
      , choice <- listTensor
                    [ multichoose (length grp) (derivTrees' (map (subtract 1) blk))
                    | grp@(blk:_) <- group (sort part)
                    ]
      ]
  | otherwise = []
  where
    -- all size-k sub-multisets (with repetition) of the given list, each
    -- listed exactly once, as non-decreasing selections by position
    multichoose :: Int -> [a] -> [[a]]
    multichoose 0 _      = [[]]
    multichoose _ []     = []
    multichoose k (x:xs) = map (x:) (multichoose (k-1) (x:xs)) ++ multichoose k xs

--------------------------------------------------------------------------------