{-# LANGUAGE BangPatterns, ForeignFunctionInterface, OverloadedStrings #-}
-- | Buffer builder to assemble Bgzf blocks.  The idea is to serialize
-- stuff (BAM and BCF) into a buffer, then compress chunks of that
-- buffer into Bgzf blocks.
-- We use a large buffer, and we always make sure there is plenty of
-- space in it (to avoid redundant checks).

module Bio.Streaming.Bgzf (
    bgunzip,
    getBgzfHdr,
    BB(..),
    newBuffer,
    fillBuffer,
    expandBuffer,
    encodeBgzf,
    BgzfTokens(..),
    BclArgs(..),
    BclSpecialType(..),
    loop_dec_int,
    loop_bcl_special,
    CompressionError(..),
    DecompressionError(..)
                            ) where

import Bio.Prelude
import Bio.Streaming
import Foreign.C.Types                     ( CInt(..) )
import Foreign.Marshal.Utils               ( copyBytes, fillBytes, with )

import qualified Bio.Streaming.Bytes        as S
import qualified Data.ByteString            as B
import qualified Data.ByteString.Internal   as B
import qualified Data.ByteString.Unsafe     as B
import qualified Data.Vector.Storable       as V

{-| Decompresses a bgzip stream.  Individual chunks are decompressed in
    parallel.  Leftovers are discarded (some compressed HETFA files
    appear to have junk at the end). -}

bgunzip :: MonadIO m => ByteStream m r -> ByteStream m r
bgunzip s = do
    np <- liftIO $ getNumCapabilities
    S.fromChunks $ psequence (2*np) $ lift (getBgzfHdr s) >>= go
  where
    go (Nothing,    _hdr, s1) = lift (S.effects s1)
    go (Just bsize, _hdr, s1) = do
            blk :> s2 <- lift $ S.splitAt' bsize s1
            wrap (decompressChunk blk :> (lift (getBgzfHdr s2) >>= go))
{-# INLINABLE bgunzip #-}
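
-- A minimal usage sketch for 'bgunzip'.  The reader and writer below
-- ('S.hGetContents', 'S.hPut') are assumed to exist with the usual
-- streaming-bytestring signatures; substitute whatever
-- "Bio.Streaming.Bytes" actually provides.
--
-- > import System.IO (withBinaryFile, IOMode(..))
-- >
-- > decompressFile :: FilePath -> FilePath -> IO ()
-- > decompressFile inp outp =
-- >     withBinaryFile inp  ReadMode  $ \hin  ->
-- >     withBinaryFile outp WriteMode $ \hout ->
-- >         S.hPut hout (bgunzip (S.hGetContents hin))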

getBgzfHdr :: Monad m => ByteStream m r -> m (Maybe Int, B.ByteString, ByteStream m r)
getBgzfHdr s0 = do
        hdr :> s1 <- S.splitAt' 18 s0
        if or [ B.length hdr < 18
              , B.index hdr 0  /= 31               -- gzip magic, first byte (0x1f)
              , B.index hdr 1  /= 139              -- gzip magic, second byte (0x8b)
              , B.index hdr 3  .&. 4 /= 4          -- FLG: FEXTRA must be set
              , B.index hdr 10 /= 6                -- XLEN (little endian) must be 6
              , B.index hdr 11 /= 0
              , B.index hdr 12 /= 66               -- subfield identifier 'B'
              , B.index hdr 13 /= 67 ]             -- subfield identifier 'C'
          then return (Nothing, hdr, s1)
          else do
            -- BSIZE (bytes 16-17) is the total block size minus one; we
            -- already consumed 18 header bytes, so BSIZE - 17 bytes remain.
            let bsize = (fromIntegral (B.index hdr 16) + fromIntegral (B.index hdr 17) `shiftL` 8) - 17
            return (Just bsize, hdr, s1)
{-# INLINE getBgzfHdr #-}

-- | We manage a large buffer (multiple megabytes), of which we fill an
-- initial portion.  We remember the size, the used part, and two marks
-- where we later fill in sizes for the length prefixed BAM or BCF
-- records.  We move the buffer down when we yield a piece downstream,
-- and when we run out of space, we simply move to a new buffer.
-- Garbage collection should take care of the rest.  Unused 'mark' must
-- be set to (maxBound::Int) so it doesn't interfere with flushing.
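-- Roughly: everything between 'off' and 'used' still awaits flushing,
-- and the flushing code in 'encodeBgzf' only hands out data up to the
-- lower of 'mark' and 'used', because the length slot at 'mark' has yet
-- to be patched.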

data BB = BB { buffer :: {-# UNPACK #-} !(ForeignPtr Word8)
             , size   :: {-# UNPACK #-} !Int            -- total size of buffer
             , off    :: {-# UNPACK #-} !Int            -- offset of active portion
             , used   :: {-# UNPACK #-} !Int            -- used portion (inactive & active)
             , mark   :: {-# UNPACK #-} !Int            -- offset of mark
             , mark2  :: {-# UNPACK #-} !Int }          -- offset of mark2

-- | Things we are able to encode.  Taking inspiration from
-- binary-serialise-cbor, we define these as a lazy list-like thing and
-- consume it in an interpreter.

data BgzfTokens = TkWord32   {-# UNPACK #-} !Word32       BgzfTokens -- a 4-byte int
                | TkWord16   {-# UNPACK #-} !Word16       BgzfTokens -- a 2-byte int
                | TkWord8    {-# UNPACK #-} !Word8        BgzfTokens -- a byte
                | TkFloat    {-# UNPACK #-} !Float        BgzfTokens -- a float
                | TkDouble   {-# UNPACK #-} !Double       BgzfTokens -- a double
                | TkString   {-# UNPACK #-} !B.ByteString BgzfTokens -- a raw string
                | TkDecimal  {-# UNPACK #-} !Int          BgzfTokens -- roughly ':%d'

                | TkMemFill {-# UNPACK #-} !Int {-# UNPACK #-} !Word8   BgzfTokens
                | TkMemCopy {-# UNPACK #-} !(V.Vector Word8)            BgzfTokens

                | TkSetMark                               BgzfTokens -- sets the first mark
                | TkEndRecord                             BgzfTokens -- completes a BAM record
                | TkEndRecordPart1                        BgzfTokens -- completes part 1 of a BCF record
                | TkEndRecordPart2                        BgzfTokens -- completes part 2 of a BCF record
                | TkEnd                                              -- nothing more, for now

                | TkBclSpecial !BclArgs                   BgzfTokens
                | TkLowLevel {-# UNPACK #-} !Int (BB -> IO BB) BgzfTokens
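
-- A token chain is normally built up as an 'Endo BgzfTokens' so pieces
-- concatenate cheaply; 'encodeBgzf' applies the 'Endo' to 'TkEnd' and
-- feeds the resulting chain to 'fillBuffer'.  A hypothetical
-- length-prefixed record could be assembled like this (a sketch, not
-- the actual BAM or BCF encoder):
--
-- > exampleRecord :: Bytes -> Endo BgzfTokens
-- > exampleRecord nm = Endo $ TkSetMark         -- reserve four bytes for the length prefix
-- >                         . TkWord32 4711     -- some fixed-width fields
-- >                         . TkWord8  37
-- >                         . TkString nm       -- a raw string
-- >                         . TkEndRecord       -- patch the length back into the mark's slot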

data BclSpecialType = BclNucsBin  | BclNucsAsc  | BclNucsAscRev  | BclNucsWide
                    | BclQualsBin | BclQualsAsc | BclQualsAscRev

data BclArgs = BclArgs BclSpecialType
                       {-# UNPACK #-} !(V.Vector Word8)   -- bcl matrix
                       {-# UNPACK #-} !Int                -- stride
                       {-# UNPACK #-} !Int                -- first cycle
                       {-# UNPACK #-} !Int                -- last cycle
                       {-# UNPACK #-} !Int                -- cluster index

-- | Creates a buffer.
newBuffer :: Int -> IO BB
newBuffer sz = mallocForeignPtrBytes sz >>= \ar -> return $ BB ar sz 0 0 maxBound maxBound

-- | Creates a new buffer, copying the active content from an old one,
-- with higher capacity.  The size of the new buffer is twice the free
-- space in the old buffer, but at least @minsz@.
expandBuffer :: Int -> BB -> IO BB
expandBuffer minsz b = do
    let sz' = max (2 * (size b - used b)) minsz
    arr1 <- mallocForeignPtrBytes sz'
    withForeignPtr arr1 $ \d ->
        withForeignPtr (buffer b) $ \s ->
             copyBytes d (plusPtr s (off b)) (used b - off b)
    return BB{ buffer = arr1
             , size   = sz'
             , off    = 0
             , used   = used b - off b
             , mark   = if mark  b == maxBound then maxBound else mark  b - off b
             , mark2  = if mark2 b == maxBound then maxBound else mark2 b - off b }

data CompressionError = CompressionError !CInt deriving (Typeable,Show)
instance Exception CompressionError where
    displayException (CompressionError rc) = "compress_chunk failed: " ++ show rc

data DecompressionError = DecompressionError !CInt deriving (Typeable,Show)
instance Exception DecompressionError where
    displayException (DecompressionError rc) = "decompress_chunk failed: " ++ show rc

compressChunk :: Int -> ForeignPtr Word8 -> Int -> Int -> IO B.ByteString
compressChunk lv fptr off slen =
    withForeignPtr fptr                 $ \ptr ->
    B.createAndTrim 65536               $ \buf ->
    with 65536                          $ \p_len -> do
        rc <- compress_chunk buf p_len (plusPtr ptr off) (fromIntegral slen) (fromIntegral lv)
        when (rc /= 0 && rc /= 1) $ throwIO $ CompressionError rc
        fromIntegral <$> peek p_len

decompressChunk :: B.ByteString -> IO B.ByteString
decompressChunk ck =
    B.unsafeUseAsCString ck                         $ \psrc ->
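    -- the last four bytes of a gzip member hold ISIZE, the uncompressed size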
    peekByteOff psrc (B.length ck - 4)            >>= \dlen ->
    B.create (fromIntegral (dlen::Word32))          $ \pdest -> do
        rc <- decompress_chunk pdest (fromIntegral dlen) (castPtr psrc) (fromIntegral $ B.length ck)
        when (rc /= 0) $ throwIO $ DecompressionError rc


-- | Expand a chain of tokens into a buffer, sending finished pieces
-- downstream as soon as possible.
encodeBgzf :: MonadIO m => Int -> Stream (Of (Endo BgzfTokens)) m b -> S.ByteStream m b
encodeBgzf lv str = do
    np <- liftIO $ getNumCapabilities
    bb <- liftIO $ newBuffer (1024*1024)
    S.fromChunks $ psequence (2*np) $ lift (inspect str) >>= go bb
  where
    go :: MonadIO m
       => BB
       -> Either b (Of (Endo BgzfTokens) (Stream (Of (Endo BgzfTokens)) m b))
       -> Stream (Of (IO Bytes)) m b
    go bb0 (Left r) = final_flush bb0 r
    go bb0 (Right (f :> s))
        -- initially, we make sure we have reasonable space.  this may not be enough.
        | size bb0 - used bb0 < 1024 = liftIO (expandBuffer (1024*1024) bb0) >>= \bb' -> go' bb' (appEndo f TkEnd) s
        | otherwise                  =                                                   go' bb0 (appEndo f TkEnd) s

    go' bb0 tk s = liftIO (fillBuffer bb0 tk) >>= \(bb',tk') -> flush_blocks tk' bb' s

    -- We can flush anything that is between 'off' and the lower of 'mark'
    -- and 'used'.  When done, we bump 'off'.
    flush_blocks tk bb s
        | min (mark bb) (used bb) - off bb < maxBlockSize =
            case tk of TkEnd -> lift (inspect s) >>= go bb
                       _     -> -- we arrive here because we ran out of buffer space, so we expand it.
                                liftIO (expandBuffer (1024*1024) bb) >>= \bb' -> go' bb' tk s
        | otherwise =
            wrap $  compressChunk lv (buffer bb) (off bb) maxBlockSize
                 :> flush_blocks tk bb { off = off bb + maxBlockSize } s

    final_flush bb r
        | used bb > off bb =
            wrap $  compressChunk lv (buffer bb) (off bb) (used bb - off bb)
                 :> wrap (return bgzfEofMarker :> pure r)
        | otherwise =
            wrap (return bgzfEofMarker :> pure r)

    -- maximum block size for Bgzf: 64k with some room for
    -- headers and uncompressible stuff
    maxBlockSize = 65478
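
-- A minimal way to drive 'encodeBgzf' (a sketch): turn a list of
-- prepared token chains into a 'Stream', compress it, and write it out.
-- 'each' is assumed to come from the streaming package and 'S.hPut'
-- from "Bio.Streaming.Bytes"; the first argument of 'encodeBgzf' is
-- handed to @compress_chunk@ as the compression level.  The BGZF EOF
-- marker is appended automatically.
--
-- > writeBgzf :: Handle -> [Endo BgzfTokens] -> IO ()
-- > writeBgzf hdl recs = S.hPut hdl (encodeBgzf 6 (each recs))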


fillBuffer :: BB -> BgzfTokens -> IO (BB, BgzfTokens)
fillBuffer bb0 tk = withForeignPtr (buffer bb0) (\p -> go_slowish p bb0 tk)
  where
    go_slowish p bb = go_fast p bb (used bb)

    go_fast p bb use tk1 = case tk1 of
        -- no space?  not our job.
        _ | size bb - use < 1024 -> return (bb { used = use },tk1)

        -- the actual end.
        TkEnd                    -> return (bb { used = use },tk1)

        -- I'm cheating.  This stuff works only if the platform allows
        -- unaligned accesses, is little-endian and uses IEEE floats.
        -- It's true on i386 and x86_64.
        TkWord32   x tk' -> do pokeByteOff p use x
                               go_fast p bb (use + 4) tk'

        TkWord16   x tk' -> do pokeByteOff p use x
                               go_fast p bb (use + 2) tk'

        TkWord8    x tk' -> do pokeByteOff p use x
                               go_fast p bb (use + 1) tk'

        TkFloat    x tk' -> do pokeByteOff p use x
                               go_fast p bb (use + 4) tk'

        TkDouble   x tk' -> do pokeByteOff p use x
                               go_fast p bb (use + 8) tk'

        -- The next three may be too big to handle.  By returning with
        -- unfinished business, we will get progressively bigger buffers
        -- and eventually handle it just fine.
        TkString   s tk'
            | B.length s > size bb - use -> return (bb { used = use },tk1)

            | otherwise  -> do let ln = B.length s
                               B.unsafeUseAsCString s $ \q ->
                                    copyBytes (p `plusPtr` use) q ln
                               go_fast p bb (use + ln) tk'

        TkMemFill ln c tk'
            | ln > size bb - use -> return (bb { used = use },tk1)

            | otherwise  -> do fillBytes (p `plusPtr` use) c ln
                               go_fast p bb (use + ln) tk'

        TkMemCopy v tk'
            | V.length v > size bb - use -> return (bb { used = use },tk1)

            | otherwise  -> do let ln = V.length v
                               V.unsafeWith v $ \q ->
                                    copyBytes (p `plusPtr` use) q ln
                               go_fast p bb (use + ln) tk'


        TkDecimal  x tk' -> do ln <- int_loop (p `plusPtr` use) (fromIntegral x)
                               go_fast p bb (use + fromIntegral ln) tk'

        TkSetMark        tk' ->    go_slowish p bb { used = use + 4, mark = use } tk'

        TkEndRecord      tk' -> do let !l = use - mark bb - 4
                                   pokeByteOff p (mark bb) (fromIntegral l :: Word32)
                                   go_slowish p bb { used = use, mark = maxBound } tk'

        TkEndRecordPart1 tk' -> do let !l = use - mark bb - 4
                                   pokeByteOff p (mark bb - 4) (fromIntegral l :: Word32)
                                   go_slowish p bb { used = use, mark2 = use } tk'

        TkEndRecordPart2 tk' -> do let !l = use - mark2 bb
                                   pokeByteOff p (mark bb) (fromIntegral l :: Word32)
                                   go_slowish p bb { used = use, mark = maxBound } tk'

        TkBclSpecial special_args tk' -> do
            l <- loop_bcl_special (p `plusPtr` use) special_args
            go_fast p bb (use + l) tk'

        TkLowLevel minsize proc tk'
            | size bb - use < minsize -> return (bb { used = use },tk1)
            | otherwise               -> do bb' <- proc bb { used = use }
                                            go_slowish p bb' tk'

-- | The EOF marker for BGZF files.
-- This is just an empty string compressed as BGZF.  Appended to BAM
-- files to indicate their end.
bgzfEofMarker :: Bytes
bgzfEofMarker = "\x1f\x8b\x8\x4\0\0\0\0\0\xff\x6\0\x42\x43\x2\0\x1b\0\x3\0\0\0\0\0\0\0\0\0"

loop_dec_int :: Ptr Word8 -> Int -> IO Int
loop_dec_int p i = fromIntegral <$> int_loop p (fromIntegral i)

loop_bcl_special :: Ptr Word8 -> BclArgs -> IO Int
loop_bcl_special p (BclArgs tp vec stride u v i) =

    V.unsafeWith vec $ \q -> case tp of
        BclNucsBin -> do
            nuc_loop p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ (v - u + 2) `div` 2

        BclNucsWide -> do
            nuc_loop_wide p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ v - u + 1

        BclNucsAsc -> do
            nuc_loop_asc p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ v - u + 1

        BclNucsAscRev -> do
            nuc_loop_asc_rev p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ v - u + 1

        BclQualsBin -> do
            qual_loop p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ v - u + 1

        BclQualsAsc -> do
            qual_loop_asc p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ v - u + 1

        BclQualsAscRev -> do
            qual_loop_asc_rev p (fromIntegral stride) (plusPtr q i) (fromIntegral u) (fromIntegral v)
            return $ v - u + 1

foreign import ccall unsafe "nuc_loop"
    nuc_loop :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "nuc_loop_wide"
    nuc_loop_wide :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "nuc_loop_asc"
    nuc_loop_asc :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "nuc_loop_asc_rev"
    nuc_loop_asc_rev :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "qual_loop"
    qual_loop :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "qual_loop_asc"
    qual_loop_asc :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "qual_loop_asc_rev"
    qual_loop_asc_rev :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> CInt -> IO ()

foreign import ccall unsafe "int_loop"
    int_loop :: Ptr Word8 -> CInt -> IO CInt

foreign import ccall unsafe "compress_chunk"
    compress_chunk :: Ptr Word8 -> Ptr CInt -> Ptr Word8 -> CInt -> CInt -> IO CInt

foreign import ccall unsafe "decompress_chunk"
    decompress_chunk :: Ptr Word8 -> CInt -> Ptr Word8 -> CInt -> IO CInt