{-# LANGUAGE OverloadedStrings #-}

module Codec.Compression.ShannonFano.Internal
  ( Input,
    Table,
    split,
    chunksOf,
    decode,
    compressChunk,
    compressWithLeftover,
    decompressWithLeftover,
  )
where

import Control.Arrow ((&&&))
import Data.Bits
import Data.Bool (bool)
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BSL
import Data.Word

-- | Association list pairing each byte value ('Word8') with a payload of
-- type @a@. Used as a simple intermediate structure.
type Table a = [(Word8, a)]

-- | Alias for a lazy 'ByteString', used in signatures to distinguish a raw
-- input from a coded or compressed 'ByteString'.
type Input = ByteString

-- | Splits a probability table into a leading part and a trailing part.
--
-- Entries are moved, in order, from the front of the table into the leading
-- part for as long as the leading part (including the entry being moved)
-- weighs strictly less than what would remain. The first entry for which
-- that no longer holds is still placed in the leading part, and everything
-- after it becomes the trailing part. An empty table yields two empty
-- tables.
split :: Table Float -> (Table Float, Table Float)
split table = go [] table
  where
    -- Sum of the second components of a list of pairs.
    mass entries = sum (map snd entries)

    go left [] = (left, [])
    go left (entry : rest) =
      let left' = left ++ [entry]
       in if snd entry + mass left < mass rest
            then go left' rest
            else (left', rest)

-- | Packs a 'ByteString' made of the ASCII characters @'0'@ and @'1'@ into
-- a single 'Word8', reading the characters most-significant bit first.
--
-- A chunk shorter than 8 characters is right-aligned: its bits occupy the
-- low-order positions of the result and the remaining high bits are zero.
-- An empty chunk yields @0@. If more than 8 characters are supplied, only
-- the last 8 contribute to the result, because earlier bits are shifted
-- out of the 'Word8'.
--
-- Any byte other than @48@ (@'0'@) or @49@ (@'1'@) is a programming error
-- and raises an 'error' naming the offending byte.
--
-- Example:
-- @
-- compressChunk "00000001" == 1
-- compressChunk "101"      == 5
-- @
compressChunk :: ByteString -> Word8
compressChunk = BSL.foldl' pushBit zeroBits
  where
    -- Shift the accumulator left by one and append the next bit. The strict
    -- fold visits each character exactly once.
    pushBit :: Word8 -> Word8 -> Word8
    pushBit acc 48 = acc `shiftL` 1 -- "0"
    pushBit acc 49 = (acc `shiftL` 1) .|. 1 -- "1"
    pushBit _ other =
      error
        ( "compressChunk: expected '0' (48) or '1' (49), got byte "
            ++ show other
        )

-- | Cuts a 'ByteString' into consecutive pieces of @n@ bytes each.
--
-- The final piece may be shorter than @n@ when the input length is not a
-- multiple of @n@. Empty input produces an empty list, and so does any
-- @n@ for which the first piece comes out empty (e.g. @n <= 0@).
chunksOf :: Int -> ByteString -> [ByteString]
chunksOf n bytes
  | BSL.null piece = []
  | otherwise = piece : chunksOf n remainder
  where
    (piece, remainder) = BSL.splitAt (fromIntegral n) bytes

-- | Packs a 'ByteString' of @'0'@ and @'1'@ characters into its binary
-- form: the input is cut into groups of 8 characters and each group is
-- turned into one byte by 'compressChunk'.
--
-- A final group of fewer than 8 characters still produces one byte.
compress :: ByteString -> ByteString
compress bits = BSL.pack [compressChunk chunk | chunk <- chunksOf 8 bits]

-- | The same as 'compress', but prefixes the result with one header byte
-- holding the number of characters that ended up in the last (possibly
-- partial) group of 8.
--
-- The header is @length s `mod` 8@, so it is always in the range @0..7@;
-- @0@ means the last byte is completely filled (or the input is empty).
compressWithLeftover :: ByteString -> ByteString
compressWithLeftover s = BSL.cons leftover (compress s)
  where
    -- The remainder modulo 8 always fits in a 'Word8', so the header byte
    -- can be written directly without a range check.
    leftover :: Word8
    leftover = fromIntegral (BSL.length s `mod` 8)

-- | Expands every byte of a compressed 'ByteString' into the 8 characters
-- (@'0'@ or @'1'@) of its binary representation, most-significant bit
-- first, and concatenates the results.
decode :: ByteString -> ByteString
decode = BSL.concatMap expandByte
  where
    expandByte :: Word8 -> ByteString
    expandByte byte = bool2BS (bitList byte)

-- | Expands a 'ByteString' produced by 'compressWithLeftover' back into a
-- 'ByteString' of @'0'@ and @'1'@ characters.
--
-- The first byte is a header giving how many characters of the final byte
-- are meaningful. Every following byte is expanded to 8 characters; the
-- final one is trimmed to its last @header@ characters, unless the header
-- is @0@, in which case it is kept whole.
--
-- Input with no header byte at all (an empty 'ByteString') decodes to an
-- empty 'ByteString' instead of failing on a partial list function.
decompressWithLeftover :: ByteString -> ByteString
decompressWithLeftover compressed =
  case chunksOf 8 (decode compressed) of
    [] -> BSL.empty
    header : body ->
      BSL.concat (trimLast body (fromEnum (compressChunk header)))
  where
    -- Walk to the last chunk and drop its unused leading characters. The
    -- chunk list is matched first so the header is only inspected once the
    -- final chunk is reached.
    trimLast :: [ByteString] -> Int -> [ByteString]
    trimLast [] _ = []
    trimLast [x] 0 = [x] -- If 0 then there is no need to truncate
    trimLast [x] used = [BSL.drop (fromIntegral (8 - used)) x]
    trimLast (x : xs) used = x : trimLast xs used

-- | Lists the 8 low-order bits of a value as 'Bool's, from bit 7 down to
-- bit 0 ('True' meaning the bit is set). Bits above position 7 are ignored.
bitList :: Bits a => a -> [Bool]
bitList value = [testBit value position | position <- [7, 6 .. 0]]

-- | Renders a list of 'Bool's as a bit 'ByteString': one character per
-- element, @"1"@ for 'True' and @"0"@ for 'False', in list order.
bool2BS :: [Bool] -> ByteString
bool2BS flags = BSL.concat [bool "0" "1" flag | flag <- flags]