{-# LANGUAGE BangPatterns, CPP #-}
module Codec.Picture.Gif.Internal.LZWEncoding( lzwEncode ) where

#if !MIN_VERSION_base(4,8,0)
import Control.Applicative( (<$>) )
import Data.Monoid( mempty )
#endif

import Control.Monad.ST( runST )
import qualified Data.ByteString.Lazy as L
import Data.Maybe( fromMaybe )
import Data.Word( Word8 )

#if MIN_VERSION_containers(0,5,0)
import qualified Data.IntMap.Strict as I
#else
import qualified Data.IntMap as I
#endif
import qualified Data.Vector.Storable as V

import Codec.Picture.BitWriter

-- | One level of the LZW dictionary. Keys are input byte values (widened to
-- 'Int'); each value is the node for the current string extended by that
-- byte.
type Trie = I.IntMap TrieNode

-- | A dictionary entry: the code assigned to a string, plus the level
-- holding every known one-byte extension of that string.
data TrieNode = TrieNode
    { trieIndex :: {-# UNPACK #-} !Int  -- ^ LZW code of this string
    , trieSub   :: !Trie                -- ^ extensions of this string by one more byte
    }

-- | Template node: placeholder code @-1@ and no children. Both uses in this
-- module overwrite 'trieIndex' via record update before storing the node.
emptyNode :: TrieNode
emptyNode = TrieNode (-1) I.empty

-- | Root dictionary level: one leaf for every byte value @0 .. 255@, whose
-- LZW code is the byte value itself. It does not depend on the key size
-- later passed to the encoder.
initialTrie :: Trie
initialTrie = I.fromDistinctAscList (map rootEntry [0 .. 255])
  where
    -- Keys are generated in strictly ascending order, as required by
    -- 'I.fromDistinctAscList'.
    rootEntry :: Int -> (Int, TrieNode)
    rootEntry byteValue = (byteValue, emptyNode { trieIndex = byteValue })

-- | Match the longest string starting at position @firstIndex@ of @vector@
-- that is already in the dictionary @trie@. If the match stops before the
-- end of the input, register the match extended by the next byte under code
-- @freeIndex@.
--
-- Returns a triple of:
--
--   * the dictionary code of the longest matched string (the code to emit);
--   * the read position just past that match (where the next match begins);
--   * the dictionary, extended with the new entry when the match stopped
--     before the end of the input, and returned unchanged otherwise.
--
-- NOTE(review): presumably the root of @trie@ holds every byte value (as
-- 'initialTrie' does). If it did not, the first byte would be unmatched and
-- the placeholder code @0@ would be returned with no input consumed.
lookupUpdate :: V.Vector Word8 -> Int -> Int -> Trie -> (Int, Int, Trie)
lookupUpdate vector freeIndex firstIndex trie =
    matchUpdate $ go trie 0 firstIndex
  where
    -- 'go' yields 'Nothing' when the dictionary did not change; fall back
    -- to the original trie in that case.
    matchUpdate :: (a, b, Maybe Trie) -> (a, b, Trie)
    matchUpdate (lzwOutputIndex, nextReadIndex, sub) =
        (lzwOutputIndex, nextReadIndex, fromMaybe trie sub)

    maxi :: Int
    maxi = V.length vector

    -- Descend one trie level per input byte. @prevIndex@ is the code of the
    -- string matched so far; @index@ is the position of the next byte.
    go :: Trie -> Int -> Int -> (Int, Int, Maybe Trie)
    go !currentTrie !prevIndex !index
        -- Input exhausted: the match so far is final and nothing is added.
      | index >= maxi = (prevIndex, index, Nothing)
      | otherwise = case I.lookup val currentTrie of
          -- The string extended by this byte is known: keep matching one
          -- level deeper, then rebuild this level if the subtree changed.
          Just (TrieNode ix subTable) ->
              let (lzwOutputIndex, nextReadIndex, newTable) =
                      go subTable ix (index + 1)
                  tableUpdater t = I.insert val (TrieNode ix t) currentTrie
              in (lzwOutputIndex, nextReadIndex, tableUpdater <$> newTable)
          -- Unknown extension: the match ends here. Register it under
          -- @freeIndex@ and return @index@, which points at the unmatched
          -- byte. (The former "index == maxi" sub-case was unreachable:
          -- the guard above already guarantees index < maxi here.)
          Nothing ->
              (prevIndex, index, Just $ I.insert val newNode currentTrie)
      where
        -- Only forced in the 'otherwise' branch, where index < maxi holds,
        -- so the unchecked index is in bounds.
        val :: Int
        val = fromIntegral $ vector `V.unsafeIndex` index

        newNode :: TrieNode
        newNode = emptyNode { trieIndex = freeIndex }

-- | LZW-compress a byte vector (presumably GIF palette indices) with the
-- variable-width code scheme used by GIF image data.
--
-- @initialKeySize@ fixes the reserved codes:
--
--   * the clear code is @2 ^ initialKeySize@;
--   * the end-of-information code follows it;
--   * the next code is the first free dictionary slot.
--
-- Coding works as follows:
--
--   * codes start at @initialKeySize + 1@ bits and widen by one bit whenever
--     the slot just used equals @2 ^ width@;
--   * once slot 4095 has been used at 12 bits, a 12-bit clear code is written
--     and coding restarts from the initial width and dictionary.
--
-- The stream opens with a clear code and closes with an end-of-information
-- code. Bits are packed by the GIF writer from "Codec.Picture.BitWriter".
lzwEncode :: Int -> V.Vector Word8 -> L.ByteString
lzwEncode initialKeySize vec = runST $ do
    writer <- newWriteStateRef

    let -- Write one code using the given bit width.
        emit code width = writeBitsGif writer (fromIntegral code) width

        -- Bookkeeping after the code for dictionary slot @slot@ was emitted:
        -- reset everything (announced by a 12-bit clear code) once slot 4095
        -- is used at width 12; widen codes when the slot reaches the current
        -- power of two; otherwise just move on to the next slot.
        advance width slot dict
            | width == 12 && slot == 2 ^ (12 :: Int) - 1 = do
                emit clearCode 12
                return (startCodeSize, firstFreeIndex, initialTrie)
            | slot == 2 ^ width = return (width + 1, slot + 1, dict)
            | otherwise         = return (width, slot + 1, dict)

        -- Main loop over read positions, threading (width, slot, dict).
        loop !pos (!width, !slot, !dict)
            | pos >= maxi = emit endOfInfo width
            | otherwise = do
                let (code, nextPos, dict') = lookupUpdate vec slot pos dict
                emit code width
                advance width slot dict' >>= loop nextPos

    emit clearCode startCodeSize
    loop 0 (startCodeSize, firstFreeIndex, initialTrie)
    finalizeBoolWriterGif writer
  where
    maxi :: Int
    maxi = V.length vec

    startCodeSize :: Int
    startCodeSize = initialKeySize + 1

    clearCode, endOfInfo, firstFreeIndex :: Int
    clearCode      = 2 ^ initialKeySize
    endOfInfo      = clearCode + 1
    firstFreeIndex = endOfInfo + 1