{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Mutable.Private where

import qualified Data.Array.Comfort.Shape as Shape

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Marshal.Array (copyArray, pokeArray, peekArray)
import Foreign.Storable (Storable, pokeElemOff, peekElemOff)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.ST (ST)
import Control.Monad (liftM)
import Control.Applicative ((<$>))

import Data.Tuple.HT (mapFst)

import qualified Prelude as P
import Prelude hiding (read, show)


-- | A mutable array: a shape descriptor paired with a foreign buffer that
-- holds the elements.  The type parameter @m@ does not occur in any field;
-- it only records the monad the array is meant to be used in
-- (see 'STArray' and 'IOArray').
data Array (m :: * -> *) sh a =
   Array
      { shape :: sh
        -- ^ shape descriptor; 'Shape.size' of it gives the element count
      , buffer :: ForeignPtr a
        -- ^ element storage
      }

-- | Array tied to the 'ST' monad with state thread @s@.
-- The shape and element parameters are deliberately left off so that the
-- synonym can be used partially applied.
type STArray s = Array (ST s)
-- | Array tied to the 'IO' monad (also left unsaturated, like 'STArray').
type IOArray = Array IO


-- | Allocate a new array of the same shape and copy every element of the
-- source buffer into it.  The source array is not modified.
copy ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array m sh a -> m (Array m sh a)
copy (Array sh source) = unsafeCreateWithSize sh copyInto
   where
      -- 'count' is the element count supplied by the allocator.
      copyInto count target =
         withForeignPtr source (\src -> copyArray target src count)


-- | Allocate an 'IOArray' of the given shape and let the callback fill it
-- through the buffer pointer.  Use 'createWithSize' when the callback also
-- needs the element count.
create ::
   (Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> IO (IOArray sh a)
create sh fill = createWithSize sh (\_count ptr -> fill ptr)

-- | Allocate an 'IOArray' of the given shape.  The callback receives the
-- element count and the buffer pointer.  Its unit result is discarded.
createWithSize ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> IO (IOArray sh a)
createWithSize sh = fmap fst . createWithSizeAndResult sh

-- | Allocate a buffer of @Shape.size sh@ elements via 'Alloc.create' and
-- run the callback on it.  The callback gets the element count and the
-- buffer pointer, and its result is returned next to the new array.
-- NOTE(review): whether the buffer is zeroed before the callback runs depends
-- on "Foreign.Marshal.Array.Guarded" — presumably it is not; confirm.
createWithSizeAndResult ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> IO (IOArray sh a, b)
createWithSizeAndResult sh f =
   fmap wrap (Alloc.create count (f count))
   where
      count = Shape.size sh
      -- Irrefutable match, so the pair is rebuilt as lazily as 'mapFst' would.
      wrap ~(fptr, result) = (Array sh fptr, result)


-- | Like 'create', but the result lives in an arbitrary 'PrimMonad'.  The IO
-- fill action is run through 'unsafeIOToPrim' (see
-- 'unsafeCreateWithSizeAndResult').
unsafeCreate ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreate sh fill = unsafeCreateWithSize sh (\_count ptr -> fill ptr)

-- | Like 'createWithSize', but the result lives in an arbitrary
-- 'PrimMonad'.  The callback's unit result is dropped.
unsafeCreateWithSize ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreateWithSize sh fill =
   unsafeCreateWithSizeAndResult sh fill >>= return . fst

-- | Run 'createWithSizeAndResult' in 'IO', lift it into @m@ with
-- 'unsafeIOToPrim', and retag the resulting array for @m@.
-- This is \"unsafe\" because arbitrary IO is executed inside @m@.  Callers
-- must ensure that the fill action is acceptable there (e.g. in 'ST').
unsafeCreateWithSizeAndResult ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> m (Array m sh a, b)
unsafeCreateWithSizeAndResult sh f =
   unsafeIOToPrim (retag <$> createWithSizeAndResult sh f)
   where
      -- Irrefutable match keeps the pair as lazy as 'mapFst'.
      retag ~(arr, result) = (unsafeArrayIOToPrim arr, result)

-- | Reinterpret an 'IOArray' as an array for monad @m@.  Only the phantom
-- monad parameter changes.  Shape and buffer are shared, not copied.
unsafeArrayIOToPrim :: (PrimMonad m) => IOArray sh a -> Array m sh a
unsafeArrayIOToPrim arr =
   case arr of
      Array sh storage -> Array sh storage


-- | Render the current contents as
-- @StorableArray.fromList <shape> [elements]@.  The shape is shown at
-- precedence 11 so that it is parenthesised when needed.
show ::
   (PrimMonad m, Shape.C sh, Show sh, Storable a, Show a) =>
   Array m sh a -> m String
show arr = liftM render (toList arr)
   where
      render elems =
         showString "StorableArray.fromList " .
         showsPrec 11 (shape arr) .
         showChar ' ' $
         P.show elems

-- | Read the element at the given index.  'Shape.offset' maps the index to a
-- linear element offset into the buffer.
-- NOTE(review): whether out-of-range indices are rejected depends on the
-- 'Shape.offset' implementation of the shape type — confirm.
read ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> m a
read (Array sh fptr) ix =
   unsafeIOToPrim (withForeignPtr fptr (\ptr -> peekElemOff ptr pos))
   where
      pos = Shape.offset sh ix

-- | Overwrite the element at the given index with a new value.  The index is
-- converted to a buffer offset with 'Shape.offset'.
-- NOTE(review): bounds checking depends on the shape's 'Shape.offset' — confirm.
write ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> a -> m ()
write (Array sh fptr) ix a =
   unsafeIOToPrim (withForeignPtr fptr store)
   where
      store ptr = pokeElemOff ptr (Shape.offset sh ix) a

-- | Apply a function to the element at the given index in place.  The
-- operation reads the old value and then writes back @f old@ at the same
-- offset.
update ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> (a -> a) -> m ()
update (Array sh fptr) ix f =
   unsafeIOToPrim $ withForeignPtr fptr $ \ptr -> do
      old <- peekElemOff ptr pos
      pokeElemOff ptr pos (f old)
   where
      pos = Shape.offset sh ix

-- | Allocate an array of the given shape with every element set to @x@.
new :: (PrimMonad m, Shape.C sh, Storable a) => sh -> a -> m (Array m sh a)
new sh x = unsafeCreateWithSize sh fill
   where
      fill count ptr = pokeArray ptr (replicate count x)

-- | Read all @Shape.size@ elements of the buffer into a list, in buffer
-- order.
toList :: (PrimMonad m, Shape.C sh, Storable a) => Array m sh a -> m [a]
toList (Array sh fptr) =
   unsafeIOToPrim (withForeignPtr fptr (\ptr -> peekArray count ptr))
   where
      count = Shape.size sh

-- | Build an array of the given shape from a list.  Surplus elements are
-- ignored.  If the list is too short, the missing positions are 'error'
-- thunks.  Writing one is expected to raise that error, assuming the
-- element type's 'Storable' poke forces its value.
fromList ::
   (PrimMonad m, Shape.C sh, Storable a) => sh -> [a] -> m (Array m sh a)
fromList sh xs = unsafeCreateWithSize sh fill
   where
      fill count ptr = pokeArray ptr (take count padded)
      padded =
         xs ++
         repeat (error "Array.Comfort.Storable.fromList: list too short for shape")

-- | Build a one-dimensional, zero-based array whose size is the length of
-- the list.  All list elements are stored in order.
vectorFromList ::
   (PrimMonad m, Storable a) => [a] -> m (Array m (Shape.ZeroBased Int) a)
vectorFromList xs =
   let sh = Shape.ZeroBased (length xs)
   in unsafeCreate sh (\ptr -> pokeArray ptr xs)