{-# LANGUAGE OverloadedStrings #-}

module Codec.Compression.ShannonFano where

import Codec.Compression.ShannonFano.Internal
import Control.Arrow
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.List (lookup, sortBy)
import Data.Word (Word8)
import System.IO

-- | Signals that compressed data could not be decoded with the supplied
--   code table; this can happen when the wrong code table is provided.
data DecodeTableError
  = DecodeTableError
  deriving (Show, Eq)

-- | Counts how many times each byte occurs in the input.
--
--   The result holds one entry per distinct byte, in ascending byte order.
frequency ::
  -- | Input string
  Input ->
  -- | Resulting table
  Table Int
frequency input =
  [ (BSL.head run, fromEnum (BSL.length run))
  | run <- BSL.group sortedInput
  ]
  where
    -- Sorting makes equal bytes adjacent, so 'BSL.group' yields exactly one
    -- (non-empty) run per distinct byte.
    sortedInput = BSL.fromStrict (BS.sort (BSL.toStrict input))

-- | Computes the relative frequency of each byte in the input: its count
--   (from 'frequency') divided by the total input length.
--
--   Entries keep the order produced by 'frequency'. An empty input yields an
--   empty table, so no division by zero is performed.
probability ::
  -- | Input string
  Input ->
  -- | Resulting table
  Table Float
probability s =
  [ (byte, fromIntegral count / total)
  | (byte, count) <- frequency s
  ]
  where
    total :: Float
    total = fromIntegral (fromEnum (BSL.length s))

-- | Builds the Shannon-Fano code table for the input.
--
--   Bytes are ordered by decreasing probability (ties keep the ascending
--   byte order produced by 'probability') and the ordered table is
--   recursively partitioned with 'split'. Every entry of a left part gets
--   @"0"@ prepended to its code and every entry of a right part gets @"1"@,
--   so codes are strings of ASCII @'0'@ / @'1'@ bytes.
genCodeTable ::
  -- | Input string
  Input ->
  -- | Resulting code table
  Table ByteString
genCodeTable s = assign (split byDecreasingProbability)
  where
    -- Probabilities are compared in reverse to sort in descending order.
    -- Ties compare 'EQ', which keeps the (stable) 'sortBy' well-defined.
    byDecreasingProbability :: Table Float
    byDecreasingProbability =
      sortBy (\x y -> compare (snd y) (snd x)) (probability s)

    -- Prepends one code symbol to every code of a sub-table.
    prefix :: ByteString -> Table ByteString -> Table ByteString
    prefix bit = map (second (BSL.append bit))

    -- Turns one level of partitioning into codes, recursing into each half.
    assign :: (Table Float, Table Float) -> Table ByteString
    assign ([], []) = []
    assign ([(x, _)], [(y, _)]) = [(x, "0"), (y, "1")]
    assign ([(x, _)], r) = (x, "0") : prefix "1" (assign (split r))
    assign (l, [(y, _)]) = prefix "0" (assign (split l)) ++ [(y, "1")]
    assign (l, r) =
      prefix "0" (assign (split l)) ++ prefix "1" (assign (split r))

-- | Compresses the input with a Shannon-Fano code derived from the input
--   itself (see 'genCodeTable').
--
--   Each byte is replaced by its @'0'@ / @'1'@ code and the concatenated
--   code string is passed to 'compressWithLeftover'. The code table is not
--   included in the result; decompression needs 'genCodeTable' of the same
--   input.
compress ::
  -- | Input string
  Input ->
  -- | Result compressed
  ByteString
compress s = compressWithLeftover (BSL.concatMap encodeByte s)
  where
    codes :: Table ByteString
    codes = genCodeTable s

    -- The table is generated from @s@ itself, so every byte of @s@ is
    -- expected to have an entry; a miss indicates a bug in 'genCodeTable',
    -- not bad input, hence the explicit error instead of a partial pattern.
    encodeByte :: Word8 -> ByteString
    encodeByte byte =
      case lookup byte codes of
        Just code -> code
        Nothing ->
          error
            "Codec.Compression.ShannonFano.compress: \
            \byte missing from generated code table"

-- | Decompresses a compressed 'ByteString', given a code table.
--
--   The input is first expanded with 'decompressWithLeftover' into a string
--   of code symbols, which is decoded greedily: symbols are accumulated until
--   they match a code of the table, then the corresponding byte is emitted
--   and accumulation restarts.
--
--   Returns 'Nothing' when the trailing symbols do not form a complete code
--   of the table (typically because the wrong code table was provided),
--   instead of silently dropping them.
decompress ::
  -- | Coded input to decompress
  ByteString ->
  -- | Code table associated with the input
  Table ByteString ->
  -- | Result decompressed
  Maybe Input
decompress s t
  | BSL.null s = Just BSL.empty
  | otherwise = go [] BSL.empty (decompressWithLeftover s)
  where
    -- Reverse lookup table: code -> byte.
    decodeTable :: [(ByteString, Word8)]
    decodeTable = map (snd &&& fst) t

    -- @out@ holds the decoded bytes in reverse order; @pending@ holds the
    -- code symbols read since the last emitted byte. Tail-recursive, and the
    -- output is packed once at the end instead of one chunk per byte.
    go :: [Word8] -> ByteString -> ByteString -> Maybe ByteString
    go out pending symbols =
      case BSL.uncons symbols of
        Nothing
          | BSL.null pending -> Just (BSL.pack (reverse out))
          | otherwise -> Nothing
        Just (sym, rest) ->
          let candidate = BSL.snoc pending sym
           in -- Force the candidate so no thunk chain builds up when the
              -- table never matches (e.g. an empty table).
              candidate `seq`
                case lookup candidate decodeTable of
                  Just byte -> go (byte : out) BSL.empty rest
                  Nothing -> go out candidate rest

-- | Reads all contents from a handle, compresses them and writes the
--   results to disk.
--
--   The resulting files are:
--
--    * @<filename>@ - the binary compressed data
--    * @<filename>.tab@ - the code table needed for decoding, written with
--      'show'
compressToFile ::
  -- | Handle from where data will be read
  Handle ->
  -- | Output file name
  String ->
  IO ()
compressToFile h filename = do
  contents <- BSL.hGetContents h
  -- NOTE: 'compress' derives the same table from the input internally, so
  -- the input is analysed twice.
  let codeTable = genCodeTable contents
  writeFile (filename ++ ".tab") (show codeTable)
  BSL.writeFile filename (compress contents)

-- | Reads compressed data from a handle, decodes it with the given code
--   table and writes the decompressed bytes to a file.
--
--   Returns @'Left' 'DecodeTableError'@, without writing the output file,
--   when 'decompress' cannot decode the data with the given table.
decompressFromFile ::
  -- | Handle from where compressed data will be read
  Handle ->
  -- | Decode table
  Table ByteString ->
  -- | Output file name
  String ->
  IO (Either DecodeTableError ())
decompressFromFile h dt filename =
  BSL.hGetContents h >>= \contents ->
    maybe
      (return (Left DecodeTableError))
      (fmap Right . BSL.writeFile filename)
      (decompress contents dt)