{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Private where

import qualified Data.Array.Comfort.Storable.Mutable.Private as MutArray
import qualified Data.Array.Comfort.Shape as Shape

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Storable (Storable, )
import Foreign.ForeignPtr (ForeignPtr, )

import Control.DeepSeq (NFData, rnf)
import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.ST (runST)
import Control.Monad (liftM)

import Data.Foldable (forM_)


-- | An immutable array: a shape value describing the index space,
-- paired with a 'ForeignPtr' to the element storage.
--
-- The constructor performs no validation. Code in this module passes
-- @'Shape.size' shape@ as the element count whenever the buffer is handed
-- to the allocation helpers, so the buffer is assumed to hold that many
-- elements.
data Array sh a = Array
   { shape :: sh
     -- ^ Index space of the array.
   , buffer :: ForeignPtr a
     -- ^ Element storage.
   }

-- | Renders the array with 'MutArray.show', applied to a mutable view
-- obtained through 'unsafeThaw' (the variant backed by
-- @Alloc.thawInplace@). The view is only read, never written.
--
-- The text is wrapped in parentheses when the surrounding precedence
-- exceeds 10, i.e. when it appears as an argument of a function application.
instance (Shape.C sh, Show sh, Storable a, Show a) => Show (Array sh a) where
   showsPrec prec arr = showParen (prec > 10) (showString rendered)
      where rendered = runST (unsafeThaw arr >>= MutArray.show)

-- | Evaluates the buffer pointer to weak head normal form and the shape to
-- normal form. The elements behind the pointer are not touched.
instance (NFData sh) => NFData (Array sh a) where
   rnf arr = buffer arr `seq` rnf (shape arr)

-- | Two arrays are equal when their shapes are equal and 'toList' yields
-- equal element lists for both.
--
-- Because '&&' short-circuits, the element lists are only built when the
-- shapes match.
instance (Shape.C sh, Eq sh, Storable a, Eq a) => Eq (Array sh a) where
   arrX == arrY =
      case (arrX, arrY) of
         (Array shX _, Array shY _) ->
            shX == shY && toList arrX == toList arrY

-- | Replaces the shape of an array and keeps its buffer as is.
--
-- No check is made that the new shape has the same size as the old one;
-- that is the caller's responsibility.
reshape :: sh1 -> Array sh0 a -> Array sh1 a
reshape newShape arr =
   case arr of
      Array _ ptr -> Array newShape ptr

-- | Transforms the shape of an array with the given function and keeps the
-- buffer. The result goes through 'reshape', so no size check is made.
mapShape :: (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape g arr =
   let newShape = g (shape arr)
   in  reshape newShape arr


infixl 9 !

-- | Reads the element at the given index.
--
-- The array is viewed through 'unsafeThaw' inside 'runST' and only read.
-- What happens for an out-of-range index is decided by 'MutArray.read'.
(!) :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> a
arr ! ix = runST (unsafeThaw arr >>= \marr -> MutArray.read marr ix)

-- | Lists all elements in the order produced by 'MutArray.toList', read
-- through a mutable view obtained with 'unsafeThaw'.
toList :: (Shape.C sh, Storable a) => Array sh a -> [a]
toList arr = runST (unsafeThaw arr >>= MutArray.toList)

-- | Builds an array of the given shape from a list of elements.
--
-- The mutable array is filled by 'MutArray.fromList' and then frozen
-- with 'unsafeFreeze'; it is not used again after freezing. How a list
-- whose length differs from @Shape.size sh@ is handled is up to
-- 'MutArray.fromList' (NOTE(review): confirm).
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Array sh a
fromList sh xs = runST (MutArray.fromList sh xs >>= unsafeFreeze)

-- | Builds a one-dimensional array with a @'Shape.ZeroBased' 'Int'@ shape
-- from a list. The shape is chosen by 'MutArray.vectorFromList'.
--
-- The freshly built mutable array is frozen with 'unsafeFreeze' and not
-- used again afterwards.
vectorFromList :: (Storable a) => [a] -> Array (Shape.ZeroBased Int) a
vectorFromList xs = runST (MutArray.vectorFromList xs >>= unsafeFreeze)


-- | Returns an array with the listed positions overwritten by the given
-- values.
--
-- The input is thawed with 'thaw' (backed by @Alloc.thaw@, not the
-- in-place variant), so it is not expected to be modified. The writes
-- happen in list order, so a later entry for the same index wins.
(//) ::
   (Shape.Indexed sh, Storable a) =>
   Array sh a -> [(Shape.Index sh, a)] -> Array sh a
arr // updates =
   runST (do
      marr <- thaw arr
      mapM_ (uncurry (MutArray.write marr)) updates
      unsafeFreeze marr)

-- | For each @(ix, b)@ in the list, in list order, replaces the element
-- @old@ at @ix@ with @combine old b@.
--
-- The input is copied with 'thaw' first, and the updated copy is
-- returned via 'unsafeFreeze'.
accumulate ::
   (Shape.Indexed sh, Storable a) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate combine arr xs =
   runST (do
      marr <- thaw arr
      mapM_
         (\(ix, b) -> MutArray.update marr ix (\old -> combine old b))
         xs
      unsafeFreeze marr)

-- | Creates an array of the given shape filled with a default value, then
-- writes the listed index/value pairs in list order. A later pair for the
-- same index overwrites an earlier one.
fromAssociations ::
   (Shape.Indexed sh, Storable a) =>
   a -> sh -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations deflt sh xs =
   runST (do
      marr <- MutArray.new sh deflt
      mapM_ (uncurry (MutArray.write marr)) xs
      unsafeFreeze marr)


-- | Converts a mutable array into an immutable one. It passes the storage
-- and the element count @Shape.size sh@ to @Alloc.freeze@, and keeps the
-- shape unchanged.
--
-- Unlike 'unsafeFreeze', which uses @Alloc.freezeInplace@, this presumably
-- copies the data (NOTE(review): confirm in Foreign.Marshal.Array.Guarded).
-- The IO work is embedded into @m@ with 'unsafeIOToPrim'.
freeze ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   MutArray.Array m sh a -> m (Array sh a)
freeze marr =
   case marr of
      MutArray.Array sh mptr ->
         unsafeIOToPrim $ do
            fptr <- Alloc.freeze (Shape.size sh) mptr
            return (Array sh fptr)

-- | Creates a mutable array from an immutable one. It passes the buffer and
-- the element count @Shape.size sh@ to @Alloc.thaw@, and keeps the shape
-- unchanged.
--
-- This is the non-@unsafe@ counterpart of 'unsafeThaw', which uses
-- @Alloc.thawInplace@. '(//)' and 'accumulate' use it before writing.
thaw ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array sh a -> m (MutArray.Array m sh a)
thaw arr =
   case arr of
      Array sh fptr ->
         unsafeIOToPrim $ do
            mptr <- Alloc.thaw (Shape.size sh) fptr
            return (MutArray.Array sh mptr)

-- | Converts a mutable array into an immutable one through
-- @Alloc.freezeInplace@, passing the element count @Shape.size sh@.
--
-- Presumably no copy is made, so the mutable array must not be written
-- after this call. Every use in this module is the last action on a
-- freshly built or thawed array.
unsafeFreeze ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   MutArray.Array m sh a -> m (Array sh a)
unsafeFreeze marr =
   case marr of
      MutArray.Array sh mptr ->
         unsafeIOToPrim $ do
            fptr <- Alloc.freezeInplace (Shape.size sh) mptr
            return (Array sh fptr)

-- | Views an immutable array as a mutable one through @Alloc.thawInplace@,
-- passing the element count @Shape.size sh@.
--
-- The result presumably shares storage with the input, so writing to it
-- would be visible through the immutable array. In this module it is only
-- used for read access: 'Show', '(!)' and 'toList'.
unsafeThaw ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array sh a -> m (MutArray.Array m sh a)
unsafeThaw arr =
   case arr of
      Array sh fptr ->
         unsafeIOToPrim $ do
            mptr <- Alloc.thawInplace (Shape.size sh) fptr
            return (MutArray.Array sh mptr)