-- |
-- Module      :  ELynx.Tree.Distance
-- Description :  Compute distances between trees
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jun 13 17:15:54 2019.
--
-- Various distance functions for trees.
--
-- The functions provided in this module return distances for __unrooted__
-- trees. See comments of 'symmetric', 'branchScore', and 'bipartitionToBranch',
-- as well as the documentation of
-- [treedist](http://evolution.genetics.washington.edu/phylip/doc/treedist.html).
--
-- It is a little unfortunate that the 'Tree' data type represents rooted trees.
-- However, rooted trees are much easier to handle computationally. In the
-- future, a separate data type for unrooted trees may be introduced, for
-- example, using algebraic graphs. Difficulties may arise because the branches
-- of an unrooted tree are undirected.
module ELynx.Tree.Distance
  ( symmetric,
    incompatibleSplits,
    branchScore,
  )
where

import Data.Bifunctor
import Data.List
import qualified Data.Map as M
import Data.Monoid
import Data.Set (Set)
import qualified Data.Set as S
import ELynx.Tree.Bipartition
import ELynx.Tree.Length
import ELynx.Tree.Partition
import ELynx.Tree.Rooted

-- Symmetric difference between two 'Set's.
--
-- The result contains exactly those elements that are a member of one of the
-- two sets, but not of both.
symmetricDifference :: Ord a => Set a -> Set a -> Set a
symmetricDifference xs ys = (xs S.\\ ys) `S.union` (ys S.\\ xs)

-- | Symmetric (Robinson-Foulds) distance between two trees.
--
-- The distance is the number of bipartitions that are induced by exactly one
-- of the two trees, that is, the size of the symmetric difference of the two
-- sets of bipartitions.
--
-- Although a rooted tree data type is used, the distance between the unrooted
-- trees is returned.
--
-- Return 'Left' with an error message if the trees contain different leaves,
-- or if 'bipartitions' fails for one of the trees.
--
-- XXX: Comparing a list of trees recomputes bipartitions.
symmetric :: Ord a => Tree e1 a -> Tree e2 a -> Either String Int
symmetric t1 t2
  -- The leaf sets are compared as sets, so the order of the leaves is ignored.
  | S.fromList (leaves t1) /= S.fromList (leaves t2) =
      Left "symmetric: Trees contain different leaves."
  | otherwise = do
      bps1 <- bipartitions t1
      bps2 <- bipartitions t2
      return $ S.size $ symmetricDifference bps1 bps2

-- Count the bipartitions in the first set that are not compatible with any of
-- the partitions in the second set.
--
-- Each bipartition is converted to a partition with 'bpToPt' before it is
-- checked against the given partitions with 'compatible'.
countIncompatibilities :: (Show a, Ord a) => Set (Bipartition a) -> Set (Partition a) -> Int
countIncompatibilities bs ms = S.size $ S.filter isIncompatible bs
  where
    -- A bipartition counts as incompatible only if none of the partitions is
    -- compatible with it.
    isIncompatible b = not $ any (compatible (bpToPt b)) ms

-- | Number of incompatible splits.
--
-- Similar to 'symmetric' but all bipartitions induced by multifurcations are
-- considered. For a detailed description of how the distance is calculated, see
-- 'ELynx.Tree.Bipartition.bipartitionCompatible'.
--
-- A multifurcation on a tree may (but not necessarily does) represent missing
-- information about the order of bifurcations. In this case, it is interesting
-- to get a set of compatible bifurcations of the tree. For example, the star tree
--
-- > (A,B,C,D);
--
-- induces the following bipartitions:
--
-- > A|BCD
-- > B|ACD
-- > C|ABD
-- > D|ABC
--
-- However, the tree is additionally compatible with the following hidden
-- bipartitions:
--
-- > AB|CD
-- > AC|BD
-- > AD|BC
--
-- For an explanation of how compatibility of partitions is checked, see
-- 'compatible'. Before using 'compatible', bipartitions are simply converted to
-- partitions with two subsets.
--
-- A bipartition is incompatible with a tree if it is incompatible with all
-- induced multifurcations of the tree.
--
-- The returned number is the sum of two counts: the bipartitions found only in
-- the first tree that are incompatible with the second tree, and the
-- bipartitions found only in the second tree that are incompatible with the
-- first tree.
--
-- Return 'Left' with an error message if the trees contain different leaves,
-- or if 'bipartitions' or 'partitions' fails for one of the trees.
--
-- XXX: Comparing a list of trees recomputes bipartitions.
incompatibleSplits :: (Show a, Ord a) => Tree e1 a -> Tree e2 a -> Either String Int
incompatibleSplits t1 t2
  | S.fromList (leaves t1) /= S.fromList (leaves t2) =
      Left "incompatibleSplits: Trees contain different leaves."
  | otherwise = do
      -- Bipartitions.
      bs1 <- bipartitions t1
      bs2 <- bipartitions t2
      -- Putative incompatible bipartitions of trees one and two, respectively.
      -- Bipartitions shared by both trees are removed up front; only the
      -- remaining ones are checked against the partitions of the other tree.
      let putIncBs1 = bs1 S.\\ bs2
          putIncBs2 = bs2 S.\\ bs1
      -- Partitions.
      ms1 <- partitions t1
      ms2 <- partitions t2
      return $
        countIncompatibilities putIncBs1 ms2
          + countIncompatibilities putIncBs2 ms1

-- | Compute branch score distance between two trees.
--
-- Each tree is converted to a map from bipartitions to branch lengths using
-- 'bipartitionToBranch'. The distance is the square root of the sum of squared
-- terms, with one term per bipartition:
--
-- - for a bipartition present in both trees, the difference of the two branch
--   lengths;
--
-- - for a bipartition present in only one tree, the branch length itself.
--
-- Although a rooted tree data type is used, the distance between the unrooted
-- trees is returned.
--
-- Return 'Left' with an error message if the trees contain different leaves,
-- or if 'bipartitionToBranch' fails for one of the trees.
--
-- XXX: Comparing a list of trees recomputes bipartitions.
branchScore :: (HasLength e1, HasLength e2, Ord a) => Tree e1 a -> Tree e2 a -> Either String Double
branchScore t1 t2
  | S.fromList (leaves t1) /= S.fromList (leaves t2) =
      Left "branchScore: Trees contain different leaves."
  | otherwise = do
      -- 'bipartitionToBranch' needs a 'Semigroup' on the branch labels; the
      -- lengths are wrapped in 'Sum' so that they are combined by addition.
      bpToBr1 <- bipartitionToBranch $ first (Sum . getLength) t1
      bpToBr2 <- bipartitionToBranch $ first (Sum . getLength) t2
      let -- 'M.unionWith' applies the subtraction only to bipartitions found
          -- in both maps; values found in just one map are kept unchanged.
          -- Their sign is irrelevant because all values are squared below.
          dBs = M.unionWith (-) bpToBr1 bpToBr2
          dsSquared = foldl' (\acc e -> acc + e * e) 0 dBs
      return $ sqrt $ fromLength $ getSum dsSquared