{-# LANGUAGE BangPatterns
           , FlexibleContexts
           , FlexibleInstances
           , ParallelListComp
           , TypeFamilies
           , TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Contains functions to compute and manipulate histograms as well as some
-- image transformations which are histogram-based.
--
-- Every polymorphic function is specialised for histograms of 'Int32', 'Double'
-- and 'Float'. Other types can be specialised as every polymorphic function is
-- declared @INLINABLE@.
module Vision.Histogram (
    -- * Types & helpers
      Histogram (..), HistogramShape (..), ToHistogram (..)
    , index, (!), linearIndex, map, assocs, pixToBin
    -- * Histogram computations
    , histogram,  histogram2D, reduce, resize, cumulative, normalize
    -- * Images processing
    , equalizeImage
    -- * Histogram comparisons
    , compareCorrel, compareChi, compareIntersect, compareEMD
    ) where

import Data.Int
import Data.Vector.Storable (Vector)
import Foreign.Storable (Storable)
import Prelude hiding (map)

import qualified Data.Vector.Storable as V

import Vision.Image.Class (Pixel, MaskedImage, Image, ImagePixel, FunctorImage)
import Vision.Image.Grey.Type (GreyPixel (..))
import Vision.Image.HSV.Type  (HSVPixel (..))
import Vision.Image.RGBA.Type (RGBAPixel (..))
import Vision.Image.RGB.Type  (RGBPixel (..))
import Vision.Primitive (
      Z (..), (:.) (..), Shape (..), DIM1, DIM3, DIM4, DIM5, ix1, ix3, ix4
    )

import qualified Vision.Image.Class as I

-- There is no rule to simplify the conversion from Int32 to Double and Float
-- when using realToFrac: both conversions go through a temporary yet useless
-- Rational value. The rewrite rules below replace realToFrac by fromIntegral
-- for these two conversions.

{-# RULES
"realToFrac/Int32->Double" realToFrac = fromIntegral :: Int32 -> Double
"realToFrac/Int32->Float"  realToFrac = fromIntegral :: Int32 -> Float
  #-}

-- Types -----------------------------------------------------------------------

-- | A histogram whose bins are indexed by values of type @sh@ and hold values
-- of type @a@.
data Histogram sh a = Histogram {
      shape  :: !sh -- ^ Number of bins for each dimension of the histogram.
    , vector :: !(Vector a) -- ^ Values of the histogram in row-major order.
    } deriving (Eq, Ord, Show)

-- | Subclass of 'Shape' which defines how to rescale an index so that it fits
-- inside a resized histogram.
class Shape sh => HistogramShape sh where
    -- | Given the number of bins of a histogram, reduces an index so that it
    -- is mapped to a bin.
    toBin :: sh -- ^ The number of bins we are mapping to.
          -> sh -- ^ The number of possible values of the original index.
          -> sh -- ^ The original index.
          -> sh -- ^ The index of the bin in the histogram.

-- Base case of the dimension-wise recursion: the zero-dimensional index always
-- maps to 'Z', whatever the arguments are.
instance HistogramShape Z where
    toBin _ _ _ = Z
    {-# INLINE toBin #-}

-- Maps the innermost dimension to its bin and recurses on the remaining
-- dimensions.
instance HistogramShape sh => HistogramShape (sh :. Int) where
    toBin !(binsTail :. nBins) !(domainTail :. nValues) !(ixTail :. i) =
        if nBins == nValues
            -- As many bins as possible values: the index is already the bin.
            then outer :. i
            -- Otherwise, rescales the index from [0; nValues[ to [0; nBins[.
            else outer :. ((i * nBins) `quot` nValues)
      where
        outer = toBin binsTail domainTail ixTail
    {-# INLINE toBin #-}

-- | This class defines how many dimensions a histogram will have and what will
-- be its default number of bins.
class (Pixel p, Shape (PixelValueSpace p)) => ToHistogram p where
    -- | Gives the value space of a pixel. Single-channel pixels will be 'DIM1'
    -- whereas three-channel pixels will be 'DIM3'.
    -- This is used to determine the rank of the generated histogram.
    type PixelValueSpace p

    -- | Converts a pixel to an index.
    pixToIndex :: p -> PixelValueSpace p

    -- | Returns the maximum number of different values an index can take for
    -- each dimension of the histogram (i.e. the maximum index returned by
    -- 'pixToIndex' plus one). The argument is ignored by the instances defined
    -- in this module.
    domainSize :: p -> PixelValueSpace p

-- Grey pixels are indexed by their single value, which can take 256 different
-- values.
instance ToHistogram GreyPixel where
    type PixelValueSpace GreyPixel = DIM1

    pixToIndex !(GreyPixel grey) = ix1 (int grey)
    {-# INLINE pixToIndex #-}

    domainSize = const (ix1 256)

-- RGBA pixels are indexed by their four channels, in the red, green, blue and
-- alpha order. Each channel can take 256 different values.
instance ToHistogram RGBAPixel where
    type PixelValueSpace RGBAPixel = DIM4

    pixToIndex !(RGBAPixel red green blue alpha) =
        ix4 (int red) (int green) (int blue) (int alpha)
    {-# INLINE pixToIndex #-}

    domainSize = const (ix4 256 256 256 256)

-- RGB pixels are indexed by their three channels, in the red, green and blue
-- order. Each channel can take 256 different values.
instance ToHistogram RGBPixel where
    type PixelValueSpace RGBPixel = DIM3

    pixToIndex !(RGBPixel red green blue) =
        ix3 (int red) (int green) (int blue)
    {-# INLINE pixToIndex #-}

    domainSize = const (ix3 256 256 256)

-- HSV pixels are indexed by their three channels, in the hue, saturation and
-- value order. The first channel can take 180 different values, the two others
-- 256.
instance ToHistogram HSVPixel where
    type PixelValueSpace HSVPixel = DIM3

    pixToIndex !(HSVPixel hue sat val) = ix3 (int hue) (int sat) (int val)
    {-# INLINE pixToIndex #-}

    domainSize = const (ix3 180 256 256)

-- Functions -------------------------------------------------------------------

-- | Returns the value of the bin at the given index of the histogram.
index :: (Shape sh, Storable a) => Histogram sh a -> sh -> a
index !hist = \ix -> linearIndex hist (toLinearIndex (shape hist) ix)
{-# INLINE index #-}

-- | Infix alias of 'index': @hist '!' ix@ is the value of the bin at the index
-- @ix@.
(!) :: (Shape sh, Storable a) => Histogram sh a -> sh -> a
(!) = index
{-# INLINE (!) #-}

-- | Returns the value at the given position of the underlying vector, i.e. as
-- if the histogram was a single dimension vector (row-major representation).
linearIndex :: (Shape sh, Storable a) => Histogram sh a -> Int -> a
linearIndex !(Histogram _ vec) = (vec V.!)
{-# INLINE linearIndex #-}

-- | Applies the function to the value of every bin. The shape of the histogram
-- is left unchanged.
map :: (Storable a, Storable b) => (a -> b) -> Histogram sh a -> Histogram sh b
map f !hist = Histogram (shape hist) (V.map f (vector hist))
{-# INLINE map #-}

-- | Returns all index/value pairs from the histogram, by pairing the indexes
-- enumerated by 'shapeList' with the values of the underlying vector.
assocs :: (Shape sh, Storable a) => Histogram sh a -> [(sh, a)]
assocs !(Histogram sh vec) = zip (shapeList sh) (V.toList vec)
{-# INLINE assocs #-}

-- | Given the number of bins of a histogram and a pixel, returns the index of
-- the bin the pixel falls in.
--
-- The pixel is first converted to an index with 'pixToIndex', which is then
-- rescaled from the pixel's 'domainSize' to the given number of bins.
pixToBin :: (HistogramShape (PixelValueSpace p), ToHistogram p)
         => PixelValueSpace p -> p -> PixelValueSpace p
pixToBin size p = toBin size domain ix
  where
    -- Both values are forced, in this order, before calling 'toBin'.
    !domain = domainSize p
    !ix     = pixToIndex p
{-# INLINE pixToBin #-}

-- | Computes a histogram from a (possibly) multi-channel image.
--
-- When no size is given ('Nothing'), the histogram has as many bins as the
-- range of values of the pixels of the image (see 'domainSize') and pixels are
-- indexed with 'pixToIndex'.
--
-- When a size is given, pixels are mapped to bins with 'pixToBin', so every bin
-- of a given dimension covers the same range of values (uniform histogram).
histogram :: ( MaskedImage i, ToHistogram (ImagePixel i), Storable a, Num a
             , HistogramShape (PixelValueSpace (ImagePixel i)))
         => Maybe (PixelValueSpace (ImagePixel i)) -> i
         -> Histogram (PixelValueSpace (ImagePixel i)) a
histogram mSize img =
    -- Starts with empty bins and adds one to the bin of every image value.
    Histogram size (V.accumulate_ (+) zeros binIxs increments)
  where
    !size = case mSize of
        Just s  -> s
        Nothing -> domainSize (I.pixel img)
    !nChans = I.nChannels img
    -- NOTE(review): this is the number of pixels times the number of channels;
    -- presumably only as many increments as there are values are consumed -
    -- confirm against 'V.accumulate_'.
    !nPixs = shapeLength (I.shape img) * nChans
    !nBins = shapeLength size

    zeros      = V.replicate nBins 0
    increments = V.replicate nPixs 1
    binIxs     = V.map toIndex (I.values img)

    -- Linear index (in the histogram's vector) of the bin of the pixel.
    toIndex !p =
        let !binIx = case mSize of
                Just _  -> pixToBin   size p
                Nothing -> pixToIndex p
        in toLinearIndex size binIx
    {-# INLINE toIndex #-}
{-# INLINABLE histogram #-}

-- | Similar to 'histogram' but adds two dimensions for the y and x-coordinates
-- of the sampled points. This way, the histogram maps different regions of the
-- original image.
--
-- For example, an 'RGB' image will be mapped as
-- @'Z' ':.' red channel ':.' green channel ':.' blue channel ':.' y region
-- ':.' x region@.
--
-- As there is no reason to create a histogram as large as the number of pixels
-- of the image, a size is always needed.
histogram2D :: ( Image i, ToHistogram (ImagePixel i), Storable a, Num a
               , HistogramShape (PixelValueSpace (ImagePixel i)))
            => (PixelValueSpace (ImagePixel i)) :. Int :. Int -> i
            -> Histogram ((PixelValueSpace (ImagePixel i)) :. Int :. Int) a
histogram2D size img =
    -- Starts with empty bins and adds one to the bin of every pixel.
    Histogram size (V.accumulate_ (+) zeros binIxs increments)
  where
    !imgSize@(Z :. h :. w) = I.shape img
    -- Number of possible values of a non-resized index: the domain of the pixel
    -- values extended with the dimensions of the image.
    !maxSize = domainSize (I.pixel img) :. h :. w
    !nChans = I.nChannels img
    !nPixs = shapeLength imgSize * nChans
    !nBins = shapeLength size

    zeros      = V.replicate nBins 0
    increments = V.replicate nPixs 1
    -- Coordinates in the image, enumerated with 'shapeSucc' from 'shapeZero'
    -- and paired with the pixels of the image's vector.
    coords     = V.iterateN nPixs (shapeSucc imgSize) shapeZero
    binIxs     = V.zipWith toIndex coords (I.vector img)

    -- Linear index (in the histogram's vector) of the bin of the pixel located
    -- at the given coordinates.
    toIndex !(Z :. y :. x) !p =
        toLinearIndex size $! (toBin size maxSize $! (pixToIndex p :. y :. x))
    {-# INLINE toIndex #-}
{-# INLINABLE histogram2D #-}

-- Reshaping -------------------------------------------------------------------

-- | Reduces a 2D histogram to its linear representation by summing, for each
-- bin of the outer dimensions, the values of its @h * w@ region bins. See
-- 'resize' for a reduction of the number of bins of a histogram.
--
-- @'histogram' == 'reduce' . 'histogram2D'@
reduce :: (HistogramShape sh, Storable a, Num a)
       => Histogram (sh :. Int :. Int) a -> Histogram sh a
reduce !(Histogram (binsShape :. h :. w) vec) =
    Histogram binsShape (V.unfoldrN (shapeLength binsShape) sumRegion vec)
  where
    -- Number of consecutive values of the vector which share the same outer
    -- index.
    !regionLen = h * w

    -- Sums the next @regionLen@ values and carries on with the remaining ones.
    sumRegion !remaining =
        case V.splitAt regionLen remaining of
            (!region, !remaining') -> Just (V.sum region, remaining')
{-# SPECIALIZE reduce :: Histogram DIM5 Int32  -> Histogram DIM3 Int32
                      ,  Histogram DIM5 Double -> Histogram DIM3 Double
                      ,  Histogram DIM5 Float  -> Histogram DIM3 Float
                      ,  Histogram DIM3 Int32  -> Histogram DIM1 Int32
                      ,  Histogram DIM3 Double -> Histogram DIM1 Double
                      ,  Histogram DIM3 Float  -> Histogram DIM1 Float #-}
{-# INLINABLE reduce #-}

-- | Resizes a histogram to another index shape. See 'reduce' for a reduction
-- of the number of dimensions of a histogram.
--
-- Every bin of the original histogram is mapped to a bin of the new shape with
-- 'toBin' and the values of the bins which are mapped to the same bin are
-- summed.
resize :: (HistogramShape sh, Storable a, Num a)
       => sh -> Histogram sh a -> Histogram sh a
resize !sh' (Histogram sh vec) =
    let initial = V.replicate (shapeLength sh') 0
        -- TODO: In this scheme, indexes are computed for each bin of the
        -- original histogram. It's sub-optimal as some parts of those indexes
        -- (lower dimensions) don't change at each bin.
        reIndex = toLinearIndex sh' . toBin sh' sh . fromLinearIndex sh
        -- Linear index in the resized histogram of each original bin.
        ixs = V.map reIndex $ V.enumFromN 0 (shapeLength sh)
    in Histogram sh' (V.accumulate_ (+) initial ixs vec)
-- Declared INLINABLE, as the other polymorphic functions of this module, so
-- that it can be specialised in client modules.
{-# INLINABLE resize #-}

-- Normalisation ---------------------------------------------------------------

-- | Computes the cumulative histogram of a single dimension histogram.
--
-- @C(i) = SUM H(j)@ for each @j@ in @[0..i]@ where @C@ is the cumulative
-- histogram and @H@ the original histogram. The shape is left unchanged.
cumulative :: (Storable a, Num a) => Histogram DIM1 a -> Histogram DIM1 a
cumulative hist = hist { vector = V.scanl1' (+) (vector hist) }
{-# SPECIALIZE cumulative :: Histogram DIM1 Int32  -> Histogram DIM1 Int32
                          ,  Histogram DIM1 Double -> Histogram DIM1 Double
                          ,  Histogram DIM1 Float  -> Histogram DIM1 Float #-}
{-# INLINABLE cumulative #-}

-- | Normalizes the histogram so that the sum of the histogram bins is equal to
-- the given value (normalisation by the @L1@ norm).
--
-- Every bin is multiplied by the given value divided by the sum of the bins:
-- this sum is used as a divisor.
--
-- This is useful to compare two histograms which have been computed from images
-- with a different number of pixels.
normalize :: (Storable a, Real a, Storable b, Fractional b)
          => b -> Histogram sh a -> Histogram sh b
normalize norm !hist = map scale hist
  where
    -- Factor by which every bin is multiplied, computed only once.
    !ratio = norm / realToFrac (V.sum (vector hist))

    scale !val = realToFrac val * ratio
    {-# INLINE scale #-}
{-# SPECIALIZE normalize :: Double -> Histogram sh Int32  -> Histogram sh Double
                         ,  Float  -> Histogram sh Int32  -> Histogram sh Float
                         ,  Double -> Histogram sh Double -> Histogram sh Double
                         ,  Float  -> Histogram sh Double -> Histogram sh Float
                         ,  Double -> Histogram sh Float  -> Histogram sh Double
                         ,  Float  -> Histogram sh Float  -> Histogram sh Float
                         #-}
{-# INLINABLE normalize #-}

-- | Equalizes a single channel image by equalising its histogram.
--
-- The algorithm equalizes the brightness and increases the contrast of the
-- image by mapping each pixel value to the value at the index of the
-- cumulative @L1@-normalized histogram:
--
-- @N(x, y) = H(I(x, y))@ where @N@ is the equalized image, @I@ is the image and
-- @H@ the cumulative of the histogram normalized over an @L1@ norm.
--
-- The histogram is normalized so that its cumulative ends at the number of bins
-- minus one, i.e. at the largest index of the histogram. Hence equalized values
-- stay in the range of the indexes of the original pixel values.
--
-- See <https://en.wikipedia.org/wiki/Histogram_equalization>.
equalizeImage :: ( FunctorImage i i, Integral (ImagePixel i)
                 , ToHistogram (ImagePixel i)
                 , PixelValueSpace (ImagePixel i) ~ DIM1)
              => i -> i
equalizeImage img =
    I.map equalizePixel img
  where
    hist            = histogram Nothing img             :: Histogram DIM1 Int32
    Z :. nBins      = shape hist
    -- The last value of the cumulative histogram is the sum of the normalized
    -- bins, i.e. the norm. Normalizing to nBins (instead of nBins - 1) would
    -- map the brightest pixels to nBins, which is out of the domain of the
    -- pixel values and would wrap around when converted with fromIntegral.
    cumNormalized   = cumulative $ normalize (double (nBins - 1)) hist
    -- Lookup table, computed once, from the original to the equalized values.
    !cumNormalized' = map round cumNormalized           :: Histogram DIM1 Int32
    equalizePixel !val = fromIntegral $ cumNormalized' ! ix1 (int val)
    {-# INLINE equalizePixel #-}
{-# INLINABLE equalizeImage #-}

-- Comparisons -----------------------------------------------------------------

-- | Computes the /Pearson\'s correlation coefficient/ between each
-- corresponding bins of the two histograms.
--
-- A value of 1 implies a perfect correlation, a value of -1 a perfect
-- opposition and a value of 0 no correlation at all.
--
-- @'compareCorrel' = SUM  [ (H1(i) - µ(H1)) (H2(i) - µ(H2)) ]
--                  / (   SQRT [ SUM [ (H1(i) - µ(H1))^2 ] ]
--                      * SQRT [ SUM [ (H2(i) - µ(H2))^2 ] ] )@
--
-- Where @µ(H)@ is the average value of the histogram @H@.
--
-- Returns 1 when the denominator is zero. Fails with an error if the two
-- histograms don't have the same shape.
--
-- See <http://en.wikipedia.org/wiki/Pearson_correlation_coefficient>.
compareCorrel :: (Shape sh, Storable a, Real a, Storable b, Eq b, Floating b)
              => Histogram sh a -> Histogram sh a -> b
compareCorrel (Histogram sh1 vec1) (Histogram sh2 vec2) =
    if sh1 /= sh2
        then error "Histograms are not of equal size."
        else if denominat == 0
            then 1
            else numerat / denominat
  where
    dev1 = deviations vec1
    dev2 = deviations vec2

    numerat   = V.sum (V.zipWith (*) dev1 dev2)
    denominat = magnitude dev1 * magnitude dev2

    -- Differences between each bin and the average value of the bins.
    deviations vec =
        let mean = realToFrac (V.sum vec) / realToFrac (V.length vec)
        in V.map (\val -> realToFrac val - mean) vec

    -- Square root of the sum of the squares.
    magnitude devs = sqrt (V.sum (V.map square devs))
{-# SPECIALIZE compareCorrel
    :: Shape sh => Histogram sh Int32  -> Histogram sh Int32  -> Double
    ,  Shape sh => Histogram sh Int32  -> Histogram sh Int32  -> Float
    ,  Shape sh => Histogram sh Double -> Histogram sh Double -> Double
    ,  Shape sh => Histogram sh Double -> Histogram sh Double -> Float
    ,  Shape sh => Histogram sh Float  -> Histogram sh Float  -> Double
    ,  Shape sh => Histogram sh Float  -> Histogram sh Float  -> Float  #-}
{-# INLINABLE compareCorrel #-}

-- | Computes the Chi-squared distance between two histograms.
--
-- A value of 0 indicates a perfect match.
--
-- @'compareChi' = SUM (d(i))@ for each index @i@ of the histograms where
-- @d(i) = 2 * ((H1(i) - H2(i))^2 / (H1(i) + H2(i)))@.
--
-- Pairs of bins whose sum is zero don't contribute to the distance. Fails with
-- an error if the two histograms don't have the same shape.
compareChi :: (Shape sh, Storable a, Real a, Storable b, Fractional b)
           => Histogram sh a -> Histogram sh a -> b
compareChi (Histogram sh1 vec1) (Histogram sh2 vec2)
    | sh1 /= sh2 = error "Histograms are not of equal size."
    | otherwise  = (V.sum $ V.zipWith step vec1 vec2) * 2
  where
    step !v1 !v2 = let !denom = v1 + v2
                   in if denom == 0
                        then 0
                        -- The difference is squared after its conversion to the
                        -- fractional type: squaring it in the type of the bins
                        -- overflows with fixed-width integers such as Int32 as
                        -- soon as two bins differ by more than 46340.
                        else square (realToFrac (v1 - v2)) / realToFrac denom
    {-# INLINE step #-}
{-# SPECIALIZE compareChi
    :: Shape sh => Histogram sh Int32  -> Histogram sh Int32  -> Double
    ,  Shape sh => Histogram sh Int32  -> Histogram sh Int32  -> Float
    ,  Shape sh => Histogram sh Double -> Histogram sh Double -> Double
    ,  Shape sh => Histogram sh Double -> Histogram sh Double -> Float
    ,  Shape sh => Histogram sh Float  -> Histogram sh Float  -> Double
    ,  Shape sh => Histogram sh Float  -> Histogram sh Float  -> Float  #-}
{-# INLINABLE compareChi #-}

-- | Computes the intersection of the two histograms.
--
-- The higher the score is, the better the match is.
--
-- @'compareIntersect' = SUM (min(H1(i), H2(i)))@ for each index @i@ of the
-- histograms.
--
-- Fails with an error if the two histograms don't have the same shape.
compareIntersect :: (Shape sh, Storable a, Num a, Ord a)
                 => Histogram sh a -> Histogram sh a -> a
compareIntersect (Histogram sh1 vec1) (Histogram sh2 vec2) =
    if sh1 /= sh2
        then error "Histograms are not of equal size."
        else V.sum (V.zipWith min vec1 vec2)
{-# SPECIALIZE compareIntersect
    :: Shape sh => Histogram sh Int32  -> Histogram sh Int32  -> Int32
    ,  Shape sh => Histogram sh Double -> Histogram sh Double -> Double
    ,  Shape sh => Histogram sh Float  -> Histogram sh Float  -> Float #-}
{-# INLINABLE compareIntersect #-}

-- | Computes the /Earth mover's distance/ between two histograms, as the sum
-- of the absolute differences between the bins of their cumulative histograms.
--
-- Current algorithm only supports histograms of one dimension.
--
-- Fails with an error if the two histograms don't have the same shape.
--
-- See <https://en.wikipedia.org/wiki/Earth_mover's_distance>.
compareEMD :: (Num a, Storable a)
           => Histogram DIM1 a -> Histogram DIM1 a -> a
compareEMD hist1 hist2
    | shape hist1 /= shape hist2 = error "Histograms are not of equal size."
    | otherwise                  =
        V.sum (V.zipWith absDiff (cumulated hist1) (cumulated hist2))
  where
    -- Values of the cumulative histogram.
    cumulated hist = vector (cumulative hist)

    absDiff v1 v2 = abs (v1 - v2)
{-# SPECIALIZE compareEMD
    :: Histogram DIM1 Int32  -> Histogram DIM1 Int32  -> Int32
    ,  Histogram DIM1 Double -> Histogram DIM1 Double -> Double
    ,  Histogram DIM1 Float  -> Histogram DIM1 Float  -> Float #-}
{-# INLINABLE compareEMD #-}

-- Multiplies the value by itself.
square :: Num a => a -> a
square x = x * x

-- Converts any integral value to a Double.
double :: Integral a => a -> Double
double n = fromIntegral n

-- Converts any integral value to an Int.
int :: Integral a => a -> Int
int n = fromIntegral n