{-# LANGUAGE Rank2Types #-}
module Bio.Bam.Pileup
( PrimChunks(..)
, PrimBase(..)
, PosPrimChunks
, DamagedBase(..)
, DmgToken(..)
, dissect
, CallStats(..)
, V_Nuc(..)
, V_Nucs(..)
, IndelVariant(..)
, BasePile
, IndelPile
, Pile'(..)
, Pile
, pileup
) where
import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Prelude
import Bio.Streaming
import qualified Data.ByteString as B
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Storable as U
import qualified Streaming.Prelude as Q
-- | What follows a read base, as produced by 'dissect'.
data PrimChunks
    = Seek {-# UNPACK #-} !Int PrimBase
      -- ^ The next base sits at the given reference position.
    | Indel [Nucleotides] [DamagedBase] PrimBase
      -- ^ Deleted reference nucleotides and inserted read bases that come
      -- before the next base; both lists are empty when the read simply
      -- continues with another aligned base.
    | EndOfRead
      -- ^ The read has no further bases.
    deriving Show
-- | One aligned read base together with the remainder of the read.
data PrimBase = Base
    { _pb_wait   :: {-# UNPACK #-} !Int
      -- ^ number of pileup positions to pass before this base is counted
      -- (non-zero right after a deletion)
    , _pb_base   :: {-# UNPACK #-} !DamagedBase
      -- ^ the base itself
    , _pb_mapq   :: {-# UNPACK #-} !Qual
      -- ^ mapping quality of the read
    , _pb_chunks :: PrimChunks
      -- ^ what comes after this base
    }
    deriving Show

-- | A dissected read: reference sequence, the record's position, whether
-- the record is reversed, and the chunks describing its bases.
type PosPrimChunks = ( Refseq, Int, Bool, PrimChunks )
-- | A single read base as seen by the pileup.
data DamagedBase = DB
    { db_call    :: {-# UNPACK #-} !Nucleotide
      -- ^ nucleotide called from the read
    , db_qual    :: {-# UNPACK #-} !Qual
      -- ^ base quality ('dissect' caps it at the mapping quality)
    , db_dmg_tk  :: {-# UNPACK #-} !DmgToken
      -- ^ token passed to 'dissect'
    , db_dmg_pos :: {-# UNPACK #-} !Int
      -- ^ offset in the stored read sequence: the index itself in the first
      -- half, otherwise a negative offset from the end
    , db_ref     :: {-# UNPACK #-} !Nucleotides
      -- ^ reference nucleotide(s) opposite this base
    }

-- | Token that 'dissect' attaches unchanged to every 'DamagedBase'.
newtype DmgToken = DmgToken
    { fromDmgToken :: Int }
-- | Renders the call and quality separated by an at-sign; when the reference
-- differs from the call, a slash and the reference are inserted before it.
instance Show DamagedBase where
    showsPrec _ (DB call qual _ _ ref) =
        shows call . refPart . ('@' :) . shows qual
      where
        refPart
            | nucToNucs call == ref = id
            | otherwise             = ('/' :) . shows ref
{-# INLINE dissect #-}
-- | Decomposes a raw BAM record into a description of its aligned bases.
--
-- The result is @[]@ in three cases: the record is unmapped, it is flagged as
-- a duplicate, or its reference sequence is not valid.  Otherwise the result
-- is a singleton @(refseq, pos, isReversed, chunks)@.  Here @chunks@ walks the
-- CIGAR and, if present, the MD tag, describing base by base what the read
-- contributes.
--
-- Base qualities are handled as follows:
--
-- * they default to 23 when the record carries none;
-- * they are corrected with the optional @BQ@ (BAQ offset) tag;
-- * they are capped at the mapping quality.
dissect :: DmgToken -> BamRaw -> [PosPrimChunks]
dissect dtok br =
    if isUnmapped b || isDuplicate b || not (isValidRefseq b_rname)
      then [] else [(b_rname, b_pos, isReversed b, pchunks)]
  where
    b@BamRec{..} = unpackBam br
    pchunks = firstBase b_pos 0 0 (fromMaybe [] $ getMd b)

    !max_cig = V.length b_cigar
    !max_seq = V.length b_seq
    !baq     = extAsString "BQ" b

    -- Read base @i@ as a 'DamagedBase'; @f@ maps the read's nucleotide code
    -- to the reference nucleotides to record.
    get_seq :: Int -> (Nucleotides -> Nucleotides) -> DamagedBase
    get_seq i f = case b_seq `V.unsafeIndex` i of
        n | n == nucsA -> DB nucA qe dtok dmg (f n)
          | n == nucsC -> DB nucC qe dtok dmg (f n)
          | n == nucsG -> DB nucG qe dtok dmg (f n)
          | n == nucsT -> DB nucT qe dtok dmg (f n)
          | otherwise  -> DB nucA (Q 0) dtok dmg (f n)   -- non-ACGT: placeholder call, quality 0
      where
        !q  = maybe (Q 23) (`V.unsafeIndex` i) b_qual
        -- The BQ tag stores per-base offsets with BAQ_i = Q_i - (BQ_i - 64).
        -- Work in Int and clamp to 0..255 so that an out-of-range offset
        -- cannot wrap around to a high quality.
        !q' | i >= B.length baq = q
            | otherwise = Q . fromIntegral . max 0 . min 255 $
                            fromIntegral (unQ q) + 64 - (fromIntegral (B.index baq i) :: Int)
        !qe  = min q' b_mapq
        !dmg = if i+i > max_seq then i-max_seq else i

    -- Skips CIGAR operations up to the first aligned base, advancing the
    -- reference position over deletions and skips.
    firstBase :: Int -> Int -> Int -> [MdOp] -> PrimChunks
    firstBase !pos !is !ic mds
        | is >= max_seq || ic >= max_cig = EndOfRead
        | otherwise = case b_cigar `V.unsafeIndex` ic of
            Ins :* cl -> firstBase  pos     (cl+is) (ic+1) mds
            SMa :* cl -> firstBase  pos     (cl+is) (ic+1) mds
            Del :* cl -> firstBase (pos+cl)  is     (ic+1) (drop_del cl mds)
            Nop :* cl -> firstBase (pos+cl)  is     (ic+1) mds
            HMa :*  _ -> firstBase  pos      is     (ic+1) mds
            Pad :*  _ -> firstBase  pos      is     (ic+1) mds
            Mat :*  0 -> firstBase  pos      is     (ic+1) mds
            Mat :*  _ -> Seek pos $ nextBase 0 pos is ic 0 mds
      where
        -- Drops @n@ deleted reference nucleotides from the MD operations.
        drop_del n (MdDel ns : mds')
            | n < length ns = MdDel (drop n ns) : mds'
            | n > length ns = drop_del (n - length ns) mds'
            | otherwise     = mds'
        drop_del n (MdNum 0 : mds') = drop_del n mds'
        drop_del _ mds' = mds'

    -- Produces the aligned base at read index @is@, taking its reference
    -- nucleotide from the MD operations (N when MD has no information).
    nextBase :: Int -> Int -> Int -> Int -> Int -> [MdOp] -> PrimBase
    nextBase !wt !pos !is !ic !io mds = case mds of
        MdNum   0 : mds' -> nextBase wt pos is ic io mds'
        MdDel  [] : mds' -> nextBase wt pos is ic io mds'
        MdNum   1 : mds' -> nextBase' (get_seq is   id          ) mds'
        MdNum   n : mds' -> nextBase' (get_seq is   id          ) (MdNum (n-1) : mds')
        MdRep ref : mds' -> nextBase' (get_seq is $ const ref   ) mds'
        MdDel   _ : _    -> nextBase' (get_seq is $ const nucsN ) mds
        [              ] -> nextBase' (get_seq is $ const nucsN ) [ ]
      where
        nextBase' ref mds' = Base wt ref b_mapq $ nextIndel [] [] (pos+1) (is+1) ic (io+1) mds'

    -- Collects insertions and deletions between two aligned bases; @ins@ is
    -- kept in reverse order, @del@ in reference order.
    nextIndel :: [[DamagedBase]] -> [Nucleotides] -> Int -> Int -> Int -> Int -> [MdOp] -> PrimChunks
    nextIndel ins del !pos !is !ic !io mds
        | is >= max_seq || ic >= max_cig = EndOfRead
        | otherwise = case b_cigar `V.unsafeIndex` ic of
            Ins :* cl -> nextIndel (isq cl) del pos (cl+is) (ic+1) 0 mds
            SMa :* cl -> nextIndel  ins     del pos (cl+is) (ic+1) 0 mds
            Del :* cl -> nextIndel ins (del++dsq) (pos+cl) is (ic+1) 0 mds'
                where (dsq,mds') = split_del cl mds
            Pad :*  _ -> nextIndel ins del pos is (ic+1) 0 mds
            HMa :*  _ -> nextIndel ins del pos is (ic+1) 0 mds
            Nop :* cl -> firstBase (pos+cl) is (ic+1) mds
            Mat :* cl | io == cl  -> nextIndel ins del pos is (ic+1) 0 mds
                      | otherwise -> indel del out $ nextBase (length del) pos is ic io mds
      where
        -- Forces the inserted bases before building the chunk.
        indel d o k = foldr seq (Indel d o k) o
        out    = concat $ reverse ins
        isq cl = [ get_seq i $ const gap | i <- [is..is+cl-1] ] : ins

        -- Takes @n@ deleted reference nucleotides from the MD operations,
        -- padding with N when MD does not supply them.
        split_del n (MdDel ns : mds')
            | n < length ns = (take n ns, MdDel (drop n ns) : mds')
            | n > length ns = let (ns', mds'') = split_del (n - length ns) mds' in (ns++ns', mds'')
            | otherwise     = (ns, mds')
        split_del n (MdNum 0 : mds') = split_del n mds'
        split_del n mds' = (replicate n nucsN, mds')
-- | Mapping-quality summary over the reads contributing to a pile.
data CallStats = CallStats
    { read_depth       :: {-# UNPACK #-} !Int  -- ^ number of reads
    , reads_mapq0      :: {-# UNPACK #-} !Int  -- ^ reads with mapping quality 0
    , sum_mapq         :: {-# UNPACK #-} !Int  -- ^ sum of mapping qualities
    , sum_mapq_squared :: {-# UNPACK #-} !Int  -- ^ sum of squared mapping qualities
    }
    deriving (Show, Eq, Generic)
-- | The empty summary: no reads seen yet.
instance Monoid CallStats where
    mempty  = CallStats 0 0 0 0
    mappend = (<>)
-- | Combines two summaries by adding every counter field-wise.
instance Semigroup CallStats where
    CallStats depth1 zero1 sum1 sq1 <> CallStats depth2 zero2 sum2 sq2 =
        CallStats (depth1 + depth2) (zero1 + zero2) (sum1 + sum2) (sq1 + sq2)
-- | Storable vector of 'Nucleotide' values.
newtype V_Nuc = V_Nuc (U.Vector Nucleotide)
    deriving (Eq, Ord, Show)

-- | Storable vector of 'Nucleotides' values.
newtype V_Nucs = V_Nucs (U.Vector Nucleotides)
    deriving (Eq, Ord, Show)

-- | An indel, described by its deleted and inserted bases.
data IndelVariant = IndelVariant
    { deleted_bases  :: !V_Nucs
    , inserted_bases :: !V_Nuc
    }
    deriving (Eq, Ord, Show, Generic)

-- | The bases observed at one position.
type BasePile = [DamagedBase]

-- | Indel observations at one position: each read's mapping quality paired
-- with (deleted reference nucleotides, inserted read bases).
type IndelPile = [ (Qual, ([Nucleotides], [DamagedBase])) ]
-- | Everything collected at one reference position.
data Pile' a b = Pile
    { p_refseq     :: {-# UNPACK #-} !Refseq     -- ^ reference sequence
    , p_pos        :: {-# UNPACK #-} !Int        -- ^ position on the reference
    , p_snp_stat   :: {-# UNPACK #-} !CallStats  -- ^ summary of reads contributing bases
    , p_snp_pile   :: a                          -- ^ the bases
    , p_indel_stat :: {-# UNPACK #-} !CallStats  -- ^ summary of reads contributing indel observations
    , p_indel_pile :: b                          -- ^ the indel observations
    }
    deriving Show

-- | The pile produced by 'pileup'; each pair holds the observations from
-- non-reversed reads first and from reversed reads second.
type Pile = Pile' (BasePile, BasePile) (IndelPile, IndelPile)
{-# INLINE pileup #-}
-- | Turns a position-sorted stream of dissected reads into a stream of
-- piles.  It starts at reference 0, position 0, with no active or waiting
-- bases.  Once the driver stops, nothing may remain active or waiting; the
-- remaining input is then drained for its return value.
pileup :: Monad m => Stream (Of PosPrimChunks) m b -> Stream (Of Pile) m b
pileup = runPileM pileup' done (Refseq 0) 0 ([],[]) (Empty,Empty)
  where
    done () _ _ active waiting rest = case (active, waiting) of
        ( ([],[]), (Empty,Empty) ) -> lift (Q.effects rest)
        _                          -> error "logic error: leftovers after pileup"
-- | Continuation-passing monad that drives 'pileup'.  A computation receives
-- its continuation followed by the state described by 'PileF'.
newtype PileM m a = PileM
    { runPileM :: forall r . (a -> PileF m r) -> PileF m r }

-- | The state threaded through 'PileM', in argument order:
--
-- * current reference sequence;
-- * current position;
-- * active bases (non-reversed reads, reversed reads);
-- * waiting bases (the same split);
-- * remaining input.
--
-- The result is the output stream of piles.
type PileF m r
    =  Refseq
    -> Int
    -> ( [PrimBase], [PrimBase] )
    -> ( Heap, Heap )
    -> Stream (Of PosPrimChunks) m r
    -> Stream (Of Pile) m r
-- | Maps the result before it reaches the continuation.
instance Functor (PileM m) where
    {-# INLINE fmap #-}
    fmap f m = PileM (\kont -> runPileM m (\x -> kont (f x)))
-- | Runs the function computation, then the argument computation, and hands
-- the application to the continuation.
instance Applicative (PileM m) where
    {-# INLINE pure #-}
    pure x = PileM (\kont -> kont x)

    {-# INLINE (<*>) #-}
    mf <*> mx = PileM $ \kont ->
        runPileM mf $ \f ->
        runPileM mx $ \x ->
        kont (f x)
-- | Sequencing: the continuation of the first computation runs the second.
--
-- 'return' is intentionally left undefined here.  Its default
-- (@return = pure@) is the canonical definition, whereas a separate
-- definition duplicates 'pure' and triggers -Wnoncanonical-monad-instances.
instance Monad (PileM m) where
    {-# INLINE (>>=) #-}
    m >>= k = PileM $ \k' -> runPileM m (\a -> runPileM (k a) k')
{-# INLINE upd_pos #-}
-- | Replaces the current position by @f@ applied to it.  The new position is
-- forced before the continuation runs.
upd_pos :: (Int -> Int) -> PileM m ()
upd_pos f = PileM $ \kont refseq pos ->
    let !pos' = f pos
    in  kont () refseq pos'
{-# INLINE yieldPile #-}
-- | Emits a pile for the current reference sequence and position, then
-- continues with the unchanged state.
yieldPile :: Monad m => CallStats -> BasePile -> BasePile -> CallStats -> IndelPile -> IndelPile -> PileM m ()
yieldPile bstat bfwd brev istat ifwd irev = PileM $ \ !kont !refseq !pos !active !waiting !rest ->
    Q.cons (Pile { p_refseq     = refseq
                 , p_pos        = pos
                 , p_snp_stat   = bstat
                 , p_snp_pile   = (bfwd, brev)
                 , p_indel_stat = istat
                 , p_indel_pile = (ifwd, irev) })
           (kont () refseq pos active waiting rest)
-- | Decides where the pileup continues, or whether it stops.
--
-- One element of the input is inspected and pushed back (via 'wrap'), so the
-- input is not consumed here.
--
-- * If any base is active on either strand, processing stays at the
--   current position.
-- * Otherwise it jumps to whichever comes first: the smallest waiting key
--   (on the current reference) or the next input record.
-- * With nothing waiting and the input exhausted, the continuation @k@ is
--   called.
pileup' :: Monad m => PileM m ()
pileup' = PileM $ \ !k !refseq !pos !active !waiting inp0 -> do
    inp <- lift $ inspect inp0
    -- push the inspected element back so later stages still see it
    let inp1 = either pure wrap inp
        cont2 rs po = runPileM pileup'' k rs po active waiting inp1
        leave = k () refseq pos active waiting inp1
    case (active, getMinKeysH waiting, inp) of
        ( (_:_,_), _, _ ) -> cont2 refseq pos
        ( (_,_:_), _, _ ) -> cont2 refseq pos
        ( _, Just nw, Left _ ) -> cont2 refseq nw
        ( _, Nothing, Left _ ) -> leave
        ( _, Nothing, Right ((r,p,_,_):>_) ) -> cont2 r p
        ( _, Just nw, Right ((r,p,_,_):>_) )
            -- waiting keys are positions on the current reference sequence
            | (refseq,nw) <= (r,p) -> cont2 refseq nw
            | otherwise -> cont2 r p
  where
    -- smallest key over both strands' waiting heaps
    getMinKeysH :: (Heap, Heap) -> Maybe Int
    getMinKeysH (a,b) = case (getMinKeyH a, getMinKeyH b) of
        ( Nothing, Nothing ) -> Nothing
        ( Just x, Nothing ) -> Just x
        ( Nothing, Just y ) -> Just y
        ( Just x, Just y ) -> Just (min x y)
-- | Handles the current position, then advances it and hands control back
-- to 'pileup''.  The steps are:
--
-- * absorb the input records located here;
-- * activate the waiting bases that are due;
-- * scan the active bases;
-- * emit a pile unless it holds neither bases nor non-empty indel
--   observations.
pileup'' :: Monad m => PileM m ()
pileup'' = do
    p'feed_input
    p'check_waiting
    ((bsF, bpF), (bsR, bpR), (isF, ipF), (isR, ipR)) <- p'scan_active
    let isBlank (_, (dels, inss)) = null dels && null inss
        nothingToReport = null bpF && null bpR && all isBlank ipF && all isBlank ipR
    unless nothingToReport $
        yieldPile (bsF <> bsR) bpF bpR (isF <> isR) ipF ipR
    upd_pos succ
    pileup'
-- | Absorbs every input record located at the current reference sequence
-- and position.
--
-- * 'EndOfRead' records are dropped.
-- * The next base of an 'Indel' becomes active immediately.
-- * A 'Seek' base goes into the waiting heap, keyed by its position.
--
-- Reversed records (@str@) go to the second component of each pair.  The
-- first record not at the current position is pushed back into the stream.
p'feed_input :: Monad m => PileM m ()
p'feed_input = PileM $ \kont rs po ac@(af,ar) wt@(wf,wr) ->
    lift . inspect >=> \case
        Right ((rs', po', str, prim) :> bs)
            | rs == rs' && po == po' ->
                case prim of
                    EndOfRead -> runPileM p'feed_input kont rs po ac wt bs
                    Indel _ _ !pb -> runPileM p'feed_input kont rs po (if str then (af,pb:ar) else (pb:af,ar)) wt bs
                    Seek !p !pb -> runPileM p'feed_input kont rs po ac (if str then (wf,wr') else (wf',wr)) bs
                      where wf' = Node p pb Empty Empty `unionH` wf
                            wr' = Node p pb Empty Empty `unionH` wr
        -- input exhausted or next record elsewhere: push it back and continue
        inp -> kont () rs po ac wt $ either pure wrap inp
-- | Moves every waiting base whose key equals the current position onto the
-- active list of its strand.  The first component of each pair is drained
-- first, then the second.
p'check_waiting :: PileM m ()
p'check_waiting = PileM $ \kont rs po (af0,ar0) (wf0,wr0) ->
    let drain acc heap k = case viewMinH heap of
            Just (!key, !pb, !heap') | key == po -> drain (pb:acc) heap' k
            _                                    -> k acc heap
    in  drain af0 wf0 $ \af wf ->
        drain ar0 wr0 $ \ar wr ->
        kont () rs po (af,ar) (wf,wr)
-- | Collects the bases and indel observations at the current position,
-- first for the non-reversed strand (first components), then for the
-- reversed one.
--
-- Each active base is handled as follows:
--
-- * A base with a positive wait count stays active with the count
--   decremented, contributing nothing.
-- * Any other base is added to the base pile, with its quality capped at
--   the read's mapping quality.  What happens next depends on its
--   continuation:
--
--     * 'Seek' moves the next base to the waiting heap.
--     * 'Indel' records (mapq, (deleted, inserted)) in the indel pile and
--       activates the next base.  This includes empty entries for plain
--       consecutive bases.
--     * 'EndOfRead' drops the read.
--
-- Every contributing read is counted in the corresponding 'CallStats'.
p'scan_active :: PileM m (( CallStats, BasePile ), ( CallStats, BasePile ),
                          ( CallStats, IndelPile ), ( CallStats, IndelPile ))
p'scan_active = do
    (bpf,ipf) <- PileM $ \kont rs pos (af,ar) (wf,wr) -> go (\r af' wf' -> kont r rs pos (af',ar) (wf',wr)) [] wf mempty mempty af
    (bpr,ipr) <- PileM $ \kont rs pos (af,ar) (wf,wr) -> go (\r ar' wr' -> kont r rs pos (af,ar') (wf,wr')) [] wr mempty mempty ar
    return (bpf,bpr,ipf,ipr)
  where
    -- k: continuation; ac: next active list (built reversed); wt: waiting heap;
    -- bpile/ipile: accumulated (stats, pile) for bases and indels
    go k !ac !wt !bpile !ipile [ ] = k (bpile, ipile) (reverse ac) wt
    go k !ac !wt !bpile !ipile (Base nwt qs mq pchunks : bs) =
        case pchunks of
            _ | nwt > 0 -> b' `seq` go k (b':ac) wt bpile ipile bs
            Seek p' pb' -> go k ac (ins p' pb' wt) (z bpile) ipile bs
            Indel nd ni pb' -> go k (pb':ac) wt (z bpile) (y ipile) bs
                where y = put (,) mq (nd,ni)
            EndOfRead -> go k ac wt (z bpile) ipile bs
      where
        b' = Base (nwt-1) qs mq pchunks
        -- adds this base to a base pile, capping its quality at the mapping quality
        z = put (\q x -> x { db_qual = min q (db_qual x) }) mq qs

    -- inserts a value into a heap under key q
    ins q v w = Node q v Empty Empty `unionH` w

    -- counts one read with mapping quality q in the stats and conses (f q x) onto the pile
    put f (Q !q) !x (!st,!vs) = ( st { read_depth = read_depth st + 1
                                     , reads_mapq0 = reads_mapq0 st + (if q == 0 then 1 else 0)
                                     , sum_mapq = sum_mapq st + fromIntegral q
                                     , sum_mapq_squared = sum_mapq_squared st + fromIntegral q * fromIntegral q }
                                , f (Q q) x : vs )
-- | Skew heap of 'PrimBase's keyed by reference position; it holds bases
-- waiting to become active.
data Heap
    = Empty
    | Node {-# UNPACK #-} !Int PrimBase Heap Heap
      -- ^ key, value, left and right subheaps
-- | Merges two heaps, skew-heap style.  The root with the smaller key (the
-- left one on ties) stays on top.  The other heap is merged into its right
-- subheap, and the children are swapped.
unionH :: Heap -> Heap -> Heap
unionH h1 h2 = case h1 of
    Empty -> h2
    Node k1 v1 l1 r1 -> case h2 of
        Empty -> h1
        Node k2 v2 l2 r2
            | k1 <= k2  -> Node k1 v1 (h2 `unionH` r1) l1
            | otherwise -> Node k2 v2 (h1 `unionH` r2) l2
-- | The key at the root of the heap, i.e. its smallest key, if any.
getMinKeyH :: Heap -> Maybe Int
getMinKeyH heap = case heap of
    Node key _ _ _ -> Just key
    Empty          -> Nothing
-- | Splits off the root entry (the smallest key) together with the merge of
-- its two subheaps.
viewMinH :: Heap -> Maybe (Int, PrimBase, Heap)
viewMinH heap = case heap of
    Empty                    -> Nothing
    Node key val left right  -> Just (key, val, unionH left right)