{-# LANGUAGE BangPatterns #-}
module ELynx.Simulate.PointProcess
( PointProcess (..),
TimeSpec,
simulate,
toReconstructedTree,
simulateReconstructedTree,
simulateNReconstructedTrees,
)
where
import Control.Monad
import Control.Monad.Primitive
import Data.Function
import Data.List
import Data.Sequence (Seq)
import qualified Data.Sequence as S
import ELynx.Data.Tree.Measurable
import ELynx.Data.Tree.Phylogeny
import ELynx.Data.Tree.Rooted
import ELynx.Distribution.BirthDeath
import ELynx.Distribution.BirthDeathCritical
import ELynx.Distribution.BirthDeathCriticalNoTime
import ELynx.Distribution.BirthDeathNearlyCritical
import ELynx.Distribution.TimeOfOrigin
import ELynx.Distribution.TimeOfOriginNearCritical
import ELynx.Distribution.Types
import qualified Statistics.Distribution as D
( genContVar,
)
import System.Random.MWC
-- | Threshold on @|mu - lambda|@ up to which 'simulate' draws the values of a
-- point process with a given time from the near-critical distribution
-- ('BDNCD') instead of the general birth-death distribution.
epsNearCriticalPointProcess :: Double
epsNearCriticalPointProcess = 1.0e-5

-- | Threshold on @|mu - lambda|@ up to which 'simulate' (without a given time
-- of origin) draws the time of origin from the near-critical time-of-origin
-- distribution ('TONCD') instead of the general one ('TOD').
epsNearCriticalTimeOfOrigin :: Double
epsNearCriticalTimeOfOrigin = 1.0e-8
-- | Absolute tolerance used by '=~='.
eps :: Double
eps = 1.0e-12

-- | Approximate equality of doubles: @x =~= y@ holds exactly when
-- @|x - y|@ is strictly smaller than 'eps'.
(=~=) :: Double -> Double -> Bool
(=~=) x y = abs (x - y) < eps
-- | Pair every element with its original (0-based) index and sort the pairs
-- by the element.
--
-- The sort is stable, so equal elements keep their original relative order.
sortListWithIndices :: Ord a => [a] -> [(a, Int)]
sortListWithIndices xs = sortOn fst indexed
  where
    indexed = zip xs [0 ..]
-- | Insert an element into a list at a random position.
--
-- The position is drawn with 'uniformR' from the closed range
-- @[0, length xs]@, so the element may also end up first or last.
randomInsertList :: PrimMonad m => a -> [a] -> Gen (PrimState m) -> m [a]
randomInsertList x xs gen = do
  pos <- uniformR (0, length xs) gen
  let (front, back) = splitAt pos xs
  return (front ++ x : back)
-- | A point process: a list of points, a list of values and an origin.
--
-- 'simulate' produces @n@ points and @n - 1@ values;
-- 'toReconstructedTree' requires exactly one more point than values.
data PointProcess a b = PointProcess
  { points :: ![a]
    -- ^ Points; they become the leaves of the reconstructed tree, in order.
  , values :: ![b]
    -- ^ Values; in 'toReconstructedTree' the value at index @i@ is the height
    -- at which the subtrees containing points @i@ and @i + 1@ are joined.
  , origin :: !b
    -- ^ Origin; the stem of the reconstructed tree's root is extended to it.
  }
  deriving (Eq, Read, Show)
-- | Specification of the time handed to 'simulate'.
--
-- * 'Nothing': 'simulate' determines the time of origin itself.
--
-- * @'Just' (t, False)@: @t@ is used as the time of origin.
--
-- * @'Just' (t, True)@: @t@ is used as the time of origin and, in addition,
--   is inserted among the sampled values, so that one value equals @t@
--   (presumably conditioning on the root height — confirm against callers).
type TimeSpec = Maybe (Time, Bool)
-- | Sample a point process of a birth-death process.
--
-- The result has the @n@ points @[0 .. n - 1]@ and @n - 1@ values. How the
-- values and the time of origin are obtained depends on the 'TimeSpec':
--
-- * 'Nothing': if @mu > lambda@ an error is raised. If @mu =~= lambda@,
--   @n - 1@ values are drawn from the critical distribution without time
--   ('BDCNTD'), and their maximum is used as the time of origin. Otherwise the
--   time of origin is drawn from 'TONCD' (when
--   @|mu - lambda| <= 'epsNearCriticalTimeOfOrigin'@) or from 'TOD', and the
--   process is then simulated as for @'Just' (t, False)@.
--
-- * @'Just' (t, False)@: @t@ is the time of origin and @n - 1@ values are drawn.
--
-- * @'Just' (t, True)@: @n - 2@ values are drawn and @t@ is inserted at a
--   random position, so that one of the @n - 1@ values equals @t@.
--
-- Given a time, values are drawn from the critical ('BDCD'), the near-critical
-- ('BDNCD', when @|mu - lambda| <= 'epsNearCriticalPointProcess'@) or the
-- general ('BDD') birth-death distribution.
--
-- Calls 'error' on invalid input: fewer than one point, a negative time or
-- rate, or too few points for the critical process without a time or for
-- the conditioned process.
simulate ::
  (PrimMonad m) =>
  -- | Number of points @n@.
  Int ->
  -- | Time specification.
  TimeSpec ->
  -- | Birth rate (lambda).
  Rate ->
  -- | Death rate (mu).
  Rate ->
  -- | Random number generator.
  Gen (PrimState m) ->
  m (PointProcess Int Double)
simulate n Nothing l m g
  | n < 1 = error "Number of samples needs to be one or larger."
  | l < 0.0 = error "Birth rate needs to be positive."
  | m < 0.0 = error "Death rate needs to be non-negative."
  | m > l =
      error
        "Time of origin distribution formula not available when mu > lambda. Please specify height for the moment."
  | (m =~= l) && n < 2 =
      -- A single point has no values, so the time of origin (the maximum of
      -- the values) would be undefined.
      error "Critical process without time of origin needs two or more samples."
  | m =~= l = do
      -- There is no time of origin distribution for the critical process; the
      -- largest sampled value is used instead.
      !vs <- replicateM (n - 1) (D.genContVar (BDCNTD l) g)
      let t = maximum vs
      return $ PointProcess [0 .. (n - 1)] vs t
  | abs (m - l) <= epsNearCriticalTimeOfOrigin = do
      t <- D.genContVar (TONCD n l m) g
      simulate n (Just (t, False)) l m g
  | otherwise = do
      t <- D.genContVar (TOD n l m) g
      simulate n (Just (t, False)) l m g
simulate n (Just (t, c)) l m g
  | n < 1 = error "Number of samples needs to be one or larger."
  | t < 0.0 = error "Time of origin needs to be positive."
  | l < 0.0 = error "Birth rate needs to be positive."
  | m < 0.0 = error "Death rate needs to be non-negative."
  | c && n < 2 =
      -- n - 2 drawn values plus the inserted t would give n values for n
      -- points instead of n - 1.
      error "Conditioning on the root height needs two or more samples."
  | c = do
      !vs <- replicateM (n - 2) drawValue
      vs' <- randomInsertList t vs g
      return $ PointProcess [0 .. (n - 1)] vs' t
  | otherwise = do
      !vs <- replicateM (n - 1) drawValue
      return $ PointProcess [0 .. (n - 1)] vs t
  where
    -- Sampler for a single value. The distribution depends on how close the
    -- process is to being critical.
    drawValue
      | m =~= l = D.genContVar (BDCD t l) g
      | abs (m - l) <= epsNearCriticalPointProcess = D.genContVar (BDNCD t l m) g
      | otherwise = D.genContVar (BDD t l m) g
-- | Sort the values of a point process in ascending order.
--
-- Returns the sorted values together with the merge positions obtained by
-- passing the values' original indices, in sorted order, to
-- 'flattenIndices'. The points and the origin are ignored.
sortPP :: (Ord b) => PointProcess a b -> ([b], [Int])
sortPP (PointProcess _ vs _) = (sortedValues, flattenIndices originalIndices)
  where
    (sortedValues, originalIndices) = unzip (sortListWithIndices vs)
-- | Turn positions in the original list of gaps into positions in the
-- shrinking list of subtrees.
--
-- Each index is shifted left by the number of earlier indices that are
-- smaller than it. This matches repeated merges in which each merge replaces
-- two neighbouring subtrees by one. The cost is quadratic in the input length.
flattenIndices :: [Int] -> [Int]
flattenIndices = snd . mapAccumL fAcc []

-- | Accumulator step for 'flattenIndices': remember @i@ and shift it left by
-- the number of previously seen indices smaller than @i@.
fAcc :: [Int] -> Int -> ([Int], Int)
fAcc seen i = (i : seen, i - nSmaller)
  where
    nSmaller = length [j | j <- seen, j < i]
-- | Simulate a number of reconstructed trees, one after another, with
-- 'simulateReconstructedTree' and the same random number generator.
--
-- Returns an empty forest when the requested number of trees is zero or
-- negative.
simulateNReconstructedTrees ::
  (PrimMonad m) =>
  -- | Number of trees.
  Int ->
  -- | Number of points per tree.
  Int ->
  -- | Time specification.
  TimeSpec ->
  -- | Birth rate.
  Rate ->
  -- | Death rate.
  Rate ->
  -- | Random number generator.
  Gen (PrimState m) ->
  m (Forest Length Int)
simulateNReconstructedTrees nTrees nPoints timeSpec lambda mu gen =
  if nTrees < 1
    then pure []
    else replicateM nTrees oneTree
  where
    oneTree = simulateReconstructedTree nPoints timeSpec lambda mu gen
-- | Simulate a point process with 'simulate' and convert it to a
-- reconstructed tree; internal nodes are labelled @0@.
simulateReconstructedTree ::
  (PrimMonad m) =>
  -- | Number of points.
  Int ->
  -- | Time specification.
  TimeSpec ->
  -- | Birth rate.
  Rate ->
  -- | Death rate.
  Rate ->
  -- | Random number generator.
  Gen (PrimState m) ->
  m (Tree Length Int)
simulateReconstructedTree nPoints timeSpec lambda mu gen =
  fmap (toReconstructedTree 0) (simulate nPoints timeSpec lambda mu gen)
-- | Convert a point process into a reconstructed tree.
--
-- Each point becomes a leaf, in order. The values are processed in ascending
-- order, and each one joins the current subtrees to the left and right of its
-- position into a new internal node at that height. Finally, the root's stem
-- is extended by @origin - (largest value)@ so that it reaches the origin.
--
-- Calls 'error' if the number of points is not the number of values plus one,
-- or if there are no values.
toReconstructedTree ::
  -- | Label of the internal nodes.
  a ->
  PointProcess a Double ->
  Tree Length a
toReconstructedTree l pp@(PointProcess ps vs o)
  | length ps /= length vs + 1 = error "Too few or too many points."
  -- A single value (two points) already forms a valid cherry; only an empty
  -- list of values has no root height.
  | null vs = error "Too few values."
  | otherwise =
      -- The strict bindings live here rather than in the where clause, because
      -- strict where bindings are forced before the guards above are
      -- evaluated. There they would crash (e.g. 'last' on an empty list)
      -- before the input is validated.
      let !treeRoot = toReconstructedTree' isSorted vsSorted l lvs heights
          !h = last vsSorted
       in applyStem (+ (o - h)) treeRoot
  where
    (vsSorted, isSorted) = sortPP pp
    -- One leaf per point; every leaf starts at height 0 with a zero stem.
    lvs = S.fromList [Node (Length 0) p [] | p <- ps]
    heights = S.replicate (length ps) 0
-- | Worker for 'toReconstructedTree'.
--
-- For every merge position @i@ and value @v@ (the caller passes the values
-- in ascending order), the trees at positions @i@ and @i + 1@ are joined
-- under a new internal node labelled @l@. Both children's stems are extended
-- so that they reach height @v@. The height sequence tracks the root height
-- of each current subtree. When all merges are done, the first tree of the
-- sequence is returned; this is the only tree when there are @n - 1@ merges
-- for @n@ trees.
--
-- Calls 'error' if the lists of merge positions and values differ in length.
toReconstructedTree' ::
  [Int] ->
  [Double] ->
  a ->
  Seq (Tree Length a) ->
  Seq Double ->
  Tree Length a
toReconstructedTree' [] [] _ trs _ = trs `S.index` 0
toReconstructedTree' (i : is) (v : vs) l trs hs =
  toReconstructedTree' is vs l trs' hs'
  where
    !hl = hs `S.index` i
    !hr = hs `S.index` (i + 1)
    -- Extend each child's stem from its current root height up to v.
    !tl = applyStem (+ (v - hl)) $ trs `S.index` i
    !tr = applyStem (+ (v - hr)) $ trs `S.index` (i + 1)
    !tm = Node (Length 0) l [tl, tr]
    -- Replace the two merged trees (and their heights) by the new node, which
    -- sits at exactly height v.
    !trs' = (S.take i trs S.|> tm) S.>< S.drop (i + 2) trs
    !hs' = (S.take i hs S.|> v) S.>< S.drop (i + 2) hs
toReconstructedTree' _ _ _ _ _ =
  error "toReconstructedTree': Number of merge positions and values differ."