module Data.Array.Comfort.Storable (
   Array,
   shape,
   reshape,
   mapShape,

   (!),
   Array.toList,
   fromList,
   fromMap,
   Array.vectorFromList,
   sample,
   fromBoxed,
   toBoxed,

   Array.map,
   Array.mapWithIndex,
   (//),
   accumulate,
   fromAssociations,
   ) where

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArrayNC
import qualified Data.Array.Comfort.Storable.Mutable as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Boxed as BoxedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array)

import Foreign.Storable (Storable)

import Control.Monad.ST (runST)

import qualified Data.Map as Map
import qualified Data.List as List
import Data.Map (Map)
import Data.Set (Set)
import Data.Foldable (forM_)

import Text.Printf (printf)

import Prelude hiding (map)


-- | The shape descriptor the array was built with.
shape :: Array sh a -> sh
shape arr = Array.shape arr

-- | Give the array a new shape with the same number of elements.
--
-- The sizes of the old and the new shape are compared with 'Shape.size'.
-- If they differ, this calls 'error' with a message that reports both
-- sizes. Otherwise it delegates to the unchecked 'Array.reshape'.
reshape :: (Shape.C sh0, Shape.C sh1) => sh1 -> Array sh0 a -> Array sh1 a
reshape sh1 arr
   | oldSize == newSize = Array.reshape sh1 arr
   | otherwise =
        error $
        printf
           ("Array.Comfort.Storable.reshape: " ++
            "different sizes of old (%d) and new (%d) shape")
           oldSize newSize
   where
      oldSize = Shape.size (shape arr)
      newSize = Shape.size sh1

-- | Replace the array's shape with one computed from the current shape.
--
-- The result goes through 'reshape'. A derived shape whose size differs
-- from the original therefore triggers the same 'error'.
mapShape ::
   (Shape.C sh0, Shape.C sh1) => (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr =
   let newShape = f (shape arr)
   in reshape newShape arr


-- | Build an array of the given shape from a list of elements.
--
-- The elements are written into a mutable array by 'MutArray.fromList',
-- the checked variant, which presumably validates the list against the
-- shape (see its documentation). The result is then frozen with
-- 'MutArrayNC.unsafeFreeze'. That is sound here because the mutable array
-- is never touched after freezing and cannot escape 'runST'.
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Array sh a
fromList sh xs =
   runST (MutArray.fromList sh xs >>= MutArrayNC.unsafeFreeze)

-- | Convert a 'Map' into an array whose shape is the map's key set.
--
-- 'Map.elems' yields the values in ascending key order. This is presumably
-- the order in which a 'Set' shape enumerates its indices; verify against
-- the 'Shape.C' instance for 'Set'.
fromMap :: (Ord k, Storable a) => Map k a -> Array (Set k) a
fromMap m =
   let keys = Map.keysSet m
       vals = Map.elems m
   in fromList keys vals

-- | Tabulate a function over every index of a shape.
--
-- The function is applied to each index from 'Shape.indices', in that
-- order. The resulting list is handed to the unchecked 'Array.fromList';
-- presumably 'Shape.indices' yields exactly one index per element.
sample ::
   (Shape.Indexed sh, Storable a) => sh -> (Shape.Index sh -> a) -> Array sh a
sample sh f = Array.fromList sh [f ix | ix <- Shape.indices sh]


-- | Copy a boxed array into a storable array with the same shape.
--
-- The elements are taken from 'BoxedArray.toList' and passed to the
-- unchecked 'Array.fromList' together with the boxed array's own shape.
fromBoxed :: (Shape.C sh, Storable a) => BoxedArray.Array sh a -> Array sh a
fromBoxed arr =
   let sh = BoxedArray.shape arr
   in Array.fromList sh (BoxedArray.toList arr)

-- | Copy a storable array into a boxed array with the same shape.
--
-- The elements are taken from 'Array.toList' and passed to
-- 'BoxedArray.fromList' together with the storable array's own shape.
toBoxed :: (Shape.C sh, Storable a) => Array sh a -> BoxedArray.Array sh a
toBoxed arr =
   let sh = Array.shape arr
   in BoxedArray.fromList sh (Array.toList arr)


infixl 9 !

-- | Look up the element at the given index.
--
-- The immutable array is viewed as a mutable one through
-- 'MutArrayNC.unsafeThaw'. That view is only ever read, never written,
-- so the original array is not modified. The element is fetched with
-- 'MutArray.read' from the checked module, which presumably rejects
-- out-of-range indices; see its documentation.
(!) :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> a
arr ! ix =
   runST (MutArrayNC.unsafeThaw arr >>= \view -> MutArray.read view ix)


-- | Return a copy of the array with some elements overwritten.
--
-- The writes are performed in list order, so when an index occurs more
-- than once the last pair for it wins. 'MutArray.thaw' is relied on to
-- produce a fresh mutable copy, so the input array itself is left
-- unchanged. The copy is frozen with 'MutArrayNC.unsafeFreeze' because
-- nothing else holds a reference to it.
(//) ::
   (Shape.Indexed sh, Storable a) =>
   Array sh a -> [(Shape.Index sh, a)] -> Array sh a
arr // updates =
   runST (do
      copy <- MutArray.thaw arr
      mapM_ (uncurry (MutArray.write copy)) updates
      MutArrayNC.unsafeFreeze copy)

-- | Combine array elements with values from an association list.
--
-- For every pair @(ix, b)@ the element at @ix@ is replaced by @f old b@.
-- Pairs are processed in list order, so repeated indices accumulate
-- left to right. The input array is thawed into a mutable copy with
-- 'MutArray.thaw', which is relied on to copy, and the copy is frozen
-- at the end.
accumulate ::
   (Shape.Indexed sh, Storable a) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate f arr xs =
   runST (do
      copy <- MutArray.thaw arr
      let step (ix, b) = MutArray.update copy ix (\old -> f old b)
      mapM_ step xs
      MutArrayNC.unsafeFreeze copy)

-- | Build an array filled with a default value, then apply the given
-- index/value assignments.
--
-- 'MutArray.new' allocates the array with every element set to the
-- default. The assignments are then written in list order, so the last
-- pair for a repeated index wins.
fromAssociations ::
   (Shape.Indexed sh, Storable a) =>
   sh -> a -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations sh a xs =
   runST (do
      buf <- MutArray.new sh a
      mapM_ (\(ix, x) -> MutArray.write buf ix x) xs
      MutArrayNC.unsafeFreeze buf)