{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Private where

import qualified Data.Array.Comfort.Storable.Mutable.Private as MutArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Semigroup (Semigroup((<>)))
import Data.Monoid (Monoid(mempty, mappend))

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Storable (Storable)

import System.IO.Unsafe (unsafePerformIO)

import Control.DeepSeq (NFData, rnf)
import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.ST (runST)
import Control.Monad (liftM)

import Data.Foldable (forM_)


-- | Immutable array: a shape value paired with a foreign buffer of elements.
--
-- NOTE(review): the functions in this module size buffers with
-- @Shape.size shape@, so the buffer presumably holds exactly that many
-- elements; nothing in this type enforces it.
data Array sh a =
   Array
      { shape :: sh
        -- ^ the index space of the array
      , buffer :: ForeignPtr a
        -- ^ storage for the elements
      }

-- | Rendering is delegated to 'MutArray.show', applied to a view obtained
-- with 'unsafeThaw' (intended to avoid a copy; the array is only read).
-- The text is parenthesized when the context precedence exceeds 10.
instance (Shape.C sh, Show sh, Storable a, Show a) => Show (Array sh a) where
   showsPrec prec arr =
      let rendered = runST (unsafeThaw arr >>= MutArray.show)
      in  showParen (prec > 10) (showString rendered)

-- | Forces the buffer pointer to weak head normal form and the shape
-- completely. The elements stored in the buffer are not inspected;
-- presumably they are already fully evaluated because they live in
-- foreign memory.
instance (NFData sh) => NFData (Array sh a) where
   rnf arr = buffer arr `seq` rnf (shape arr)

-- | Two arrays are equal when their shapes are equal and their element
-- lists (as produced by 'toList') are equal. Shapes are compared first,
-- so arrays of different shape never have their elements read.
instance (Shape.C sh, Eq sh, Storable a, Eq a) => Eq (Array sh a) where
   x == y =
      case (x, y) of
         (Array shX _, Array shY _) ->
            shX == shY && toList x == toList y

-- | Attach a new shape to an existing buffer without copying it.
--
-- No check is made that the new shape's size matches the buffer;
-- callers must guarantee that themselves.
reshape :: sh1 -> Array sh0 a -> Array sh1 a
reshape newShape arr =
   case arr of
      Array _oldShape storage -> Array newShape storage

-- | Transform the shape of an array with a function, keeping the buffer.
-- Like 'reshape', this performs no size consistency check.
mapShape :: (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape transform arr = reshape newShape arr
   where
      newShape = transform (shape arr)


infixl 9 !

-- | Read the element at the given index.
--
-- The lookup goes through 'MutArray.read' on an 'unsafeThaw'ed view.
-- NOTE(review): whether out-of-range indices are rejected depends on
-- 'MutArray.read', which is not visible here.
(!) :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> a
arr ! ix = runST (unsafeThaw arr >>= \marr -> MutArray.read marr ix)

-- | List all elements of the array, in the order produced by
-- 'MutArray.toList' on an 'unsafeThaw'ed view of the buffer.
toList :: (Shape.C sh, Storable a) => Array sh a -> [a]
toList arr = runST (unsafeThaw arr >>= MutArray.toList)

-- | Build an array of the given shape from a list of elements.
--
-- The mutable array is built by 'MutArray.fromList' and then frozen
-- with 'unsafeFreeze'. This is acceptable because the mutable array
-- never escapes this function.
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Array sh a
fromList sh xs = runST (MutArray.fromList sh xs >>= unsafeFreeze)

-- | Build a zero-based one-dimensional array from a list.
--
-- The mutable vector from 'MutArray.vectorFromList' never escapes, so it
-- is frozen with 'unsafeFreeze'.
vectorFromList :: (Storable a) => [a] -> Array (Shape.ZeroBased Int) a
vectorFromList xs = runST (MutArray.vectorFromList xs >>= unsafeFreeze)


-- | Return a copy of the array with the given @(index, value)@ pairs
-- written into it, applied left to right (later pairs win).
--
-- The input is copied with 'thaw', so it is left untouched. The private
-- copy is then frozen in place.
(//) ::
   (Shape.Indexed sh, Storable a) =>
   Array sh a -> [(Shape.Index sh, a)] -> Array sh a
arr // updates =
   runST
      (thaw arr >>= \marr ->
       mapM_ (uncurry (MutArray.write marr)) updates >>
       unsafeFreeze marr)

-- | Return a copy of the array in which, for every @(ix, b)@ pair in
-- order, the element at @ix@ is replaced by @combine old b@.
--
-- The input array is copied with 'thaw' and is not modified.
accumulate ::
   (Shape.Indexed sh, Storable a) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate combine arr xs =
   runST (do
      marr <- thaw arr
      let bump (ix, b) = MutArray.update marr ix (\old -> combine old b)
      forM_ xs bump
      unsafeFreeze marr)

-- | Create an array of the given shape in which every element starts as
-- @initial@, then write the @(index, value)@ pairs in order (later pairs
-- win).
fromAssociations ::
   (Shape.Indexed sh, Storable a) =>
   a -> sh -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations initial sh assocs =
   runST
      (MutArray.new sh initial >>= \marr ->
       mapM_ (uncurry (MutArray.write marr)) assocs >>
       unsafeFreeze marr)


-- | Turn a mutable array into an immutable one using 'Alloc.freeze'.
--
-- NOTE(review): unlike 'unsafeFreeze' (which uses 'Alloc.freezeInplace'),
-- this presumably copies the buffer, so later writes to the mutable array
-- do not affect the result. Confirm against "Foreign.Marshal.Array.Guarded".
freeze ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   MutArray.Array m sh a -> m (Array sh a)
freeze (MutArray.Array sh mptr) =
   unsafeIOToPrim (liftM wrap (Alloc.freeze count mptr))
   where
      count = Shape.size sh
      wrap = Array sh

-- | Turn an immutable array into a mutable one using 'Alloc.thaw'.
--
-- NOTE(review): unlike 'unsafeThaw' (which uses 'Alloc.thawInplace'),
-- this presumably copies the buffer, so mutating the result leaves the
-- input untouched. '(//)' and 'accumulate' rely on that.
thaw ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array sh a -> m (MutArray.Array m sh a)
thaw (Array sh fptr) =
   unsafeIOToPrim (liftM wrap (Alloc.thaw count fptr))
   where
      count = Shape.size sh
      wrap = MutArray.Array sh

-- | Turn a mutable array into an immutable one using
-- 'Alloc.freezeInplace'.
--
-- NOTE(review): presumably no copy is made, so the mutable array must not
-- be written afterwards. Within this module it is only applied to arrays
-- that never escape their 'runST' block.
unsafeFreeze ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   MutArray.Array m sh a -> m (Array sh a)
unsafeFreeze (MutArray.Array sh mptr) =
   unsafeIOToPrim (liftM wrap (Alloc.freezeInplace count mptr))
   where
      count = Shape.size sh
      wrap = Array sh

-- | Turn an immutable array into a mutable one using 'Alloc.thawInplace'.
--
-- NOTE(review): presumably no copy is made, so the result must only be
-- read. Within this module it is used solely for read-only access
-- ('Show', '!', 'toList').
unsafeThaw ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array sh a -> m (MutArray.Array m sh a)
unsafeThaw (Array sh fptr) =
   unsafeIOToPrim (liftM wrap (Alloc.thawInplace count fptr))
   where
      count = Shape.size sh
      wrap = MutArray.Array sh


-- | Concatenation: the elements of the left array are followed by those
-- of the right one. The result shape is computed by 'Shape.append'.
instance (Shape.AppendSemigroup sh, Storable a) => Semigroup (Array sh a) where
   x <> y = append Shape.append x y

-- | The identity is the array with shape 'Shape.empty' and no elements.
-- 'mappend' is the concatenation from the 'Semigroup' instance.
instance (Shape.AppendMonoid sh, Storable a) => Monoid (Array sh a) where
   mempty = fromList Shape.empty []
   mappend x y = x <> y

-- | Concatenate two arrays into a freshly allocated buffer of
-- @size shX + size shY@ elements. The left elements come first, followed
-- by the right ones. The result shape is @appendShape shX shY@.
--
-- NOTE(review): the size of @appendShape shX shY@ is not checked against
-- @sizeX + sizeY@; callers must make sure they agree.
--
-- The 'unsafePerformIO' is presumably safe because the IO only allocates
-- a new buffer through 'Alloc.create' and copies into it from the two
-- input buffers, which are only read. Confirm that 'Alloc.create' has no
-- other observable effects.
append ::
   (Shape.C shx, Shape.C shy, Storable a) =>
   (shx -> shy -> shz) ->
   Array shx a -> Array shy a -> Array shz a
append appendShape (Array shX x) (Array shY y) =
   unsafePerformIO $ do
      created <- Alloc.create (sizeX + sizeY) fill
      return (Array (appendShape shX shY) (fst created))
   where
      sizeX = Shape.size shX
      sizeY = Shape.size shY
      -- copy the left buffer to the start, the right buffer just after it
      fill zPtr =
         withForeignPtr x $ \xPtr ->
         withForeignPtr y $ \yPtr -> do
            copyArray zPtr xPtr sizeX
            copyArray (advancePtr zPtr sizeX) yPtr sizeY