{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
module Codec.Picture.ColorQuant
( palettize
, defaultPaletteOptions
, PaletteCreationMethod(..)
, PaletteOptions( .. )
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative (Applicative (..), (<$>))
#endif
import Data.Bits (unsafeShiftL, unsafeShiftR, (.&.), (.|.))
import Data.List (elemIndex)
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Word (Word32)
import Data.Vector (Vector, (!))
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Storable as VS
import Codec.Picture.Types
data PaletteCreationMethod =
MedianMeanCut
| Uniform
data PaletteOptions = PaletteOptions
{
paletteCreationMethod :: PaletteCreationMethod
, enableImageDithering :: Bool
, paletteColorCount :: Int
}
defaultPaletteOptions :: PaletteOptions
defaultPaletteOptions = PaletteOptions
{ paletteCreationMethod = MedianMeanCut
, enableImageDithering = True
, paletteColorCount = 256
}
palettize :: PaletteOptions -> Image PixelRGB8 -> (Image Pixel8, Palette)
palettize opts@PaletteOptions { paletteCreationMethod = method } =
case method of
MedianMeanCut -> medianMeanCutQuantization opts
Uniform -> uniformQuantization opts
medianMeanCutQuantization :: PaletteOptions -> Image PixelRGB8
-> (Image Pixel8, Palette)
medianMeanCutQuantization opts img
| isBelow =
(pixelMap okPaletteIndex img, vecToPalette okPaletteVec)
| enableImageDithering opts = (pixelMap paletteIndex dImg, palette)
| otherwise = (pixelMap paletteIndex img, palette)
where
maxColorCount = paletteColorCount opts
(okPalette, isBelow) = isColorCountBelow maxColorCount img
okPaletteVec = V.fromList $ Set.toList okPalette
okPaletteIndex p = nearestColorIdx p okPaletteVec
palette = vecToPalette paletteVec
paletteIndex p = nearestColorIdx p paletteVec
paletteVec = mkPaletteVec cs
cs = Set.toList . clusters maxColorCount $ img
dImg = pixelMapXY dither img
uniformQuantization :: PaletteOptions -> Image PixelRGB8 -> (Image Pixel8, Palette)
uniformQuantization opts img
| enableImageDithering opts =
(pixelMap paletteIndex (pixelMapXY dither img), palette)
| otherwise = (pixelMap paletteIndex img, palette)
where
maxCols = paletteColorCount opts
palette = listToPalette paletteList
paletteList = [PixelRGB8 r g b | r <- [0,dr..255]
, g <- [0,dg..255]
, b <- [0,db..255]]
(bg, br, bb) = bitDiv3 maxCols
(dr, dg, db) = (2^(8-br), 2^(8-bg), 2^(8-bb))
paletteIndex (PixelRGB8 r g b) = fromIntegral $ fromMaybe 0 (elemIndex
(PixelRGB8 (r .&. negate dr) (g .&. negate dg) (b .&. negate db))
paletteList)
isColorCountBelow :: Int -> Image PixelRGB8 -> (Set.Set PixelRGB8, Bool)
isColorCountBelow maxColorCount img = go 0 Set.empty
where rawData = imageData img
maxIndex = VS.length rawData
go !idx !allColors
| Set.size allColors > maxColorCount = (Set.empty, False)
| idx >= maxIndex - 2 = (allColors, True)
| otherwise = go (idx + 3) $ Set.insert px allColors
where px = unsafePixelAt rawData idx
vecToPalette :: Vector PixelRGB8 -> Palette
vecToPalette ps = generateImage (\x _ -> ps ! x) (V.length ps) 1
listToPalette :: [PixelRGB8] -> Palette
listToPalette ps = generateImage (\x _ -> ps !! x) (length ps) 1
bitDiv3 :: Int -> (Int, Int, Int)
bitDiv3 n = case r of
0 -> (q, q, q)
1 -> (q+1, q, q)
_ -> (q+1, q+1, q)
where
r = m `mod` 3
q = m `div` 3
m = floor . logBase (2 :: Double) $ fromIntegral n
dither :: Int -> Int -> PixelRGB8 -> PixelRGB8
dither x y (PixelRGB8 r g b) = PixelRGB8 (fromIntegral r')
(fromIntegral g')
(fromIntegral b')
where
r' = min 255 (fromIntegral r + (x' + y') .&. 16)
g' = min 255 (fromIntegral g + (x' + y' + 7973) .&. 16)
b' = min 255 (fromIntegral b + (x' + y' + 15946) .&. 16)
x' = 119 * x
y' = 28084 * y
data Fold a b = forall x . Fold (x -> a -> x) x (x -> b)
fold :: Fold PackedRGB b -> VU.Vector PackedRGB -> b
fold (Fold step begin done) = done . VU.foldl' step begin
{-# INLINE fold #-}
data Pair a b = Pair !a !b
instance Functor (Fold a) where
fmap f (Fold step begin done) = Fold step begin (f . done)
{-# INLINABLE fmap #-}
instance Applicative (Fold a) where
pure b = Fold (\() _ -> ()) () (\() -> b)
{-# INLINABLE pure #-}
(Fold stepL beginL doneL) <*> (Fold stepR beginR doneR) =
let step (Pair xL xR) a = Pair (stepL xL a) (stepR xR a)
begin = Pair beginL beginR
done (Pair xL xR) = doneL xL $ doneR xR
in Fold step begin done
{-# INLINABLE (<*>) #-}
intLength :: Fold a Int
intLength = Fold (\n _ -> n + 1) 0 id
mkPaletteVec :: [Cluster] -> Vector PixelRGB8
mkPaletteVec = V.fromList . map (toRGB8 . meanColor)
type PackedRGB = Word32
data Cluster = Cluster
{ value :: {-# UNPACK #-} !Float
, meanColor :: !PixelRGBF
, dims :: !PixelRGBF
, colors :: VU.Vector PackedRGB
}
instance Eq Cluster where
a == b =
(value a, meanColor a, dims a) == (value b, meanColor b, dims b)
instance Ord Cluster where
compare a b =
compare (value a, meanColor a, dims a) (value b, meanColor b, dims b)
data Axis = RAxis | GAxis | BAxis
inf :: Float
inf = read "Infinity"
fromRGB8 :: PixelRGB8 -> PixelRGBF
fromRGB8 (PixelRGB8 r g b) =
PixelRGBF (fromIntegral r) (fromIntegral g) (fromIntegral b)
toRGB8 :: PixelRGBF -> PixelRGB8
toRGB8 (PixelRGBF r g b) =
PixelRGB8 (round r) (round g) (round b)
meanRGB :: Fold PixelRGBF PixelRGBF
meanRGB = mean <$> intLength <*> pixelSum
where
pixelSum = Fold (mixWith $ const (+)) (PixelRGBF 0 0 0) id
mean n = colorMap (/ nf)
where nf = fromIntegral n
minimal :: Fold PixelRGBF PixelRGBF
minimal = Fold mini (PixelRGBF inf inf inf) id
where mini = mixWith $ const min
maximal :: Fold PixelRGBF PixelRGBF
maximal = Fold maxi (PixelRGBF (-inf) (-inf) (-inf)) id
where maxi = mixWith $ const max
extrems :: Fold PixelRGBF (PixelRGBF, PixelRGBF)
extrems = (,) <$> minimal <*> maximal
volAndDims :: Fold PixelRGBF (Float, PixelRGBF)
volAndDims = deltify <$> extrems
where deltify (mini, maxi) = (dr * dg * db, delta)
where delta@(PixelRGBF dr dg db) =
mixWith (const (-)) maxi mini
unpackFold :: Fold PixelRGBF a -> Fold PackedRGB a
unpackFold (Fold step start done) = Fold (\acc -> step acc . transform) start done
where transform = fromRGB8 . rgbIntUnpack
mkCluster :: VU.Vector PackedRGB -> Cluster
mkCluster ps = Cluster
{ value = v * fromIntegral l
, meanColor = m
, dims = ds
, colors = ps
}
where
worker = (,,) <$> volAndDims <*> meanRGB <*> intLength
((v, ds), m, l) = fold (unpackFold worker) ps
maxAxis :: PixelRGBF -> Axis
maxAxis (PixelRGBF r g b) =
case (r `compare` g, r `compare` b, g `compare` b) of
(GT, GT, _) -> RAxis
(LT, GT, _) -> GAxis
(GT, LT, _) -> BAxis
(LT, LT, GT) -> GAxis
(EQ, GT, _) -> RAxis
(_, _, _) -> BAxis
subdivide :: Cluster -> (Cluster, Cluster)
subdivide cluster = (mkCluster px1, mkCluster px2)
where
(PixelRGBF mr mg mb) = meanColor cluster
(px1, px2) = VU.partition (cond . rgbIntUnpack) $ colors cluster
cond = case maxAxis $ dims cluster of
RAxis -> \(PixelRGB8 r _ _) -> fromIntegral r < mr
GAxis -> \(PixelRGB8 _ g _) -> fromIntegral g < mg
BAxis -> \(PixelRGB8 _ _ b) -> fromIntegral b < mb
rgbIntPack :: PixelRGB8 -> PackedRGB
rgbIntPack (PixelRGB8 r g b) =
wr `unsafeShiftL` (2 * 8) .|. wg `unsafeShiftL` 8 .|. wb
where wr = fromIntegral r
wg = fromIntegral g
wb = fromIntegral b
rgbIntUnpack :: PackedRGB -> PixelRGB8
rgbIntUnpack v = PixelRGB8 r g b
where
r = fromIntegral $ v `unsafeShiftR` (2 * 8)
g = fromIntegral $ v `unsafeShiftR` 8
b = fromIntegral v
initCluster :: Image PixelRGB8 -> Cluster
initCluster img = mkCluster $ VU.generate ((w * h) `div` subSampling) packer
where samplingFactor = 3
subSampling = samplingFactor * samplingFactor
compCount = componentCount (undefined :: PixelRGB8)
w = imageWidth img
h = imageHeight img
rawData = imageData img
packer ix =
rgbIntPack . unsafePixelAt rawData $ ix * subSampling * compCount
split :: Set Cluster -> Set Cluster
split cs = Set.insert c1 . Set.insert c2 $ cs'
where
(c, cs') = Set.deleteFindMax cs
(c1, c2) = subdivide c
clusters :: Int -> Image PixelRGB8 -> Set Cluster
clusters maxCols img = clusters' (maxCols - 1)
where
clusters' :: Int -> Set Cluster
clusters' 0 = Set.singleton c
clusters' n = split (clusters' (n-1))
c = initCluster img
dist2Px :: PixelRGB8 -> PixelRGB8 -> Int
dist2Px (PixelRGB8 r1 g1 b1) (PixelRGB8 r2 g2 b2) = dr*dr + dg*dg + db*db
where
(dr, dg, db) =
( fromIntegral r1 - fromIntegral r2
, fromIntegral g1 - fromIntegral g2
, fromIntegral b1 - fromIntegral b2 )
nearestColorIdx :: PixelRGB8 -> Vector PixelRGB8 -> Pixel8
nearestColorIdx p ps = fromIntegral $ V.minIndex (V.map (`dist2Px` p) ps)