{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Codec.Picture.BitWriter( BoolReader
                              , emptyBoolState
                              , BoolState
                              , byteAlignJpg
                              , getNextBitsLSBFirst
                              , getNextBitsMSBFirst
                              , getNextBitJpg
                              , getNextIntJpg
                              , setDecodedString
                              , setDecodedStringMSB
                              , setDecodedStringJpg
                              , runBoolReader
                              , BoolWriteStateRef
                              , newWriteStateRef
                              , finalizeBoolWriter
                              , finalizeBoolWriterGif
                              , writeBits'
                              , writeBitsGif
                              , initBoolState
                              , initBoolStateJpg
                              , execBoolReader
                              , runBoolReaderWith
                              ) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative( (<*>), (<$>) )
#endif
import Data.STRef
import Control.Monad( when )
import Control.Monad.ST( ST )
import qualified Control.Monad.Trans.State.Strict as S
import Data.Int ( Int32 )
import Data.Word( Word8, Word32 )
import Data.Bits( (.&.), (.|.), unsafeShiftR, unsafeShiftL )
import Codec.Picture.VectorByteConversion( blitVector )
import qualified Data.Vector.Storable.Mutable as M
import qualified Data.Vector.Storable as VS
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
-- | Position of a bit reader inside its input.
--
-- The three fields are, in order: a bit cursor relative to the current
-- byte, the current byte itself, and the input bytes that have not been
-- loaded yet.  The meaning of the cursor depends on the reader in use:
-- 'getNextBit' counts it up from 0 to 7, 'getNextBitJpg' counts it down
-- from 7 to 0, and 'getNextBitsMSBFirst' treats it as the number of bits
-- still unread in the current byte.
data BoolState = BoolState {-# UNPACK #-} !Int
                           {-# UNPACK #-} !Word8
                           !B.ByteString
-- | A reader state that holds no input: the bit cursor is @-1@, the current
-- byte is 0 and there are no pending bytes.
emptyBoolState :: BoolState
emptyBoolState = BoolState noCursor 0 B.empty
  where
    noCursor = -1
-- | A bit-reading computation: a strict state transformer carrying a
-- 'BoolState' on top of the 'ST' monad.
type BoolReader s a = S.StateT BoolState (ST s) a
-- | Run a reader starting from a state with cursor 0, current byte 0 and no
-- pending input, and return only its result.
runBoolReader :: BoolReader s a -> ST s a
runBoolReader action = action `S.evalStateT` blankState
  where
    blankState = BoolState 0 0 B.empty
-- | Run a reader from the supplied state, returning both its result and the
-- state it finished in.
runBoolReaderWith :: BoolState -> BoolReader s a -> ST s (a, BoolState)
runBoolReaderWith = flip S.runStateT
-- | Run a reader from the supplied state and keep only the final state,
-- discarding the reader's result.
execBoolReader :: BoolState -> BoolReader s a -> ST s BoolState
execBoolReader = flip S.execStateT
-- | Build the initial state for the least-significant-bit-first reader:
-- the first byte of the input becomes the current byte with the cursor on
-- bit 0.  Empty input yields a zero byte and no pending data.
initBoolState :: B.ByteString -> BoolState
initBoolState = maybe drained (uncurry (BoolState 0)) . B.uncons
  where
    drained = BoolState 0 0 B.empty
-- | Build the initial state for the JPEG reader.
--
-- A leading @0xFF 0x00@ pair is loaded as a single @0xFF@ data byte; a
-- @0xFF@ followed by any other byte is skipped together with that byte and
-- initialisation restarts on what follows.  Any other byte is loaded as is.
-- Loaded bytes start with the cursor on bit 7.
--
-- Note that completely empty input produces a cursor of 0, whereas input
-- consisting of a lone @0xFF@ produces a cursor of 7.
initBoolStateJpg :: B.ByteString -> BoolState
initBoolStateJpg str =
  case B.uncons str of
    Nothing -> BoolState 0 0 B.empty
    Just (byte, rest)
      | byte /= 0xFF -> BoolState 7 byte rest
      | otherwise    -> afterFF (B.uncons rest)
  where
    -- Decide what a 0xFF byte means from the byte that follows it.
    afterFF Nothing                = BoolState 7 0 B.empty
    afterFF (Just (0x00, payload)) = BoolState 7 0xFF payload
    afterFF (Just (_, payload))    = initBoolStateJpg payload
-- | Replace the reader state so that reading continues, least significant
-- bit first, from the start of the given bytes.  With empty input the
-- current byte becomes 0 and nothing is left pending.
setDecodedString :: B.ByteString -> BoolReader s ()
setDecodedString str = S.put nextState
  where
    nextState = case B.uncons str of
      Just (firstByte, others) -> BoolState 0 firstByte others
      Nothing                  -> BoolState 0 0 B.empty
-- | Drop the rest of a partially consumed byte in the JPEG reader.
--
-- When the cursor is not on bit 7 the current byte is abandoned and the
-- next one is loaded from the pending input with 'setDecodedStringJpg';
-- when the cursor is already on bit 7 nothing happens.
byteAlignJpg :: BoolReader s ()
byteAlignJpg = S.get >>= realign
  where
    realign (BoolState cursor _ pending) =
      when (cursor /= 7) $ setDecodedStringJpg pending
-- | Read one bit with the JPEG reader.
--
-- The bit under the cursor is returned and the cursor moves one position
-- towards bit 0; once bit 0 has been read the next byte is loaded with
-- 'setDecodedStringJpg'.
getNextBitJpg :: BoolReader s Bool
{-# INLINE getNextBitJpg #-}
getNextBitJpg = do
    BoolState cursor byte pending <- S.get
    case cursor of
      0 -> setDecodedStringJpg pending
      _ -> S.put $ BoolState (cursor - 1) byte pending
    return $ byte .&. (1 `unsafeShiftL` cursor) /= 0
-- | Read the requested number of bits with the JPEG reader and return them
-- as an integer whose most significant bit is the first bit read.
--
-- Whole remainders of bytes are consumed at once: when the request covers
-- everything left in the current byte, those bits are merged into the
-- result and the next byte is loaded with 'setDecodedStringJpg'; otherwise
-- only the top part of what is left is taken and the cursor is moved.
getNextIntJpg :: Int -> BoolReader s Int32
{-# INLINE getNextIntJpg #-}
getNextIntJpg = collect 0
  where
    collect !acc 0 = return acc
    collect !acc !wanted = do
      BoolState cursor byte pending <- S.get
      -- The cursor is the index of the next bit, so cursor + 1 bits remain.
      let !available = cursor + 1
      if wanted < available
        then do
          -- The current byte holds more than we need: take its top part.
          let !unread = available - wanted
              !bits = (fromIntegral byte `unsafeShiftR` unread)
                        .&. ((1 `unsafeShiftL` wanted) - 1)
          S.put $ BoolState (unread - 1) byte pending
          return $ bits .|. acc
        else do
          -- Use up the current byte entirely and keep going on the next.
          setDecodedStringJpg pending
          let !stillWanted = wanted - available
              !lowBits = fromIntegral byte
                           .&. ((1 `unsafeShiftL` available) - 1)
          collect (acc .|. (lowBits `unsafeShiftL` stillWanted)) stillWanted
-- | Replace the reader state so that 'getNextBitsMSBFirst' continues from
-- the start of the given bytes: the first byte becomes current with 8 bits
-- unread.  With empty input the current byte is 0 and nothing is pending.
setDecodedStringMSB :: B.ByteString -> BoolReader s ()
setDecodedStringMSB str =
    S.put $ maybe drained (uncurry (BoolState 8)) (B.uncons str)
  where
    drained = BoolState 8 0 B.empty
{-# INLINE getNextBitsMSBFirst #-}
-- | Read the requested number of bits, most significant bit first, and
-- return them as a word whose top bit is the first bit read.
--
-- In this reader the first field of the state is the count of bits still
-- unread in the current byte, and the current byte is stored with the bits
-- already consumed masked off.
getNextBitsMSBFirst :: Int -> BoolReader s Word32
getNextBitsMSBFirst requested = collect 0 requested
  where
    collect :: Word32 -> Int -> BoolReader s Word32
    collect !acc 0 = return acc
    collect !acc !wanted = do
      BoolState available byte pending <- S.get
      if wanted < available
        then do
          -- Take only the top of what is left and keep the low bits around.
          let !unread = available - wanted
              !keepMask = (1 `unsafeShiftL` unread) - 1
          S.put $ BoolState unread (byte .&. keepMask) pending
          return $ (fromIntegral byte `unsafeShiftR` unread) .|. acc
        else do
          -- Everything left in this byte is needed; move on to the next.
          setDecodedStringMSB pending
          let !stillWanted = wanted - available
              !highBits = fromIntegral byte `unsafeShiftL` stillWanted
          collect (acc .|. highBits) stillWanted
{-# INLINE getNextBitsLSBFirst #-}
-- | Read @count@ bits one at a time with 'getNextBit' and assemble them
-- into a word, the first bit read landing in bit 0 of the result.
--
-- A @count@ of zero or less reads nothing and returns 0.
getNextBitsLSBFirst :: Int -> BoolReader s Word32
getNextBitsLSBFirst count = aux 0 count
  where
    aux :: Word32 -> Int -> BoolReader s Word32
    -- The accumulator is forced at every step so that no chain of pending
    -- '.|.' thunks builds up while the bits are being collected.
    aux !acc n
      -- '<=' rather than '==' so that a negative request terminates
      -- immediately instead of counting down without ever reaching zero.
      | n <= 0 = return acc
      | otherwise = do
          bit <- getNextBit
          let nextVal
                | bit       = acc .|. (1 `unsafeShiftL` (count - n))
                | otherwise = acc
          aux nextVal (n - 1)
{-# INLINE getNextBit #-}
-- | Read one bit, least significant bit first.
--
-- The bit under the cursor is returned and the cursor moves one position
-- towards bit 7; once bit 7 has been read the next byte is loaded with
-- 'setDecodedString'.
getNextBit :: BoolReader s Bool
getNextBit = do
    BoolState cursor byte pending <- S.get
    case cursor of
      7 -> setDecodedString pending
      _ -> S.put $ BoolState (cursor + 1) byte pending
    return $ byte .&. (1 `unsafeShiftL` cursor) /= 0
-- | Load the next data byte of a JPEG stream into the reader state, with
-- the cursor on bit 7.
--
-- A @0xFF 0x00@ pair is loaded as a single @0xFF@ data byte.  A @0xFF@
-- followed by any other byte is skipped together with that byte, and
-- loading restarts on what follows.  When the input runs out (including
-- right after a lone @0xFF@) the current byte becomes 0 with nothing
-- pending.
setDecodedStringJpg :: B.ByteString -> BoolReader s ()
setDecodedStringJpg str = case B.uncons str of
    Nothing -> S.put exhausted
    Just (byte, rest)
      | byte /= 0xFF -> S.put $ BoolState 7 byte rest
      | otherwise -> case B.uncons rest of
          Nothing                -> S.put exhausted
          Just (0x00, payload)   -> S.put $ BoolState 7 0xFF payload
          Just (_, payload)      -> setDecodedStringJpg payload
  where
    exhausted = BoolState 7 0 B.empty
-- | Size, in bytes, of every buffer allocated by the bit writer (256 KiB).
defaultBufferSize :: Int
defaultBufferSize = kib * 256
  where
    kib = 1024
-- | Mutable state of a bit writer living in the 'ST' monad.
--
-- Bits are gathered into a one-byte accumulator; completed bytes go to a
-- fixed-size buffer, and full buffers are converted to strict byte strings
-- and collected in a list.
data BoolWriteStateRef s = BoolWriteStateRef
        { bwsCurrBuffer   :: STRef s (M.MVector s Word8)
          -- ^ Buffer that completed bytes are currently written to.
        , bwsBufferList   :: STRef s [B.ByteString]
          -- ^ Chunks already flushed, oldest first.
        , bwsWrittenWords :: STRef s Int
          -- ^ Number of bytes written to the current buffer, which is also
          -- the index the next byte goes to.
        , bwsBitAcc       :: STRef s Word8
          -- ^ Bits waiting to form a complete byte.
        , bwsBitReaded    :: STRef s Int
          -- ^ How many bits are currently held in 'bwsBitAcc'.
        }
-- | Create a fresh bit writer: a newly allocated buffer of
-- 'defaultBufferSize' bytes, no flushed chunks, no bytes written and an
-- empty bit accumulator.
newWriteStateRef :: ST s (BoolWriteStateRef s)
newWriteStateRef = do
    bufferRef  <- M.new defaultBufferSize >>= newSTRef
    chunksRef  <- newSTRef []
    writtenRef <- newSTRef 0
    accRef     <- newSTRef 0
    bitsRef    <- newSTRef 0
    return BoolWriteStateRef
        { bwsCurrBuffer   = bufferRef
        , bwsBufferList   = chunksRef
        , bwsWrittenWords = writtenRef
        , bwsBitAcc       = accRef
        , bwsBitReaded    = bitsRef
        }
-- | Finish writing: emit any pending bits with 'flushLeftBits'', flush the
-- current buffer, and return every chunk produced so far as one lazy byte
-- string.
finalizeBoolWriter :: BoolWriteStateRef s -> ST s L.ByteString
finalizeBoolWriter st = do
    flushLeftBits' st
    forceBufferFlushing' st
    chunks <- readSTRef (bwsBufferList st)
    return (L.fromChunks chunks)
-- | Unconditionally retire the current buffer: the bytes written to it so
-- far are turned into a strict byte string that is appended to the end of
-- the chunk list, a new buffer of 'defaultBufferSize' bytes takes its
-- place, and the written-byte counter goes back to 0.
forceBufferFlushing' :: BoolWriteStateRef s -> ST s ()
forceBufferFlushing' st = do
    filled <- readSTRef bufferRef
    used   <- readSTRef usedRef
    fresh  <- M.new defaultBufferSize
    chunk  <- byteStringFromVector filled used
    writeSTRef bufferRef fresh
    modifySTRef chunksRef (++ [chunk])
    writeSTRef usedRef 0
  where
    bufferRef = bwsCurrBuffer st
    usedRef   = bwsWrittenWords st
    chunksRef = bwsBufferList st
-- | Retire the current buffer with 'forceBufferFlushing'' only if it is
-- full, that is when at least 'defaultBufferSize' bytes have been written
-- to it.
flushCurrentBuffer' :: BoolWriteStateRef s -> ST s ()
flushCurrentBuffer' st =
    readSTRef (bwsWrittenWords st) >>= \used ->
      when (used >= defaultBufferSize) $ forceBufferFlushing' st
-- | Turn the first @size@ bytes of a mutable vector into a strict byte
-- string, by freezing the vector without copying and handing it to
-- 'blitVector' with offset 0.
--
-- NOTE(review): because the freeze is unsafe, the vector presumably must
-- not be written to afterwards -- confirm against 'blitVector'.
byteStringFromVector :: M.MVector s Word8 -> Int -> ST s B.ByteString
byteStringFromVector vec size = fmap toBytes (VS.unsafeFreeze vec)
  where
    toBytes frozen = blitVector frozen 0 size
-- | Store a new accumulator value together with the number of bits it
-- holds.
setBitCount' :: BoolWriteStateRef s -> Word8 -> Int -> ST s ()
{-# INLINE setBitCount' #-}
setBitCount' st acc count =
    writeSTRef (bwsBitAcc st) acc >> writeSTRef (bwsBitReaded st) count
-- | Empty the bit accumulator: both its value and its bit count become 0.
resetBitCount' :: BoolWriteStateRef s -> ST s ()
{-# INLINE resetBitCount' #-}
resetBitCount' st = setBitCount' st noBits noCount
  where
    noBits  = 0
    noCount = 0
-- | Append one byte to the output.  The current buffer is retired first if
-- it is already full, then the byte is stored at the next free index and
-- the written-byte counter is advanced.
pushByte' :: BoolWriteStateRef s -> Word8 -> ST s ()
{-# INLINE pushByte' #-}
pushByte' st v = do
    flushCurrentBuffer' st
    buffer <- readSTRef (bwsCurrBuffer st)
    slot   <- readSTRef counter
    M.write buffer slot v
    writeSTRef counter (slot + 1)
  where
    counter = bwsWrittenWords st
-- | Write out the bits still waiting in the accumulator, if any, as one
-- final byte: the pending bits are moved to the top of the byte and the
-- unused low bits are left at zero.  The byte goes straight to
-- 'pushByte''.
--
-- The accumulator is emptied afterwards, so the same bits cannot be
-- emitted a second time by a later write or flush.
flushLeftBits' :: BoolWriteStateRef s -> ST s ()
flushLeftBits' st = do
    currCount <- readSTRef $ bwsBitReaded st
    when (currCount > 0) $ do
      currWord <- readSTRef $ bwsBitAcc st
      pushByte' st $ currWord `unsafeShiftL` (8 - currCount)
      -- The pending bits now live in the output buffer; forget them.
      resetBitCount' st
-- | Append the low bits of a word to the output, most significant bit
-- first.  Every completed byte equal to @0xFF@ is followed by an extra
-- @0x00@ byte.
--
-- Bits that do not complete a byte stay in the accumulator until a later
-- call (or 'flushLeftBits'') completes or flushes them.
--
-- NOTE(review): the mask selecting the low bits is built with
-- 'unsafeShiftL' on a 'Word32', so a bit count of 32 or more is presumably
-- unsupported -- confirm against callers.
writeBits' :: BoolWriteStateRef s
           -> Word32     -- ^ Word holding the bits to write in its low end
           -> Int        -- ^ Number of low bits of the word to write
           -> ST s ()
{-# INLINE writeBits' #-}
writeBits' st d c = do
    currWord <- readSTRef $ bwsBitAcc st
    currCount <- readSTRef $  bwsBitReaded st
    serialize d c currWord currCount
  where -- Emit a completed byte, following 0xFF with a 0x00 byte.
        dumpByte 0xFF = pushByte' st 0xFF >> pushByte' st 0x00
        dumpByte    i = pushByte' st i
        -- bitData/bitCount: bits still to write and how many of them;
        -- currentWord/count: accumulator contents and how many bits it holds.
        serialize bitData bitCount currentWord count
            -- The new bits exactly complete the pending byte: emit it and
            -- leave the accumulator empty.
            | bitCount + count == 8 = do
                     resetBitCount' st
                     dumpByte (fromIntegral $ (currentWord `unsafeShiftL` bitCount) .|.
                                                fromIntegral cleanData)
            -- Still short of a byte: shift the pending bits up, put the new
            -- ones below them and remember the new count.
            | bitCount + count < 8 =
                let newVal = currentWord `unsafeShiftL` bitCount
                in setBitCount' st (newVal .|. fromIntegral cleanData) $ count + bitCount
            -- More than a byte: complete the pending byte with the top bits
            -- of the data, emit it, and continue with the remaining low bits
            -- and an empty accumulator.
            | otherwise =
                let leftBitCount = 8 - count :: Int
                    highPart = cleanData `unsafeShiftR` (bitCount - leftBitCount) :: Word32
                    prevPart = fromIntegral currentWord `unsafeShiftL` leftBitCount :: Word32
                    nextMask = (1 `unsafeShiftL` (bitCount - leftBitCount)) - 1 :: Word32
                    newData = cleanData .&. nextMask :: Word32
                    newCount = bitCount - leftBitCount :: Int
                    toWrite = fromIntegral $ prevPart .|. highPart :: Word8
                in dumpByte toWrite >> serialize newData newCount 0 0
              -- Keep only the bitCount low bits of the data.
              where cleanMask = (1 `unsafeShiftL` bitCount) - 1 :: Word32
                    cleanData = bitData .&. cleanMask     :: Word32
-- | Append the low bits of a word to the output, least significant bit
-- first: new bits are placed above the bits already pending in the
-- accumulator.  Completed bytes are written unchanged.
--
-- Bits that do not complete a byte stay in the accumulator until a later
-- call (or 'flushLeftBitsGif') completes or flushes them.
--
-- NOTE(review): the mask selecting the low bits is built with
-- 'unsafeShiftL' on a 'Word32', so a bit count of 32 or more is presumably
-- unsupported -- confirm against callers.
writeBitsGif :: BoolWriteStateRef s
             -> Word32     -- ^ Word holding the bits to write in its low end
             -> Int        -- ^ Number of low bits of the word to write
             -> ST s ()
{-# INLINE writeBitsGif #-}
writeBitsGif st d c = do
    currWord <- readSTRef $ bwsBitAcc st
    currCount <- readSTRef $  bwsBitReaded st
    serialize d c currWord currCount
  where dumpByte = pushByte' st
        -- bitData/bitCount: bits still to write and how many of them;
        -- currentWord/count: accumulator contents and how many bits it holds.
        serialize bitData bitCount currentWord count
            -- The new bits exactly complete the pending byte: emit it and
            -- leave the accumulator empty.
            | bitCount + count == 8 = do
                     resetBitCount' st
                     dumpByte (fromIntegral $ currentWord  .|.
                                                (fromIntegral cleanData `unsafeShiftL` count))
            -- Still short of a byte: put the new bits above the pending ones
            -- and remember the new count.
            | bitCount + count < 8 =
                let newVal = fromIntegral cleanData `unsafeShiftL` count
                in setBitCount' st (newVal .|. currentWord) $ count + bitCount
            -- More than a byte: complete the pending byte with the low bits
            -- of the data, emit it, and continue with the remaining high
            -- bits and an empty accumulator.
            | otherwise =
                let leftBitCount = 8 - count :: Int
                    newData = cleanData `unsafeShiftR` leftBitCount :: Word32
                    newCount = bitCount - leftBitCount :: Int
                    toWrite = fromIntegral $ fromIntegral currentWord
                                            .|. (cleanData `unsafeShiftL` count) :: Word8
                in dumpByte toWrite >> serialize newData newCount 0 0
              -- Keep only the bitCount low bits of the data.
              where cleanMask = (1 `unsafeShiftL` bitCount) - 1 :: Word32
                    cleanData = bitData .&. cleanMask     :: Word32
-- | Finish writing a stream produced with 'writeBitsGif': emit any pending
-- bits with 'flushLeftBitsGif', flush the current buffer, and return every
-- chunk produced so far as one lazy byte string.
finalizeBoolWriterGif :: BoolWriteStateRef s -> ST s L.ByteString
finalizeBoolWriterGif st =
    flushLeftBitsGif st
      >> forceBufferFlushing' st
      >> L.fromChunks `fmap` readSTRef (bwsBufferList st)
-- | Write out the bits still waiting in the accumulator, if any, as one
-- final byte.  The accumulator is written as it stands, so the pending
-- bits occupy the low end of the byte.
--
-- The accumulator is emptied afterwards, so the same bits cannot be
-- emitted a second time by a later write or flush.
flushLeftBitsGif :: BoolWriteStateRef s -> ST s ()
flushLeftBitsGif st = do
    currCount <- readSTRef $ bwsBitReaded st
    when (currCount > 0) $ do
      currWord <- readSTRef $ bwsBitAcc st
      pushByte' st currWord
      -- The pending bits now live in the output buffer; forget them.
      resetBitCount' st
{-# ANN module "HLint: ignore Reduce duplication" #-}