-- |
-- Module      :  ELynx.Tree.Distance
-- Description :  Compute distances between trees
-- Copyright   :  (c) Dominik Schrempf 2021
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jun 13 17:15:54 2019.
--
-- Various distance functions for trees.
--
-- The functions provided in this module return distances for __unrooted__
-- trees. See comments of 'symmetric', 'branchScore', and 'bipartitionToBranch',
-- as well as the documentation of
-- [treedist](http://evolution.genetics.washington.edu/phylip/doc/treedist.html).
--
-- It is a little unfortunate that the 'Tree' data type represents rooted trees.
-- However, rooted trees are much easier to handle computationally. In the
-- future, a separate data type for unrooted trees may be introduced, for
-- example, using algebraic graphs. Difficulties may arise because the branches
-- of an unrooted tree are undirected.
module ELynx.Tree.Distance
  ( symmetric,
    incompatibleSplits,
    branchScore,
  )
where

import Data.Bifunctor
import Data.List
import qualified Data.Map as M
import Data.Monoid
import Data.Set (Set)
import qualified Data.Set as S
import ELynx.Tree.Bipartition
import ELynx.Tree.Length
import ELynx.Tree.Partition
import ELynx.Tree.Rooted

-- Symmetric difference between two 'Set's: the set of elements that are
-- contained in exactly one of the two given sets.
symmetricDifference :: Ord a => Set a -> Set a -> Set a
symmetricDifference xs ys = S.difference xs ys `S.union` S.difference ys xs

-- | Symmetric (Robinson-Foulds) distance between two trees.
--
-- The distance is the number of bipartitions that are induced by exactly one of
-- the two trees.
--
-- Although a rooted tree data type is used, the distance between the unrooted
-- trees is returned.
--
-- Return 'Left' with an error message if the trees contain different leaves, or
-- if 'bipartitions' fails for one of the trees.
--
-- XXX: Comparing a list of trees recomputes bipartitions.
symmetric :: Ord a => Tree e1 a -> Tree e2 a -> Either String Int
symmetric t1 t2
  | S.fromList (leaves t1) /= S.fromList (leaves t2) =
    Left "symmetric: Trees contain different leaves."
  | otherwise = do
    bps1 <- bipartitions t1
    bps2 <- bipartitions t2
    return $ S.size $ symmetricDifference bps1 bps2

-- Count the bipartitions of the first set that are not 'compatible' with any
-- partition of the second set. Each bipartition is converted to a partition
-- with 'bpToPt' before the check.
countIncompatibilities :: (Show a, Ord a) => Set (Bipartition a) -> Set (Partition a) -> Int
countIncompatibilities bs ms = foldl' countIfIncompatible 0 bs
  where
    -- Leave the counter untouched as soon as one compatible partition is found.
    countIfIncompatible i b
      | any (compatible $ bpToPt b) ms = i
      | otherwise = i + 1

-- | Number of incompatible splits.
--
-- Similar to 'symmetric' but all bipartitions induced by multifurcations are
-- considered. For a detailed description of how the distance is calculated, see
-- 'ELynx.Tree.Bipartition.bipartitionCompatible'.
--
-- A multifurcation on a tree may (but not necessarily does) represent missing
-- information about the order of bifurcations. In this case, it is interesting
-- to get a set of compatible bifurcations of the tree. For example, the star tree
--
-- > (A,B,C,D);
--
-- induces the following bipartitions:
--
-- > A|BCD
-- > B|ACD
-- > C|ABD
-- > D|ABC
--
-- However, the tree is additionally compatible with the following hidden
-- bipartitions:
--
-- > AB|CD
-- > AC|BD
-- > AD|BC
--
-- For an explanation of how compatibility of partitions is checked, see
-- 'compatible'. Before using 'compatible', bipartitions are simply converted to
-- partitions with two subsets.
--
-- A bipartition is incompatible with a tree if it is incompatible with all
-- induced multifurcations of the tree.
--
-- Return 'Left' with an error message if the trees contain different leaves, or
-- if 'bipartitions' or 'partitions' fails for one of the trees.
--
-- XXX: Comparing a list of trees recomputes bipartitions.
incompatibleSplits :: (Show a, Ord a) => Tree e1 a -> Tree e2 a -> Either String Int
incompatibleSplits t1 t2
  | S.fromList (leaves t1) /= S.fromList (leaves t2) =
    Left "incompatibleSplits: Trees contain different leaves."
  | otherwise = do
    -- Bipartitions.
    bs1 <- bipartitions t1
    bs2 <- bipartitions t2
    -- traceShowM $ "bs1" ++ show (S.map bpHuman bs1)
    -- traceShowM $ "bs2" ++ show (S.map bpHuman bs2)
    let -- Putative incompatible bipartitions of trees one and two, respectively.
        -- Bipartitions shared by both trees cannot be incompatible and are
        -- removed up front.
        putIncBs1 = bs1 S.\\ bs2
        putIncBs2 = bs2 S.\\ bs1
    -- Partitions.
    ms1 <- partitions t1
    ms2 <- partitions t2
    -- traceShowM $ "putIncBs1 " ++ show (S.map bpHuman putIncBs1)
    -- traceShowM $ "putIncBs2 " ++ show (S.map bpHuman putIncBs2)
    -- Check the remaining bipartitions of each tree against the partitions of
    -- the other tree.
    return $ countIncompatibilities putIncBs1 ms2 + countIncompatibilities putIncBs2 ms1

-- | Compute branch score distance between two trees.
--
-- The distance is the square root of the sum of squared branch length
-- differences, where branches are identified by the bipartitions they induce
-- (see 'bipartitionToBranch'). A branch that is present in one tree only
-- contributes its full length.
--
-- Although a rooted tree data type is used, the distance between the unrooted
-- trees is returned.
--
-- Return 'Left' with an error message if the trees contain different leaves, or
-- if 'bipartitionToBranch' fails for one of the trees.
--
-- XXX: Comparing a list of trees recomputes bipartitions.
branchScore :: (HasLength e1, HasLength e2, Ord a) => Tree e1 a -> Tree e2 a -> Either String Double
branchScore t1 t2
  | S.fromList (leaves t1) /= S.fromList (leaves t2) =
    Left "branchScore: Trees contain different leaves."
  | otherwise = do
    -- Wrap the lengths in 'Sum' to satisfy the 'Semigroup' constraint of
    -- 'bipartitionToBranch'.
    bpToBr1 <- bipartitionToBranch $ first (Sum . getLength) t1
    bpToBr2 <- bipartitionToBranch $ first (Sum . getLength) t2
    let -- Shared bipartitions get the difference of the two branch lengths;
        -- 'M.unionWith' keeps the length unchanged for a bipartition found in
        -- one map only. The sign is irrelevant because values are squared.
        dBs = M.unionWith (-) bpToBr1 bpToBr2
        dsSquared = foldl' (\acc e -> acc + e * e) 0 dBs
    return $ sqrt $ fromLength $ getSum dsSquared