-- | Printers for BAM and SAM.  BAM is properly supported, SAM can be
-- piped to standard output.

module Bio.Bam.Writer (
    IsBamRec(..),
    encodeBamWith,

    packBam,
    writeBamFile,
    writeBamHandle,
    pipeBamOutput,
    pipeSamOutput
                      ) where

import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Prelude
import Bio.Streaming
import Bio.Streaming.Bgzf

import Data.ByteString.Builder.Prim ( (>*<) )
import Data.ByteString.Internal     ( fromForeignPtr )
import Data.ByteString.Lazy         ( foldrChunks )
import Foreign.Marshal.Alloc        ( alloca )

import qualified Bio.Streaming.Bytes                as S
import qualified Data.ByteString                    as B
import qualified Data.ByteString.Builder            as B
import qualified Data.ByteString.Builder.Extra      as B
import qualified Data.ByteString.Builder.Prim       as E
import qualified Data.Vector.Generic                as V
import qualified Data.Vector.Storable               as W
import qualified Data.Vector.Unboxed                as U
import qualified Streaming.Prelude                  as Q

-- | Writes a stream of records to stdout in SAM format.
--
-- The header rendered from the 'BamMeta' is printed first; after that
-- every record is unpacked and printed as one SAM line.  This is meant
-- for piping into other tools (AWK scripts, say) or for debugging.
-- There is deliberately no variant that sends SAM to a file or
-- compresses it.
pipeSamOutput :: (IsBamRec a, MonadIO m) => BamMeta -> Stream (Of a) m r -> m r
pipeSamOutput meta records = emit (showBamMeta meta) >> Q.mapM_ (emit . samLine) records
  where
    -- print a builder on standard output
    emit builder   = liftIO (B.hPutBuilder stdout builder)
    -- one record, rendered against the reference table of the header
    samLine record = encodeSamEntry (meta_refs meta) (unpackBamRec record)
{-# INLINE pipeSamOutput #-}

-- | Renders one record as a single SAM line, trailing newline included.
--
-- The eleven mandatory columns are written separated by tabs, followed
-- by one tab-prefixed field per optional tag.  Reference names are
-- looked up in @refs@, both positions are printed incremented by one
-- (SAM coordinates are 1-based), and qualities are printed with an
-- offset of 33.  SAM does not allow empty columns, so an empty CIGAR,
-- an empty sequence and missing or empty qualities are all printed as
-- @*@.
encodeSamEntry :: Refs -> BamRec -> B.Builder
encodeSamEntry refs b =
    B.byteStringCopy (b_qname b)                         <> B.char7 '\t' <>
    B.intDec         (b_flag b .&. 0xffff)               <> B.char7 '\t' <>
    B.byteStringCopy (sq_name $ getRef refs $ b_rname b) <> B.char7 '\t' <>
    B.intDec         (b_pos b + 1)                       <> B.char7 '\t' <>
    B.word8Dec       (unQ $ b_mapq b)                    <> B.char7 '\t' <>
    buildCigar       (b_cigar b)                         <> B.char7 '\t' <>
    buildMrnm        (b_mrnm b) (b_rname b)              <> B.char7 '\t' <>
    B.intDec         (b_mpos b + 1)                      <> B.char7 '\t' <>
    B.intDec         (b_isize b)                         <> B.char7 '\t' <>
    buildSeq         (b_seq b)                           <> B.char7 '\t' <>
    buildQual        (b_qual b)                          <>
    foldMap buildExt (b_exts b)                          <> B.char7 '\n'
  where
    -- Each operation is printed as its length followed by its letter.
    buildCigar = orStar $ E.primUnfoldrBounded
                    (E.intDec >*< E.liftFixedToBounded E.word8)
                    (vuncons $ \(op :* num) -> (num, B.index "MIDNSHP" (fromEnum op)))

    -- The mate's reference is abbreviated to '=' if it is valid and the
    -- same as the read's own reference.
    buildMrnm mrnm rname
        | isValidRefseq mrnm && mrnm == rname  =  B.char7 '='
        | otherwise                            =  B.byteString (sq_name $ getRef refs mrnm)

    -- One letter per base, selected by the low four bits of the code.
    buildSeq  = orStar $ E.primUnfoldrFixed E.word8 (vuncons $ \(Ns x) -> B.index "-ACMGRSVTWYHKDBN" $ fromIntegral $ x .&. 15)
    buildQual = maybe (B.char7 '*') (orStar (E.primUnfoldrFixed E.word8 (vuncons $ \(Q q) -> q + 33)))

    -- The two characters of the tag are the low and the high byte of the key.
    buildExt (BamKey k,v) = B.char7 '\t' <>
                            B.word8 (fromIntegral         k   ) <>
                            B.word8 (fromIntegral (shiftR k 8)) <>
                            B.char7 ':' <>
                            buildExtVal v

    buildExtVal (Int      i) = B.char7 'i' <> B.char7 ':' <> B.intDec i
    buildExtVal (Float    f) = B.char7 'f' <> B.char7 ':' <> B.floatDec f
    buildExtVal (Text     t) = B.char7 'Z' <> B.char7 ':' <> B.byteStringCopy t
    buildExtVal (Bin      x) = B.char7 'H' <> B.char7 ':' <> B.byteStringHex x
    buildExtVal (Char     c) = B.char7 'A' <> B.char7 ':' <> B.word8 c
    buildExtVal (IntArr   a) = B.char7 'B' <> B.char7 ':' <> B.char7 'i' <> buildArr   B.intDec a
    buildExtVal (FloatArr a) = B.char7 'B' <> B.char7 ':' <> B.char7 'f' <> buildArr B.floatDec a

    -- Array elements are each preceded by a comma.
    buildArr p = U.foldr (\x k -> B.char7 ',' <> p x <> k) mempty

    -- Prints '*' in place of an empty vector, because an empty column
    -- is not valid SAM; anything else goes to the given printer.
    orStar f v | V.null v  = B.char7 '*'
               | otherwise = f v

    -- Unfolding step: splits off the head of a vector and maps it.
    vuncons f v | V.null  v = Nothing
                | otherwise = Just (f (V.unsafeHead v), V.unsafeTail v)


-- | Record types the writers in this module accept.
class IsBamRec a where
    -- | Puts the tokens that encode this record in front of a token stream.
    pushBam :: a -> BgzfTokens -> BgzfTokens
    -- | Converts the record into an unpacked 'BamRec'.
    unpackBamRec :: a -> BamRec

-- | Raw records are pushed with 'pushBamRaw' and unpacked with
-- 'unpackBam'.
instance IsBamRec BamRaw where
    {-# INLINE pushBam #-}
    pushBam = pushBamRaw
    {-# INLINE unpackBamRec #-}
    unpackBamRec = unpackBam

-- | Unpacked records are encoded with 'pushBamRec'; unpacking them is
-- the identity.
instance IsBamRec BamRec where
    {-# INLINE pushBam #-}
    pushBam = pushBamRec
    {-# INLINE unpackBamRec #-}
    unpackBamRec = id

-- | Either of two record types: both methods dispatch to the instance
-- of whichever alternative is present.
instance (IsBamRec a, IsBamRec b) => IsBamRec (Either a b) where
    {-# INLINE pushBam #-}
    pushBam = either pushBam pushBam
    {-# INLINE unpackBamRec #-}
    unpackBamRec = either unpackBamRec unpackBamRec

-- | Turns a header and a stream of records into a BGZF encoded BAM
-- byte stream.
--
-- The header and every record are expressed as token producers, which
-- are handed to 'encodeBgzf' together with the level @lv@.  Records are
-- encoded directly by 'pushBam', without an intermediate representation.
encodeBamWith :: (IsBamRec a, MonadIO m) => Int -> BamMeta -> Stream (Of a) m r -> ByteStream m r
encodeBamWith lv meta = encodeBgzf lv . Q.cons (Endo bamHeader) . Q.map (Endo . pushBam)
  where
    refseqs = unRefs (meta_refs meta)

    -- Magic string, header text preceded by its length, then the number
    -- of reference sequences and one entry for each of them.
    bamHeader :: BgzfTokens -> BgzfTokens
    bamHeader = TkString "BAM\1"
              . TkSetMark                       -- leaves room for the length
              . headerText
              . TkEndRecord                     -- fills the length in
              . TkWord32 (fromIntegral (V.length refseqs))
              . appEndo (foldMap (Endo . refEntry) refseqs)

    -- The textual header, pushed chunk by chunk.
    headerText :: BgzfTokens -> BgzfTokens
    headerText rest = foldrChunks TkString rest . B.toLazyByteString $ showBamMeta meta

    -- Name length (counting the terminating NUL), name, NUL, sequence length.
    refEntry :: BamSQ -> BgzfTokens -> BgzfTokens
    refEntry sq = TkWord32 (fromIntegral (B.length name + 1))
                . TkString name
                . TkWord8 0
                . TkWord32 (fromIntegral (sq_length sq))
      where
        name = sq_name sq
{-# INLINE encodeBamWith #-}

-- | Pushes a raw record: its length as a 32 bit word, followed by the
-- unmodified bytes.
pushBamRaw :: BamRaw -> BgzfTokens -> BgzfTokens
pushBamRaw raw = TkWord32 size . TkString payload
  where
    payload = raw_data raw
    size    = fromIntegral (B.length payload)
{-# INLINE pushBamRaw #-}

-- | Writes a header and a stream of records to a file in BAM format,
-- encoded at level 6.
-- The actual writing is done by 'S.writeFile'; according to the note
-- that used to stand here, the data goes into a temporary file, which is
-- renamed when done.
writeBamFile :: (IsBamRec a, MonadIO m, MonadMask m) => FilePath -> BamMeta -> Stream (Of a) m r -> m r
writeBamFile fp meta records = S.writeFile fp (encodeBamWith 6 meta records)

-- | Writes a header and a stream of records to stdout in BAM format.
-- The output is uncompressed(!), which makes it useful for piping to
-- other tools.  It is still wrapped in a BGZF stream, because that is
-- what all tools expect; only the individual blocks are not compressed
-- (the encoding level is 0).
pipeBamOutput :: (IsBamRec a, MonadIO m) => BamMeta -> Stream (Of a) m r -> m r
pipeBamOutput meta = \records -> S.hPut stdout (encodeBamWith 0 meta records)
{-# INLINE pipeBamOutput #-}

-- | Writes a header and a stream of records to a 'Handle' in BAM
-- format, encoded at level 6.
writeBamHandle :: (IsBamRec a, MonadIO m) => Handle -> BamMeta -> Stream (Of a) m r -> m r
writeBamHandle hdl meta records = S.hPut hdl (encodeBamWith 6 meta records)

-- Encoding a record that has only just been unpacked from its raw form
-- is replaced by pushing the raw form itself.
-- NOTE(review): presumably the phase control on the INLINE pragma of
-- 'pushBamRec' exists so that the call stays visible long enough for
-- this rule to match; confirm before touching either pragma.
{-# RULES
    "pushBam/unpackBam"     forall b . pushBamRec (unpackBam b) = pushBamRaw b
  #-}

{-# INLINE[1] pushBamRec #-}
-- | Encodes an unpacked record as BAM tokens in front of a token stream.
--
-- The record is bracketed by 'TkSetMark' and 'TkEndRecord'.  In between
-- come the fixed size fields, the NUL terminated read name, the CIGAR
-- operations, the sequence packed two bases to a byte, the qualities
-- (one @0xff@ byte per base if there are none), and the optional fields.
pushBamRec :: BamRec -> BgzfTokens -> BgzfTokens
pushBamRec BamRec{..} =
      TkSetMark
    . TkWord32 (unRefseq b_rname)
    . TkWord32 (fromIntegral b_pos)
    . TkWord8  (fromIntegral $ B.length b_qname + 1)    -- counts the NUL
    . TkWord8  (unQ b_mapq)
    . TkWord16 (fromIntegral bin)
    . TkWord16 (fromIntegral $ W.length b_cigar)
    . TkWord16 (fromIntegral b_flag)
    . TkWord32 (fromIntegral $ V.length b_seq)
    . TkWord32 (unRefseq b_mrnm)
    . TkWord32 (fromIntegral b_mpos)
    . TkWord32 (fromIntegral b_isize)
    . TkString b_qname
    . TkWord8 0
    . TkMemCopy (W.unsafeCast b_cigar)
    . pushSeq b_seq
    . maybe (TkMemFill (V.length b_seq) 0xff) (TkMemCopy . W.unsafeCast) b_qual
    . foldr ((.) . pushExt) id b_exts
    . TkEndRecord
  where
    bin = distinctBin b_pos (alignedLength b_cigar)

    -- Packs two bases into each byte, the first one into the high
    -- nybble.  A single left over base gets a low nybble of zero.
    pushSeq :: V.Vector vec Nucleotides => vec Nucleotides -> BgzfTokens -> BgzfTokens
    pushSeq v = case v V.!? 0 of
                    Nothing -> id
                    Just a  -> case v V.!? 1 of
                        Nothing -> TkWord8 (unNs a `shiftL` 4)
                        Just b  -> TkWord8 (unNs a `shiftL` 4 .|. unNs b) . pushSeq (V.drop 2 v)

    -- One optional field: the key, a type code, and the value.
    pushExt :: (BamKey, Ext) -> BgzfTokens -> BgzfTokens
    pushExt (BamKey k, e) = case e of
        Text  t -> common 'Z' . TkString t . TkWord8 0
        Bin   t -> common 'H' . TkString t . TkWord8 0
        Char  c -> common 'A' . TkWord8 c
        Float f -> common 'f' . TkWord32 (fromFloat f)

        -- A single integer gets the first type code that can hold it.
        Int i   -> case put_some_int (U.singleton i) of
                        (c,op) -> common c . op i

        -- NOTE(review): both array cases write the element count minus
        -- one, while the BAM specification describes this field as the
        -- number of elements.  Presumably the reader of this library
        -- uses the matching convention; confirm before changing it.
        IntArr  ia -> case put_some_int ia of
                        (c,op) -> common 'B' . TkWord8 (fromIntegral $ ord c)
                                  . TkWord32 (fromIntegral $ U.length ia-1)
                                  . U.foldr ((.) . op) id ia

        FloatArr fa -> common 'B' . TkWord8 (fromIntegral $ ord 'f')
                       . TkWord32 (fromIntegral $ U.length fa-1)
                       . U.foldr ((.) . TkWord32 . fromFloat) id fa
      where
        -- The key followed by the type code.
        common :: Char -> BgzfTokens -> BgzfTokens
        common z = TkWord16 k . TkWord8 (fromIntegral $ ord z)

        -- Picks the first integer type that can hold every element and
        -- returns its type code along with the matching encoder.  The
        -- unsigned 32 bit case has to admit zero: otherwise a vector
        -- containing a zero next to a value above 0x7fffffff would be
        -- declared signed and that value would read back negative.
        put_some_int :: U.Vector Int -> (Char, Int -> BgzfTokens -> BgzfTokens)
        put_some_int is
            | U.all (between        0    0xff) is = ('C', TkWord8  . fromIntegral)
            | U.all (between   (-0x80)   0x7f) is = ('c', TkWord8  . fromIntegral)
            | U.all (between        0  0xffff) is = ('S', TkWord16 . fromIntegral)
            | U.all (between (-0x8000) 0x7fff) is = ('s', TkWord16 . fromIntegral)
            | U.all                     (>= 0) is = ('I', TkWord32 . fromIntegral)
            | otherwise                           = ('i', TkWord32 . fromIntegral)

        -- Inclusive range check.
        between :: Int -> Int -> Int -> Bool
        between l r x = l <= x && x <= r

        -- Reinterprets the bits of a 'Float' as a 'Word32' by writing it
        -- to a scratch buffer and reading it back.  The action touches
        -- nothing but that private buffer, which is why running it
        -- through 'unsafeDupablePerformIO' does no harm.
        fromFloat :: Float -> Word32
        fromFloat float = unsafeDupablePerformIO $ alloca $ \buf ->
                          pokeByteOff buf 0 float >> peek buf

-- | Encodes a single unpacked record and wraps the result as a 'BamRaw'.
--
-- The tokens of the record go into a buffer that starts out with 1000
-- bytes.  Whenever 'fillBuffer' stops with tokens left over, the buffer
-- is passed through 'expandBuffer' and filling continues.  The raw
-- record is the used part of the buffer minus its first four bytes
-- (presumably the length word, which is not part of the record itself).
packBam :: BamRec -> IO BamRaw
packBam br = do
    filled <- newBuffer 1000 >>= drain (pushBamRec br TkEnd)
    bamRaw 0 $ fromForeignPtr (buffer filled) 4 (used filled - 4)
  where
    -- keep filling until no tokens are left, expanding the buffer as needed
    drain tokens buf = do
        (buf', rest) <- fillBuffer buf tokens
        case rest of
            TkEnd -> return buf'
            _     -> expandBuffer (128*1024) buf' >>= drain rest