{-# LANGUAGE BangPatterns #-}

-- |
--   Module      :  ELynx.Tree.Simulate.PointProcess
--   Description :  Point process and functions
--   Copyright   :  (c) Dominik Schrempf 2021
--   License     :  GPL-3.0-or-later
--
--   Maintainer  :  dominik.schrempf@gmail.com
--   Stability   :  unstable
--   Portability :  portable
--
-- Creation date: Tue Feb 13 13:16:18 2018.
--
-- See Gernhard, T. (2008). The conditioned reconstructed process. Journal of
-- Theoretical Biology, 253(4), 769–778. http://doi.org/10.1016/j.jtbi.2008.04.005.
--
-- The point process can be used to simulate reconstructed trees under the birth
-- and death process.
module ELynx.Tree.Simulate.PointProcess
  ( PointProcess (..),
    TimeSpec (..),
    simulate,
    toReconstructedTree,
    simulateReconstructedTree,
    simulateNReconstructedTrees,
  )
where

import Control.Monad
import Control.Monad.Primitive
import Data.Function
import Data.List
import Data.Sequence (Seq)
import qualified Data.Sequence as S
import ELynx.Tree.Distribution.BirthDeath
import ELynx.Tree.Distribution.BirthDeathCritical
import ELynx.Tree.Distribution.BirthDeathCriticalNoTime
import ELynx.Tree.Distribution.BirthDeathNearlyCritical
import ELynx.Tree.Distribution.TimeOfOrigin
import ELynx.Tree.Distribution.TimeOfOriginNearCritical
import ELynx.Tree.Distribution.Types
import ELynx.Tree.Length
import ELynx.Tree.Rooted
import qualified Statistics.Distribution as D
  ( genContVar,
  )
import System.Random.MWC

-- Use the nearly critical distribution of the point process values when the
-- birth and death rates differ by at most this value.
epsNearCriticalPointProcess :: Double
epsNearCriticalPointProcess = 1.0e-5

-- Also the distribution of the time of origin needs a Taylor expansion for
-- near critical values. Use the near critical distribution of the time of
-- origin when the birth and death rates differ by at most this value.
epsNearCriticalTimeOfOrigin :: Double
epsNearCriticalTimeOfOrigin = 1e-8

-- Treat the process as critical when the birth and death rates are closer
-- than this value.
eps :: Double
eps = 1.0e-12

-- Approximate equality of two rates: true when their absolute difference is
-- strictly smaller than 'eps'. Comparisons involving NaN are always false.
(=~=) :: Double -> Double -> Bool
a =~= b = abs (a - b) < eps

-- Sort a list in ascending order and pair every element with its position in
-- the input list. The sort is stable: equal elements keep their input order.
sortListWithIndices :: Ord a => [a] -> [(a, Int)]
sortListWithIndices xs = sortBy byValue (zip xs [0 ..])
  where
    byValue (a, _) (b, _) = compare a b

-- Insert an element at a random position of a list. The position is drawn
-- with 'uniformR' from the inclusive range @[0, length v]@, so the element may
-- also end up in front of the first or behind the last element.
randomInsertList :: PrimMonad m => a -> [a] -> Gen (PrimState m) -> m [a]
randomInsertList e v g = do
  pos <- uniformR (0, length v) g
  let (front, back) = splitAt pos v
  return (front ++ e : back)

-- | A __point process__ for \(n\) points and of age \(t_{or}\) is defined as
-- follows. Draw \(n\) points on the horizontal axis at \(1,2,\ldots,n\). Pick
-- \(n-1\) points at locations \((i+1/2, s_i)\), \(i=1,2,\ldots,n-1\);
-- \(0 < s_i < t_{or}\). There is a bijection between (ranked) oriented trees
-- and the point process. Usually, @a@ will be 'String' (or 'Int') and @b@ will
-- be 'Double'.
data PointProcess a b = PointProcess
  { -- | The \(n\) points; they become the leaf labels of the tree.
    points :: ![a],
    -- | The \(n-1\) values \(s_i\); they are the heights at which neighboring
    -- subtrees merge.
    values :: ![b],
    -- | The age \(t_{or}\) of the point process.
    origin :: !b
  }
  deriving (Read, Show, Eq)

-- | How the height of a simulated tree is determined.
data TimeSpec
  = Random
  -- ^ Sample the time of origin from the respective distribution.
  | Origin Time
  -- ^ Condition on the given time of origin.
  | Mrca Time
  -- ^ Condition on the given time of the most recent common ancestor (MRCA).

-- | Sample a point process under the birth and death process, with the height
-- specified by the 'TimeSpec'. The points are labelled with the integers
-- @0 .. n-1@.
simulate ::
  (PrimMonad m) =>
  -- | Number of points (samples).
  Int ->
  -- | Time of origin or MRCA.
  TimeSpec ->
  -- | Birth rate.
  Rate ->
  -- | Death rate.
  Rate ->
  -- | Generator.
  Gen (PrimState m) ->
  m (PointProcess Int Double)
simulate n ts l m g
  | n < 1 = error "Number of samples needs to be one or larger."
  | l < 0.0 = error "Birth rate needs to be positive."
  | otherwise = dispatch ts
  where
    -- The death rate is not validated here.
    dispatch Random = simulateRandom n l m g
    dispatch (Origin t) = simulateOrigin n t l m g
    dispatch (Mrca t) = simulateMrca n t l m g

-- No time of origin given; it is sampled (or, for the critical process, the
-- values are sampled directly). We also don't need to take care of the
-- conditioning (origin or MRCA).
simulateRandom ::
  PrimMonad m =>
  Int ->
  Double ->
  Double ->
  Gen (PrimState m) ->
  m (PointProcess Int Double)
simulateRandom n l m g
  | -- There is no formula for the sub-critical process (the death rate exceeds
    -- the birth rate).
    m > l =
    error "simulateRandom: Please specify height if mu > lambda."
  | -- Without a given height, the height of the critical point process is the
    -- largest value. A single point has no values, so the height would be
    -- undefined (and 'maximum' would fail on the empty list).
    m =~= l && n < 2 =
    error "simulateRandom: Need at least two points if mu == lambda and no height is given."
  | -- For the critical process, we have no idea about the time of origin, but
    -- can use a specially derived distribution.
    m =~= l = do
      !vs <- replicateM (n - 1) (D.genContVar (BDCNTD l) g)
      -- The length of the root branch will be 0.
      let t = maximum vs
      return $ PointProcess [0 .. (n - 1)] vs t
  | -- For the near critical process, we use a special distribution of the
    -- time of origin.
    abs (m - l) <= epsNearCriticalTimeOfOrigin = do
      t <- D.genContVar (TONCD n l m) g
      simulateOrigin n t l m g
  | -- For the super-critical process (the birth rate exceeds the death rate),
    -- we can use the formula from Tanja Stadler.
    otherwise = do
      t <- D.genContVar (TOD n l m) g
      simulateOrigin n t l m g

-- Time of origin is given. All n - 1 values are drawn from one distribution,
-- chosen by the distance between the birth and the death rate:
-- 1. the critical branching process (rates approximately equal);
-- 2. the near critical branching process;
-- 3. normal values :).
--
-- A check rejecting negative death rates was removed with reference to
-- Stadler, T., & Steel, M. (2019). Swapping birth and death: symmetries and
-- transformations in phylodynamic models. http://dx.doi.org/10.1101/494583.
simulateOrigin ::
  PrimMonad m =>
  Int ->
  Time ->
  Double ->
  Double ->
  Gen (PrimState m) ->
  m (PointProcess Int Double)
simulateOrigin n t l m g
  | t < 0.0 = error "simulateOrigin: Time of origin needs to be positive."
  | m =~= l = fromSampler $ D.genContVar (BDCD t l) g
  | abs (m - l) <= epsNearCriticalPointProcess =
    fromSampler $ D.genContVar (BDNCD t l m) g
  | otherwise = fromSampler $ D.genContVar (BDD t l m) g
  where
    -- Draw the n - 1 values with the given sampler; the origin is 't'.
    fromSampler draw = do
      !vs <- replicateM (n - 1) draw
      return $ PointProcess [0 .. (n - 1)] vs t

-- Time of MRCA is given. The MRCA provides one of the n - 1 values (at height
-- 't'), inserted at a random position; the remaining n - 2 values are drawn
-- from the distribution selected by the distance between the rates (critical,
-- near critical, or normal). The origin of the point process is set to 't'.
simulateMrca ::
  PrimMonad m =>
  Int ->
  Time ->
  Double ->
  Double ->
  Gen (PrimState m) ->
  m (PointProcess Int Double)
simulateMrca n t l m g
  | t < 0.0 = error "simulateMrca: Time of MRCA needs to be positive."
  | -- An MRCA needs at least two points. For a single point, inserting the MRCA
    -- would yield one value for one point, which is not a valid point process.
    n < 2 =
    error "simulateMrca: Need at least two points when conditioning on the time of the MRCA."
  | m =~= l = fromSampler $ D.genContVar (BDCD t l) g
  | abs (m - l) <= epsNearCriticalPointProcess =
    fromSampler $ D.genContVar (BDNCD t l m) g
  | otherwise = fromSampler $ D.genContVar (BDD t l m) g
  where
    -- Draw n - 2 values with the given sampler and add the MRCA at height 't'.
    fromSampler draw = do
      !vs <- replicateM (n - 2) draw
      vs' <- randomInsertList t vs g
      return $ PointProcess [0 .. (n - 1)] vs' t

-- Sort the values of a point process in ascending order, together with the
-- indices that the corresponding merges will have while the tree is created
-- (see 'flattenIndices').
sortPP :: (Ord b) => PointProcess a b -> ([b], [Int])
sortPP (PointProcess _ vs _) = (sortedValues, flattenIndices originalIndices)
  where
    (sortedValues, originalIndices) = unzip (sortListWithIndices vs)

-- Translate indices into the shrinking sequence of subtrees: every merge
-- removes one subtree, so each index is decremented by the number of earlier
-- merges at lower indices (see 'fAcc').
flattenIndices :: [Int] -> [Int]
flattenIndices = snd . mapAccumL fAcc []

-- NOTE: fAcc is the speed bottleneck for simulating large trees. Each call
-- scans all indices seen so far, so 'flattenIndices' is quadratic overall.
--
-- Accumulating function for 'flattenIndices'. The accumulator holds the
-- indices processed so far (most recent first). The current index is lowered
-- by the number of previously processed indices that are smaller than it.
fAcc :: [Int] -> Int -> ([Int], Int)
fAcc seen i = (i : seen, i - smaller)
  where
    smaller = length [j | j <- seen, j < i]

-- | Simulate several independent reconstructed trees; see
-- 'simulateReconstructedTree'. A non-positive number of trees yields an empty
-- forest.
simulateNReconstructedTrees ::
  (PrimMonad m) =>
  -- | Number of trees
  Int ->
  -- | Number of points (samples)
  Int ->
  -- | Time of origin or MRCA
  TimeSpec ->
  -- | Birth rate
  Rate ->
  -- | Death rate
  Rate ->
  -- | Generator (see 'System.Random.MWC')
  Gen (PrimState m) ->
  m (Forest Length Int)
simulateNReconstructedTrees nT nP t l m g =
  if nT <= 0
    then return []
    else replicateM nT (simulateReconstructedTree nP t l m g)

-- | Simulate a point process with 'simulate' and convert it to a reconstructed
-- tree with 'toReconstructedTree'. The tree has a fixed number of leaves and,
-- depending on the 'TimeSpec', a specific height. Internal nodes are labelled
-- with 0.
simulateReconstructedTree ::
  (PrimMonad m) =>
  -- | Number of points (samples)
  Int ->
  -- | Time of origin or MRCA
  TimeSpec ->
  -- | Birth rate
  Rate ->
  -- | Death rate
  Rate ->
  -- | Generator (see 'System.Random.MWC')
  Gen (PrimState m) ->
  m (Tree Length Int)
simulateReconstructedTree n t l m g = do
  PointProcess labels heights originTime <- simulate n t l m g
  let pp =
        PointProcess
          labels
          (map toLengthUnsafe heights)
          (toLengthUnsafe originTime)
  return (toReconstructedTree 0 pp)

-- | Convert a point process to a reconstructed tree. See Lemma 2.2.
--
-- The point process must have exactly one point more than values, and at
-- least two values; otherwise, an error is raised.

-- Of course, I decided to only use one tree structure with extinct and extant
-- leaves (actually a complete tree). So a tree created here just does not
-- contain extinct leaves. A function 'isReconstructed' is provided to test if a
-- tree is reconstructed (and not complete) in this sense. However, a complete
-- tree might show up as "reconstructed", just because, by chance, it does not
-- contain extinct leaves. I wanted to use a Monoid constraint to get the unit
-- element, but this fails for classical 'Int's. So, I rather have another
-- (useless) argument.
toReconstructedTree ::
  a -> -- Default node label.
  PointProcess a Length ->
  Tree Length a
toReconstructedTree l pp@(PointProcess ps vs o)
  | length ps /= length vs + 1 = error "Too few or too many points."
  | length vs <= 1 = error "Too few values."
  | -- Test is deactivated:
    -- otherwise = if isReconstructed treeOrigin then treeOrigin else error "Error in algorithm."
    otherwise =
    buildTree
  where
    -- The strict bindings live inside 'buildTree' on purpose. Strict bindings
    -- in the where clause of a guarded equation are forced before the guards
    -- are tested, which would preempt the input validation above (e.g.,
    -- 'last' failing on an empty list instead of "Too few values.").
    buildTree =
      let (vsSorted, isSorted) = sortPP pp
          -- Leaves with branch length zero.
          !lvs = S.fromList [Node 0 p [] | p <- ps]
          -- Heights of the current subtrees; all leaves start at height zero.
          !heights = S.replicate (length ps) 0
          !treeRoot = toReconstructedTree' isSorted vsSorted l lvs heights
          -- Height of the root, i.e., the largest value.
          !h = last vsSorted
          -- Extend the root branch so that the tree reaches the origin 'o'.
          !treeOrigin = modifyStem (+ (o - h)) treeRoot
       in treeOrigin

-- Move up the tree, connect nodes when they join according to the point process.
--
-- Each step consumes one sorted index i and one sorted value v. The subtrees at
-- positions i and i + 1 of the sequence are extended to height v and joined
-- below a new internal node, which replaces both of them. The sequence thus
-- shrinks by one per step, until a single tree remains.
toReconstructedTree' ::
  [Int] -> -- Sorted indices, see 'sortPP'.
  [Length] -> -- Sorted merge values.
  a -> -- Default node label.
  Seq (Tree Length a) -> -- Leaves with accumulated root branch lengths.
  Seq Length -> -- Accumulated heights of the leaves.
  Tree Length a
toReconstructedTree' [] [] _ trs _ = trs `S.index` 0
toReconstructedTree' (i : is') (v : vs') l trs hs =
  toReconstructedTree' is' vs' l trs'' hs'
  where
    -- For the algorithm, see 'simulate' but index starts at zero.
    -- Left: l, right: r.
    !hl = hs `S.index` i
    !hr = hs `S.index` (i + 1)
    -- Branch lengths needed to extend both subtrees to the merge height.
    !dvl = v - hl
    !dvr = v - hr
    !tl = modifyStem (+ dvl) $ trs `S.index` i
    !tr = modifyStem (+ dvr) $ trs `S.index` (i + 1)
    !h' = hl + dvl -- Should be the same as 'hr + dvr'.
    !tm = Node 0 l [tl, tr]
    -- Replace the two merged subtrees (and their heights) by the new node.
    !trs'' = (S.take i trs S.|> tm) S.>< S.drop (i + 2) trs
    !hs' = (S.take i hs S.|> h') S.>< S.drop (i + 2) hs
toReconstructedTree' _ _ _ _ _ =
  error "toReconstructedTree': Number of indices and values differ."