-- | Trimming and fusing of reads as found in BAM files.
--
-- This API is remarkably ugly because the core loop is implemented in
-- C.  This requires the adapters to be in storable vectors, and since
-- they shouldn't be constantly copied around, the ugly 'withADSeqs'
-- function is needed.  The performance gain seems to be worth it,
-- though.

module Bio.Bam.Trim
        ( trim_3
        , trim_3'
        , trim_low_quality
        , AD_Seqs
        , withADSeqs
        , default_fwd_adapters
        , default_rev_adapters
        , find_merge
        , mergeBam
        , find_trim
        , trimBam
        , merged_seq
        , merged_qual
        ) where

import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Bam.Rmdup               ( ECig(..), setMD, toECig )
import Bio.Prelude
import Control.Monad.Trans.Control ( MonadBaseControl, control )
import Foreign.C.Types             ( CInt(..) )
import Foreign.Marshal.Array       ( allocaArray )

import qualified Data.ByteString                        as B
import qualified Data.Vector.Generic                    as V
import qualified Data.Vector.Storable                   as W

-- | Trims from the 3' end of a sequence.
-- @trim_3' p b@ trims the 3' end of the sequence in @b@ at the
-- earliest position such that @p@ evaluates to true on every suffix
-- that was trimmed off.  Note that the 3' end may be the beginning of
-- the sequence if it happens to be stored in reverse-complemented form.
-- Also note that trimming from the 3' end may not make sense for reads
-- that were constructed by merging paired end data (but we cannot take
-- care of that here).  Further note that trimming may break dependent
-- information, notably the "mate" information and many optional fields.
-- Since the intention is to trim based on quality scores, reads without
-- qualities are passed along unchanged.  If @p@ does not even accept the
-- empty suffix, nothing is trimmed.

trim_3' :: ([Nucleotides] -> [Qual] -> Bool) -> BamRec -> BamRec
trim_3' p b = case b_qual b of
    Nothing                         ->  b
    Just qs | b_flag b `testBit` 4  ->  trim_3 len_rev b
            | otherwise             ->  trim_3 len_fwd b
      where
        -- 3' end at the back of the stored sequence
        len_fwd = trim_len (reverse . V.toList $ b_seq b) (reverse . V.toList $ qs)
        -- 3' end at the front of the stored sequence (reversed read)
        len_rev = trim_len (V.toList $ b_seq b) (V.toList qs)

        -- Length of the longest 3' stretch all of whose suffixes satisfy
        -- @p@.  'inits' yields the empty list first, hence 'subtract 1';
        -- 'max 0' keeps a predicate that rejects the empty suffix from
        -- producing -1, which would otherwise reach 'trim_3' as a
        -- negative trim length.
        trim_len ns qs' = max 0 . subtract 1 . length . takeWhile (uncurry p) $
                              zip (inits ns) (inits qs')


-- | @trim_3 l b@ removes @l@ bases from the 3' end of read @b@.  For a
-- read with flag bit 4 set, the bases are removed from the start of the
-- stored sequence, and the position advances by the offset that
-- 'trim_fwd_cigar' reports.  Otherwise they are removed from the end.
-- The following are shortened consistently:
--
-- * sequence and qualities;
-- * cigar;
-- * the MD information (when 'getMd' finds one);
-- * the text fields listed in @trim_set@.
trim_3 :: Int -> BamRec -> BamRec
trim_3 l b | b_flag b `testBit` 4 = trim_rev
           | otherwise            = trim_fwd
  where
    trim_fwd = let (_, cigar') = trim_back_cigar (b_cigar b) l
                   -- 'max 0': trimming more than the whole read keeps
                   -- nothing, instead of asking for a negative length
                   c = modMd (takeECig (max 0 (V.length (b_seq b) - l))) b
               in c { b_seq   = V.take (V.length (b_seq c) - l)  $  b_seq  c
                    , b_qual  = V.take (V.length (b_seq c) - l) <$> b_qual c
                    , b_cigar = cigar'
                    , b_exts  = map (\(k,e) -> case e of
                                        Text t | k `elem` trim_set
                                          -> (k, Text (B.take (B.length t - l) t))
                                        _ -> (k,e)
                                    ) (b_exts c) }

    trim_rev = let (off, cigar') = trim_fwd_cigar (b_cigar b) l
                   c = modMd (dropLeadRef . dropECig l) b
               in c { b_seq   = V.drop l  $  b_seq  c
                    , b_qual  = V.drop l <$> b_qual c
                    , b_pos   = b_pos c + off
                    , b_cigar = cigar'
                    , b_exts  = map (\(k,e) -> case e of
                                        Text t | k `elem` trim_set
                                          -> (k, Text (B.drop l t))
                                        _ -> (k,e)
                                    ) (b_exts c) }

    -- text fields that are shortened along with the sequence
    trim_set = ["BQ","CQ","CS","E2","OQ","U2"]

    -- Applies an edit to the MD representation; if 'getMd' yields
    -- nothing, the record is returned unchanged.
    modMd :: (ECig -> ECig) -> BamRec -> BamRec
    modMd f br = maybe br (setMD br . f . toECig (b_cigar br)) (getMd br)

    -- the terminator (WithMD / WithoutMD) of an edit script
    endOf :: ECig -> ECig
    endOf  WithMD     = WithMD
    endOf  WithoutMD  = WithoutMD
    endOf (Mat' _ es) = endOf es
    endOf (Ins' _ es) = endOf es
    endOf (SMa' _ es) = endOf es
    endOf (Rep' _ es) = endOf es
    endOf (Del' _ es) = endOf es
    endOf (Nop' _ es) = endOf es
    endOf (HMa' _ es) = endOf es
    endOf (Pad' _ es) = endOf es

    -- Keeps the first @n@ query bases.  A run of @m@ bases contributes
    -- @min n m@ of them; whatever follows the cut (including trailing
    -- deletions) is replaced by the script's terminator.
    takeECig :: Int -> ECig -> ECig
    takeECig 0  es          = endOf es
    takeECig _  WithMD      = WithMD
    takeECig _  WithoutMD   = WithoutMD
    takeECig n (Mat' m  es) = let k = min n m in Mat' k $ takeECig (n-k) es
    takeECig n (Ins' m  es) = let k = min n m in Ins' k $ takeECig (n-k) es
    takeECig n (SMa' m  es) = let k = min n m in SMa' k $ takeECig (n-k) es
    takeECig n (Rep' ns es) = Rep' ns $ takeECig (n-1) es
    takeECig n (Del' ns es) = Del' ns $ takeECig n es
    takeECig n (Nop' m  es) = Nop' m  $ takeECig n es
    takeECig n (HMa' m  es) = HMa' m  $ takeECig n es
    takeECig n (Pad' m  es) = Pad' m  $ takeECig n es

    -- Removes the first @n@ query bases.  A run of @m > n@ bases keeps
    -- its remaining @m - n@ bases and everything after it.
    dropECig :: Int -> ECig -> ECig
    dropECig 0  es         = es
    dropECig _  WithMD     = WithMD
    dropECig _  WithoutMD  = WithoutMD
    dropECig n (Mat' m es) | n < m     = Mat' (m-n) es
                           | otherwise = dropECig (n-m) es
    dropECig n (Ins' m es) | n < m     = Ins' (m-n) es
                           | otherwise = dropECig (n-m) es
    dropECig n (SMa' m es) | n < m     = SMa' (m-n) es
                           | otherwise = dropECig (n-m) es
    dropECig n (Rep' _ es) = dropECig (n-1) es
    dropECig n (Del' _ es) = dropECig n es
    dropECig n (Nop' _ es) = dropECig n es
    dropECig n (HMa' _ es) = dropECig n es
    dropECig n (Pad' _ es) = dropECig n es

    -- 'sanitize_cigar' folds deletions, skips and padding that end up at
    -- the new start of the alignment into the position offset.  Drop them
    -- from the MD representation too, so that MD and cigar agree.
    dropLeadRef :: ECig -> ECig
    dropLeadRef (Del' _ es) = dropLeadRef es
    dropLeadRef (Nop' _ es) = dropLeadRef es
    dropLeadRef (Pad' _ es) = dropLeadRef es
    dropLeadRef es          = es


-- | @trim_back_cigar c l@ removes @l@ query bases from the end of cigar
-- @c@; @trim_fwd_cigar c l@ removes them from the start.  Both return the
-- reference offset computed by 'trim_cigar' and 'sanitize_cigar' for the
-- removed operations, together with the shortened cigar.
trim_back_cigar, trim_fwd_cigar :: V.Vector v Cigar => v Cigar -> Int -> ( Int, v Cigar )
trim_back_cigar c l = fmap (V.fromList . reverse) . sanitize_cigar . trim_cigar l . reverse $ V.toList c
trim_fwd_cigar  c l = fmap  V.fromList            . sanitize_cigar . trim_cigar l           $ V.toList c

-- | Cleans up the front of a cigar list after trimming.
--
-- * Padding is dropped.
-- * Leading deletions and skips are dropped, and their lengths are added
--   to the offset.
-- * A leading insertion becomes a soft clip.
-- * Anything else is left as is.
sanitize_cigar :: (Int, [Cigar]) -> (Int, [Cigar])
sanitize_cigar (off, cig) = case cig of
    [] -> (off, [])
    (op :* n) : rest
        | op == Pad              -> sanitize_cigar (off, rest)
        | op == Del || op == Nop -> sanitize_cigar (off + n, rest)
        | op == Ins              -> (off, (SMa :* n) : rest)
        | otherwise              -> (off, cig)

-- | @trim_cigar l cs@ removes @l@ query bases from the front of @cs@.
-- It returns the number of reference bases covered by the removed
-- operations, together with the remaining operations.
--
-- * Operations that consume no query bases, met while bases remain to be
--   trimmed, are dropped whole.
-- * A non-positive @l@ leaves the cigar unchanged.
-- * Operations left at the new front are handled by 'sanitize_cigar'.
trim_cigar :: Int -> [Cigar] -> (Int, [Cigar])
trim_cigar l cs | l <= 0 = (0, cs)
trim_cigar _ [] = (0, [])
trim_cigar l ((op:*ll):cs)
    | bad_op op = let (o,cs') = trim_cigar l cs in (o + reflen op ll, cs')
    | otherwise = case l `compare` ll of
        LT -> (reflen op  l, (op :* (ll-l)):cs)
        EQ -> (reflen op ll,                cs)
        GT -> let (o,cs') = trim_cigar (l - ll) cs in (o + reflen op ll, cs')
  where
    reflen op' = if ref_op op' then id else const 0
    -- operations that consume no query bases
    bad_op o = o /= Mat && o /= Ins && o /= SMa
    -- Operations that consume reference bases.  N (skipped region) does,
    -- just like D, and 'sanitize_cigar' already counts it that way.
    ref_op o = o == Mat || o == Del || o == Nop


-- | Trim predicate that removes low-quality sequence.
-- @trim_low_quality q ns qs@ holds exactly when every quality in @qs@ is
-- below (i.e. worse than) @q@.  The bases @ns@ are ignored.
trim_low_quality :: Qual -> a -> [Qual] -> Bool
trim_low_quality q _ = all (< q)


-- | Finds the merge point of a read pair.
--
-- Inputs, in order:
--
-- * the forward adapters;
-- * the reverse adapters;
-- * read 1's sequence and qualities;
-- * read 2's sequence and qualities.
--
-- The output is the merge point and two qualities (YM, YN).  Read 2 is
-- handed to the scorer both as is and reverse-complemented.
find_merge :: AD_Seqs -> AD_Seqs
           -> W.Vector Nucleotides -> W.Vector Qual
           -> W.Vector Nucleotides -> W.Vector Qual
           -> IO (Int, Int, Int)
find_merge pads1 pads2 r1 q1 r2 q2 =
    with_fw_seq r1 q1 $ \fw1 ->
    with_fw_seq r2 q2 $ \fw2 ->
    with_rc_seq r2 q2 $
        min_merge_score pads1 pads2 fw1 fw2

-- | Overlap-merging of read pairs.  We compute the likelihood for every
-- possible overlap, then select the most likely one (unless it looks
-- completely random).  A quality is computed from the second best merge,
-- and the reads are then merged with the quality clamped accordingly.
-- (We could try looking for chimaera after completing the merge, if only
-- we knew which ones to expect?)
--
-- Two reads go in, with two adapter lists.  @lowq@ and @highq@ are
-- thresholds on the computed merge quality (YM).  The result is:
--
-- * empty, if both reads are empty, or if the merge is confident
--   (YM >= @highq@) and the merged length is zero;
-- * the two original reads, if YM is below @lowq@ or the merged length
--   is negative;
-- * only the merged read, if the merge is confident, or if the merged
--   length is more than 20 bases shorter than either read;
-- * otherwise, the two original reads and the merged read, all flagged
--   with 'eflagAlternative'.
--
-- The merged read has the following properties:
--
-- * it is unmapped, and its pairing flags are cleared;
-- * it is flagged 'eflagMerged';
-- * it is additionally flagged 'eflagTrimmed' if it is shorter than the
--   first read.
--
-- Every returned read carries the computed qualities in YM and YN.
-- Reads without qualities are scored as if every base had quality 23.
--
-- The merging automatically limits quality scores some of the time.  We
-- additionally impose a hard limit of 63 to avoid difficulties
-- representing the result, and even that is ridiculous.  Sane people
-- would further limit the returned quality!  (In practice, map quality
-- later imposes a limit anyway, so no worries...)

mergeBam :: Int -> Int
         -> AD_Seqs -> AD_Seqs
         -> BamRec -> BamRec -> IO [BamRec]
mergeBam lowq highq ads1 ads2 r1 r2 = do
    let seq1, seq2 :: W.Vector Nucleotides
        seq1 = V.convert (b_seq r1)
        seq2 = V.convert (b_seq r2)

        -- reads without qualities count as Q23 throughout
        quals_of sq = fromMaybe (V.map (const (Q 23)) sq)
        qs1 = quals_of seq1 (b_qual r1)
        qs2 = quals_of seq2 (b_qual r2)

    (mlen, qual1, qual2) <- find_merge ads1 ads2 seq1 qs1 seq2 qs2

    let len1 = V.length (b_seq r1)
        len2 = V.length (b_seq r2)

        with_quals br = br { b_exts = updateE "YM" (Int qual1) . updateE "YN" (Int qual2) $ b_exts br }
        as_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) (b_exts br) }

        pair_bits = flagPaired .|. flagProperlyPaired .|. flagMateUnmapped
                .|. flagMateReversed .|. flagFirstMate .|. flagSecondMate
        trimmed_bit = if mlen < len1 then eflagTrimmed else 0

        r1q = with_quals r1
        r2q = with_quals r2
        merged = with_quals nullBamRec
            { b_qname = b_qname r1
            , b_flag  = flagUnmapped .|. (complement pair_bits .&. b_flag r1)
            , b_seq   = V.convert (merged_seq mlen seq1 qs1 seq2 qs2)
            , b_qual  = Just (merged_qual (fromIntegral (min 63 qual1)) mlen seq1 qs1 seq2 qs2)
            , b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagMerged .|. trimmed_bit) (b_exts r1) }

        outcome
            | V.null seq1 && V.null seq2            = []
            | qual1 < lowq || mlen < 0              = [r1q, r2q]
            | qual1 >= highq                        = [merged | mlen /= 0]
            | mlen < len1 - 20 || mlen < len2 - 20  = [merged]
            | otherwise                             = map as_alternative [r1q, r2q, merged]

    return outcome

{-# INLINE merged_seq #-}
-- | Computes the sequence of a merged read pair.  @l@ is the merged
-- length as returned by 'find_merge'.  Read lengths are taken from the
-- quality vectors.
--
-- * Read 1 covers merged positions @[0, len_r1)@.
-- * Read 2 is given in sequencing direction; reverse-complemented, it
--   covers @[l - len_r2, l)@.
-- * Bases of either read at index @l@ or beyond lie outside the merged
--   sequence and are ignored.
--
-- Where the reads overlap, agreeing bases are kept.  On disagreement,
-- read 1's base wins if its quality is strictly higher; otherwise the
-- complemented base of read 2 wins.
merged_seq :: (V.Vector v Nucleotides, V.Vector v Qual)
           => Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Nucleotides
merged_seq l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
        [ V.take (l - len_r2) b_seq_r1
        , V.zipWith4 zz          (ovl_r1 b_seq_r1)
                                 (ovl_r1 b_qual_r1)
                     (V.reverse $ ovl_r2 b_seq_r2)
                     (V.reverse $ ovl_r2 b_qual_r2)
          -- The part covered only by read 2 must be reverse-complemented,
          -- just like read 2's bases in the overlap ('compls n2' in 'zz').
        , V.reverse $ V.map compls $ V.take (l - len_r1) b_seq_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2

    -- The overlapping parts of the two reads.  Each read is first cut to
    -- the merged length, then the part not covered by the other read is
    -- dropped.  Dropping first and taking @l@ afterwards would keep
    -- read-2 bases beyond @l@ whenever read 1 is shorter than @l@ and
    -- read 2 is longer.  That misaligns the reversed read 2.
    ovl_r1, ovl_r2 :: V.Vector u a => u a -> u a
    ovl_r1 = V.drop (l - len_r2) . V.take l
    ovl_r2 = V.drop (l - len_r1) . V.take l

    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 =        n1
                               | q1 > q2         =        n1
                               | otherwise       = compls n2

{-# INLINE merged_qual #-}
-- | Computes the qualities of a merged read pair.  @qmax@ caps the
-- quality of positions where both reads agree, and @l@ is the merged
-- length.  Read 1 covers merged positions @[0, len_r1)@; read 2,
-- reversed, covers @[l - len_r2, l)@.  Positions of either read at index
-- @l@ or beyond are ignored.
--
-- In the overlap:
--
-- * agreeing bases get the sum of both qualities, capped at @qmax@;
-- * disagreeing bases get the difference of the two qualities.
merged_qual :: (V.Vector v Nucleotides, V.Vector v Qual)
            => Word8 -> Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Qual
merged_qual qmax l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
        [ V.take (l - len_r2) b_qual_r1
        , V.zipWith4 zz           (ovl_r1 b_seq_r1)
                                  (ovl_r1 b_qual_r1)
                      (V.reverse $ ovl_r2 b_seq_r2)
                      (V.reverse $ ovl_r2 b_qual_r2)
        , V.reverse $ V.take (l - len_r1) b_qual_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2

    -- The overlapping parts of the two reads.  Each read is cut to the
    -- merged length first, then the part not covered by the other read
    -- is dropped.  With the opposite order, read 2 would be misaligned
    -- when read 1 is shorter than @l@ and read 2 is longer.
    ovl_r1, ovl_r2 :: V.Vector u a => u a -> u a
    ovl_r1 = V.drop (l - len_r2) . V.take l
    ovl_r2 = V.drop (l - len_r1) . V.take l

    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 = Q $ capped_sum q1 q2
                               | q1 > q2         = Q $ q1 - q2
                               | otherwise       = Q $ q2 - q1

    -- Sums in 'Int' so that two high qualities cannot wrap around in
    -- 'Word8' before being capped at @qmax@.
    capped_sum :: Word8 -> Word8 -> Word8
    capped_sum a c = fromIntegral (min (fromIntegral qmax) (fromIntegral a + fromIntegral c :: Int))



-- | Finds the trimming point of a single read.
--
-- The input is the forward adapters, the sequence and its qualities.
-- The output is the trim point and two qualities (YM, YN).  This runs
-- the merge scoring against a single empty reverse adapter and an empty
-- second read.
find_trim :: AD_Seqs
          -> W.Vector Nucleotides -> W.Vector Qual
          -> IO (Int, Int, Int)
find_trim pads1 r1 q1 =
    withADSeqs [W.empty] $ \no_rev_ads ->
    with_fw_seq r1 q1    $ \fw1 ->
        min_merge_score pads1 no_rev_ads fw1 no_fw no_rc
  where
    -- there is no second read
    no_fw = FW_Seq nullPtr nullPtr 0
    no_rc = RC_Seq nullPtr nullPtr 0

-- | Trimming for a single read:  we need one adapter only (the one coming
-- /after/ the read), here provided as a list of options, and then we
-- merge with an empty second read.  @lowq@ and @highq@ are thresholds
-- on the computed quality (YM).  The result is:
--
-- * empty, if the read is empty, or if the trim is confident
--   (YM >= @highq@) and the trim point is zero;
-- * only the original read, if YM is below @lowq@ or the trim point is
--   negative;
-- * only the trimmed read, if the trim is confident;
-- * otherwise both, flagged with 'eflagAlternative'.
--
-- The trimmed read is unmapped and flagged 'eflagTrimmed'.  Every
-- returned read carries the qualities in YM and YN.  A read without
-- qualities is scored as if every base had quality 23.
trimBam :: Int -> Int -> AD_Seqs -> BamRec -> IO [BamRec]
trimBam lowq highq ads1 r1 = do
    let sq :: W.Vector Nucleotides
        sq = V.convert (b_seq r1)
        qs = fromMaybe (V.map (const (Q 23)) sq) (b_qual r1)

    (mlen, qual1, qual2) <- find_trim ads1 sq qs

    let with_quals br = br { b_exts = updateE "YM" (Int qual1) . updateE "YN" (Int qual2) $ b_exts br }
        as_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) (b_exts br) }

        original = with_quals r1
        trimmed  = with_quals nullBamRec
            { b_qname = b_qname r1
            , b_flag  = flagUnmapped .|. b_flag r1
            , b_seq   = V.take mlen (b_seq r1)
            , b_qual  = V.take mlen <$> b_qual r1
            , b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagTrimmed) (b_exts r1) }

        outcome
            | V.null (b_seq r1)           = []
            | mlen == 0 && qual1 >= highq = []
            | qual1 < lowq || mlen < 0    = [original]
            | qual1 >= highq              = [trimmed]
            | otherwise                   = map as_alternative [original, trimmed]

    return outcome


-- | For merging, we don't need the complete adapters (length around 70!),
-- only a sufficient prefix.  Taking only the more-or-less constant part,
-- there aren't all that many different adapters in the world.  The
-- prefixes below are 20 bases long.  To deal with pretty much every
-- library, we only need the following forward adapters, which are the
-- default:  Genomic R2, Multiplex R2 and Graft P7.  They are written in
-- the direction they would be sequenced in.

default_fwd_adapters :: [W.Vector Nucleotides]
default_fwd_adapters =
    [ W.fromList [ toNucleotides (c2w c) | c <- adapter ]
    | adapter <- [ "AGATCGGAAGAGCGGTTCAG"       -- Genomic R2
                 , "AGATCGGAAGAGCACACGTC"       -- Multiplex R2
                 , "AGATCGGAAGAGCTCGTATG" ]     -- Graft P7
    ]

-- | Like 'default_fwd_adapters', these are the few adapter prefixes
-- needed, here for the reverse read:  Genomic R1 and CL 72.  They are
-- written in the direction they would be sequenced in as part of the
-- second read.

default_rev_adapters :: [W.Vector Nucleotides]
default_rev_adapters =
    [ W.fromList [ toNucleotides (c2w c) | c <- adapter ]
    | adapter <- [ "AGATCGGAAGAGCGTCGTGT"       -- Genomic_R1
                 , "GGAAGAGCGTCGTGTAGGGA" ]     -- CL72
    ]

-- We need to compute the likelihood of a read pair given an assumed
-- insert length.
--
-- * The likelihood of the first read is the likelihood of a match with
--   the adapter where it overlaps the 3' adapter, and 1/4 per position
--   elsewhere.
-- * The likelihood of the second read is the likelihood of a match with
--   the adapter where it overlaps the adapter, the likelihood of a
--   read-read match where it overlaps read one, and 1/4 per position
--   elsewhere.
--
-- (Yes, this ignores base composition.  It doesn't matter enough.)
--
-- The computation itself happens in the C routine 'prim_merge_score'.
-- This wrapper passes it the prepared arguments and a two-element
-- scratch array, which is read back after the call to obtain the
-- minimum and second minimum scores.

min_merge_score
    :: AD_Seqs              -- 3' adapters as they appear in the first read
    -> AD_Seqs              -- 5' adapters as they appear in the second read
    -> FW_Seq               -- first read, prepped
    -> FW_Seq               -- second read, qual, prepped
    -> RC_Seq               -- second read, qual, reversed and prepped
    -> IO (Int,Int,Int)     -- best length, min score, 2nd min score
min_merge_score (AD_Seqs !fw_ads !fw_lns !n_fw) (AD_Seqs !rv_ads !rv_lns !n_rv)
                (FW_Seq !rd1 !qs1 !len1) (FW_Seq !rd2 !qs2 !len2) (RC_Seq !rc_rd2 !rc_qs2 _) =
    allocaArray 2 $ \p_mins -> do
        best <- prim_merge_score fw_ads fw_lns (fromIntegral n_fw)
                                 rv_ads rv_lns (fromIntegral n_rv)
                                 rd1    qs1    (fromIntegral len1)
                                 rd2    qs2    (fromIntegral len2)
                                 rc_rd2 rc_qs2 p_mins
        min1 <- peekElemOff p_mins 0
        min2 <- peekElemOff p_mins 1
        return (fromIntegral best, fromIntegral min1, fromIntegral min2)

-- C implementation of the merge scoring used by 'min_merge_score'.
--
-- Arguments, in order:
--
-- * forward adapter pointers, their lengths and their count;
-- * the same three for the reverse adapters;
-- * read 1 bases, qualities and length;
-- * read 2 bases, qualities and length;
-- * reverse-complemented read 2 bases and qualities;
-- * an output array of two elements, which 'min_merge_score' reads back
--   as the minimum and second minimum score.
--
-- The result is what 'min_merge_score' reports as the best length.  The
-- import is @unsafe@, so the C code must not call back into Haskell.
foreign import ccall unsafe "prim_merge_score"
    prim_merge_score :: Ptr (Ptr Nucleotides) -> Ptr CInt -> CInt
                     -> Ptr (Ptr Nucleotides) -> Ptr CInt -> CInt
                     -> Ptr Nucleotides -> Ptr Qual -> CInt
                     -> Ptr Nucleotides -> Ptr Qual -> CInt
                     -> Ptr Nucleotides -> Ptr Qual
                     -> Ptr CInt -> IO CInt



-- | Adapters prepared for the C code.  The fields are an array of
-- pointers to the adapter bases, an array of their lengths, and the
-- number of adapters.  'withADSeqs' builds it; the pointers are valid
-- only inside that function's continuation.
data AD_Seqs = AD_Seqs !(Ptr (Ptr Nucleotides)) !(Ptr CInt) !Int
-- | A read prepared for the C code: pointers to its bases and qualities,
-- and its length.  It is normally built by 'with_fw_seq'.  'find_trim'
-- also uses a null/zero-length value to stand for a missing second read.
data FW_Seq  = FW_Seq !(Ptr Nucleotides) !(Ptr Qual) !Int
-- | A read prepared for the C code in reverse-complemented form.  The
-- fields are pointers to the reverse-complemented bases and the reversed
-- qualities, plus the length.  It is normally built by 'with_rc_seq'.
data RC_Seq  = RC_Seq !(Ptr Nucleotides) !(Ptr Qual) !Int

-- | Makes a list of adapters available to the C code as an 'AD_Seqs' for
-- the duration of the continuation.  Two arrays are allocated: one holds
-- a pointer to each adapter's data, the other each adapter's length.
-- The pointers must not escape the continuation.
--
-- Maybe pad with something suitable?
withADSeqs :: MonadBaseControl IO m => [W.Vector Nucleotides] -> (AD_Seqs -> m r) -> m r
withADSeqs ads0 k =
    control                                                 $ \run_io ->
    allocaArray num_ads                                     $ \pps ->
    allocaArray num_ads                                     $ \pls ->
        -- nest one 'W.unsafeWith' per adapter, so that every pointer
        -- stays valid until the continuation has returned
        let install (i, v) inner = W.unsafeWith v $ \pa -> do
                pokeElemOff pps i pa
                pokeElemOff pls i (fromIntegral (W.length v))
                inner
        in foldr install (run_io (k $! AD_Seqs pps pls num_ads)) (zip [0 ..] ads0)
  where
    num_ads = length ads0

-- | Exposes a read's sequence and qualities to the C code as an 'FW_Seq'
-- for the duration of the continuation.  Throws 'LengthMismatch' if the
-- two vectors differ in length.  Only reads pass through here; adapters
-- go through 'withADSeqs', so the error names the read.
--
-- Maybe pad with something suitable?
with_fw_seq :: W.Vector Nucleotides -> W.Vector Qual -> (FW_Seq -> IO r) -> IO r
with_fw_seq ns qs k
    | W.length ns == W.length qs
        = W.unsafeWith ns    $ \p_ns ->
          W.unsafeWith qs    $ \p_qs ->
          k (FW_Seq p_ns p_qs $ W.length ns)
    | otherwise
        = throwIO $ LengthMismatch "forward read"
{-# INLINE with_fw_seq #-}

-- | Exposes a read to the C code in reverse-complemented form, as an
-- 'RC_Seq', for the duration of the continuation.  The complemented
-- bases and the qualities are first copied in reverse order.  Throws
-- 'LengthMismatch' if sequence and qualities differ in length.  Only
-- reads pass through here, so the error names the read, not an adapter.
--
-- Maybe pad with something suitable?
with_rc_seq :: W.Vector Nucleotides -> W.Vector Qual -> (RC_Seq -> IO r) -> IO r
with_rc_seq ns qs k
    | W.length ns == W.length qs
        = W.unsafeWith (W.reverse $ W.map compls ns)    $ \p_rns ->
          W.unsafeWith (W.reverse qs)                   $ \p_rqs ->
          k (RC_Seq p_rns p_rqs $ W.length ns)
    | otherwise
        = throwIO $ LengthMismatch "reverse-complemented read"
{-# INLINE with_rc_seq #-}