{-# LANGUAGE OverloadedStrings #-}
module ELynx.Data.Tree.MeasurableTree
( Measurable (..)
, distancesRootLeaves
, distancesOriginLeaves
, averageDistanceOriginLeaves
, height
, rootHeight
, lengthenStem
, shortenStem
, summarize
, totalBranchLength
, normalize
, prune
, removeMultifurcations
, ultrametric
) where
import qualified Data.ByteString.Lazy.Char8 as L
import Data.Foldable
import Data.Tree
import ELynx.Data.Tree.Tree
import ELynx.Tools.Equality (allNearlyEqual)
-- | Types that carry a length which can be read and replaced. In this module
-- the instances are used as node labels of a 'Tree', and the length is
-- treated as the length of the branch attached to the node.
class Measurable a where
  -- | Read the length stored in a value.
  getLen :: a -> Double

  -- | Return the value with its length replaced by the given one.
  setLen :: Double -> a -> a
-- | Add the given amount to the length of a measurable value. A negative
-- amount shortens it; no check is made that the result stays non-negative.
lengthen :: Measurable a => Double -> a -> a
lengthen delta x = setLen newLen x
  where
    newLen = delta + getLen x
-- | Distances from the root node to every leaf, one entry per leaf.
--
-- The length stored in the root label itself is not included; see
-- 'distancesOriginLeaves' for the variant that adds it. A tree consisting of
-- a single node yields @[0]@.
distancesRootLeaves :: (Measurable a) => Tree a -> [Double]
distancesRootLeaves (Node _ []) = [0]
distancesRootLeaves (Node _ children) = concatMap viaChild children
  where
    -- Distances to the leaves below one child, each extended by the length
    -- of the branch leading to that child.
    viaChild child =
      map (+ getLen (rootLabel child)) (distancesRootLeaves child)
-- | Distances from the origin to every leaf, one entry per leaf. These are
-- the root-to-leaf distances of 'distancesRootLeaves', each increased by the
-- length stored in the root label (the stem).
distancesOriginLeaves :: (Measurable a) => Tree a -> [Double]
distancesOriginLeaves tree = map (+ stem) (distancesRootLeaves tree)
  where
    stem = getLen (rootLabel tree)
-- | Arithmetic mean of the origin-to-leaf distances returned by
-- 'distancesOriginLeaves'.
averageDistanceOriginLeaves :: (Measurable a) => Tree a -> Double
averageDistanceOriginLeaves tr = total / count
  where
    dists = distancesOriginLeaves tr
    total = sum dists
    count = fromIntegral (length dists)
-- | Height of the tree measured from the origin: the largest of the
-- origin-to-leaf distances, so the length of the stem is included.
height :: (Measurable a) => Tree a -> Double
height tree = maximum (distancesOriginLeaves tree)
-- | Height of the tree measured from the root node: the largest of the
-- root-to-leaf distances, so the length of the stem is not included.
rootHeight :: (Measurable a) => Tree a -> Double
rootHeight tree = maximum (distancesRootLeaves tree)
-- | Add the given amount to the length stored in the root label (the stem).
-- The sub-forest is left untouched.
lengthenStem :: (Measurable a) => Double -> Tree a -> Tree a
lengthenStem dl tree = tree { rootLabel = lengthen dl (rootLabel tree) }
-- | Subtract the given amount from the length stored in the root label (the
-- stem). Implemented as 'lengthenStem' with the negated amount, so no check
-- is made that the resulting length stays non-negative.
shortenStem :: (Measurable a) => Double -> Tree a -> Tree a
shortenStem dl tree = lengthenStem (negate dl) tree
-- | Render a short, human-readable report about a tree.
--
-- The report consists of four lines joined by newlines (without a trailing
-- newline): the number of leaves, the 'height', an average leaf distance, and
-- the 'totalBranchLength'.
--
-- NOTE(review): the third line is labelled "root to leaves", but the value is
-- the mean of 'distancesOriginLeaves' and therefore includes the stem length.
summarize :: (Measurable a) => Tree a -> L.ByteString
summarize t = L.intercalate "\n" [leafLine, heightLine, averageLine, lengthLine]
  where
    -- Format one report line as @label ++ show value ++ "."@.
    entry :: Show v => String -> v -> L.ByteString
    entry label value = L.pack (label ++ show value ++ ".")

    leafCount = length (leaves t)
    meanDistance = sum (distancesOriginLeaves t) / fromIntegral leafCount

    leafLine = entry "Leaves: " leafCount
    heightLine = entry "Height: " (height t)
    averageLine = entry "Average distance root to leaves: " meanDistance
    lengthLine = entry "Total branch length: " (totalBranchLength t)
-- | Sum of the lengths stored in all node labels of the tree, including the
-- label of the root (the stem).
totalBranchLength :: (Measurable a) => Tree a -> Double
totalBranchLength tree = foldl' addLen 0 tree
  where
    -- Strict left fold step: accumulate the length of one more label.
    addLen total lbl = total + getLen lbl
-- | Rescale every length in the tree by dividing it by the
-- 'totalBranchLength' of the input tree.
--
-- The divisor is not checked: if the total branch length is zero, the
-- resulting lengths are the outcome of a floating-point division by zero.
normalize :: (Measurable a) => Tree a -> Tree a
normalize tree = fmap rescale tree
  where
    total = totalBranchLength tree
    rescale lbl = setLen (getLen lbl / total) lbl
-- | Prune a tree with 'pruneWith', combining two labels by adding the length
-- of the second argument to the first one.
--
-- NOTE(review): presumably 'pruneWith' removes degree-two nodes and calls the
-- combining function with the daughter label first and the parent label
-- second, so that branch lengths are preserved -- confirm against
-- "ELynx.Data.Tree.Tree".
prune :: (Measurable a) => Tree a -> Tree a
prune = pruneWith absorbParent
  where
    absorbParent daughter parent = lengthen (getLen parent) daughter
-- | Resolve multifurcations into a cascade of bifurcations.
--
-- Leaves are returned unchanged, and nodes with one or two children keep
-- their shape while their children are processed recursively. A node with
-- three or more children keeps its first child; the remaining children are
-- moved below a newly introduced node, which is then resolved recursively.
--
-- The introduced node carries a copy of the parent's label with its length
-- set to 0. A zero-length branch leaves all root-to-leaf distances and the
-- total branch length of the tree unchanged.
removeMultifurcations :: Measurable a => Tree a -> Tree a
removeMultifurcations t@(Node _ []) = t
removeMultifurcations (Node l [x]) = Node l [removeMultifurcations x]
removeMultifurcations (Node l [x, y]) = Node l $ map removeMultifurcations [x, y]
removeMultifurcations (Node l (x : xs)) =
  Node l $ map removeMultifurcations [x, Node l' xs]
  where
    -- Label of the introduced node: the parent's label with length 0, so
    -- that the extra branch does not add any distance.
    l' = setLen 0 l
-- | Check whether all origin-to-leaf distances of the tree are nearly equal,
-- as decided by 'allNearlyEqual'.
--
-- NOTE(review): the tolerance is whatever 'allNearlyEqual' from
-- "ELynx.Tools.Equality" uses -- not visible from this module.
ultrametric :: Measurable a => Tree a -> Bool
ultrametric tree = allNearlyEqual (distancesOriginLeaves tree)