{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Massiv.Array.Manifest.Internal
( M
, Manifest(..)
, Array(..)
, toManifest
, compute
, computeS
, computeIO
, computePrimM
, computeAs
, computeProxy
, computeSource
, computeWithStride
, computeWithStrideAs
, clone
, convert
, convertAs
, convertProxy
, gcastArr
, fromRaggedArrayM
, fromRaggedArray'
, sizeofArray
, sizeofMutableArray
, iterateUntil
, iterateUntilM
) where
import Control.Exception (try)
import Control.Monad.ST
import Control.Scheduler
import qualified Data.Foldable as F (Foldable(..))
import Data.Massiv.Array.Delayed.Pull
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Fold.Internal
import Data.Massiv.Array.Mutable.Internal (unsafeCreateArray_)
import Data.Massiv.Vector.Stream as S (steps, isteps)
import Data.Massiv.Core.Common
import Data.Massiv.Core.List
import Data.Maybe (fromMaybe)
import Data.Typeable
import GHC.Base hiding (ord)
import System.IO.Unsafe (unsafePerformIO)
#if MIN_VERSION_primitive(0,6,2)
import Data.Primitive.Array (sizeofArray, sizeofMutableArray)
#else
import qualified Data.Primitive.Array as A (Array(..), MutableArray(..))
import GHC.Exts (sizeofArray#, sizeofMutableArray#)
sizeofArray :: A.Array a -> Int
sizeofArray (A.Array a) = I# (sizeofArray# a)
{-# INLINE sizeofArray #-}
sizeofMutableArray :: A.MutableArray s a -> Int
sizeofMutableArray (A.MutableArray ma) = I# (sizeofMutableArray# ma)
{-# INLINE sizeofMutableArray #-}
#endif
-- | Representation for manifest arrays described solely by a computation
-- strategy, a size and a flat-index lookup function (see 'toManifest').
data M

-- | An 'M' array. The computation strategy and size are strict fields, while
-- the lookup function is stored lazily (it carries no strictness annotation).
data instance Array M ix e = MArray { mComp :: !Comp
                                    -- ^ Computation strategy of the array.
                                    , mSize :: !(Sz ix)
                                    -- ^ Size of the array.
                                    , mLinearIndex :: Int -> e
                                    -- ^ Element lookup by linear index.
                                    }
-- | Render an 'M' array with massiv's shared array printer ('showsArrayPrec'),
-- applied with 'id' as the conversion function.
instance (Ragged L ix e, Show e) => Show (Array M ix e) where
  showsPrec prec arr = showsArrayPrec id prec arr
  showList arrs = showArrayList arrs
-- | Equality is delegated to massiv's 'eq' helper, comparing elements
-- with '(==)'.
instance (Eq e, Index ix) => Eq (Array M ix e) where
  arrA == arrB = eq (==) arrA arrB
  {-# INLINE (==) #-}
-- | Ordering is delegated to massiv's 'ord' helper (GHC.Base's character
-- 'ord' is hidden by this module's imports), comparing elements with
-- 'compare'.
instance (Ord e, Index ix) => Ord (Array M ix e) where
  compare arrA arrB = ord compare arrA arrB
  {-# INLINE compare #-}
-- | View any 'Manifest' array as an 'M' array. The computation strategy and
-- size are carried over, and elements are read through
-- 'unsafeLinearIndexM'; no elements are copied.
toManifest :: Manifest r ix e => Array r ix e -> Array M ix e
toManifest !arr =
  MArray
    { mComp = getComp arr
    , mSize = size arr
    , mLinearIndex = unsafeLinearIndexM arr
    }
{-# INLINE toManifest #-}
-- | 'Foldable' for 'M' arrays. Every method is routed to massiv's own fold
-- helpers instead of the class defaults.
instance Index ix => Foldable (Array M ix) where
  -- The unqualified 'fold' on the right is not Foldable's method (that one is
  -- only imported qualified as F), so this is not a self-reference; it is
  -- presumably massiv's monoidal fold.
  fold = fold
  {-# INLINE fold #-}
  foldMap toMonoid arr = foldMono toMonoid arr
  {-# INLINE foldMap #-}
  -- Lazy and strict left folds map to separate helpers.
  foldl step acc arr = lazyFoldlS step acc arr
  {-# INLINE foldl #-}
  foldl' step acc arr = foldlS step acc arr
  {-# INLINE foldl' #-}
  foldr step acc arr = foldrFB step acc arr
  {-# INLINE foldr #-}
  foldr' step acc arr = foldrS step acc arr
  {-# INLINE foldr' #-}
  -- Emptiness and length come straight from the stored size.
  null arr = totalElem (mSize arr) == 0
  {-# INLINE null #-}
  length arr = totalElem (size arr)
  {-# INLINE length #-}
  -- Expressed with 'build' so the resulting list can take part in fusion.
  toList arr = build (\cons nil -> foldrFB cons nil arr)
  {-# INLINE toList #-}
-- | Element access for 'M' arrays goes straight to the stored linear
-- lookup function.
instance Index ix => Source M ix e where
  unsafeLinearIndex = mLinearIndex
  {-# INLINE unsafeLinearIndex #-}
  -- A flat slice of @sz@ elements beginning at linear index @start@.
  --
  -- The lookup function is offset directly. The previous formulation first
  -- resized the array to an Ix1 view of length @sz@ and then indexed that view
  -- at @start + i@, i.e. beyond its own declared bounds whenever @start > 0@.
  -- That only worked because Ix1 linear indexing ignores the size; a
  -- bounds-checked 'unsafeIndex' (e.g. a debug build) would presumably
  -- reject it.
  unsafeLinearSlice start sz arr =
    MArray (getComp arr) sz (unsafeLinearIndex arr . (+ start))
  {-# INLINE unsafeLinearSlice #-}
-- | Manifest access hands back the stored linear lookup function unchanged.
instance Index ix => Manifest M ix e where
  unsafeLinearIndexM (MArray _ _ indexAt) = indexAt
  {-# INLINE unsafeLinearIndexM #-}
-- | Resizing replaces only the recorded size. The computation strategy and
-- lookup function are kept as they are, and nothing checks that the new size
-- has the same number of elements.
instance Index ix => Resize M ix where
  unsafeResize !newSz (MArray comp _ indexAt) = MArray comp newSz indexAt
  {-# INLINE unsafeResize #-}
-- | Extract a sub-array of size @newSz@ whose origin is @sIx@ in the source.
-- Element @i@ of the result is read from source index
-- @fromLinearIndex newSz i + sIx@, with the addition done per dimension.
instance Index ix => Extract M ix e where
  unsafeExtract !sIx !newSz !arr =
    MArray (getComp arr) newSz (unsafeIndex arr . shiftByOrigin . fromLinearIndex newSz)
    where
      shiftByOrigin localIx = liftIndex2 (+) localIx sIx
  {-# INLINE unsafeExtract #-}
-- | Slicing a one-dimensional array yields the single element at @start@.
-- The cut size and dimension arguments are ignored.
instance {-# OVERLAPPING #-} Slice M Ix1 e where
  unsafeSlice arr start _cutSz _dim = pure $ unsafeLinearIndex arr start
  {-# INLINE unsafeSlice #-}
-- | Slice along dimension @dim@. The @cutSz@ region at @start@ is extracted
-- and then relabelled with the second size returned by 'pullOutSzM'
-- (presumably @cutSz@ with @dim@ removed). Failure of 'pullOutSzM' is
-- reported in the caller's monad.
instance ( Index ix
         , Index (Lower ix)
         , Elt M ix e ~ Array M (Lower ix) e
         ) =>
         Slice M ix e where
  unsafeSlice arr start cutSz dim =
    relabel <$> pullOutSzM cutSz dim
    where
      relabel (_, lowerSz) = unsafeResize lowerSz (unsafeExtract start cutSz arr)
  {-# INLINE unsafeSlice #-}
-- | For a one-dimensional array the outer slice at @i@ is just its element
-- at @i@.
instance {-# OVERLAPPING #-} OuterSlice M Ix1 e where
  unsafeOuterSlice arr = arr `seq` unsafeIndex arr
  {-# INLINE unsafeOuterSlice #-}
-- | The outer slice at @i@ is a contiguous run of elements. It starts at the
-- linear index of @consDim i zeroIndex@ and has the size of the remaining
-- (lower) dimensions.
instance (Elt M ix e ~ Array M (Lower ix) e, Index ix, Index (Lower ix)) => OuterSlice M ix e where
  unsafeOuterSlice !arr !i =
    MArray (getComp arr) innerSz (\k -> unsafeLinearIndex arr (k + offset))
    where
      !offset = toLinearIndex (size arr) (consDim i (zeroIndex :: Lower ix))
      (_, innerSz) = unconsSz (size arr)
  {-# INLINE unsafeOuterSlice #-}
-- | For a one-dimensional array the inner slice at @i@ is its element at
-- @i@; the size pair is not needed.
instance {-# OVERLAPPING #-} InnerSlice M Ix1 e where
  unsafeInnerSlice arr _ = arr `seq` unsafeIndex arr
  {-# INLINE unsafeInnerSlice #-}
-- | The inner slice at @i@ has size @szL@. Element @k@ of it lives at linear
-- index @k * m + start@, where @m@ is the second component of the size pair
-- (presumably the innermost dimension's length) and @start@ is the linear
-- index of @snocDim zeroIndex i@.
instance (Elt M ix e ~ Array M (Lower ix) e, Index ix, Index (Lower ix)) => InnerSlice M ix e where
  unsafeInnerSlice !arr (szL, m) !i = MArray (getComp arr) szL elemAt
    where
      !start = toLinearIndex (size arr) (snocDim (zeroIndex :: Lower ix) i)
      rowLen = unSz m
      elemAt k = unsafeLinearIndex arr (k * rowLen + start)
  {-# INLINE unsafeInnerSlice #-}
-- | Loading an 'M' array passes the element count and the stored lookup
-- function to 'splitLinearlyWith_' on the given scheduler.
instance Index ix => Load M ix e where
  getComp = mComp
  {-# INLINE getComp #-}
  size = mSize
  {-# INLINE size #-}
  loadArrayM scheduler !arr =
    splitLinearlyWith_ scheduler (totalElem (mSize arr)) (mLinearIndex arr)
  {-# INLINE loadArrayM #-}
-- | Strided loading for 'M' arrays. No methods are overridden, so the
-- class's default implementations are used.
instance Index ix => StrideLoad M ix e
-- | Streaming is delegated to the generic 'S.steps' and 'S.isteps' helpers.
instance Index ix => Stream M ix e where
  toStreamIx arr = S.isteps arr
  {-# INLINE toStreamIx #-}
  toStream arr = S.steps arr
  {-# INLINE toStream #-}
-- | Pure counterpart of 'computeIO': materialise an array in representation
-- @r@. The input is forced to WHNF before any work is done.
compute :: forall r ix e r' . (Mutable r ix e, Load r' ix e) => Array r' ix e -> Array r ix e
compute arr = arr `seq` unsafePerformIO (computeIO arr)
{-# INLINE compute #-}
-- | Pure counterpart of 'computePrimM' (which loads via 'loadArrayS'), run
-- inside 'runST'. The input is forced to WHNF first.
computeS :: forall r ix e r' . (Mutable r ix e, Load r' ix e) => Array r' ix e -> Array r ix e
computeS arr = arr `seq` runST (computePrimM arr)
{-# INLINE computeS #-}
-- | Load an array with 'loadArray' and freeze the result, keeping the source
-- array's computation strategy. Runs in any 'MonadIO'.
computeIO ::
     forall r ix e r' m. (Mutable r ix e, Load r' ix e, MonadIO m)
  => Array r' ix e
  -> m (Array r ix e)
computeIO arr =
  liftIO $ do
    loaded <- loadArray arr
    unsafeFreeze (getComp arr) loaded
{-# INLINE computeIO #-}
-- | Load an array with 'loadArrayS' in any 'PrimMonad' and freeze the
-- result, keeping the source array's computation strategy.
computePrimM ::
     forall r ix e r' m. (Mutable r ix e, Load r' ix e, PrimMonad m)
  => Array r' ix e
  -> m (Array r ix e)
computePrimM arr = do
  loaded <- loadArrayS arr
  unsafeFreeze (getComp arr) loaded
{-# INLINE computePrimM #-}
-- | 'compute' with the target representation chosen by a value-level
-- witness, which is never inspected.
computeAs :: (Mutable r ix e, Load r' ix e) => r -> Array r' ix e -> Array r ix e
computeAs _repr arr = compute arr
{-# INLINE computeAs #-}
-- | 'compute' with the target representation chosen by a proxy, which is
-- never inspected.
computeProxy :: (Mutable r ix e, Load r' ix e) => proxy r -> Array r' ix e -> Array r ix e
computeProxy _proxy arr = compute arr
{-# INLINE computeProxy #-}
-- | Like 'compute', but when the source and target representations are the
-- same type the array is returned as-is instead of being recomputed.
computeSource :: forall r ix e r' . (Mutable r ix e, Source r' ix e)
              => Array r' ix e -> Array r ix e
computeSource arr =
  case (eqT :: Maybe (r' :~: r)) of
    Just Refl -> arr
    Nothing -> compute arr
{-# INLINE computeSource #-}
-- | Produce a separate copy of an array by thawing it with 'thaw'
-- (presumably the copying variant, as opposed to 'unsafeThaw') and freezing
-- the result with the original computation strategy.
clone :: Mutable r ix e => Array r ix e -> Array r ix e
clone arr =
  unsafePerformIO $ do
    copy <- thaw arr
    unsafeFreeze (getComp arr) copy
{-# INLINE clone #-}
-- | Reinterpret an array as representation @r@. This succeeds only when @r@
-- and @r'@ are the same type; otherwise the result is 'Nothing'.
gcastArr :: forall r ix e r' . (Typeable r, Typeable r')
       => Array r' ix e -> Maybe (Array r ix e)
gcastArr arr =
  case (eqT :: Maybe (r :~: r')) of
    Just Refl -> Just arr
    Nothing -> Nothing
-- | Convert to representation @r@. If the array already has that
-- representation it is reused via 'gcastArr'; otherwise it is computed with
-- 'compute'.
convert :: forall r ix e r' . (Mutable r ix e, Load r' ix e)
        => Array r' ix e -> Array r ix e
convert arr =
  case gcastArr arr of
    Just sameArr -> sameArr
    Nothing -> compute arr
{-# INLINE convert #-}
-- | 'convert' with the target representation chosen by a value-level
-- witness, which is never inspected.
convertAs :: (Mutable r ix e, Load r' ix e)
          => r -> Array r' ix e -> Array r ix e
convertAs _repr arr = convert arr
{-# INLINE convertAs #-}
-- | 'convert' with the target representation chosen by a proxy, which is
-- never inspected.
convertProxy :: (Mutable r ix e, Load r' ix e)
             => proxy r -> Array r' ix e -> Array r ix e
convertProxy _proxy arr = convert arr
{-# INLINE convertProxy #-}
-- | Convert a ragged array into the manifest representation @r@.
--
-- A target array of size @'edgeSize' arr@ is allocated and filled with
-- 'loadRagged', using a scheduler for the array's 'Comp'. A 'ShapeException'
-- thrown while loading is caught and rethrown in @m@ via 'throwM'. It is
-- presumably raised when the ragged input does not fit that size (TODO
-- confirm). On success the filled array is frozen and returned. Exceptions
-- of other types are not caught here.
--
-- The allocation and loading run inside 'unsafePerformIO'.
fromRaggedArrayM ::
     forall r ix e r' m . (Mutable r ix e, Ragged r' ix e, Load r' ix e, MonadThrow m)
  => Array r' ix e
  -> m (Array r ix e)
fromRaggedArrayM arr =
  let sz = edgeSize arr
   in either (\(e :: ShapeException) -> throwM e) pure $
      unsafePerformIO $ do
        marr <- unsafeNew sz
        -- Freeze only if loading completed without a ShapeException; the
        -- Left case is passed through unchanged by 'traverse'.
        traverse (\_ -> unsafeFreeze (getComp arr) marr) =<<
          try (withScheduler_ (getComp arr) $ \scheduler ->
                  loadRagged (scheduleWork scheduler) (unsafeLinearWrite marr) 0 (totalElem sz) sz arr)
{-# INLINE fromRaggedArrayM #-}
-- | Like 'fromRaggedArrayM', but a shape failure is thrown as an imprecise
-- exception instead of being returned.
fromRaggedArray' ::
     forall r ix e r'. (Mutable r ix e, Load r' ix e, Ragged r' ix e)
  => Array r' ix e
  -> Array r ix e
fromRaggedArray' arr =
  case fromRaggedArrayM arr of
    Left exc -> throw exc
    Right result -> result
{-# INLINE fromRaggedArray' #-}
-- | Compute only the elements selected by @stride@. The result has size
-- @'strideSize' stride (size arr)@ and uses the source's computation
-- strategy.
computeWithStride ::
     forall r ix e r'. (Mutable r ix e, StrideLoad r' ix e)
  => Stride ix
  -> Array r' ix e
  -> Array r ix e
computeWithStride stride !arr =
  unsafePerformIO $ unsafeCreateArray_ (getComp arr) resultSz fill
  where
    !resultSz = strideSize stride (size arr)
    -- Write each strided element into the freshly created target.
    fill scheduler target =
      loadArrayWithStrideM scheduler stride resultSz arr (unsafeLinearWrite target)
{-# INLINE computeWithStride #-}
-- | 'computeWithStride' with the target representation chosen by a
-- value-level witness, which is never inspected.
computeWithStrideAs ::
     (Mutable r ix e, StrideLoad r' ix e) => r -> Stride ix -> Array r' ix e -> Array r ix e
computeWithStrideAs _repr stride arr = computeWithStride stride arr
{-# INLINE computeWithStrideAs #-}
-- | Repeatedly apply @iteration@ until @convergence@ returns 'True'.
--
-- * @convergence n prev next@ receives the iteration number together with
--   the array before and after iteration @n@.
-- * @iteration n arr@ produces the next (loadable) array from @arr@.
--
-- Iteration @0@ is computed with 'compute'; if it already satisfies the
-- convergence check it is returned directly. Otherwise the remaining
-- iterations are driven by 'iterateLoop', which thaws the previous result
-- with 'unsafeThaw' to use as the next write target. Arrays passed to
-- @convergence@ (other than the caller's own @initArr0@, which is never
-- written to) may therefore be overwritten later and should not be retained.
iterateUntil ::
     (Load r' ix e, Mutable r ix e)
  => (Int -> Array r ix e -> Array r ix e -> Bool)
  -> (Int -> Array r ix e -> Array r' ix e)
  -> Array r ix e
  -> Array r ix e
iterateUntil convergence iteration initArr0
  | convergence 0 initArr0 initArr1 = initArr1
  | otherwise =
    unsafePerformIO $ do
      let loadArr = iteration 1 initArr1
      marr <- unsafeNew (size loadArr)
      -- The pure convergence check is lifted to iterateLoop's monadic,
      -- four-argument form; the mutable array argument is ignored.
      iterateLoop
        (\n a a' _ -> pure $ convergence n a a')
        iteration
        1
        initArr1
        loadArr
        (asArr initArr0 marr)
  where
    -- Result of iteration 0, forced strictly.
    !initArr1 = compute $ iteration 0 initArr0
    -- Type-only helper (it is 'id'): ties the representation of the freshly
    -- allocated mutable array to that of the input array.
    asArr :: Array r ix e -> MArray s r ix e -> MArray s r ix e
    asArr _ = id
{-# INLINE iterateUntil #-}
-- | Monadic variant of 'iterateUntil'. The convergence check runs in @m@ and
-- receives each newly computed result as a mutable array.
--
-- @convergence n prev next@ gets the iteration number, the previous (frozen)
-- array, and the mutable array holding the result of iteration @n@. As with
-- 'iterateUntil', result buffers are recycled by 'iterateLoop' via
-- 'unsafeThaw', so arrays seen by @convergence@ may later be overwritten.
iterateUntilM ::
     (Load r' ix e, Mutable r ix e, PrimMonad m, MonadIO m, PrimState m ~ RealWorld)
  => (Int -> Array r ix e -> MArray (PrimState m) r ix e -> m Bool)
  -> (Int -> Array r ix e -> Array r' ix e)
  -> Array r ix e
  -> m (Array r ix e)
iterateUntilM convergence iteration initArr0 = do
  -- Iteration 0 is loaded into its own freshly allocated buffer, so the
  -- caller's initArr0 is never written to.
  let loadArr0 = iteration 0 initArr0
  initMArr1 <- unsafeNew (size loadArr0)
  computeInto initMArr1 loadArr0
  -- The check sees the mutable buffer before it is frozen.
  shouldStop <- convergence 0 initArr0 initMArr1
  initArr1 <- unsafeFreeze (getComp loadArr0) initMArr1
  if shouldStop
    then pure initArr1
    else do
      let loadArr1 = iteration 1 initArr1
      marr <- unsafeNew (size loadArr1)
      -- Adapt to iterateLoop's four-argument check: drop the frozen new
      -- array and pass the mutable one through.
      iterateLoop (\n a _ -> convergence n a) iteration 1 initArr1 loadArr1 marr
{-# INLINE iterateUntilM #-}
-- | Shared driver for 'iterateUntil' and 'iterateUntilM'.
--
-- @go n arr loadArr marr@ computes @loadArr@ (iteration @n@, derived from
-- @arr@) into @marr@. If @marr@'s element count differs from @loadArr@'s,
-- @marr@ is first shrunk or grown to @loadArr@'s size. The frozen result is
-- checked with @convergence n arr arr' marr'@. If the check fails, the
-- previous array @arr@ is thawed with 'unsafeThaw' to serve as the buffer
-- for iteration @n + 1@, so memory is reused between iterations.
iterateLoop ::
     (Load r' ix e, Mutable r ix e, PrimMonad m, MonadIO m, PrimState m ~ RealWorld)
  => (Int -> Array r ix e -> Array r ix e -> MArray (PrimState m) r ix e -> m Bool)
  -> (Int -> Array r ix e -> Array r' ix e)
  -> Int
  -> Array r ix e
  -> Array r' ix e
  -> MArray (PrimState m) r ix e
  -> m (Array r ix e)
iterateLoop convergence iteration = go
  where
    go !n !arr !loadArr !marr = do
      let !sz = size loadArr
          !k = totalElem sz
          !mk = totalElem (msize marr)
      -- Make the target buffer hold exactly k elements.
      -- NOTE(review): when k == mk the buffer is reused with its existing
      -- shape. If loadArr has the same element count but a different shape
      -- than marr, the frozen result presumably keeps marr's shape. Confirm
      -- whether iterations are allowed to change shape.
      marr' <-
        if k == mk
          then pure marr
          else if k < mk
                 then unsafeLinearShrink marr sz
                 else unsafeLinearGrow marr sz
      computeInto marr' loadArr
      arr' <- unsafeFreeze (getComp loadArr) marr'
      shouldStop <- convergence n arr arr' marr'
      if shouldStop
        then pure arr'
        else do
          -- Recycle the previous iteration's array as the next write target.
          nextMArr <- unsafeThaw arr
          go (n + 1) arr' (iteration (n + 1) arr') nextMArr
{-# INLINE iterateLoop #-}