-- | Vastly extended primitive arrays. Some basic ideas are now modeled after the vector package,
-- especially the monadic mutable / pure immutable array system.
--
-- Note that, in general, only bulk operations should error out; error handling for individual
-- index/read/write operations is too costly. The intended usage is to create the data structures
-- and run the DP code within an error monad, keeping error handling to high-level operations.

module Data.PrimitiveArray.Class where

import           Control.Applicative (Applicative, pure, (<$>), (<*>))
import           Control.Exception (assert)
import           Control.Monad.Except
import           Control.Monad (forM_)
import           Control.Monad.Primitive (PrimMonad, liftPrim)
import           Control.Monad.ST (runST)
import           Data.Proxy
import           Data.Vector.Fusion.Util
import           Debug.Trace
import           GHC.Generics (Generic)
import           Prelude as P
import qualified Data.Vector.Fusion.Stream.Monadic as SM
import           GHC.Stack
import           Data.Kind (Constraint)

import           Data.PrimitiveArray.Index.Class



-- | Mutable version of an array.
--
-- @MutArr m arr@ is the mutable counterpart of the immutable array type @arr@, used inside the
-- monad @m@ (the monadic operations of 'PrimArrayOps' work on @MutArr m (arr sh elm)@). Being a
-- data family, every array representation supplies its own data instance.

data family MutArr (m :: * -> *) (arr :: *) :: *

-- | Associate a fill structure with each type of array (dense, sparse, ...).
--
-- The fill structure is the extra argument taken by 'newSM' and 'newWithSM'.
--
-- Example: @type instance FillStruc (Sparse w v sh e) = (w sh)@ associates the type @(w sh)@, which
-- is of the same type as the underlying @w@ structure holding index information for a sparse array.

type family FillStruc arr :: *



-- | The core set of operations for pure and monadic arrays.
--
-- @arr@ is the array representation, @sh@ the index (shape) type and @elm@ the element type.
-- Bounds are always given as a single inclusive upper bound of type @'LimitType' sh@.

class (Index sh) => PrimArrayOps arr sh elm where

  -- ** Pure operations

  -- | Returns the upper bound of an immutable array. Bounds are inclusive, as in @[lb..ub]@.
  upperBound :: arr sh elm -> LimitType sh

  -- | Extract a single element from the array. Generally unsafe, as no bounds-checking is
  -- performed.
  unsafeIndex :: arr sh elm -> sh -> elm

  -- | Index into an immutable array, but safe in case @sh@ is not part of the array: returns
  -- @Nothing@ then.
  safeIndex :: arr sh elm -> sh -> Maybe elm

  -- | Safely transform the shape space of a table, given a function on upper bounds.
  transformShape :: Index sh' => (LimitType sh -> LimitType sh') -> arr sh elm -> arr sh' elm

  -- ** Monadic operations

  -- | Return the upper bound of the array. All bounds are inclusive, as in @[lb..ub]@. Technically
  -- not monadic, but rather working on a monadic array.
  upperBoundM :: MutArr m (arr sh elm) -> LimitType sh

  -- | Given the upper bound and a list of /all/ elements, produce a mutable array.
  fromListM :: PrimMonad m => LimitType sh -> [elm] -> m (MutArr m (arr sh elm))

  -- | Creates a new array with the given upper bound, with each element within the array being in
  -- an undefined state.
  newM :: PrimMonad m => LimitType sh -> m (MutArr m (arr sh elm))

  -- | Variant of 'newM' that requires a fill structure. Mostly for special / sparse structures
  -- (hence the @S@, also to be interpreted as "safe", since these functions won't fail with sparse
  -- structures).
  newSM :: (Monad m, PrimMonad m) => LimitType sh -> FillStruc (arr sh elm) -> m (MutArr m (arr sh elm))

  -- | Creates a new array with all elements being equal to the given element.
  newWithM :: PrimMonad m => LimitType sh -> elm -> m (MutArr m (arr sh elm))

  -- | Variant of 'newWithM' that additionally requires a fill structure (see 'newSM').
  newWithSM :: (Monad m, PrimMonad m) => LimitType sh -> FillStruc (arr sh elm) -> elm -> m (MutArr m (arr sh elm))

  -- | Reads a single element in the array.
  readM :: PrimMonad m => MutArr m (arr sh elm) -> sh -> m elm

  -- | Read from the mutable array, but return @Nothing@ in case @sh@ does not exist. This will
  -- allow streaming DP combinators to "jump" over missing elements.
  --
  -- Should be used with @Stream.Monadic.mapMaybe@ to get efficient code.
  safeReadM :: (Monad m, PrimMonad m) => MutArr m (arr sh elm) -> sh -> m (Maybe elm)

  -- | Writes a single element in the array.
  writeM :: PrimMonad m => MutArr m (arr sh elm) -> sh -> elm -> m ()

  -- | Write into the mutable array, but if the index @sh@ does not exist, silently continue.
  safeWriteM :: (Monad m, PrimMonad m) => MutArr m (arr sh elm) -> sh -> elm -> m ()

  -- | Freezes a mutable array and returns its immutable version. This operation is /O(1)/ and both
  -- arrays share the same memory. Do not use the mutable array afterwards.
  unsafeFreezeM :: PrimMonad m => MutArr m (arr sh elm) -> m (arr sh elm)

  -- | Thaw an immutable array into a mutable one. Both versions share memory.
  unsafeThawM :: PrimMonad m => arr sh elm -> m (MutArr m (arr sh elm))


-- | Arrays whose elements can be mapped over, changing the element type from @e@ to @e'@.
class PrimArrayMap arr sh e e' where
  -- | Map a function of type @e -> e'@ over the primitive array, returning another primitive array
  -- of the same type and shape but with a different element type.
  mapArray :: (e -> e') -> arr sh e -> arr sh e'


-- | Sum type of errors that can happen when using primitive arrays.

data PAErrors
  = PAEUpperBound
  -- ^ Raised by 'safeNewWithPA' when 'sizeIsValid' rejects the total size of the requested
  -- upper bound.
  deriving stock (Eq, Generic)

-- | Human-readable description of each error.
instance Show PAErrors where
  show PAEUpperBound = "Upper bound is too large for @Int@ size!"



-- | Infix index operator.
--
-- Before rewriting, the index is looked up via 'safeIndex', and an index that is not part of the
-- array aborts with 'error', reporting the array's upper bound and the offending index.
--
-- The @unsafeIndex@ rewrite rule replaces @(!)@ by 'unsafeIndex' from simplifier phase @[2]@
-- onwards. @(!)@ itself is only inlined from phase @[1]@, so the rule gets a chance to fire
-- first. Consequently, optimized code (rewrite rules enabled) performs /no/ bounds checking.
-- Use '(!?)' when out-of-bounds access must be detected reliably.

--(!) :: (HasCallStack, PrimArrayOps arr sh elm) => arr sh elm -> sh -> elm
(!) :: (PrimArrayOps arr sh elm) => arr sh elm -> sh -> elm
{-# Inline [1] (!) #-}
{-# Rules "unsafeIndex" [2] (!) = unsafeIndex #-}
(!) = \arr idx -> case safeIndex arr idx of
      Nothing -> error $ "Data.PrimitiveArray.Class.(!): index out of bounds (upper bound, index): "
                      ++ show (showBound (upperBound arr), showIndex idx)
      Just v  -> v



-- | Return the value at an index that might not exist.
--
-- Yields @Nothing@ when the index is not part of the array. This is 'safeIndex' under an infix
-- name.

(!?) :: PrimArrayOps arr sh elm => arr sh elm -> sh -> Maybe elm
{-# Inline (!?) #-}
(!?) = safeIndex

-- | Returns true if the index is valid for the mutable array.
--
-- This is 'inBounds' applied to the array's upper bound as reported by 'upperBoundM'. No monadic
-- action is run.

inBoundsM :: (Monad m, PrimArrayOps arr sh elm) => MutArr m (arr sh elm) -> sh -> Bool
-- The 'Monad' constraint is not used by the body. It is kept so the signature, including the
-- order of type variables seen by @TypeApplications@, stays unchanged for callers.
inBoundsM marr idx = inBounds (upperBoundM marr) idx
{-# INLINE inBoundsM #-}

-- -- | Given two arrays with the same dimensionality, their respective starting
-- -- index, and how many steps to go in each dimension (in terms of a dimension
-- -- again), determine if the multidimensional slices have the same value at
-- -- all positions
-- --
-- -- TODO specialize for DIM1 (and maybe higher dim's) to use memcmp
-- 
-- sliceEq :: (Eq elm, PrimArrayOps arr sh elm) => arr sh elm -> sh -> arr sh elm -> sh -> sh -> Bool
-- sliceEq arr1 k1 arr2 k2 xtnd = assert ((inBounds arr1 k1) && (inBounds arr2 k2) && (inBounds arr1 $ k1 `addDim` xtnd) && (inBounds arr2 $ k2 `addDim` xtnd)) $ and res where
--   res = zipWith (==) xs ys
--   xs = P.map (unsafeIndex arr1) $ rangeList k1 xtnd
--   ys = P.map (unsafeIndex arr2) $ rangeList k2 xtnd
-- {-# INLINE sliceEq #-}

-- | Construct a mutable primitive array from an upper bound, a default element, and a list of
-- associations.
--
-- Every cell is first set to the default element via 'newWithM'. The associations are then
-- written with 'writeM' in list order, so for duplicate indices the last association wins.
-- Indices are handed to 'writeM' unchecked; what happens for an index outside the bound depends
-- on the instance's 'writeM'.

fromAssocsM
  :: (PrimMonad m, PrimArrayOps arr sh elm)
  => LimitType sh -> elm -> [(sh,elm)] -> m (MutArr m (arr sh elm))
fromAssocsM ub def xs = do
  ma <- newWithM ub def
  forM_ xs $ \(k,v) -> writeM ma k v
  pure ma
{-# INLINE fromAssocsM #-}

-- | Initialize an immutable array but stay within the primitive monad @m@.
--
-- Allocates a mutable array with every cell set to the given element ('newWithM') and freezes it
-- with 'unsafeFreezeM'. The mutable array never escapes this function, so freezing without a
-- copy is safe.

newWithPA
  :: (PrimMonad m, PrimArrayOps arr sh elm)
  => LimitType sh
  -> elm
  -> m (arr sh elm)
newWithPA ub def = newWithM ub def >>= unsafeFreezeM
{-# Inlinable newWithPA #-}

-- | Initialize an immutable array with a fill structure.
--
-- Like 'newWithPA', but allocates via 'newWithSM', which additionally takes the array's
-- 'FillStruc'. The freshly allocated mutable array never escapes, so 'unsafeFreezeM' is safe here.

newWithSPA
  :: (PrimMonad m, PrimArrayOps arr sh elm)
  => LimitType sh
  -> FillStruc (arr sh elm)
  -> elm
  -> m (arr sh elm)
{-# Inlinable newWithSPA #-}
newWithSPA ub xs def = newWithSM ub xs def >>= unsafeFreezeM

-- | Safely prepare a primitive array.
--
-- Before allocating, the total size of the upper bound ('totalSize') is checked with
-- 'sizeIsValid' against @maxBound@. If the check fails, 'PAEUpperBound' is thrown via
-- 'throwError' and nothing is allocated. Otherwise this behaves exactly like 'newWithPA'.
--
-- NOTE(review): @maxBound@ is taken at the type of 'sizeIsValid''s first argument (@Word@), while
-- the 'PAEUpperBound' message talks about @Int@ -- confirm which limit is intended.
--
-- TODO Check if having a 'MonadError' instance degrades performance. (We
-- should see this once the test with NeedlemanWunsch is under way).

safeNewWithPA
  :: forall m arr sh elm
  . (PrimMonad m, MonadError PAErrors m, PrimArrayOps arr sh elm)
  => LimitType sh
  -> elm
  -> m (arr sh elm)
safeNewWithPA ub def =
  case runExcept $ sizeIsValid maxBound [totalSize ub] of
    Left  (SizeError _) -> throwError PAEUpperBound
    Right (CellSize  _) -> newWithPA ub def
{-# Inlinable safeNewWithPA #-}


-- | Return all associations from an array as a list of @(index, element)@ pairs.
--
-- The pairs appear in the order produced by 'assocsS', which is run in the 'Id' monad here.

assocs :: forall arr sh elm . (IndexStream sh, PrimArrayOps arr sh elm) => arr sh elm -> [(sh,elm)]
assocs arr = unId . SM.toList $ assocsS arr
{-# INLINE assocs #-}

-- | Stream all associations from an array as @(index, element)@ pairs.
--
-- Indices are enumerated with 'streamUp' from @zeroBound'@ up to the array's 'upperBound'. Each
-- index is paired with its element via 'unsafeIndex'. The unchecked lookup relies on 'streamUp'
-- only producing indices within the array's own bounds.

assocsS :: forall m arr sh elm . (Monad m, IndexStream sh, PrimArrayOps arr sh elm) => arr sh elm -> SM.Stream m (sh,elm)
assocsS arr = SM.map (\k -> (k, unsafeIndex arr k)) $ streamUp zeroBound' (upperBound arr)
{-# INLINE assocsS #-}

-- | Creates an immutable array from an upper bound and a complete list of elements.
--
-- Runs 'fromListM' in 'ST' and freezes the result with 'unsafeFreezeM'. The mutable array never
-- escapes 'runST', so the unsafe freeze is safe.

fromList :: (PrimArrayOps arr sh elm) => LimitType sh -> [elm] -> arr sh elm
fromList ub xs = runST $ fromListM ub xs >>= unsafeFreezeM
{-# INLINE fromList #-}

-- | Creates an immutable array from an upper bound, a default element, and a list of
-- associations.
--
-- This is 'fromAssocsM' run in 'ST', followed by 'unsafeFreezeM'. As there, duplicate indices
-- resolve to the last association in the list.

fromAssocs :: (PrimArrayOps arr sh elm) => LimitType sh -> elm -> [(sh,elm)] -> arr sh elm
fromAssocs ub def xs = runST $ fromAssocsM ub def xs >>= unsafeFreezeM
{-# INLINE fromAssocs #-}

-- -- | Determines if an index is valid for a given immutable array.
-- 
-- inBounds :: PrimArrayOps arr sh elm => arr sh elm -> sh -> Bool
-- inBounds arr idx = let (lb,ub) = bounds arr in inShapeRange lb (ub `addDim` unitDim) idx
-- {-# INLINE inBounds #-}

-- | Returns all elements of an immutable array as a list.
--
-- Elements are listed in the index order produced by 'streamUp' from @zeroBound'@ up to the
-- array's 'upperBound'. Each element is looked up with '(!)'.

toList :: forall arr sh elm . (IndexStream sh, PrimArrayOps arr sh elm) => arr sh elm -> [elm]
toList arr =
  let ub = upperBound arr
  in  P.map ((!) arr) . unId . SM.toList $ streamUp zeroBound' ub
{-# INLINE toList #-}



{-

-- * Freeze an inductive stack of tables with a 'Z' at the bottom.

-- | 'freezeTables' freezes a stack of tables.

class FreezeTables m t where
    type Frozen t :: *
    freezeTables :: t -> m (Frozen t)

instance Applicative m => FreezeTables m Z where
    type Frozen Z = Z
    freezeTables Z = pure Z
    {-# INLINE freezeTables #-}

instance (Functor m, Applicative m, Monad m, PrimMonad m, FreezeTables m ts, PrimArrayOps arr sh elm) => FreezeTables m (ts:.MutArr m (arr sh elm)) where
    type Frozen (ts:.MutArr m (arr sh elm)) = Frozen ts :. arr sh elm
    freezeTables (ts:.t) = (:.) <$> freezeTables ts <*> unsafeFreezeM t
    {-# INLINE freezeTables #-}

-}