module Bio.Bam.Trim
        ( trim_3
        , trim_3'
        , trim_low_quality
        , AD_Seqs
        , withADSeqs
        , default_fwd_adapters
        , default_rev_adapters
        , find_merge
        , mergeBam
        , find_trim
        , trimBam
        , merged_seq
        , merged_qual
        ) where
import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Bam.Rmdup               ( ECig(..), setMD, toECig )
import Bio.Prelude
import Control.Monad.Trans.Control ( MonadBaseControl, control )
import Foreign.C.Types             ( CInt(..) )
import Foreign.Marshal.Array       ( allocaArray )
import qualified Data.ByteString                        as B
import qualified Data.Vector.Generic                    as V
import qualified Data.Vector.Storable                   as W
-- | Trims the 3' end of a read for as long as a predicate holds.
--
-- @p@ is applied to 3'-terminal pieces of increasing length, starting
-- with the empty piece; each piece is given as its bases and qualities
-- listed from the 3' end inwards.  The longest piece for which @p@ holds
-- (and holds for every shorter piece) is cut off with 'trim_3'.  Records
-- without qualities are returned unchanged, and if @p@ rejects even the
-- empty piece nothing is trimmed.
trim_3' :: ([Nucleotides] -> [Qual] -> Bool) -> BamRec -> BamRec
trim_3' p b = case b_qual b of
    Nothing                         ->  b
    Just qs | b_flag b `testBit` 4  ->  trim_3 len_rev b
            | otherwise             ->  trim_3 len_fwd b
      where
        -- With flag bit 0x10 set, trim_3 cuts from the front of the stored
        -- record, so the pieces are read in stored order; otherwise it cuts
        -- from the back, so the stored order is reversed first.
        len_fwd = trimmable (reverse . V.toList $ b_seq b) (reverse . V.toList $ qs)
        len_rev = trimmable (V.toList $ b_seq b) (V.toList qs)

        -- Length of the longest acceptable piece.  'inits' starts with the
        -- empty piece, so a predicate rejecting it would give -1 here; that
        -- is clamped to 0 because a negative count is not a trim length.
        trimmable :: [Nucleotides] -> [Qual] -> Int
        trimmable ns qs' = max 0 . subtract 1 . length . takeWhile (uncurry p) $
                               zip (inits ns) (inits qs')
-- | Trims @l@ bases from the 3' end of a read.
--
-- When flag bit 0x10 is set the 3' end is at the front of the stored
-- record, so bases are removed from the front and 'b_pos' is advanced by
-- the reference offset reported by 'trim_fwd_cigar'.  Otherwise bases are
-- removed from the back.  The sequence, qualities, CIGAR, MD field (when
-- 'getMd' finds one) and the text fields named in @trim_set@ are all
-- shortened by the same amount.
trim_3 :: Int -> BamRec -> BamRec
trim_3 l b | b_flag b `testBit` 4 = trim_rev
           | otherwise            = trim_fwd
  where
    -- Remove the last l bases.
    trim_fwd = let (_, cigar') = trim_back_cigar (b_cigar b) l
                   c = modMd (takeECig (V.length (b_seq  b) - l)) b
               in c { b_seq   = V.take (V.length (b_seq c) - l)  $  b_seq  c
                    , b_qual  = V.take (V.length (b_seq c) - l) <$> b_qual c
                    , b_cigar = cigar'
                    , b_exts  = map (\(k,e) -> case e of
                                        Text t | k `elem` trim_set
                                          -> (k, Text (B.take (B.length t - l) t))
                                        _ -> (k,e)
                                    ) (b_exts c) }
    -- Remove the first l bases; the alignment start moves right by the
    -- reference length of what was removed.
    trim_rev = let (off, cigar') = trim_fwd_cigar (b_cigar b) l
                   c = modMd (dropECig l) b
               in c { b_seq   = V.drop l  $  b_seq  c
                    , b_qual  = V.drop l <$> b_qual c
                    , b_pos   = b_pos c + off
                    , b_cigar = cigar'
                    , b_exts  = map (\(k,e) -> case e of
                                        Text t | k `elem` trim_set
                                          -> (k, Text (B.drop l t))
                                        _ -> (k,e)
                                    ) (b_exts c) }
    -- Text-valued optional fields shortened alongside the sequence
    -- (assumed to hold one character per base -- TODO confirm).
    trim_set = ["BQ","CQ","CS","E2","OQ","U2"]

    -- Rewrite the MD field through the extended CIGAR; records for which
    -- getMd finds nothing are returned unchanged.
    modMd :: (ECig -> ECig) -> BamRec -> BamRec
    modMd f br = maybe br (setMD br . f . toECig (b_cigar br)) (getMd br)

    -- The terminator (WithMD / WithoutMD) of an extended CIGAR.
    endOf :: ECig -> ECig
    endOf  WithMD     = WithMD
    endOf  WithoutMD  = WithoutMD
    endOf (Mat' _ es) = endOf es
    endOf (Ins' _ es) = endOf es
    endOf (SMa' _ es) = endOf es
    endOf (Rep' _ es) = endOf es
    endOf (Del' _ es) = endOf es
    endOf (Nop' _ es) = endOf es
    endOf (HMa' _ es) = endOf es
    endOf (Pad' _ es) = endOf es

    -- Keep the first n query bases.  A query-consuming run that is not
    -- longer than what is still wanted is kept whole (length m) and the
    -- remainder is taken from the rest; otherwise it is cut to n and the
    -- walk ends, which also drops any trailing deletions.
    takeECig :: Int -> ECig -> ECig
    takeECig 0  es          = endOf es
    takeECig _  WithMD      = WithMD
    takeECig _  WithoutMD   = WithoutMD
    takeECig n (Mat' m  es) = if n > m then Mat' m (takeECig (n-m) es) else Mat' n (endOf es)
    takeECig n (Ins' m  es) = if n > m then Ins' m (takeECig (n-m) es) else Ins' n (endOf es)
    takeECig n (SMa' m  es) = if n > m then SMa' m (takeECig (n-m) es) else SMa' n (endOf es)
    takeECig n (Rep' ns es) = Rep' ns $ takeECig (n-1) es
    takeECig n (Del' ns es) = Del' ns $ takeECig n es
    takeECig n (Nop' m  es) = Nop' m  $ takeECig n es
    takeECig n (HMa' m  es) = HMa' m  $ takeECig n es
    takeECig n (Pad' m  es) = Pad' m  $ takeECig n es

    -- Drop the first n query bases.  A run of m bases is consumed entirely
    -- when n >= m; otherwise its remaining m-n bases stay in front of the
    -- untouched rest.  The new front is then tidied like the CIGAR.
    dropECig :: Int -> ECig -> ECig
    dropECig 0  es         = tidyFront es
    dropECig _  WithMD     = WithMD
    dropECig _  WithoutMD  = WithoutMD
    dropECig n (Mat' m es) = if n >= m then dropECig (n-m) es else Mat' (m-n) es
    dropECig n (Ins' m es) = if n >= m then dropECig (n-m) es else tidyFront (Ins' (m-n) es)
    dropECig n (SMa' m es) = if n >= m then dropECig (n-m) es else SMa' (m-n) es
    dropECig n (Rep' _ es) = dropECig (n-1) es
    dropECig n (Del' _ es) = dropECig n es
    dropECig n (Nop' _ es) = dropECig n es
    dropECig n (HMa' _ es) = dropECig n es
    dropECig n (Pad' _ es) = dropECig n es

    -- Same clean-up 'sanitize_cigar' applies to the CIGAR: leading padding,
    -- deletions and skips are removed, and a leading insertion becomes a
    -- soft clip.  This keeps the MD field from starting with a deletion
    -- that the trimmed CIGAR no longer has.
    tidyFront :: ECig -> ECig
    tidyFront (Pad' _ es) = tidyFront es
    tidyFront (Del' _ es) = tidyFront es
    tidyFront (Nop' _ es) = tidyFront es
    tidyFront (Ins' m es) = SMa' m es
    tidyFront es          = es
-- | Removes @l@ query bases from one end of a CIGAR.
--
-- 'trim_fwd_cigar' cuts from the front.  'trim_back_cigar' cuts from the
-- back by working on the reversed operation list.  Both return the number
-- of reference positions removed from the trimmed end, together with the
-- shortened CIGAR cleaned up by 'sanitize_cigar'.
trim_back_cigar, trim_fwd_cigar :: V.Vector v Cigar => v Cigar -> Int -> ( Int, v Cigar )
trim_back_cigar c l =
    let backwards = reverse (V.toList c)
        (o, kept) = sanitize_cigar (trim_cigar l backwards)
    in  (o, V.fromList (reverse kept))
trim_fwd_cigar c l =
    let (o, kept) = sanitize_cigar (trim_cigar l (V.toList c))
    in  (o, V.fromList kept)
-- | Cleans up the front of a freshly trimmed CIGAR.
--
-- Leading padding is discarded.  Leading deletions and skips are
-- discarded and their length is added to the reference offset.  A leading
-- insertion is turned into a soft clip.  Scanning stops at the first
-- operation that is not padding, deletion or skip.
sanitize_cigar :: (Int, [Cigar]) -> (Int, [Cigar])
sanitize_cigar (o, []) = (o, [])
sanitize_cigar (o, c@(op :* l) : xs) = case op of
    Pad -> sanitize_cigar (o, xs)
    Del -> sanitize_cigar (o + l, xs)
    Nop -> sanitize_cigar (o + l, xs)
    Ins -> (o, (SMa :* l) : xs)
    _   -> (o, c : xs)
-- | Removes @l@ query-consuming bases from the front of a CIGAR list.
--
-- Operations that consume no query bases (deletions, skips, hard clips,
-- padding) are removed as they are passed.  Returns the number of
-- reference positions covered by everything removed, together with the
-- remaining operations.
trim_cigar :: Int -> [Cigar] -> (Int, [Cigar])
trim_cigar 0 cs = (0, cs)
trim_cigar _ [] = (0, [])
trim_cigar l ((op:*ll):cs)
    | bad_op op = let (o,cs') = trim_cigar l cs in (o + reflen op ll, cs')
    | otherwise = case l `compare` ll of
        LT -> (reflen op  l, (op :* (ll-l)):cs)
        EQ -> (reflen op ll,                cs)
        GT -> let (o,cs') = trim_cigar (l - ll) cs in (o + reflen op ll, cs')
  where
    -- Reference length of n units of an operation.
    reflen op' = if ref_op op' then id else const 0
    -- Operations that do not consume query bases.
    bad_op o = o /= Mat && o /= Ins && o /= SMa
    -- Operations that consume reference bases.  Nop (N, a skipped region)
    -- consumes reference as well, so trimming through a splice must shift
    -- the alignment start by its length, just as 'sanitize_cigar' does.
    ref_op o = o == Mat || o == Del || o == Nop
-- | A predicate for 'trim_3'' that holds when every quality in the piece is
-- below @q@; the sequence argument is ignored.  Passed to 'trim_3'', it
-- removes the longest 3'-terminal stretch whose qualities are all below @q@.
trim_low_quality :: Qual -> a -> [Qual] -> Bool
trim_low_quality q _ = all (< q)
-- | Scores a read pair for merging.
--
-- Read one and read two are passed to 'min_merge_score' in their given
-- orientation, and read two is also passed reverse-complemented.
-- @pads1@ and @pads2@ are the forward and reverse adapter tables.  Throws
-- if a sequence and its qualities differ in length.
find_merge :: AD_Seqs -> AD_Seqs
           -> W.Vector Nucleotides -> W.Vector Qual
           -> W.Vector Nucleotides -> W.Vector Qual
           -> IO (Int, Int, Int)
find_merge pads1 pads2 r1 q1 r2 q2 =
    with_fw_seq r1 q1 $ \fwd1 ->
        with_fw_seq r2 q2 $ \fwd2 ->
            with_rc_seq r2 q2 (min_merge_score pads1 pads2 fwd1 fwd2)
-- | Tries to merge a read pair into a single read.
--
-- Both reads are scored with 'find_merge' against the adapter tables
-- @ads1@ and @ads2@; reads without qualities are treated as Q23
-- throughout.  The three numbers returned are called @mlen@, @qual1@ and
-- @qual2@ here.  @mlen@ is used as the merged length, and @qual1@/@qual2@
-- are stored in the @YM@/@YN@ fields of every emitted record.  The result is
--
-- * @[]@ if both reads are empty;
-- * the two original reads if @qual1 < lowq@ or @mlen < 0@;
-- * @[]@ if @qual1 >= highq@ and @mlen == 0@;
-- * only the merged read if @qual1 >= highq@, or if @mlen@ is more than
--   20 bases shorter than either input read;
-- * otherwise both originals and the merged read, all flagged with
--   'eflagAlternative'.
--
-- The merged read is unmapped and has the pairing flags cleared.  It
-- carries 'eflagMerged', plus 'eflagTrimmed' when it is shorter than read
-- one, and its qualities are capped at @min 63 qual1@.
mergeBam :: Int -> Int
         -> AD_Seqs -> AD_Seqs
         -> BamRec -> BamRec -> IO [BamRec]
mergeBam lowq highq ads1 ads2 r1 r2 = do
    (mlen, qual1, qual2) <- find_merge ads1 ads2 seq1 quals1 seq2 quals2
    let with_scores br = br { b_exts = updateE "YM" (Int qual1) . updateE "YN" (Int qual2) $ b_exts br }
        as_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) $ b_exts br }
        originals = map with_scores [ r1, r2 ]
        merged    = with_scores $ merge_to mlen (fromIntegral $ min 63 qual1)
        verdict
            | V.null (b_seq r1) && V.null (b_seq r2) = []
            | qual1 < lowq || mlen < 0               = originals
            | qual1 >= highq && mlen == 0            = []
            | qual1 >= highq                         = [ merged ]
            | mlen < len1 - 20 || mlen < len2 - 20   = [ merged ]
            | otherwise = map as_alternative (originals ++ [ merged ])
    return verdict
  where
    len1   = V.length $ b_seq r1
    len2   = V.length $ b_seq r2
    seq1   = V.convert $ b_seq r1
    seq2   = V.convert $ b_seq r2
    quals1 = fromMaybe (V.map (const (Q 23)) seq1) (b_qual r1)
    quals2 = fromMaybe (V.map (const (Q 23)) seq2) (b_qual r2)

    -- Flags that describe a pair and no longer apply to the merged read.
    mate_flags = flagPaired .|. flagProperlyPaired .|. flagMateUnmapped
             .|. flagMateReversed .|. flagFirstMate .|. flagSecondMate

    -- The merged record of length l with qualities capped at qmax.
    merge_to l qmax =
        let trimmed_bit = if l < len1 then eflagTrimmed else 0
        in nullBamRec
            { b_qname = b_qname r1
            , b_flag  = flagUnmapped .|. complement mate_flags .&. b_flag r1
            , b_seq   = V.convert $ merged_seq l seq1 quals1 seq2 quals2
            , b_qual  = Just $ merged_qual qmax l seq1 quals1 seq2 quals2
            , b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagMerged .|. trimmed_bit) $ b_exts r1 }
{-# INLINE merged_seq #-}
-- | Sequence of a read pair merged to length @l@.
--
-- Read two is used reverse-complemented.  The first @l - len_r2@ bases
-- come from read one alone, and the last @l - len_r1@ bases come from
-- read two alone; either stretch is empty when its length is not
-- positive.  In the overlap, read one's base is used where the reads
-- agree.  Where they disagree, the base with the higher quality wins, and
-- ties go to read two.  Read lengths are taken from the quality vectors.
merged_seq :: (V.Vector v Nucleotides, V.Vector v Qual)
           => Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Nucleotides
merged_seq l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
        [ V.take (l - len_r2) b_seq_r1
        , V.zipWith4 zz          (V.take l $ V.drop (l - len_r2) b_seq_r1)
                                 (V.take l $ V.drop (l - len_r2) b_qual_r1)
                     (V.reverse $ V.take l $ V.drop (l - len_r1) b_seq_r2)
                     (V.reverse $ V.take l $ V.drop (l - len_r1) b_qual_r2)
          -- The read-two-only tail must be complemented as well as
          -- reversed; the overlap applies compls to read two the same way.
        , V.reverse . V.map compls $ V.take (l - len_r1) b_seq_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2
    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 =        n1
                               | q1 > q2         =        n1
                               | otherwise       = compls n2
{-# INLINE merged_qual #-}
-- | Qualities matching 'merged_seq' for a pair merged to length @l@.
--
-- Stretches covered by only one read keep that read's qualities (read
-- two's reversed).  In the overlap, agreeing bases get the sum of both
-- qualities capped at @qmax@; disagreeing bases get the difference
-- between the larger and the smaller quality.
merged_qual :: (V.Vector v Nucleotides, V.Vector v Qual)
            => Word8 -> Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Qual
merged_qual qmax l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
        [ V.take (l - len_r2) b_qual_r1
        , V.zipWith4 zz           (V.take l $ V.drop (l - len_r2) b_seq_r1)
                                  (V.take l $ V.drop (l - len_r2) b_qual_r1)
                      (V.reverse $ V.take l $ V.drop (l - len_r1) b_seq_r2)
                      (V.reverse $ V.take l $ V.drop (l - len_r1) b_qual_r2)
        , V.reverse $ V.take (l - len_r1) b_qual_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2
    -- The sum is formed in Int: two Word8 qualities can exceed 255 and
    -- would wrap around before the cap was applied.
    zz !n1 (Q !q1) !n2 (Q !q2)
        | n1 == compls n2 = Q $ fromIntegral (min (fromIntegral qmax) (fromIntegral q1 + fromIntegral q2 :: Int))
        | q1 > q2         = Q $ q1 - q2
        | otherwise       = Q $ q2 - q1
-- | Scores a single read for adapter trimming.
--
-- This works like 'find_merge', but without a second read: null
-- placeholders of length 0 stand in for it, and the reverse adapter table
-- holds a single empty adapter.  Throws if the sequence and qualities
-- differ in length.
find_trim :: AD_Seqs
          -> W.Vector Nucleotides -> W.Vector Qual
          -> IO (Int, Int, Int)
find_trim pads1 r1 q1 =
    withADSeqs [W.empty] $ \emptyAdapter ->
        with_fw_seq r1 q1 $ \fwd1 ->
            min_merge_score pads1 emptyAdapter fwd1 noMate noMateRC
  where
    noMate   = FW_Seq nullPtr nullPtr 0
    noMateRC = RC_Seq nullPtr nullPtr 0
-- | Tries to remove adapter sequence from a single read.
--
-- The read is scored with 'find_trim'; a read without qualities is
-- treated as Q23.  The returned @mlen@ is the length kept, and
-- @qual1@/@qual2@ are stored in @YM@/@YN@ on every emitted record.  The
-- result is
--
-- * @[]@ if the read is empty, or if @mlen == 0@ and @qual1 >= highq@;
-- * the original read if @qual1 < lowq@ or @mlen < 0@;
-- * only the trimmed read if @qual1 >= highq@;
-- * otherwise both, flagged with 'eflagAlternative'.
--
-- The trimmed read keeps the first @mlen@ bases, gains 'flagUnmapped',
-- and has 'eflagTrimmed' added to its @FF@ field.
trimBam :: Int -> Int -> AD_Seqs -> BamRec -> IO [BamRec]
trimBam lowq highq ads1 r1 = do
    (mlen, qual1, qual2) <- find_trim ads1 seq1 quals1
    let with_scores br = br { b_exts = updateE "YM" (Int qual1) . updateE "YN" (Int qual2) $ b_exts br }
        as_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) $ b_exts br }
        original = with_scores r1
        trimmed  = with_scores $ nullBamRec
            { b_qname = b_qname r1
            , b_flag  = flagUnmapped .|. b_flag r1
            , b_seq   = V.take mlen  $  b_seq  r1
            , b_qual  = V.take mlen <$> b_qual r1
            , b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagTrimmed) $ b_exts r1 }
        verdict
            | V.null (b_seq r1)           = []
            | mlen == 0 && qual1 >= highq = []
            | qual1 < lowq || mlen < 0    = [ original ]
            | qual1 >= highq              = [ trimmed ]
            | otherwise = map as_alternative [ original, trimmed ]
    return verdict
  where
    seq1   = V.convert $ b_seq r1
    quals1 = fromMaybe (V.map (const (Q 23)) seq1) (b_qual r1)
-- | Built-in forward adapter sequences, converted from ASCII literals.
default_fwd_adapters :: [W.Vector Nucleotides]
default_fwd_adapters = map fromAscii
         [   "AGATCGGAAGAGCGGTTCAG"
         ,   "AGATCGGAAGAGCACACGTC"
         ,   "AGATCGGAAGAGCTCGTATG" ]
  where
    fromAscii = W.fromList . map (toNucleotides . c2w)
-- | Built-in reverse adapter sequences, converted from ASCII literals.
default_rev_adapters :: [W.Vector Nucleotides]
default_rev_adapters = map fromAscii
         [   "AGATCGGAAGAGCGTCGTGT"
         ,   "GGAAGAGCGTCGTGTAGGGA" ]
  where
    fromAscii = W.fromList . map (toNucleotides . c2w)
-- | Calls 'prim_merge_score' on pinned adapters and reads.
--
-- Returns the C function's return value followed by the two values it
-- writes into a temporary two-element output array.  Callers bind these
-- as (length, score, score) -- presumably the merge/trim length and the
-- best and second-best scores; confirm against the C source.  The length
-- stored in the 'RC_Seq' is not passed on; read two's length is used.
min_merge_score
    :: AD_Seqs              -- ^ forward adapters
    -> AD_Seqs              -- ^ reverse adapters
    -> FW_Seq               -- ^ first read
    -> FW_Seq               -- ^ second read (or a null placeholder)
    -> RC_Seq               -- ^ second read, reverse-complemented
    -> IO (Int,Int,Int)     -- ^ return value and both output slots
min_merge_score (AD_Seqs !p_fwd_ads !p_fwd_lns !n_fwd_ads) (AD_Seqs !p_rev_ads !p_rev_lns !n_rev_ads)
            (FW_Seq !p_rd1 !p_qs1 !l1) (FW_Seq !p_rd2 !p_qs2 !l2) (RC_Seq !p_rrd2 !p_rqs2 _) =
    allocaArray 2 $ \out -> do
        result <- prim_merge_score p_fwd_ads p_fwd_lns (fromIntegral n_fwd_ads)
                                   p_rev_ads p_rev_lns (fromIntegral n_rev_ads)
                                   p_rd1 p_qs1 (fromIntegral l1)
                                   p_rd2 p_qs2 (fromIntegral l2)
                                   p_rrd2 p_rqs2 out
        first  <- peekElemOff out 0
        second <- peekElemOff out 1
        return (fromIntegral result, fromIntegral first, fromIntegral second)
-- | C implementation behind 'min_merge_score'.
--
-- Arguments, in order:
--
-- * forward adapter table: pointers, lengths, count;
-- * reverse adapter table: pointers, lengths, count;
-- * first read: sequence, qualities, length;
-- * second read: sequence, qualities, length;
-- * reverse-complemented second read: sequence, qualities;
-- * an output array that receives two values.
--
-- The return value and the two output slots form the triple returned by
-- 'min_merge_score'.
-- NOTE(review): declared @unsafe@, so the C code must not block for long
-- or call back into Haskell -- confirm against the C source.
foreign import ccall unsafe "prim_merge_score"
    prim_merge_score :: Ptr (Ptr Nucleotides) -> Ptr CInt -> CInt
                     -> Ptr (Ptr Nucleotides) -> Ptr CInt -> CInt
                     -> Ptr Nucleotides -> Ptr Qual -> CInt
                     -> Ptr Nucleotides -> Ptr Qual -> CInt
                     -> Ptr Nucleotides -> Ptr Qual
                     -> Ptr CInt -> IO CInt
-- | An adapter table prepared for C: an array of sequence pointers, an
-- array of their lengths, and the number of entries.  Built by
-- 'withADSeqs'; the pointers are only valid inside its continuation.
data AD_Seqs = AD_Seqs !(Ptr (Ptr Nucleotides)) !(Ptr CInt) !Int
-- | A read handed to C as stored: sequence pointer, quality pointer and
-- length.  Built by 'with_fw_seq'; 'find_trim' uses a null/0 placeholder.
data FW_Seq  = FW_Seq !(Ptr Nucleotides) !(Ptr Qual) !Int
-- | A read handed to C reverse-complemented: pointer to the complemented
-- and reversed sequence, pointer to the reversed qualities, and length.
-- Built by 'with_rc_seq'.
data RC_Seq  = RC_Seq !(Ptr Nucleotides) !(Ptr Qual) !Int
-- | Makes a list of adapter sequences available to C code.
--
-- Each vector is pinned with 'W.unsafeWith', and its pointer and length
-- are written into two temporary arrays.  The continuation then receives
-- an 'AD_Seqs' describing them.  The pointers are only valid while the
-- continuation runs.  'control' lets this work in any monad with an IO
-- base.
withADSeqs :: MonadBaseControl IO m => [W.Vector Nucleotides] -> (AD_Seqs -> m r) -> m r
withADSeqs ads0 k = control $ \run_in_base ->
    allocaArray count $ \ptr_tbl ->
    allocaArray count $ \len_tbl ->
        let pin !i pending = case pending of
                []       -> run_in_base (k $! AD_Seqs ptr_tbl len_tbl i)
                (v : vs) -> W.unsafeWith v $ \pv -> do
                    pokeElemOff ptr_tbl i pv
                    pokeElemOff len_tbl i (fromIntegral (W.length v))
                    pin (succ i) vs
        in  pin 0 ads0
  where
    count = length ads0
-- | Pins a sequence and its qualities and passes them to C as an 'FW_Seq'.
--
-- The pointers are only valid inside the continuation.  Throws
-- 'LengthMismatch' if the two vectors differ in length.
-- NOTE(review): the message says "adapter", but callers pass reads.
with_fw_seq :: W.Vector Nucleotides -> W.Vector Qual -> (FW_Seq -> IO r) -> IO r
with_fw_seq ns qs k
    | W.length ns /= W.length qs = throwIO $ LengthMismatch "forward adapter"
    | otherwise =
        W.unsafeWith ns $ \seq_ptr ->
        W.unsafeWith qs $ \qual_ptr ->
            k $ FW_Seq seq_ptr qual_ptr (W.length ns)
{-# INLINE with_fw_seq #-}
-- | Reverse-complements a sequence, reverses its qualities, pins both and
-- passes them to C as an 'RC_Seq'.
--
-- The pointers are only valid inside the continuation.  Throws
-- 'LengthMismatch' if the two vectors differ in length.
with_rc_seq :: W.Vector Nucleotides -> W.Vector Qual -> (RC_Seq -> IO r) -> IO r
with_rc_seq ns qs k
    | W.length ns /= W.length qs = throwIO $ LengthMismatch "reverse adapter"
    | otherwise =
        W.unsafeWith rc_ns  $ \p_rns ->
        W.unsafeWith rev_qs $ \p_rqs ->
            k $ RC_Seq p_rns p_rqs (W.length ns)
  where
    rc_ns  = W.reverse (W.map compls ns)
    rev_qs = W.reverse qs
{-# INLINE with_rc_seq #-}