{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
The functions in this module perform no bounds checking.
It is the caller's responsibility to pass valid indices and matching shapes.
-}
module Data.Array.Comfort.Storable.Unchecked (
   Priv.Array(Array, shape, buffer),
   Priv.reshape,
   mapShape,

   (Priv.!),
   unsafeCreate,
   unsafeCreateWithSize,
   unsafeCreateWithSizeAndResult,
   Priv.toList,
   Priv.fromList,
   Priv.vectorFromList,

   map,
   mapWithIndex,
   zipWith,
   (Priv.//),
   Priv.accumulate,
   Priv.fromAssociations,

   singleton,
   append,
   take, drop,
   takeLeft, takeRight, split,

   sum, product,
   foldl,
   ) where

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as Monadic
import qualified Data.Array.Comfort.Storable.Private as Priv
import qualified Data.Array.Comfort.Storable.Memory as Memory
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Private (Array(Array), mapShape)
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import System.IO.Unsafe (unsafePerformIO)
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Storable (Storable, poke, peek)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.ST (runST)
import Control.Applicative (liftA2)

import qualified Data.List as List

import Prelude hiding (map, zipWith, foldl, take, drop, sum, product)


-- | Allocate an array of shape @sh@ and fill it by running the given
-- 'IO' action on a pointer to its buffer. The result is returned as a
-- pure value by running 'Monadic.unsafeCreate' inside 'runST'.
--
-- NOTE(review): the action is presumably expected to initialise every
-- element and to have no other observable side effects; nothing here
-- checks that.
unsafeCreate ::
   (Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> Array sh a
unsafeCreate sh fill = runST $ Monadic.unsafeCreate sh fill

-- | Like 'unsafeCreate', but the fill action also receives an 'Int'.
-- Callers in this module use it as the element count of the new buffer
-- (presumably @Shape.size sh@; confirm in 'Monadic.unsafeCreateWithSize').
unsafeCreateWithSize ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> Array sh a
unsafeCreateWithSize sh fill = runST $ Monadic.unsafeCreateWithSize sh fill

-- | Like 'unsafeCreateWithSize', but the fill action also produces an
-- extra result, which is returned alongside the new array.
unsafeCreateWithSizeAndResult ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> (Array sh a, b)
unsafeCreateWithSizeAndResult sh fill =
   runST $ Monadic.unsafeCreateWithSizeAndResult sh fill


-- | Apply a function to every element, producing a new array of the same
-- shape. Elements are read from and written to matching positions
-- @0 .. Shape.size sh - 1@ in ascending order.
map ::
   (Shape.C sh, Storable a, Storable b) =>
   (a -> b) -> Array sh a -> Array sh b
map f (Array sh src) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr src $ \srcPtr ->
   mapM_
      (\k -> poke (advancePtr dstPtr k) . f =<< peek (advancePtr srcPtr k))
      (List.take (Shape.size sh) [0 ..])

-- | Like 'map', but the function also receives the shape index of each
-- element. The k-th entry of @Shape.indices sh@ is paired with the k-th
-- buffer position. Iteration stops when the index list is exhausted.
mapWithIndex ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a, Storable b) =>
   (ix -> a -> b) -> Array sh a -> Array sh b
mapWithIndex f (Array sh src) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr src $ \srcPtr ->
   mapM_
      (\(k, ix) ->
         poke (advancePtr dstPtr k) . f ix =<< peek (advancePtr srcPtr k))
      (List.zip [0 ..] (Shape.indices sh))

-- | Combine two arrays element by element.
--
-- The result takes its shape from the /second/ array. The first array's
-- shape is ignored and never compared, so the first buffer is assumed to
-- hold at least @Shape.size@ of the second shape elements. For each
-- position, the element of the first array is read before that of the
-- second.
zipWith ::
   (Shape.C sh, Storable a, Storable b, Storable c) =>
   (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith f (Array _shA bufA) (Array sh bufB) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr bufA $ \ptrA ->
   withForeignPtr bufB $ \ptrB ->
   mapM_
      (\k ->
         poke (advancePtr dstPtr k)
            =<< liftA2 f (peek (advancePtr ptrA k)) (peek (advancePtr ptrB k)))
      [0 .. Shape.size sh - 1]


-- | An array of unit shape holding exactly the given element.
singleton :: (Storable a) => a -> Array () a
singleton a = unsafeCreate () (\ptr -> poke ptr a)

-- | Concatenate two arrays. The result has shape @shx:+:shy@. Its buffer
-- holds all elements of the first array, followed directly by all
-- elements of the second.
append ::
   (Shape.C shx, Shape.C shy, Storable a) =>
   Array shx a -> Array shy a -> Array (shx:+:shy) a
append (Array shapeL bufL) (Array shapeR bufR) =
   unsafeCreate (shapeL:+:shapeR) $ \dstPtr ->
   withForeignPtr bufL $ \ptrL ->
   withForeignPtr bufR $ \ptrR -> do
      copyArray dstPtr ptrL countL
      -- the right part starts immediately after the left part's elements
      copyArray (advancePtr dstPtr countL) ptrR countR
   where
      countL = Shape.size shapeL
      countR = Shape.size shapeR

-- | 'take' keeps, and 'drop' discards, the leading part of a zero-based
-- array. The split point is computed by 'Shape.zeroBasedSplit' from @n@.
-- How an out-of-range @n@ is handled is decided there, not in this module.
take, drop ::
   (Integral n, Storable a) =>
   n -> Array (Shape.ZeroBased n) a -> Array (Shape.ZeroBased n) a
take n xs = takeLeft (splitN n xs)
drop n xs = takeRight (splitN n xs)

-- | Reinterpret a zero-based array as an append shape split at @n@.
-- Only the shape changes (via 'mapShape'); the buffer is not copied.
splitN ::
   (Integral n, Storable a) =>
   n -> Array (Shape.ZeroBased n) a ->
   Array (Shape.ZeroBased n :+: Shape.ZeroBased n) a
splitN = mapShape . Shape.zeroBasedSplit

-- | Copy out the left part of an array with an append shape. The leading
-- elements of the source buffer are copied into a fresh array of shape
-- @sh0@.
takeLeft ::
   (Shape.C sh0, Shape.C sh1, Storable a) =>
   Array (sh0:+:sh1) a -> Array sh0 a
takeLeft (Array (leftShape :+: _rightShape) buf) =
   unsafeCreateWithSize leftShape $ \count dstPtr ->
      withForeignPtr buf $ \srcPtr ->
         copyArray dstPtr srcPtr count

-- | Copy out the right part of an array with an append shape. Copying
-- starts after the first @Shape.size sh0@ elements and fills a fresh
-- array of shape @sh1@.
takeRight ::
   (Shape.C sh0, Shape.C sh1, Storable a) =>
   Array (sh0:+:sh1) a -> Array sh1 a
takeRight (Array (sh0:+:sh1) buf) =
   unsafeCreateWithSize sh1 $ \count dstPtr ->
   withForeignPtr buf $ \srcPtr ->
      copyArray dstPtr (advancePtr srcPtr skipped) count
   where
      -- number of elements belonging to the left part
      skipped = Shape.size sh0

-- | Split an array with an append shape into its left and right parts.
-- Each part is an independent copy ('takeLeft' and 'takeRight').
split ::
   (Shape.C sh0, Shape.C sh1, Storable a) =>
   Array (sh0:+:sh1) a -> (Array sh0 a, Array sh1 a)
split x = (left, right)
   where
      left = takeLeft x
      right = takeRight x



-- | Sum of all elements: a left fold with '+', starting at @0@.
sum :: (Shape.C sh, Storable a, Num a) => Array sh a -> a
sum xs = foldl (+) 0 xs

-- | Product of all elements: a left fold with '*', starting at @1@.
product :: (Shape.C sh, Storable a, Num a) => Array sh a -> a
product xs = foldl (*) 1 xs

{-# INLINE foldl #-}
-- | Left fold over the @Shape.size sh@ elements of the array's buffer,
-- delegated to 'Memory.foldl'.
--
-- 'unsafePerformIO' is used because the fold only reads from the
-- buffer. NOTE(review): this is pure only if the buffer is never
-- mutated after construction; that is assumed, not checked here.
foldl :: (Shape.C sh, Storable a) => (b -> a -> b) -> b -> Array sh a -> b
foldl op a (Array sh x) =
   unsafePerformIO $ withForeignPtr x $ \xPtr ->
      -- Memory.foldl's step function takes an extra leading argument,
      -- which is ignored here. The final 1 is presumably the element
      -- stride; confirm against Memory.foldl.
      Memory.foldl (\_ -> op) a (Shape.size sh) xPtr 1