{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Massiv.Core.Common
( Array
, Elt
, EltRepr
, Construct(..)
, Source(..)
, Load(..)
, Size(..)
, Slice(..)
, OuterSlice(..)
, InnerSlice(..)
, Manifest(..)
, Mutable(..)
, State(..)
, WorldState
, Ragged(..)
, Nested(..)
, NestedStruct
, makeArray
, singleton
, (!?)
, index
, indexWith
, (!)
, index'
, (??)
, defaultIndex
, borderIndex
, evaluateAt
, module Data.Massiv.Core.Index
, imapM_
, module Data.Massiv.Core.Computation
) where
import Control.Monad.Primitive
import Data.Massiv.Core.Computation
import Data.Massiv.Core.Index
import Data.Massiv.Core.Scheduler
import Data.Typeable
import GHC.Prim
#include "massiv.h"
-- | The core array data family. Every representation @r@ supplies its own
-- data instance; @ix@ is the index type and @e@ the element type.
data family Array r ix e :: *

-- | Representation of the arrays produced from an array of representation
-- @r@ and index @ix@ by 'unsafeExtract' and by slicing (see 'Elt').
type family EltRepr r ix :: *

-- | Result type of taking a single slice of an array. For a one-dimensional
-- ('Ix1') array this is the element itself; otherwise it is an array of
-- representation 'EltRepr' indexed by @Lower ix@.
type family Elt r ix e :: * where
  Elt r Ix1 e = e
  Elt r ix e = Array (EltRepr r ix) (Lower ix) e

-- | Nested structure that arrays of representation @r@ are converted to and
-- from by the 'Nested' class.
type family NestedStruct r ix e :: *
-- | Array representations that can be built from a size and a function of
-- the index, and that carry a computation strategy ('Comp').
class (Typeable r, Index ix) => Construct r ix e where
  -- | Computation strategy of the array.
  getComp :: Array r ix e -> Comp

  -- | Replace the computation strategy of the array.
  setComp :: Comp -> Array r ix e -> Array r ix e

  -- | Build an array of the given size from a function of its index.
  -- Unlike 'makeArray', which clamps negative size components to zero
  -- before calling this method, the size reaches this method unchanged, so
  -- the caller is responsible for passing a valid size.
  unsafeMakeArray :: Comp -> ix -> (ix -> e) -> Array r ix e
-- | Arrays with a known size that can also be reshaped and have subarrays
-- extracted.
class Construct r ix e => Size r ix e where
  -- | Size of the array.
  size :: Array r ix e -> ix

  -- | Reinterpret the array with a new size, possibly of a different index
  -- type. NOTE(review): no check is visible here; presumably the caller must
  -- keep the total number of elements unchanged.
  unsafeResize :: Index ix' => ix' -> Array r ix e -> Array r ix' e

  -- | Extract a subarray, producing an array of representation 'EltRepr'.
  -- NOTE(review): the two @ix@ arguments look like the start index and the
  -- size of the subarray; confirm against instances.
  unsafeExtract :: ix -> ix -> Array r ix e -> Array (EltRepr r ix) ix e
-- | Arrays whose elements can be read by index.
--
-- Minimal definition: 'unsafeIndex' or 'unsafeLinearIndex'. Each default is
-- written in terms of the other, so an instance defining neither of them
-- loops forever.
class Size r ix e => Source r ix e where
  -- | Read the element at a multi-dimensional index. Callers are expected
  -- to stay within bounds (the safe helpers in this module guard every call
  -- with a border check first). The default converts the index into a
  -- linear one using the array size and calls 'unsafeLinearIndex'. It is
  -- wrapped in the index-check macro from massiv.h, which presumably adds a
  -- bounds check only in debug builds; confirm in massiv.h.
  unsafeIndex :: Array r ix e -> ix -> e
  unsafeIndex =
    INDEX_CHECK("(Source r ix e).unsafeIndex",
                size, \ !arr -> unsafeLinearIndex arr . toLinearIndex (size arr))
  {-# INLINE unsafeIndex #-}

  -- | Read the element at a linear position. No bounds check is visible
  -- here. The default converts the position back into an @ix@ using the
  -- array size and calls 'unsafeIndex'.
  unsafeLinearIndex :: Array r ix e -> Int -> e
  unsafeLinearIndex !arr = unsafeIndex arr . fromLinearIndex (size arr)
  {-# INLINE unsafeLinearIndex #-}
-- | Arrays whose elements can be computed and handed, one at a time, to a
-- write action, either sequentially ('loadS') or in parallel ('loadP').
--
-- 'loadS' and 'loadP' are expressed through 'loadArray'. When the array is
-- also a 'Source', 'loadArray' and 'loadArrayWithStride' have defaults as
-- well, so such an instance may be left empty.
class Size r ix e => Load r ix e where
  -- | Load all elements sequentially: one worker, and every piece of work is
  -- run directly in the current monad.
  loadS
    :: Monad m =>
       Array r ix e
    -> (Int -> m e) -- ^ Read action (ignored by the default 'loadArray')
    -> (Int -> e -> m ()) -- ^ Write action, presumably keyed by linear index
    -> m ()
  loadS arr = loadArray 1 id arr
  {-# INLINE loadS #-}

  -- | Load all elements in parallel: a scheduler is created over the given
  -- worker ids, and 'loadArray' runs with that scheduler's worker count and
  -- work-scheduling function.
  loadP
    :: [Int] -- ^ Worker ids handed to the scheduler
    -> Array r ix e
    -> (Int -> IO e)
    -> (Int -> e -> IO ())
    -> IO ()
  loadP wIds arr readElem writeElem =
    withScheduler_ wIds
      (\scheduler ->
         loadArray (numWorkers scheduler) (scheduleWork scheduler) arr readElem writeElem)
  {-# INLINE loadP #-}

  -- | Load a strided view of the array. @resultSize@ is the size of the
  -- strided result; the element at result index @i@ is read from source
  -- index @i * stride@ (component-wise). The work is presumably split
  -- linearly over the given number of workers using the given scheduling
  -- function.
  loadArrayWithStride
    :: Monad m =>
       Int
    -> (m () -> m ())
    -> Stride ix
    -> ix
    -> Array r ix e
    -> (Int -> m e)
    -> (Int -> e -> m ())
    -> m ()
  default loadArrayWithStride
    :: (Source r ix e, Monad m) =>
       Int
    -> (m () -> m ())
    -> Stride ix
    -> ix
    -> Array r ix e
    -> (Int -> m e)
    -> (Int -> e -> m ())
    -> m ()
  loadArrayWithStride workerCount scheduleJob stride resultSize arr _ =
    -- The read action is ignored. The write action is not named here: it is
    -- forwarded implicitly as the final argument of splitLinearlyWith_.
    splitLinearlyWith_ workerCount scheduleJob (totalElem resultSize) stridedElem
    where
      strideIx = unStride stride
      -- Source element that belongs at linear position k of the result.
      stridedElem k =
        unsafeIndex arr (liftIndex2 (*) strideIx (fromLinearIndex resultSize k))
      {-# INLINE stridedElem #-}
  {-# INLINE loadArrayWithStride #-}

  -- | Load every element of the array, splitting the work over the given
  -- number of workers with the given scheduling function.
  loadArray
    :: Monad m =>
       Int
    -> (m () -> m ())
    -> Array r ix e
    -> (Int -> m e)
    -> (Int -> e -> m ())
    -> m ()
  default loadArray
    :: (Source r ix e, Monad m) =>
       Int
    -> (m () -> m ())
    -> Array r ix e
    -> (Int -> m e)
    -> (Int -> e -> m ())
    -> m ()
  loadArray workerCount scheduleJob arr _ =
    -- The read action is ignored; the write action is forwarded implicitly
    -- as the final argument of splitLinearlyWith_.
    splitLinearlyWith_ workerCount scheduleJob elemCount (unsafeLinearIndex arr)
    where
      elemCount = totalElem (size arr)
  {-# INLINE loadArray #-}
-- | Arrays that can be sliced along their outer dimension, yielding one
-- 'Elt' per position.
class OuterSlice r ix e where
  -- | Slice at the given position of the outer dimension. No bounds check is
  -- visible here; callers presumably validate the position first.
  unsafeOuterSlice :: Array r ix e -> Int -> Elt r ix e

  -- | Number of positions along the outer dimension. The default, available
  -- for sized arrays, is the 'headDim' component of the array size.
  outerLength :: Array r ix e -> Int
  default outerLength :: Size r ix e => Array r ix e -> Int
  outerLength arr = headDim (size arr)
-- | Arrays that can be sliced along their inner dimension.
class Size r ix e => InnerSlice r ix e where
  -- | Take an inner slice. NOTE(review): the @(Lower ix, Int)@ pair looks
  -- like the array size split into its lower part and its last dimension,
  -- and the final 'Int' like the slice position; confirm against instances.
  unsafeInnerSlice :: Array r ix e -> (Lower ix, Int) -> Int -> Elt r ix e
-- | Arrays that can be sliced along an arbitrary dimension.
class Size r ix e => Slice r ix e where
  -- | Slice along the given 'Dim'. The 'Maybe' result presumably signals a
  -- slice that cannot be produced. NOTE(review): the two @ix@ arguments are
  -- unlabelled here; confirm their meaning against instances.
  unsafeSlice :: Array r ix e -> ix -> ix -> Dim -> Maybe (Elt r ix e)
-- | 'Source' arrays that additionally provide 'unsafeLinearIndexM'. The
-- indexing helpers of this module ('index', '(!)', '(!?)', 'borderIndex')
-- require this class. NOTE(review): presumably these are arrays whose
-- elements are already materialised in memory; confirm.
class Source r ix e => Manifest r ix e where
  -- | Read the element at a linear position. No default is given here and
  -- no bounds check is visible.
  unsafeLinearIndexM :: Array r ix e -> Int -> e
-- | Lifted wrapper around a primitive 'State#' token, so that the token can
-- be stored in ordinary tuples (as the @...A@ methods of 'Mutable' do).
data State s = State (State# s)

-- | Wrapped state token of 'RealWorld', the token type threaded through 'IO'.
type WorldState = State RealWorld
-- | Manifest arrays with a mutable counterpart ('MArray') that can be
-- allocated, read and written in any 'PrimMonad'.
class Manifest r ix e => Mutable r ix e where
  -- | Mutable version of an array of representation @r@; @s@ is the state
  -- token type of the monad it belongs to.
  data MArray s r ix e :: *

  -- | Size of a mutable array.
  msize :: MArray s r ix e -> ix

  -- | Turn an immutable array into a mutable one. NOTE(review): presumably
  -- no copy is made, so the input must not be used afterwards; confirm
  -- against instances.
  unsafeThaw :: PrimMonad m =>
                Array r ix e -> m (MArray (PrimState m) r ix e)

  -- | Turn a mutable array into an immutable one with the given computation
  -- strategy. NOTE(review): presumably no copy is made, so the mutable array
  -- must not be written afterwards; confirm against instances.
  unsafeFreeze :: PrimMonad m =>
                  Comp -> MArray (PrimState m) r ix e -> m (Array r ix e)

  -- | Allocate a mutable array of the given size. NOTE(review): contents are
  -- presumably left uninitialised; confirm against instances.
  unsafeNew :: PrimMonad m =>
               ix -> m (MArray (PrimState m) r ix e)

  -- | Allocate a mutable array of the given size. NOTE(review): the name
  -- suggests zero-filled memory; confirm against instances.
  unsafeNewZero :: PrimMonad m =>
                   ix -> m (MArray (PrimState m) r ix e)

  -- | Read the element at a linear index (no bounds check is visible here).
  unsafeLinearRead :: PrimMonad m =>
                      MArray (PrimState m) r ix e -> Int -> m e

  -- | Write an element at a linear index (no bounds check is visible here).
  unsafeLinearWrite :: PrimMonad m =>
                       MArray (PrimState m) r ix e -> Int -> e -> m ()

  -- The methods below are 'Applicative' variants of the methods above. Each
  -- one unwraps the 'WorldState' token, runs the corresponding 'IO' action
  -- on it with 'internal', and returns the new token next to the result.
  -- The token returned by one call presumably has to be passed to the next
  -- one, since that token is the only thing ordering the effects.

  -- | 'Applicative' variant of 'unsafeNew'.
  unsafeNewA :: Applicative f => ix -> WorldState -> f (WorldState, MArray RealWorld r ix e)
  unsafeNewA sz (State s#) =
    case internal (unsafeNew sz :: IO (MArray RealWorld r ix e)) s# of
      (# s'#, ma #) -> pure (State s'#, ma)
  {-# INLINE unsafeNewA #-}

  -- | 'Applicative' variant of 'unsafeThaw'.
  unsafeThawA :: Applicative m =>
                 Array r ix e -> WorldState -> m (WorldState, MArray RealWorld r ix e)
  unsafeThawA arr (State s#) =
    case internal (unsafeThaw arr :: IO (MArray RealWorld r ix e)) s# of
      (# s'#, ma #) -> pure (State s'#, ma)
  {-# INLINE unsafeThawA #-}

  -- | 'Applicative' variant of 'unsafeFreeze'.
  unsafeFreezeA :: Applicative m =>
                   Comp -> MArray RealWorld r ix e -> WorldState -> m (WorldState, Array r ix e)
  unsafeFreezeA comp marr (State s#) =
    case internal (unsafeFreeze comp marr :: IO (Array r ix e)) s# of
      (# s'#, a #) -> pure (State s'#, a)
  {-# INLINE unsafeFreezeA #-}

  -- | 'Applicative' variant of 'unsafeLinearWrite'; only the new token is
  -- returned.
  unsafeLinearWriteA :: Applicative m =>
                        MArray RealWorld r ix e -> Int -> e -> WorldState -> m WorldState
  unsafeLinearWriteA marr i val (State s#) =
    case internal (unsafeLinearWrite marr i val :: IO ()) s# of
      (# s'#, _ #) -> pure (State s'#)
  {-# INLINE unsafeLinearWriteA #-}
-- | Conversion between arrays and their nested representation
-- ('NestedStruct').
class Nested r ix e where
  -- | Build an array from its nested representation.
  fromNested :: NestedStruct r ix e -> Array r ix e
  -- | Convert an array into its nested representation.
  toNested :: Array r ix e -> NestedStruct r ix e
-- | Array representations that are built and taken apart one outer slice
-- ('Elt') at a time. NOTE(review): the name suggests that slices of
-- differing lengths are allowed; confirm against instances.
class Construct r ix e => Ragged r ix e where
  -- | Empty array with the given computation strategy.
  empty :: Comp -> Array r ix e
  -- | Whether the array is empty (presumably).
  isNull :: Array r ix e -> Bool
  -- | Add an outer slice at the front of the array (by analogy with list
  -- cons).
  cons :: Elt r ix e -> Array r ix e -> Array r ix e
  -- | Split off the first outer slice; 'Nothing' presumably means the array
  -- is empty.
  uncons :: Array r ix e -> Maybe (Elt r ix e, Array r ix e)
  -- | Generate an array of the given size by running a monadic action for
  -- every index.
  unsafeGenerateM :: Monad m => Comp -> ix -> (ix -> m e) -> m (Array r ix e)
  -- | NOTE(review): judging by the name, the size of the array as measured
  -- along its edges; confirm against instances.
  edgeSize :: Array r ix e -> ix
  -- | Flatten into a one-dimensional ('Ix1') array of the same representation.
  flatten :: Array r ix e -> Array r Ix1 e
  -- | Load the elements in 'IO'. NOTE(review): the arguments appear to be a
  -- work-scheduling function, an element writer, two 'Int' positions, and
  -- the lower-dimensional size; confirm against instances.
  loadRagged ::
    (IO () -> IO ()) -> (Int -> e -> IO a) -> Int -> Int -> Lower ix -> Array r ix e -> IO ()
  -- | Render the array as a 'String'. NOTE(review): the arguments appear to
  -- be an element formatter and a prefix or separator; confirm.
  raggedFormat :: (e -> String) -> String -> Array r ix e -> String
-- | Construct an array of the given size from a function of its index.
-- Any negative component of the size is clamped to zero before the array is
-- built with 'unsafeMakeArray'.
makeArray :: Construct r ix e =>
             Comp -- ^ Computation strategy
          -> ix -- ^ Size of the array; negative components become 0
          -> (ix -> e) -- ^ Element at each index
          -> Array r ix e
makeArray !comp = \sz -> unsafeMakeArray comp (liftIndex (max 0) sz)
{-# INLINE makeArray #-}
-- | Construct an array that holds exactly one element, with size
-- @pureIndex 1@ (1 along every dimension).
singleton :: Construct r ix e =>
             Comp -- ^ Computation strategy
          -> e -- ^ The only element of the array
          -> Array r ix e
singleton !comp = \x -> unsafeMakeArray comp (pureIndex 1) (const x)
{-# INLINE singleton #-}
infixl 4 ??, !?, !

-- | Index a manifest array, raising an error when the index is out of
-- bounds. Synonym for @index'@.
--
-- Defined in prefix form on purpose: with BangPatterns enabled, an infix
-- definition of this operator can be parsed as a bang pattern instead.
(!) :: Manifest r ix e => Array r ix e -> ix -> e
(!) arr ix = index' arr ix
{-# INLINE (!) #-}

-- | Safe indexing: 'Nothing' when the index is out of bounds. Synonym for
-- 'index'.
(!?) :: Manifest r ix e => Array r ix e -> ix -> Maybe e
(!?) arr ix = index arr ix
{-# INLINE (!?) #-}

-- | Index into an optional array: a missing array and an out-of-bounds index
-- both give 'Nothing'. Together with the shared fixity this allows chains
-- such as @arr !? i ?? j@ when the elements are themselves arrays.
(??) :: Manifest r ix e => Maybe (Array r ix e) -> ix -> Maybe e
(??) = maybe (const Nothing) (!?)
{-# INLINE (??) #-}
-- | Safe indexing of a manifest array: 'Just' the element, or 'Nothing'
-- (the 'Fill' border value) when the index lies outside the array.
index :: Manifest r ix e => Array r ix e -> ix -> Maybe e
index arr = handleBorderIndex (Fill Nothing) (size arr) readJust
  where
    readJust i = Just (unsafeIndex arr i)
{-# INLINE index #-}
-- | Index a manifest array, returning the supplied default value for any
-- index outside the array (a 'Fill' border).
defaultIndex :: Manifest r ix e => e -> Array r ix e -> ix -> e
defaultIndex = borderIndex . Fill
{-# INLINE defaultIndex #-}
-- | Index a manifest array, resolving indices that fall outside the array
-- according to the given 'Border' strategy.
borderIndex :: Manifest r ix e => Border e -> Array r ix e -> ix -> e
borderIndex border arr = handleBorderIndex border sz (unsafeIndex arr)
  where
    sz = size arr
{-# INLINE borderIndex #-}
-- | Index a manifest array, raising an error that reports the array size
-- and the offending index when the index is out of bounds.
--
-- The error message names this function (@index'@). The safe 'index'
-- returns 'Maybe' and never throws, so naming it in the message was
-- misleading.
index' :: Manifest r ix e => Array r ix e -> ix -> e
index' arr ix =
  -- Relies on the fill value being evaluated only when the border handler
  -- needs it, i.e. only for an out-of-bounds index.
  borderIndex (Fill (errorIx "Data.Massiv.Array.index'" (size arr) ix)) arr ix
{-# INLINE index' #-}
-- | Compute the element of any 'Source' array at the given index. Both the
-- array and the index are forced first. An out-of-bounds index raises an
-- error reporting the array size and the index.
evaluateAt :: Source r ix e => Array r ix e -> ix -> e
evaluateAt !arr !ix = handleBorderIndex outOfBounds sz (unsafeIndex arr) ix
  where
    sz = size arr
    outOfBounds = Fill (errorIx "Data.Massiv.Array.evaluateAt" sz ix)
{-# INLINE evaluateAt #-}
-- | Bounds-checked indexing helper for diagnostics. The element is read
-- with the supplied accessor only when 'isSafeIndex' accepts the index for
-- the size returned by the size getter. Otherwise 'errorIx' is raised with a
-- location prefix made of the file name and line number in angle brackets,
-- followed by the function name.
indexWith ::
     Index ix
  => String -- ^ File name (presumably of the call site)
  -> Int -- ^ Line number (presumably of the call site)
  -> String -- ^ Name of the calling function
  -> (arr -> ix) -- ^ Size getter
  -> (arr -> ix -> e) -- ^ Unchecked indexing function
  -> arr
  -> ix
  -> e
indexWith fileName lineNo funName getSize f arr ix =
  if isSafeIndex sz ix
    then f arr ix
    else errorIx location sz ix
  where
    sz = getSize arr
    location = concat ["<", fileName, ":", show lineNo, "> ", funName]
{-# INLINE indexWith #-}
-- | Run a monadic action on every index of the array together with the
-- element stored there, discarding the results. Indices are iterated with
-- 'iterM_' from 'zeroIndex' towards the array size in steps of one.
imapM_ :: (Source r ix a, Monad m) => (ix -> a -> m b) -> Array r ix a -> m ()
imapM_ f !arr = iterM_ zeroIndex (size arr) (pureIndex 1) (<) visit
  where
    visit !ix = f ix (unsafeIndex arr ix)
{-# INLINE imapM_ #-}