-- | Trimming of reads as found in BAM files.  Implements trimming low
-- quality sequence from the 3' end.

module Bio.Bam.Trim
        ( trim_3
        , trim_3'
        , trim_low_quality
        , default_fwd_adapters
        , default_rev_adapters
        , find_merge
        , mergeBam
        , find_trim
        , trimBam
        , mergeTrimBam
        , twoMins
        , merged_seq
        , merged_qual
        ) where

import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Bam.Rmdup        ( ECig(..), setMD, toECig )
import Bio.Iteratee
import Bio.Prelude

import Foreign.C.Types      ( CInt(..) )

import qualified Data.ByteString                        as B
import qualified Data.Vector.Generic                    as V
import qualified Data.Vector.Storable                   as W

-- | Trims from the 3' end of a sequence.
-- @trim_3\' p b@ trims the 3' end of the sequence in @b@ at the
-- earliest position such that @p@ evaluates to true on every suffix
-- that was trimmed off.  Note that the 3' end may be the beginning of
-- the sequence if it happens to be stored in reverse-complemented form.
-- Also note that trimming from the 3' end may not make sense for reads
-- that were constructed by merging paired end data (but we cannot take
-- care of that here).  Further note that trimming may break dependent
-- information, notably the "mate" information of the mate and many
-- optional fields.

trim_3' :: ([Nucleotides] -> [Qual] -> Bool) -> BamRec -> BamRec
trim_3' p b | b_flag b `testBit` 4 = trim_rev
            | otherwise            = trim_fwd
  where
    trim_fwd = let l = subtract 1 . fromIntegral . length . takeWhile (uncurry p) $
                            zip (inits . reverse . V.toList $ b_seq b)
                                (inits . reverse . V.toList $ b_qual b)
               in trim_3 l b

    trim_rev = let l = subtract 1 . fromIntegral . length . takeWhile (uncurry p) $
                            zip (inits . V.toList $ b_seq  b)
                                (inits . V.toList $ b_qual b)
               in trim_3 l b

trim_3 :: Int -> BamRec -> BamRec
trim_3 l b | b_flag b `testBit` 4 = trim_rev
           | otherwise            = trim_fwd
  where
    trim_fwd = let (_, cigar') = trim_back_cigar (b_cigar b) l
                   c = modMd (takeECig (V.length (b_seq  b) - l)) b
               in c { b_seq   = V.take (V.length (b_seq  c) - l) (b_seq  c)
                    , b_qual  = V.take (V.length (b_qual c) - l) (b_qual c)
                    , b_cigar = cigar'
                    , b_exts  = map (\(k,e) -> case e of
                                        Text t | k `elem` trim_set
                                          -> (k, Text (B.take (B.length t - l) t))
                                        _ -> (k,e)
                                    ) (b_exts c) }

    trim_rev = let (off, cigar') = trim_fwd_cigar (b_cigar b) l
                   c = modMd (dropECig l) b
               in c { b_seq   = V.drop l (b_seq  c)
                    , b_qual  = V.drop l (b_qual c)
                    , b_pos   = b_pos c + off
                    , b_cigar = cigar'
                    , b_exts  = map (\(k,e) -> case e of
                                        Text t | k `elem` trim_set
                                          -> (k, Text (B.drop l t))
                                        _ -> (k,e)
                                    ) (b_exts c) }

    trim_set = ["BQ","CQ","CS","E2","OQ","U2"]

    modMd :: (ECig -> ECig) -> BamRec -> BamRec
    modMd f br = maybe br (setMD br . f . toECig (b_cigar br)) (getMd br)

    endOf :: ECig -> ECig
    endOf  WithMD     = WithMD
    endOf  WithoutMD  = WithoutMD
    endOf (Mat' _ es) = endOf es
    endOf (Ins' _ es) = endOf es
    endOf (SMa' _ es) = endOf es
    endOf (Rep' _ es) = endOf es
    endOf (Del' _ es) = endOf es
    endOf (Nop' _ es) = endOf es
    endOf (HMa' _ es) = endOf es
    endOf (Pad' _ es) = endOf es

    takeECig :: Int -> ECig -> ECig
    takeECig 0  es          = endOf es
    takeECig _  WithMD      = WithMD
    takeECig _  WithoutMD   = WithoutMD
    takeECig n (Mat' m  es) = Mat' n  $ if n > m then takeECig (n-m) es else WithMD
    takeECig n (Ins' m  es) = Ins' n  $ if n > m then takeECig (n-m) es else WithMD
    takeECig n (SMa' m  es) = SMa' n  $ if n > m then takeECig (n-m) es else WithMD
    takeECig n (Rep' ns es) = Rep' ns $ takeECig (n-1) es
    takeECig n (Del' ns es) = Del' ns $ takeECig n es
    takeECig n (Nop' m  es) = Nop' m  $ takeECig n es
    takeECig n (HMa' m  es) = HMa' m  $ takeECig n es
    takeECig n (Pad' m  es) = Pad' m  $ takeECig n es

    dropECig :: Int -> ECig -> ECig
    dropECig 0  es         = es
    dropECig _  WithMD     = WithMD
    dropECig _  WithoutMD  = WithoutMD
    dropECig n (Mat' m es) = if n > m then dropECig (n-m) es else Mat' n WithMD
    dropECig n (Ins' m es) = if n > m then dropECig (n-m) es else Ins' n WithMD
    dropECig n (SMa' m es) = if n > m then dropECig (n-m) es else SMa' n WithMD
    dropECig n (Rep' _ es) = dropECig (n-1) es
    dropECig n (Del' _ es) = dropECig n es
    dropECig n (Nop' _ es) = dropECig n es
    dropECig n (HMa' _ es) = dropECig n es
    dropECig n (Pad' _ es) = dropECig n es


trim_back_cigar, trim_fwd_cigar :: V.Vector v Cigar => v Cigar -> Int -> ( Int, v Cigar )
trim_back_cigar c l = (o, V.fromList $ reverse c') where (o,c') = sanitize_cigar . trim_cigar l $ reverse $ V.toList c
trim_fwd_cigar  c l = (o, V.fromList           c') where (o,c') = sanitize_cigar $ trim_cigar l $ V.toList c

sanitize_cigar :: (Int, [Cigar]) -> (Int, [Cigar])
sanitize_cigar (o, [        ])                          = (o, [])
sanitize_cigar (o, (op:*l):xs) | op == Pad              = sanitize_cigar (o,xs)         -- del P
                               | op == Del || op == Nop = sanitize_cigar (o + l, xs)    -- adjust D,N
                               | op == Ins              = (o, (SMa :* l):xs)            -- I --> S
                               | otherwise              = (o, (op :* l):xs)             -- rest is fine

trim_cigar :: Int -> [Cigar] -> (Int, [Cigar])
trim_cigar 0 cs = (0, cs)
trim_cigar _ [] = (0, [])
trim_cigar l ((op:*ll):cs) | bad_op op = let (o,cs') = trim_cigar l cs in (o + reflen op ll, cs')
                           | otherwise = case l `compare` ll of
    LT -> (reflen op  l, (op :* (ll-l)):cs)
    EQ -> (reflen op ll,                cs)
    GT -> let (o,cs') = trim_cigar (l - ll) cs in (o + reflen op ll, cs')

  where
    reflen op' = if ref_op op' then id else const 0
    bad_op o = o /= Mat && o /= Ins && o /= SMa
    ref_op o = o == Mat || o == Del


-- | Trim predicate to get rid of low quality sequence.
-- @trim_low_quality q ns qs@ evaluates to true if all qualities in @qs@
-- are smaller (i.e. worse) than @q@.
trim_low_quality :: Qual -> a -> [Qual] -> Bool
trim_low_quality q = const $ all (< q)


-- | Finds the merge point.  Input is list of forward adapters, list of
-- reverse adapters, sequence1, quality1, sequence2, quality2; output is
-- merge point and two qualities (YM, YN).
find_merge :: [W.Vector Nucleotides] -> [W.Vector Nucleotides]
           -> W.Vector Nucleotides -> W.Vector Qual
           -> W.Vector Nucleotides -> W.Vector Qual
           -> (Int, Int, Int)
find_merge ads1 ads2 r1 q1 r2 q2 = (mlen, score2 - score1, plain_score - score1)
  where
    plain_score = 6 * fromIntegral (V.length r1 + V.length r2)
    (score1, mlen, score2) = twoMins plain_score (V.length r1 + V.length r2) $
                             merge_score ads1 ads2 r1 q1 r2 q2

-- | Overlap-merging of read pairs.  We shall compute the likelihood
-- for every possible overlap, then select the most likely one (unless it
-- looks completely random), compute a quality from the second best
-- merge, then merge and clamp the quality accordingly.
-- (We could try looking for chimaera after completing the merge, if
-- only we knew which ones to expect?)
--
-- Two reads go in, with two adapter lists.  We return 'Nothing' if all
-- merges looked mostly random.  Else we return the two original reads,
-- flagged as 'eflagVestigial' *and* the merged version, flagged as
-- 'eflagMerged' and optionally 'eflagTrimmed'.  All reads contain the
-- computed qualities (in YM and YN), which we also return.
--
-- The merging automatically limits quality scores some of the time.  We
-- additionally impose a hard limit of 63 to avoid difficulties
-- representing the result, and even that is ridiculous.  Sane people
-- would further limit the returned quality!  (In practice, map quality
-- later imposes a limit anyway, so no worries...)

mergeBam :: Int -> Int -> [W.Vector Nucleotides] -> [W.Vector Nucleotides] -> BamRec -> BamRec -> [BamRec]
mergeBam lowq highq ads1 ads2 r1 r2
    | V.null (b_seq r1) && V.null (b_seq r2) = [ r1', r2'     ]
    | qual1 < lowq                           = [ r1', r2'     ]
    | qual1 >= highq && mlen == 0            = [              ]
    | qual1 >= highq                         = [           rm ]
    | mlen < len_r1-20 || mlen < len_r2-20   = [           rm ]
    | otherwise         = map flag_alternative [ r1', r2', rm ]
  where
    len_r1    = V.length  $ b_seq  r1
    len_r2    = V.length  $ b_seq  r2

    b_seq_r1  = V.convert $ b_seq  r1
    b_seq_r2  = V.convert $ b_seq  r2
    b_qual_r1 = V.convert $ b_qual r1
    b_qual_r2 = V.convert $ b_qual r2

    (mlen, qual1, qual2) = find_merge ads1 ads2 b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2

    flag_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) $ b_exts br }
    store_quals      br = br { b_exts = updateE "YM" (Int qual1) $ updateE "YN" (Int qual2) $ b_exts br }
    pair_flags = flagPaired.|.flagProperlyPaired.|.flagMateUnmapped.|.flagMateReversed.|.flagFirstMate.|.flagSecondMate

    r1' = store_quals r1
    r2' = store_quals r2
    rm  = store_quals $ merged_read mlen (fromIntegral $ min 63 qual1)

    merged_read l qmax = nullBamRec {
                b_qname = b_qname r1,
                b_flag  = flagUnmapped .|. complement pair_flags .&. b_flag r1,
                b_seq   = V.convert $ merged_seq l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2,
                b_qual  = V.convert $ merged_qual qmax l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2,
                b_exts  = let ff = if l < len_r1 then eflagTrimmed else 0
                          in updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagMerged .|. ff) $ b_exts r1 }

{-# INLINE merged_seq #-}
merged_seq :: (V.Vector v Nucleotides, V.Vector v Qual)
           => Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Nucleotides
merged_seq l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
        [ V.take (l - len_r2) (b_seq_r1)
        , V.zipWith4 zz          (V.take l $ V.drop (l - len_r2) b_seq_r1)
                                 (V.take l $ V.drop (l - len_r2) b_qual_r1)
                     (V.reverse $ V.take l $ V.drop (l - len_r1) b_seq_r2)
                     (V.reverse $ V.take l $ V.drop (l - len_r1) b_qual_r2)
        , V.reverse $ V.take (l - len_r1) b_seq_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2
    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 =        n1
                               | q1 > q2         =        n1
                               | otherwise       = compls n2

{-# INLINE merged_qual #-}
merged_qual :: (V.Vector v Nucleotides, V.Vector v Qual)
            => Word8 -> Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Qual
merged_qual qmax l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
        [ V.take (l - len_r2) (b_qual_r1)
        , V.zipWith4 zz           (V.take l $ V.drop (l - len_r2) b_seq_r1)
                                  (V.take l $ V.drop (l - len_r2) b_qual_r1)
                      (V.reverse $ V.take l $ V.drop (l - len_r1) b_seq_r2)
                      (V.reverse $ V.take l $ V.drop (l - len_r1) b_qual_r2)
        , V.reverse $ V.take (l - len_r1) b_qual_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2
    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 = Q $ min qmax (q1 + q2)
                               | q1 > q2         = Q $           q1 - q2
                               | otherwise       = Q $           q2 - q1



-- | Finds the trimming point.  Input is list of forward adapters,
-- sequence, quality; output is trim point and two qualities (YM, YN).
find_trim :: [W.Vector Nucleotides]
          -> W.Vector Nucleotides -> W.Vector Qual
          -> (Int, Int, Int)
find_trim ads1 r1 q1 = (mlen, score2 - score1, plain_score - score1)
  where
    plain_score = 6 * fromIntegral (V.length r1)
    (score1, mlen, score2) = twoMins plain_score (V.length r1) $
                             merge_score ads1 [V.empty] r1 q1 V.empty V.empty

-- | Trimming for a single read:  we need one adapter only (the one coming
-- /after/ the read), here provided as a list of options, and then we
-- merge with an empty second read.  Results in up to two reads (the
-- original, possibly flagged, and the trimmed one, definitely flagged,
-- and two qualities).
trimBam :: Int -> Int -> [W.Vector Nucleotides] -> BamRec -> [BamRec]
trimBam lowq highq ads1 r1
    | V.null (b_seq r1)              = [ r1'      ]
    | mlen == 0 && qual1 >= highq    = [          ]
    | qual1 < lowq                   = [ r1'      ]
    | qual1 >= highq                 = [      r1t ]
    | otherwise = map flag_alternative [ r1', r1t ]
  where
    -- the "merge" score if there is no trimming

    b_seq_r1 = V.convert $ b_seq r1
    b_qual_r1 = V.convert $ b_qual r1

    (mlen, qual1, qual2) = find_trim ads1 b_seq_r1 b_qual_r1

    flag_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) $ b_exts br }
    store_quals      br = br { b_exts = updateE "YM" (Int qual1) $ updateE "YN" (Int qual2) $ b_exts br }

    r1'  = store_quals r1
    r1t  = store_quals $ trimmed_read mlen

    trimmed_read l = nullBamRec {
            b_qname = b_qname r1,
            b_flag  = flagUnmapped .|. b_flag r1,
            b_seq   = V.take l $ b_seq  r1,
            b_qual  = V.take l $ b_qual r1,
            b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagTrimmed) $ b_exts r1 }


-- | For merging, we don't need the complete adapters (length around 70!),
-- only a sufficient prefix.  Taking only the more-or-less constant
-- part (length around 30), there aren't all that many different
-- adapters in the world.  To deal with pretty much every library, we
-- only need the following forward adapters, which will be the default
-- (defined here in the direction they would be sequenced in):  Genomic
-- R2, Multiplex R2, Fraft P7.

default_fwd_adapters :: [ W.Vector Nucleotides ]
default_fwd_adapters = map (W.fromList. map toNucleotides)
         [ {- Genomic R2   -}  "AGATCGGAAGAGCGGTTCAG"
         , {- Multiplex R2 -}  "AGATCGGAAGAGCACACGTC"
         , {- Graft P7     -}  "AGATCGGAAGAGCTCGTATG" ]

-- | Like 'default_rev_adapters', these are the few adapters needed for
-- the reverse read (defined in the direction they would be sequenced in
-- as part of the second read):  Genomic R1, CL 72.

default_rev_adapters :: [ W.Vector Nucleotides ]
default_rev_adapters = map (W.fromList. map toNucleotides)
         [ {- Genomic_R1   -}  "AGATCGGAAGAGCGTCGTGT"
         , {- CL72         -}  "GGAAGAGCGTCGTGTAGGGA" ]

-- We need to compute the likelihood of a read pair given an assumed
-- insert length.  The likelihood of the first read is the likelihood of
-- a match with the adapter where it overlaps the 3' adapter, elsewhere
-- it's 1/4 per position.  The likelihood of the second read is the
-- likelihood of a match with the adapter where it overlaps the adapter,
-- the likehood of a read-read match where it overlaps read one, 1/4 per
-- position elsewhere.  (Yes, this ignores base composition.  It doesn't
-- matter enough.)

merge_score
    :: [ W.Vector Nucleotides ]                 -- 3' adapters as they appear in the first read
    -> [ W.Vector Nucleotides ]                 -- 5' adapters as they appear in the second read
    -> W.Vector Nucleotides -> W.Vector Qual    -- first read, qual
    -> W.Vector Nucleotides -> W.Vector Qual    -- second read, qual
    -> Int                                      -- assumed insert length
    -> Int                                      -- score (roughly sum of qualities at mismatches)
merge_score fwd_adapters rev_adapters !read1 !qual1 !read2 !qual2 !l
    =   6 * fromIntegral (l `min` V.length read1)                                           -- read1, part before adapter
      + 6 * fromIntegral (max 0 (l - V.length read1))                                       -- read2, part before overlap

      + foldl' (\acc fwd_ad -> min acc
                    (match_adapter l read1 qual1 fwd_ad +                                   -- read1, match with forward adapter
                     6 * fromIntegral (max 0 (V.length read1 - V.length fwd_ad - l)))       -- read1, part after (known) adapter
               ) maxBound fwd_adapters

      + foldl' (\acc rev_ad -> min acc
                    (match_adapter l read2 qual2 rev_ad +                                   -- read2, match with reverse adapter
                     6 * fromIntegral (max 0 (V.length read2 - V.length rev_ad - l)))       -- read2, part after (known) adapter
               ) maxBound rev_adapters

      + match_reads l read1 qual1 read2 qual2

{-# INLINE match_adapter #-}
match_adapter :: Int -> W.Vector Nucleotides -> W.Vector Qual -> W.Vector Nucleotides -> Int
match_adapter !off !rd !qs !ad
    | V.length rd /= V.length qs = error "read/qual length mismatch"
    | efflength <= 0             = 0
    | otherwise
        = fromIntegral . unsafePerformIO $
          W.unsafeWith rd $ \p_rd ->
          W.unsafeWith qs $ \p_qs ->
          W.unsafeWith ad $ \p_ad ->
          prim_match_ad (fromIntegral off)
                        (fromIntegral efflength)
                        p_rd p_qs p_ad
  where
    !efflength =  (V.length rd - off) `min` V.length ad

foreign import ccall unsafe "prim_match_ad"
    prim_match_ad :: CInt -> CInt
                  -> Ptr Nucleotides -> Ptr Qual
                  -> Ptr Nucleotides -> IO CInt


-- | Computes overlap score for two reads (with qualities) assuming an
-- insert length.
{-# INLINE match_reads #-}
match_reads :: Int -> W.Vector Nucleotides -> W.Vector Qual -> W.Vector Nucleotides -> W.Vector Qual -> Int
match_reads !l !rd1 !qs1 !rd2 !qs2
    | V.length rd1 /= V.length qs1 || V.length rd2 /= V.length qs2 = error "read/qual length mismatch"
    | efflength <= 0                                               = 0
    | otherwise
        = fromIntegral . unsafePerformIO $
          W.unsafeWith rd1 $ \p_rd1 ->
          W.unsafeWith qs1 $ \p_qs1 ->
          W.unsafeWith rd2 $ \p_rd2 ->
          W.unsafeWith qs2 $ \p_qs2 ->
          prim_match_reads (fromIntegral minidx1)
                           (fromIntegral maxidx2)
                           (fromIntegral efflength)
                           p_rd1 p_qs1 p_rd2 p_qs2
  where
    -- vec1, forward
    !minidx1 = (l - V.length rd2) `max` 0
    -- vec2, backward
    !maxidx2 = l `min` V.length rd2
    -- effective length
    !efflength = ((V.length rd1 + V.length rd2 - l) `min` l) `max` 0


foreign import ccall unsafe "prim_match_reads"
    prim_match_reads :: CInt -> CInt -> CInt
                     -> Ptr Nucleotides -> Ptr Qual
                     -> Ptr Nucleotides -> Ptr Qual -> IO CInt


{-# INLINE twoMins #-}
twoMins :: (Bounded a, Ord a) => a -> Int -> (Int -> a) -> (a,Int,a)
twoMins a0 imax f = go a0 0 maxBound 0 0
  where
    go !m1 !i1 !m2 !i2 !i
        | i == imax = (m1,i1,m2)
        | otherwise =
            case f i of
                x | x < m1    -> go  x  i m1 i1 (i+1)
                  | x < m2    -> go m1 i1  x  i (i+1)
                  | otherwise -> go m1 i1 m2 i2 (i+1)


mergeTrimBam :: Monad m => Int -> Int -> [W.Vector Nucleotides] -> [W.Vector Nucleotides] -> Enumeratee [BamRec] [BamRec] m a
mergeTrimBam lowq highq fwd_ads rev_ads = convStream go
  where
    go = do r1 <- headStream
            if isPaired r1
              then tryHead >>= go2 r1
              else return $ trimBam lowq highq fwd_ads r1

    go2 r1  Nothing  = error $ "Lone mate found: " ++ show (b_qname r1)
    go2 r1 (Just r2) = return $ mergeBam lowq highq fwd_ads rev_ads r1 r2