{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable (
   Array,
   shape,
   reshape,
   mapShape,

   (!),
   Array.toList,
   Array.vectorFromList,
   toAssociations,
   fromList,
   fromMap,
   fromContainer,
   toContainer,
   sample,
   fromBoxed,
   toBoxed,

   Array.map,
   Array.mapWithIndex,
   zipWith,
   (//),
   accumulate,
   fromAssociations,

   Array.singleton,
   Array.append,
   Array.take, Array.drop,
   Array.takeLeft, Array.takeRight, Array.split,

   Array.sum, Array.product,
   minimum, argMinimum,
   maximum, argMaximum,
   limits,
   Array.foldl,
   foldl1,
   foldMap,
   ) where

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArrayNC
import qualified Data.Array.Comfort.Storable.Mutable as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable.Memory as Memory
import qualified Data.Array.Comfort.Container as Container
import qualified Data.Array.Comfort.Boxed as BoxedArray
import qualified Data.Array.Comfort.Check as Check
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)
import Foreign.Storable (Storable)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.ST (runST)

import qualified Data.Map as Map
import qualified Data.Foldable as Fold
import qualified Data.List as List
import qualified Data.Tuple.Strict as StrictTuple
import Data.Map (Map)
import Data.Set (Set)
import Data.Foldable (forM_)
import Data.Semigroup
         (Semigroup, (<>), Min(Min,getMin), Max(Max,getMax), Arg(Arg))

import Prelude2010 hiding (map, zipWith, foldl1, minimum, maximum)
import Prelude ()


shape :: Array sh a -> sh
shape = Array.shape

reshape :: (Shape.C sh0, Shape.C sh1) => sh1 -> Array sh0 a -> Array sh1 a
reshape = Check.reshape "Storable" shape Array.reshape

mapShape ::
   (Shape.C sh0, Shape.C sh1) => (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr = reshape (f $ shape arr) arr


fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Array sh a
fromList sh arr = runST (MutArrayNC.unsafeFreeze =<< MutArray.fromList sh arr)

fromMap :: (Ord k, Storable a) => Map k a -> Array (Set k) a
fromMap m = fromList (Map.keysSet m) (Map.elems m)

fromContainer ::
   (Container.C f, Storable a) => f a -> Array (Container.Shape f) a
fromContainer xs = fromList (Container.toShape xs) (Fold.toList xs)

toContainer ::
   (Container.C f, Storable a) => Array (Container.Shape f) a -> f a
toContainer arr = Container.fromList (Array.shape arr) (Array.toList arr)

sample ::
   (Shape.Indexed sh, Storable a) => sh -> (Shape.Index sh -> a) -> Array sh a
sample sh f = Array.fromList sh $ List.map f $ Shape.indices sh


fromBoxed :: (Shape.C sh, Storable a) => BoxedArray.Array sh a -> Array sh a
fromBoxed arr = Array.fromList (BoxedArray.shape arr) $ BoxedArray.toList arr

toBoxed :: (Shape.C sh, Storable a) => Array sh a -> BoxedArray.Array sh a
toBoxed arr = BoxedArray.fromList (Array.shape arr) $ Array.toList arr

toAssociations ::
   (Shape.Indexed sh, Storable a) => Array sh a -> [(Shape.Index sh, a)]
toAssociations arr = zip (Shape.indices $ shape arr) (Array.toList arr)


infixl 9 !

(!) :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> a
(!) arr ix = runST (do
   marr <- MutArrayNC.unsafeThaw arr
   MutArray.read marr ix)


zipWith ::
   (Shape.C sh, Eq sh, Storable a, Storable b, Storable c) =>
   (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith f a b =
   if shape a == shape b
      then Array.zipWith f a b
      else error "zipWith: shapes mismatch"

(//) ::
   (Shape.Indexed sh, Storable a) =>
   Array sh a -> [(Shape.Index sh, a)] -> Array sh a
(//) arr xs = runST (do
   marr <- MutArray.thaw arr
   forM_ xs $ uncurry $ MutArray.write marr
   MutArrayNC.unsafeFreeze marr)

accumulate ::
   (Shape.Indexed sh, Storable a) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate f arr xs = runST (do
   marr <- MutArray.thaw arr
   forM_ xs $ \(ix,b) -> MutArray.update marr ix $ flip f b
   MutArrayNC.unsafeFreeze marr)

fromAssociations ::
   (Shape.Indexed sh, Storable a) =>
   a -> sh -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations a sh xs = runST (do
   marr <- MutArray.new sh a
   forM_ xs $ uncurry $ MutArray.write marr
   MutArrayNC.unsafeFreeze marr)



{- |
It is a checked error if the vector is empty.
-}
minimum, maximum :: (Shape.C sh, Storable a, Ord a) => Array sh a -> a
minimum = foldl1 min
maximum = foldl1 max

{-# INLINE foldl1 #-}
foldl1 :: (Shape.C sh, Storable a) => (a -> a -> a) -> Array sh a -> a
foldl1 op (Array sh x) = unsafePerformIO $
   withForeignPtr x $ \xPtr ->
      Memory.foldl1 (const id) op (Shape.size sh) xPtr 1

{- |
> limits x = (minimum x, maximum x)
-}
limits :: (Shape.C sh, Storable a, Ord a) => Array sh a -> (a,a)
limits = StrictTuple.mapPair (getMin, getMax) . foldMap (\x -> (Min x, Max x))

{-# INLINE foldMap #-}
foldMap ::
   (Shape.C sh, Storable a, Ord a, Semigroup m) => (a -> m) -> Array sh a -> m
foldMap f (Array sh x) = unsafePerformIO $
   withForeignPtr x $ \xPtr ->
      Memory.foldl1 (const f) (<>) (Shape.size sh) xPtr 1


argMinimum, argMaximum ::
   (Shape.InvIndexed sh, Shape.Index sh ~ ix, Storable a, Ord a) =>
   Array sh a -> (ix,a)
argMinimum xs = unArg xs $ getMin $ foldMapWithIndex (\k x -> Min (Arg x k)) xs
argMaximum xs = unArg xs $ getMax $ foldMapWithIndex (\k x -> Max (Arg x k)) xs

unArg ::
   (Shape.InvIndexed sh) => Array sh a -> Arg a Int -> (Shape.Index sh, a)
unArg xs (Arg x k) = (Shape.indexFromOffset (Array.shape xs) k, x)

{-# INLINE foldMapWithIndex #-}
foldMapWithIndex ::
   (Shape.C sh, Storable a, Semigroup m) => (Int -> a -> m) -> Array sh a -> m
foldMapWithIndex f (Array sh x) = unsafePerformIO $
   withForeignPtr x $ \xPtr -> Memory.foldl1 f (<>) (Shape.size sh) xPtr 1