{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
module Codec.Picture.Jpg.Internal.FastDct( referenceDct, fastDctLibJpeg ) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative( (<$>) )
#endif
import Data.Int( Int16, Int32 )
import Data.Bits( unsafeShiftR, unsafeShiftL )
import Control.Monad.ST( ST )
import qualified Data.Vector.Storable.Mutable as M
import Codec.Picture.Jpg.Internal.Types
import Control.Monad( forM, forM_ )
-- | Direct, unoptimised evaluation of the 8x8 forward DCT-II.
--
-- For every frequency pair @(u, v)@ this computes
--
-- > (1/4) * C(u) * C(v) * sum_{x,y} s(x,y) * cos((2x+1)*u*pi/16) * cos((2y+1)*v*pi/16)
--
-- where @C(0) = 1/sqrt 2@, @C(k) = 1@ otherwise, and @s(x,y)@ is the sample
-- read from @block@ at index @y * dctBlockSize + x@. The arithmetic is done in
-- 'Float'. The result is truncated toward zero and stored in @workData@ at
-- index @v * dctBlockSize + u@. @workData@ is returned once all coefficients
-- are written.
--
-- No level shift is applied to the samples.
-- NOTE(review): 'fastDctLibJpeg' subtracts 'cENTERJSAMPLE' from its DC term;
-- confirm callers account for that before comparing the two.
referenceDct :: MutableMacroBlock s Int32
             -> MutableMacroBlock s Int16
             -> ST s (MutableMacroBlock s Int32)
referenceDct workData block = mapM_ emit frequencies >> return workData
  where
    axis :: [Int]
    axis = [0 .. dctBlockSize - 1]

    -- Frequency pairs, u varying slowest.
    frequencies = [(u, v) | u <- axis, v <- axis]

    -- Compute one coefficient and store it at row v, column u.
    emit (u, v) = do
        projection <- projectOn u v
        let coeff = (1 / 4) * weight u * weight v * projection
        M.unsafeWrite workData (v * dctBlockSize + u) (truncate coeff)

    -- Inner product of the block with the (u, v) cosine basis. x is the
    -- outer index, so the Float terms are summed in a fixed, x-major order.
    projectOn u v = sum <$> mapM (basisTerm u v) [(x, y) | x <- axis, y <- axis]

    -- One sample multiplied by the horizontal and vertical cosine factors.
    basisTerm u v (x, y) = do
        sample <- fromIntegral <$> M.unsafeRead block (y * dctBlockSize + x)
        let cosX = cos ((2 * fromIntegral x + 1) * fromIntegral u * (pi :: Float) / 16)
            cosY = cos ((2 * fromIntegral y + 1) * fromIntegral v * pi / 16)
        return (sample * cosX * cosY)

    -- DCT normalisation factor C(k).
    weight :: Int -> Float
    weight 0 = 1 / sqrt 2
    weight _ = 1
-- | Number of fractional bits in the fixed-point @fIX_*@ multipliers.
cONST_BITS :: Int
cONST_BITS = 13

-- | Extra fractional bits kept in the row-pass output of 'fastDctLibJpeg'
-- and removed again by its column pass.
pASS1_BITS :: Int
pASS1_BITS = 2
-- | Fixed-point DCT multipliers with 'cONST_BITS' (13) fractional bits.
-- Each constant equals @round (x * 8192)@, where @x@ is the real value
-- spelled out in its name.
fIX_0_298631336, fIX_0_390180644, fIX_0_541196100, fIX_0_765366865,
  fIX_0_899976223, fIX_1_175875602, fIX_1_501321110, fIX_1_847759065,
  fIX_1_961570560, fIX_2_053119869, fIX_2_562915447, fIX_3_072711026
  :: Int32
fIX_0_298631336 = 2446   -- round (0.298631336 * 8192)
fIX_0_390180644 = 3196   -- round (0.390180644 * 8192)
fIX_0_541196100 = 4433   -- round (0.541196100 * 8192)
fIX_0_765366865 = 6270   -- round (0.765366865 * 8192)
fIX_0_899976223 = 7373   -- round (0.899976223 * 8192)
fIX_1_175875602 = 9633   -- round (1.175875602 * 8192)
fIX_1_501321110 = 12299  -- round (1.501321110 * 8192)
fIX_1_847759065 = 15137  -- round (1.847759065 * 8192)
fIX_1_961570560 = 16069  -- round (1.961570560 * 8192)
fIX_2_053119869 = 16819  -- round (2.053119869 * 8192)
fIX_2_562915447 = 20995  -- round (2.562915447 * 8192)
fIX_3_072711026 = 25172  -- round (3.072711026 * 8192)
-- | Level-shift offset (0x80 = 128). 'fastDctLibJpeg' removes
-- @dctBlockSize * cENTERJSAMPLE@ from each row's DC sum.
-- NOTE(review): 128 is the midpoint of 8-bit unsigned samples; confirm
-- inputs are 8-bit.
cENTERJSAMPLE :: Int32
cENTERJSAMPLE = 0x80
-- | Integer 8x8 forward DCT modelled on libjpeg's @jpeg_fdct_islow@
-- (jfdctint.c), using the 13-bit fixed-point @fIX_*@ multipliers.
--
-- Samples are read from @sample_block@ at @row * dctBlockSize + col@.
-- Coefficients are written to @workData@ with the same layout, and
-- @workData@ is returned.
--
-- Pass 1 transforms each row and keeps 'pASS1_BITS' extra fractional bits.
-- It also folds the level shift into the DC term by subtracting
-- @dctBlockSize * cENTERJSAMPLE@.
--
-- Pass 2 transforms each column in place. Its right shifts remove the
-- pass-1 scaling plus 3 more bits (a division by 8). For example, the DC
-- output is @(sum of all samples - 64 * cENTERJSAMPLE) / 8@, rounded to
-- nearest.
fastDctLibJpeg :: MutableMacroBlock s Int32
               -> MutableMacroBlock s Int16
               -> ST s (MutableMacroBlock s Int32)
fastDctLibJpeg workData sample_block = do
    firstPass workData 0
    secondPass workData 7
    return workData
  where
    -- Pass 1: 1-D DCT of row i of sample_block into row i of dataBlock,
    -- for i = 0 .. dctBlockSize - 1.
    firstPass _ i | i == dctBlockSize = return ()
    firstPass dataBlock i = do
        let baseIdx = i * dctBlockSize
            readAt idx = fromIntegral <$> sample_block `M.unsafeRead` (baseIdx + idx)
            writeAt idx = dataBlock `M.unsafeWrite` (baseIdx + idx)
            -- Products carry cONST_BITS fractional bits; keep only pASS1_BITS.
            fixShift = cONST_BITS - pASS1_BITS
            writeAtPos idx n = writeAt idx (n `unsafeShiftR` fixShift)
            -- Half of 2^fixShift, so that the shift rounds to nearest.
            fixRounding :: Int32
            fixRounding = 1 `unsafeShiftL` (fixShift - 1)
        blk0 <- readAt 0
        blk1 <- readAt 1
        blk2 <- readAt 2
        blk3 <- readAt 3
        blk4 <- readAt 4
        blk5 <- readAt 5
        blk6 <- readAt 6
        blk7 <- readAt 7
        let tmp0 = blk0 + blk7
            tmp1 = blk1 + blk6
            tmp2 = blk2 + blk5
            tmp3 = blk3 + blk4
            tmp10 = tmp0 + tmp3
            tmp12 = tmp0 - tmp3
            tmp11 = tmp1 + tmp2
            tmp13 = tmp1 - tmp2
            tmp0' = blk0 - blk7
            tmp1' = blk1 - blk6
            tmp2' = blk2 - blk5
            tmp3' = blk3 - blk4
        -- Even part. The level shift of all 8 samples is applied to DC only.
        writeAt 0 $ (tmp10 + tmp11 - dctBlockSize * cENTERJSAMPLE) `unsafeShiftL` pASS1_BITS
        writeAt 4 $ (tmp10 - tmp11) `unsafeShiftL` pASS1_BITS
        let z1 = (tmp12 + tmp13) * fIX_0_541196100 + fixRounding
        writeAtPos 2 $ z1 + tmp12 * fIX_0_765366865
        writeAtPos 6 $ z1 - tmp13 * fIX_1_847759065
        -- Odd part.
        let tmp10' = tmp0' + tmp3'
            tmp11' = tmp1' + tmp2'
            tmp12' = tmp0' + tmp2'
            tmp13' = tmp1' + tmp3'
            z1' = (tmp12' + tmp13') * fIX_1_175875602 + fixRounding
            tmp0'' = tmp0' * fIX_1_501321110
            tmp1'' = tmp1' * fIX_3_072711026
            tmp2'' = tmp2' * fIX_2_053119869
            tmp3'' = tmp3' * fIX_0_298631336
            tmp10'' = tmp10' * (- fIX_0_899976223)
            tmp11'' = tmp11' * (- fIX_2_562915447)
            tmp12'' = tmp12' * (- fIX_0_390180644) + z1'
            tmp13'' = tmp13' * (- fIX_1_961570560) + z1'
        writeAtPos 1 $ tmp0'' + tmp10'' + tmp12''
        writeAtPos 3 $ tmp1'' + tmp11'' + tmp13''
        writeAtPos 5 $ tmp2'' + tmp11'' + tmp12''
        writeAtPos 7 $ tmp3'' + tmp10'' + tmp13''
        firstPass dataBlock $ i + 1

    -- Pass 2: in-place 1-D DCT of column (7 - i), for i = 7 down to 0
    -- (i.e. columns 0 .. 7).
    secondPass :: M.STVector s Int32 -> Int -> ST s ()
    secondPass _ (-1) = return ()
    secondPass block i = do
        let col = 7 - i
            readAt idx = block `M.unsafeRead` (dctBlockSize * idx + col)
            writeAt idx = block `M.unsafeWrite` (dctBlockSize * idx + col)
            -- Remove pass-1 scaling plus 3 bits (divide by 8); products
            -- additionally drop their cONST_BITS fractional bits.
            plainShift = pASS1_BITS + 3
            fixShift = cONST_BITS + plainShift
            writeAtPos idx n = writeAt idx (n `unsafeShiftR` fixShift)
            -- Rounding offsets are half of each divisor, derived from the
            -- shifts above. Previously they were 1/8 of that, which biased
            -- every coefficient downward.
            plainRounding, fixRounding :: Int32
            plainRounding = 1 `unsafeShiftL` (plainShift - 1)
            fixRounding = 1 `unsafeShiftL` (fixShift - 1)
        blk0 <- readAt 0
        blk1 <- readAt 1
        blk2 <- readAt 2
        blk3 <- readAt 3
        blk4 <- readAt 4
        blk5 <- readAt 5
        blk6 <- readAt 6
        blk7 <- readAt 7
        let tmp0 = blk0 + blk7
            tmp1 = blk1 + blk6
            tmp2 = blk2 + blk5
            tmp3 = blk3 + blk4
            -- Rounding offset for outputs 0 and 4 rides along in tmp10.
            tmp10 = tmp0 + tmp3 + plainRounding
            tmp12 = tmp0 - tmp3
            tmp11 = tmp1 + tmp2
            tmp13 = tmp1 - tmp2
            tmp0' = blk0 - blk7
            tmp1' = blk1 - blk6
            tmp2' = blk2 - blk5
            tmp3' = blk3 - blk4
        -- Even part.
        writeAt 0 $ (tmp10 + tmp11) `unsafeShiftR` plainShift
        writeAt 4 $ (tmp10 - tmp11) `unsafeShiftR` plainShift
        let z1 = (tmp12 + tmp13) * fIX_0_541196100 + fixRounding
        writeAtPos 2 $ z1 + tmp12 * fIX_0_765366865
        writeAtPos 6 $ z1 - tmp13 * fIX_1_847759065
        -- Odd part.
        let tmp10' = tmp0' + tmp3'
            tmp11' = tmp1' + tmp2'
            tmp12' = tmp0' + tmp2'
            tmp13' = tmp1' + tmp3'
            z1' = (tmp12' + tmp13') * fIX_1_175875602 + fixRounding
            tmp0'' = tmp0' * fIX_1_501321110
            tmp1'' = tmp1' * fIX_3_072711026
            tmp2'' = tmp2' * fIX_2_053119869
            tmp3'' = tmp3' * fIX_0_298631336
            tmp10'' = tmp10' * (- fIX_0_899976223)
            tmp11'' = tmp11' * (- fIX_2_562915447)
            tmp12'' = tmp12' * (- fIX_0_390180644) + z1'
            tmp13'' = tmp13' * (- fIX_1_961570560) + z1'
        writeAtPos 1 $ tmp0'' + tmp10'' + tmp12''
        writeAtPos 3 $ tmp1'' + tmp11'' + tmp13''
        writeAtPos 5 $ tmp2'' + tmp11'' + tmp12''
        writeAtPos 7 $ tmp3'' + tmp10'' + tmp13''
        secondPass block (i - 1)
{-# ANN module "HLint: ignore Use camelCase" #-}
{-# ANN module "HLint: ignore Reduce duplication" #-}