{-# LANGUAGE BangPatterns
           , FlexibleContexts
           , RankNTypes
           , TypeFamilies #-}

-- | Contains a stateful image type which can be modified inside an 'ST' monad.
module Vision.Image.Mutable (
      MutableImage (..), create
    , MutableManifest (..)
    ) where

import Control.Monad.Primitive (PrimMonad (..))
import Control.Monad.ST (ST, runST)
import Data.Vector.Storable (MVector)
import Foreign.Storable (Storable)
import Prelude hiding (read)

import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as MV

import Vision.Image.Class (Image, ImagePixel)
import Vision.Image.Type (Manifest (..))
import Vision.Primitive (
      Point, Size, fromLinearIndex, toLinearIndex, shapeLength
    )

-- | Class for mutable images, i.e. images whose pixels can be modified inside
-- a 'PrimMonad' (e.g. 'ST') and which can be converted to and from an
-- immutable image of type @'Freezed' i@.
--
-- Pixels can be accessed either by 2D 'Point' ('read' / 'write') or by
-- row-major linear index ('linearRead' / 'linearWrite'). Each method of a pair
-- has a default implementation in terms of the other one, so an instance must
-- define at least one method of each pair (see the @MINIMAL@ pragma) to avoid
-- an infinite mutual recursion.
class Image (Freezed i) => MutableImage i where
    -- | The type of the immutable version of the mutable image 'i'.
    type Freezed i

    -- | Returns the size of the mutable image.
    --
    -- 'mShape' doesn't run in a monad as the size of a mutable image is
    -- constant.
    mShape :: i s -> Size

    -- | Creates a new mutable image of the given size. Pixels are initialized
    -- with an unknown value.
    new :: PrimMonad m => Size -> m (i (PrimState m))

    -- | Creates a new mutable image of the given size and fills it with the
    -- given value.
    new' :: PrimMonad m => Size -> ImagePixel (Freezed i) -> m (i (PrimState m))

    -- | Returns the pixel value at @Z :. y :. x@.
    --
    -- The default implementation converts the point into its row-major linear
    -- index and delegates to 'linearRead'.
    read :: PrimMonad m => i (PrimState m) -> Point
         -> m (ImagePixel (Freezed i))
    read !img !ix = img `linearRead` toLinearIndex (mShape img) ix
    {-# INLINE read #-}

    -- | Returns the pixel value as if the image was a single dimension vector
    -- (row-major representation).
    --
    -- The default implementation converts the linear index back into a
    -- 'Point' and delegates to 'read'.
    linearRead :: PrimMonad m
               => i (PrimState m) -> Int -> m (ImagePixel (Freezed i))
    linearRead !img !ix = img `read` fromLinearIndex (mShape img) ix
    {-# INLINE linearRead #-}

    -- | Overrides the value of the pixel at @Z :. y :. x@.
    --
    -- The default implementation converts the point into its row-major linear
    -- index and delegates to 'linearWrite'.
    write :: PrimMonad m => i (PrimState m) -> Point -> ImagePixel (Freezed i)
          -> m ()
    write !img !ix !val = linearWrite img (toLinearIndex (mShape img) ix) val
    {-# INLINE write #-}

    -- | Overrides the value of the pixel at the given index as if the image was
    -- a single dimension vector (row-major representation).
    --
    -- The default implementation converts the linear index back into a
    -- 'Point' and delegates to 'write'.
    linearWrite :: PrimMonad m => i (PrimState m) -> Int
                -> ImagePixel (Freezed i) -> m ()
    linearWrite !img !ix !val = write img (fromLinearIndex (mShape img) ix) val
    {-# INLINE linearWrite #-}

    -- | Returns an immutable copy of the mutable image.
    freeze :: PrimMonad m => i (PrimState m) -> m (Freezed i)

    -- | Returns the immutable version of the mutable image. The mutable image
    -- should not be modified thereafter.
    --
    -- The default implementation simply calls 'freeze' (and thus copies);
    -- instances may override it to share the underlying storage instead.
    unsafeFreeze :: PrimMonad m => i (PrimState m) -> m (Freezed i)
    unsafeFreeze = freeze

    -- | Returns a mutable copy of the immutable image.
    thaw :: PrimMonad m => Freezed i -> m (i (PrimState m))

    {-# MINIMAL mShape, new, new', (read | linearRead)
              , (write | linearWrite), freeze, thaw #-}

-- | Creates an immutable image from an 'ST' action creating a mutable image.
--
-- The mutable image returned by the action is converted with 'unsafeFreeze'.
-- This is safe because freezing is the last step run inside 'runST', and the
-- rank-2 type of the action prevents the mutable image from escaping. Nothing
-- can therefore modify it once it has been frozen.
create :: (MutableImage i) => (forall s. ST s (i s)) -> Freezed i
create action = runST (action >>= unsafeFreeze)

-- Instances -------------------------------------------------------------------

-- | Mutable image whose pixels are stored in a flat storable vector, in
-- row-major order. @p@ is the pixel type and @s@ the state token of the
-- 'PrimMonad' in which the image lives.
data MutableManifest p s = MutableManifest {
      mmSize   :: !Size          -- ^ Size of the image.
    , mmVector :: !(MVector s p) -- ^ Pixel buffer, in row-major order.
    }

-- | A 'MutableManifest' freezes into a 'Manifest' with the same size and
-- pixel buffer.
instance Storable p => MutableImage (MutableManifest p) where
    type Freezed (MutableManifest p) = Manifest p

    mShape = mmSize

    -- Allocates a buffer of 'shapeLength size' pixels whose content is left
    -- unspecified.
    new !size = do
        mvec <- MV.new (shapeLength size)
        return $! MutableManifest size mvec

    new' !size !val = do
        mvec <- MV.replicate (shapeLength size) val
        return $! MutableManifest size mvec

    -- The pixel buffer is already laid out linearly, so the linear accessors
    -- are the primitive ones here; 'read' and 'write' use the class defaults.
    linearRead !img = MV.read (mmVector img)
    {-# INLINE linearRead #-}

    linearWrite !img = MV.write (mmVector img)
    {-# INLINE linearWrite #-}

    -- Copies the pixel buffer: the mutable image may still be modified
    -- afterwards without affecting the returned image.
    freeze !(MutableManifest size mvec) = do
        vec <- V.freeze mvec
        return $! Manifest size vec

    -- Reuses the pixel buffer without copying: the mutable image must not be
    -- modified afterwards.
    unsafeFreeze !(MutableManifest size mvec) = do
        vec <- V.unsafeFreeze mvec
        return $! Manifest size vec

    -- Copies the pixel buffer of the immutable image into a fresh mutable one.
    thaw !(Manifest size vec) = do
        mvec <- V.thaw vec
        return $! MutableManifest size mvec