module Data.Array.Comfort.Boxed (
   Array,
   shape,
   reshape,
   mapShape,
   (!),
   Array.toList,
   Array.fromList,
   Array.vectorFromList,
   toAssociations,
   fromMap,
   toMap,
   fromContainer,
   toContainer,
   indices,
   Array.replicate,

   Array.map,
   zipWith,
   (//),
   accumulate,
   fromAssociations,
   ) where

import qualified Data.Array.Comfort.Boxed.Unchecked as Array
import qualified Data.Array.Comfort.Container as Container
import qualified Data.Array.Comfort.Check as Check
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Boxed.Unchecked (Array(Array))

import qualified Data.Primitive.Array as Prim

import qualified Control.Monad.Primitive as PrimM
import Control.Monad.ST (runST)
import Control.Applicative ((<$>))

import qualified Data.Foldable as Fold
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Map (Map)
import Data.Set (Set)
import Data.Foldable (forM_)

import Prelude hiding (zipWith, replicate)


-- | The shape descriptor stored alongside the elements of an array.
shape :: Array.Array sh a -> sh
shape (Array sh _) = sh

-- | Give an array a new shape.
--
-- Validation is delegated to 'Check.reshape'. It receives the label
-- @\"Boxed\"@, this module's 'shape' accessor and the unchecked
-- 'Array.reshape'. It presumably rejects shapes of a different size;
-- confirm in "Data.Array.Comfort.Check".
reshape :: (Shape.C sh0, Shape.C sh1) => sh1 -> Array sh0 a -> Array sh1 a
reshape newShape arr = Check.reshape "Boxed" shape Array.reshape newShape arr

-- | Compute a new shape from the current one and 'reshape' the array to it.
-- The same checks as for 'reshape' apply.
mapShape ::
   (Shape.C sh0, Shape.C sh1) => (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr = reshape newShape arr
   where newShape = f (shape arr)


-- | An array over shape @sh@ whose elements are the indices of @sh@,
-- stored in the order produced by 'Shape.indices'.
indices :: (Shape.Indexed sh) => sh -> Array.Array sh (Shape.Index sh)
indices sh = Array.fromList sh allIndices
   where allIndices = Shape.indices sh

-- | Turn a map into an array indexed by the map's key set.
--
-- 'Map.elems' yields the values in ascending key order. This assumes the
-- 'Set' shape also orders its indices ascending; confirm in
-- "Data.Array.Comfort.Shape".
fromMap :: (Ord k) => Map k a -> Array (Set k) a
fromMap m =
   let keys = Map.keysSet m
       vals = Map.elems m
   in  Array.fromList keys vals

-- | Convert an array indexed by a set of keys back into a map.
--
-- The keys come from 'Set.toAscList', so they are strictly ascending and
-- free of duplicates. That satisfies the precondition of
-- 'Map.fromDistinctAscList', which needs no key comparisons.
toMap :: (Ord k) => Array (Set k) a -> Map k a
toMap arr = Map.fromDistinctAscList (zip keys (Array.toList arr))
   where keys = Set.toAscList (shape arr)

-- | Build an array from any 'Container.C' instance.
--
-- The shape comes from 'Container.toShape'. The elements are taken in
-- 'Fold.toList' order.
fromContainer :: (Container.C f) => f a -> Array (Container.Shape f) a
fromContainer xs = Array.fromList sh elements
   where
      sh = Container.toShape xs
      elements = Fold.toList xs

-- | Counterpart of 'fromContainer'. Rebuilds the container from the
-- array's shape and its elements in 'Array.toList' order.
toContainer :: (Container.C f) => Array (Container.Shape f) a -> f a
toContainer arr@(Array sh _) = Container.fromList sh (Array.toList arr)


infixl 9 !

-- | Element access with bounds checking.
--
-- Calls 'error' when 'Shape.inBounds' rejects the index for the array's
-- shape.
(!) :: (Shape.Indexed sh) => Array sh a -> Shape.Index sh -> a
(!) (Array sh arr) ix
   | Shape.inBounds sh ix = Prim.indexArray arr (Shape.offset sh ix)
   | otherwise = error "Array.Comfort.Boxed.!: index out of bounds"


-- | Combine two arrays element by element.
--
-- Both arrays must have equal shapes; otherwise this calls 'error'.
zipWith ::
   (Shape.C sh, Eq sh) =>
   (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith op xs ys
   | shape xs == shape ys = Array.zipWith op xs ys
   | otherwise = error "zipWith: shapes mismatch"


-- | Copy the array and replace the elements at the given indices with the
-- associated values.
--
-- Associations are written in list order, so a later entry for the same
-- index overrides an earlier one. Calls 'error' if an index lies outside
-- the array's shape.
(//) ::
   (Shape.Indexed sh) => Array sh a -> [(Shape.Index sh, a)] -> Array sh a
(//) (Array sh arr) xs = runST (do
   marr <- Prim.thawArray arr 0 (Shape.size sh)
   forM_ xs $ \(ix,a) -> Prim.writeArray marr (checkedOffset ix) a
   Array sh <$> Prim.unsafeFreezeArray marr)
   where
      -- 'Prim.writeArray' does no bounds checking. Validate the index the
      -- same way '!' does, so a bad index becomes an error instead of an
      -- out-of-range write.
      checkedOffset ix =
         if Shape.inBounds sh ix
           then Shape.offset sh ix
           else error "Array.Comfort.Boxed.//: index out of bounds"

-- | Copy the array and fold a list of updates into it.
--
-- For each @(ix, b)@ the element at @ix@ is replaced by @f old b@.
-- Updates are applied in list order, so several updates to one index
-- combine left to right. The new values are stored unevaluated. Calls
-- 'error' if an index lies outside the array's shape.
accumulate ::
   (Shape.Indexed sh) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate f (Array sh arr) xs = runST (do
   marr <- Prim.thawArray arr 0 (Shape.size sh)
   forM_ xs $ \(ix,b) -> updateArray marr (checkedOffset ix) $ flip f b
   Array sh <$> Prim.unsafeFreezeArray marr)
   where
      -- The primitive read and write do no bounds checking. Validate the
      -- index the same way '!' does before touching the buffer.
      checkedOffset ix =
         if Shape.inBounds sh ix
           then Shape.offset sh ix
           else error "Array.Comfort.Boxed.accumulate: index out of bounds"

-- | Replace the element at offset @k@ of a mutable array with @f@ applied
-- to its current value.
--
-- The new value is written without being forced. The caller must supply
-- a valid offset.
updateArray ::
   PrimM.PrimMonad m =>
   Prim.MutableArray (PrimM.PrimState m) a -> Int -> (a -> a) -> m ()
updateArray marr k f = do
   old <- Prim.readArray marr k
   Prim.writeArray marr k (f old)

-- | Pair each index from 'Shape.indices' with the element at the same
-- position in 'Array.toList'.
toAssociations :: (Shape.Indexed sh) => Array sh a -> [(Shape.Index sh, a)]
toAssociations arr = zip ixs (Array.toList arr)
   where ixs = Shape.indices (shape arr)

-- | Create an array of shape @sh@ with every element set to the default
-- @a@, then write the listed associations.
--
-- Writes happen in list order, so a later entry for the same index wins.
-- Calls 'error' if an index lies outside @sh@.
fromAssociations ::
   (Shape.Indexed sh) => a -> sh -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations a sh xs = runST (do
   marr <- Prim.newArray (Shape.size sh) a
   forM_ xs $ \(ix,x) -> Prim.writeArray marr (checkedOffset ix) x
   Array sh <$> Prim.unsafeFreezeArray marr)
   where
      -- 'Prim.writeArray' does no bounds checking. Validate the index the
      -- same way '!' does, so a bad index becomes an error instead of an
      -- out-of-range write.
      checkedOffset ix =
         if Shape.inBounds sh ix
           then Shape.offset sh ix
           else error "Array.Comfort.Boxed.fromAssociations: index out of bounds"