{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Massiv.Array.Delayed.Windowed
( DW(..)
, Array(..)
, Window(..)
, insertWindow
, getWindow
, dropWindow
, makeWindowedArray
) where
import Control.Exception (Exception(..))
import Control.Monad (when)
import Data.Massiv.Array.Delayed.Pull
import Data.Massiv.Array.Manifest.Boxed
import Data.Massiv.Array.Manifest.Internal
import Data.Massiv.Core
import Data.Massiv.Core.Common
import Data.Massiv.Core.Index.Internal (Sz(..))
import Data.Massiv.Core.List (L, showArrayList, showsArrayPrec)
import Data.Maybe (fromMaybe)
import GHC.TypeLits
-- | Representation index of delayed windowed arrays: a delayed ('D') array
-- paired with an optional 'Window' whose elements are produced by a separate
-- indexing function.
data DW =
  DW
-- | A rectangular region of a windowed array whose elements are produced by
-- 'windowIndex' rather than by the array's own indexing function.
data Window ix e = Window
  { windowStart :: !ix
    -- ^ Index of the first element covered by the window.
  , windowSize :: !(Sz ix)
    -- ^ Extent of the window along every dimension.
  , windowIndex :: ix -> e
    -- ^ Indexing function for elements inside the window. The loaders call it
    -- with indices expressed in the coordinates of the whole array.
  , windowUnrollIx2 :: !(Maybe Int)
    -- ^ Optional row-block height used by the 2D loaders for loop unrolling.
    -- Those loaders clamp it to the range 1..7 and treat 'Nothing' as 1.
  }
-- | Post-composes a function with the window's indexing function. Start, size
-- and unrolling hint are left untouched.
instance Functor (Window ix) where
  fmap f window = window {windowIndex = \ix -> f (windowIndex window ix)}
-- | A delayed array together with an optional 'Window'. Loaders take elements
-- inside the window from the window's indexing function and all other
-- elements from 'dwArray'.
data instance Array DW ix e = DWArray
  { dwArray :: !(Array D ix e)
    -- ^ Underlying delayed array; supplies every element outside the window.
  , dwWindow :: !(Maybe (Window ix e))
    -- ^ Optional window. With 'Nothing' the loaders substitute an empty window,
    -- so the whole array comes from 'dwArray'.
  }
-- | Shows a windowed array by first computing it into a boxed ('B') array.
instance (Ragged L ix e, Load DW ix e, Show e) => Show (Array DW ix e) where
  showList = showArrayList
  showsPrec = showsArrayPrec (computeAs B)
-- | Construction goes through the delayed representation. Arrays made with
-- 'makeArray' carry no window, and 'setComp' only touches the delayed array.
instance Index ix => Construct DW ix e where
  setComp comp dw@DWArray {dwArray = darr} = dw {dwArray = darr {dComp = comp}}
  {-# INLINE setComp #-}
  makeArray comp sz f = DWArray {dwArray = makeArray comp sz f, dwWindow = Nothing}
  {-# INLINE makeArray #-}
-- | Maps over both the underlying delayed array and the window, if present.
instance Functor (Array DW ix) where
  fmap f (DWArray darr mWindow) = DWArray (fmap f darr) (fmap (fmap f) mWindow)
  {-# INLINE fmap #-}
-- | Build a windowed array from any source array. Elements inside the window
-- that starts at @wStart@ and has size @wSize@ are computed with @wIndex@; all
-- other elements come from the source array. The window is clamped to the
-- array bounds by 'insertWindow', and no unrolling hint is set.
makeWindowedArray
  :: Source r ix e
  => Array r ix e -- ^ Source array supplying the elements outside the window
  -> ix -- ^ Start index of the window
  -> Sz ix -- ^ Size of the window
  -> (ix -> e) -- ^ Indexing function used inside the window
  -> Array DW ix e
makeWindowedArray !arr wStart wSize wIndex = insertWindow (delay arr) window
  where
    window =
      Window
        { windowStart = wStart
        , windowSize = wSize
        , windowIndex = wIndex
        , windowUnrollIx2 = Nothing
        }
{-# INLINE makeWindowedArray #-}
-- | Attach a window to a delayed array. The window start is clamped so it does
-- not go past the last index of the array. The window size is clamped so the
-- window does not extend past the end of the array. The indexing function and
-- the unrolling hint are kept as given.
insertWindow
  :: Source D ix e
  => Array D ix e
  -> Window ix e
  -> Array DW ix e
insertWindow !arr !window =
  let Sz sz = size arr
      Window {windowStart = start, windowSize = Sz wantedSize} = window
      -- Round-tripping through 'Sz' presumably also lifts negative components to
      -- zero (relies on the 'Sz' smart constructor; verify there).
      start' = unSz (Sz (liftIndex2 min start (liftIndex (subtract 1) sz)))
      size' = Sz (liftIndex2 min wantedSize (liftIndex2 (-) sz start'))
      clamped = window {windowStart = start', windowSize = size'}
   in DWArray {dwArray = delay arr, dwWindow = Just $! clamped}
{-# INLINE insertWindow #-}
-- | The window of a windowed array, or 'Nothing' if it has none.
getWindow :: Array DW ix e -> Maybe (Window ix e)
getWindow (DWArray _ mWindow) = mWindow
{-# INLINE getWindow #-}
-- | Forget the window and return the underlying delayed array. Elements that
-- were inside the window are then produced by the delayed array's own indexing
-- function.
dropWindow :: Array DW ix e -> Array D ix e
dropWindow (DWArray darr _) = darr
{-# INLINE dropWindow #-}
-- | An empty window at the zero index. Loaders use it in place of a missing
-- window. Its indexing function is 'windowError', so evaluating it calls
-- 'throwImpossible' with 'EmptyWindowException'.
zeroWindow :: Index ix => Window ix e
zeroWindow =
  Window
    { windowStart = zeroIndex
    , windowSize = zeroSz
    , windowIndex = windowError
    , windowUnrollIx2 = Nothing
    }
{-# INLINE zeroWindow #-}
-- | Exception used by 'windowError', i.e. when the indexing function of an
-- empty window is evaluated.
data EmptyWindowException =
  EmptyWindowException
  deriving (Eq, Show)

instance Exception EmptyWindowException where
  displayException = const "Index of zero size Window"
-- | Placeholder indexing function stored in 'zeroWindow'. Evaluating it calls
-- 'throwImpossible' with 'EmptyWindowException'. It is marked NOINLINE,
-- presumably to keep the error path out of the inlined loaders.
windowError :: a
windowError = throwImpossible exc
  where
    exc = EmptyWindowException
{-# NOINLINE windowError #-}
-- | Load the two border segments of a 1D windowed array (before and after the
-- window), each as one loop wrapped with @with@.
--
-- Returns three things:
--
-- * an action that loads the window elements in a half-open range
--   @[from, to)@, also wrapped with @with@;
-- * the window start;
-- * the exclusive window end.
loadWithIx1 ::
     (Monad m)
  => (m () -> m ())
  -> Array DW Ix1 e
  -> (Ix1 -> e -> m a)
  -> m (Ix1 -> Ix1 -> m (), Ix1, Ix1)
loadWithIx1 with (DWArray (DArray _ sz indexB) mWindow) uWrite = do
  let Window wStart wSize indexW _ = fromMaybe zeroWindow mWindow
      wEnd = wStart + unSz wSize
      -- Write every index in [from, to) using the given indexing function.
      loadRange from to index = with $ iterM_ from to 1 (<) $ \ !i -> uWrite i (index i)
  loadRange 0 wStart indexB
  loadRange wEnd (unSz sz) indexB
  return (\from to -> loadRange from to indexW, wStart, wEnd)
{-# INLINE loadWithIx1 #-}
-- | 1D loading. The border segments are submitted with 'scheduleWork'. The
-- window is then split into one chunk per worker plus a final slack chunk for
-- the remainder.
instance Load DW Ix1 e where
  size = dSize . dwArray
  {-# INLINE size #-}
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  loadArrayM scheduler arr uWrite = do
    (loadWindow, wStart, wEnd) <- loadWithIx1 (scheduleWork scheduler) arr uWrite
    let nWorkers = numWorkers scheduler
        (chunkWidth, slackWidth) = (wEnd - wStart) `quotRem` nWorkers
        chunkStart wid = wStart + wid * chunkWidth
    loopM_ 0 (< nWorkers) (+ 1) $ \ !wid ->
      let !from = chunkStart wid
       in loadWindow from (from + chunkWidth)
    when (slackWidth > 0) $
      let !from = chunkStart nWorkers
       in loadWindow from (from + slackWidth)
  {-# INLINE loadArrayM #-}
-- | Strided 1D loading. The chunking scheme is the same as in the non-strided
-- instance: one window chunk per worker, plus a final slack chunk.
instance StrideLoad DW Ix1 e where
  loadArrayWithStrideM scheduler stride sz arr uWrite = do
    (loadWindow, (wStart, wEnd)) <- loadArrayWithIx1 (scheduleWork scheduler) arr stride sz uWrite
    let nWorkers = numWorkers scheduler
        (chunkWidth, slackWidth) = (wEnd - wStart) `quotRem` nWorkers
        chunkStart wid = wStart + wid * chunkWidth
    loopM_ 0 (< nWorkers) (+ 1) $ \ !wid ->
      let !from = chunkStart wid
       in loadWindow (from, from + chunkWidth)
    when (slackWidth > 0) $
      let !from = chunkStart nWorkers
       in loadWindow (from, from + slackWidth)
  {-# INLINE loadArrayWithStrideM #-}
-- | Strided counterpart of 'loadWithIx1'. Each loop starts at 0 or at the
-- 'strideStart' of its lower bound and advances by the stride. The element at
-- source index @i@ is written to result position @i `div` stride@.
--
-- The two border segments are loaded immediately, each wrapped with @with@.
-- The returned action loads the window indices of a half-open range
-- @(from, to)@, also wrapped with @with@. The window start and exclusive end
-- are returned alongside it. The result size argument is ignored.
loadArrayWithIx1 ::
     (Monad m)
  => (m () -> m ())
  -> Array DW Ix1 e
  -> Stride Ix1
  -> Sz1
  -> (Ix1 -> e -> m a)
  -> m ((Ix1, Ix1) -> m (), (Ix1, Ix1))
loadArrayWithIx1 with (DWArray (DArray _ arrSz indexB) mWindow) stride _ uWrite = do
  let Window wStart wSize indexW _ = fromMaybe zeroWindow mWindow
      wEnd = wStart + unSz wSize
      step = unStride stride
      loadEvery from to index =
        with $ iterM_ from to step (<) $ \ !i -> uWrite (i `div` step) (index i)
  loadEvery 0 wStart indexB
  loadEvery (strideStart stride wEnd) (unSz arrSz) indexB
  let loadWindow (from, to) = loadEvery (strideStart stride from) to indexW
  return (loadWindow, (wStart, wEnd))
{-# INLINE loadArrayWithIx1 #-}
-- | Load a 2D windowed array without a stride. The four border regions are:
--
-- * the rows (first index) above the window;
-- * the rows below the window;
-- * the parts of the window rows to the left of the window;
-- * the parts of the window rows to the right of the window.
--
-- Each region is loaded by one loop wrapped with @with@. Returns an action that
-- loads the window rows @[it', ib')@ of a band, and the window's first and end
-- row as @it :. ib@. The band loader is wrapped with @with@ and unrolled by the
-- clamped row-block height.
loadWithIx2 ::
     Monad m
  => (m () -> m ())
  -> Array DW Ix2 t1
  -> (Int -> t1 -> m ())
  -> m (Ix2 -> m (), Ix2)
loadWithIx2 with arr uWrite = do
  let DWArray (DArray _ (Sz (m :. n)) indexB) window = arr
      Window (it :. jt) (Sz (wm :. wn)) indexW mUnrollHeight = fromMaybe zeroWindow window
      ib = it + wm
      jb = jt + wn
      !blockHeight = maybe 1 (min 7 . max 1) mUnrollHeight
      !sz = strideSize oneStride (size arr)
      writeFrom index !ix = uWrite (toLinearIndex sz ix) (index ix)
      {-# INLINE writeFrom #-}
      loadBorder start end = with $ iterM_ start end (1 :. 1) (<) (writeFrom indexB)
      {-# INLINE loadBorder #-}
      loadBand (it' :. ib') =
        with $ unrollAndJam blockHeight (it' :. jt) (ib' :. jb) 1 (writeFrom indexW)
      {-# INLINE loadBand #-}
  loadBorder (0 :. 0) (it :. n)
  loadBorder (ib :. 0) (m :. n)
  loadBorder (it :. 0) (ib :. jt)
  loadBorder (it :. jb) (ib :. n)
  return (loadBand, it :. ib)
{-# INLINE loadWithIx2 #-}
-- | Strided counterpart of 'loadWithIx2'. It loads the four border regions
-- around the window, visiting indices from each region's 'strideStart' in
-- steps of the stride. It returns an action that loads a band of window rows
-- @[it', ib')@, and the window's first and end row as @it :. ib@.
--
-- Every loop is wrapped with @with@, including the band loader returned to the
-- caller. This hands the window work to the scheduler instead of running it
-- inline in the caller, matching 'loadWithIx2'.
loadArrayWithIx2 ::
     Monad m
  => (m () -> m ())
  -> Array DW Ix2 e
  -> Stride Ix2
  -> Sz2 -- ^ Size passed to 'toLinearIndexStride' when computing result positions
  -> (Int -> e -> m ())
  -> m (Ix2 -> m (), Ix2)
loadArrayWithIx2 with arr stride sz uWrite = do
  let DWArray (DArray _ (Sz (m :. n)) indexB) window = arr
  let Window (it :. jt) (Sz (wm :. wn)) indexW mUnrollHeight = fromMaybe zeroWindow window
  let ib :. jb = (wm + it) :. (wn + jt)
      !blockHeight = maybe 1 (min 7 . max 1) mUnrollHeight
      strideIx@(is :. js) = unStride stride
      writeB !ix = uWrite (toLinearIndexStride stride sz ix) (indexB ix)
      {-# INLINE writeB #-}
      writeW !ix = uWrite (toLinearIndexStride stride sz ix) (indexW ix)
      {-# INLINE writeW #-}
  -- Border regions: rows above the window, rows below it, and the parts of the
  -- window rows to the left and to the right of the window.
  with $ iterM_ (0 :. 0) (it :. n) strideIx (<) writeB
  with $ iterM_ (strideStart stride (ib :. 0)) (m :. n) strideIx (<) writeB
  with $ iterM_ (strideStart stride (it :. 0)) (ib :. jt) strideIx (<) writeB
  with $ iterM_ (strideStart stride (it :. jb)) (ib :. n) strideIx (<) writeB
  -- 'unrollAndJam' steps rows by 1, so unrolling is only possible without a
  -- vertical stride. The branch is chosen once, not per band.
  let loadBand
        | is > 1 || blockHeight <= 1 = \(it' :. ib') ->
            with $ iterM_ (strideStart stride (it' :. jt)) (ib' :. jb) strideIx (<) writeW
        | otherwise = \(it' :. ib') ->
            with $ unrollAndJam blockHeight (strideStart stride (it' :. jt)) (ib' :. jb) js writeW
  return (loadBand, it :. ib)
{-# INLINE loadArrayWithIx2 #-}
-- | Split the window rows @[it, ib)@ into @nWorkers@ bands of equal height,
-- plus a final band with the remaining rows when the height does not divide
-- evenly. Each band is loaded with @loadWindow@.
loadWindowIx2 :: Monad m => Int -> (Ix2 -> m ()) -> Ix2 -> m ()
loadWindowIx2 nWorkers loadWindow (it :. ib) = do
  let !(chunkHeight, slackHeight) = (ib - it) `quotRem` nWorkers
      bandStart wid = it + wid * chunkHeight
  loopM_ 0 (< nWorkers) (+ 1) $ \ !wid ->
    let !from = bandStart wid
     in loadWindow (from :. (from + chunkHeight))
  when (slackHeight > 0) $
    let !from = bandStart nWorkers
     in loadWindow (from :. (from + slackHeight))
{-# INLINE loadWindowIx2 #-}
-- | 2D loading. The border regions are submitted with 'scheduleWork', and the
-- window rows are split into bands by 'loadWindowIx2'.
instance Load DW Ix2 e where
  size = dSize . dwArray
  {-# INLINE size #-}
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  loadArrayM scheduler arr uWrite = do
    (loadBand, windowRows) <- loadWithIx2 (scheduleWork scheduler) arr uWrite
    loadWindowIx2 (numWorkers scheduler) loadBand windowRows
  {-# INLINE loadArrayM #-}

-- | Strided 2D loading; same banding scheme as the non-strided instance.
instance StrideLoad DW Ix2 e where
  loadArrayWithStrideM scheduler stride sz arr uWrite = do
    (loadBand, windowRows) <- loadArrayWithIx2 (scheduleWork scheduler) arr stride sz uWrite
    loadWindowIx2 (numWorkers scheduler) loadBand windowRows
  {-# INLINE loadArrayWithStrideM #-}
-- | 'IxN' arrays are loaded by 'loadWithIxN', one outermost slice at a time,
-- via the lower rank's 'Load' instance.
instance (Index (IxN n), Load DW (Ix (n - 1)) e) => Load DW (IxN n) e where
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  size = dSize . dwArray
  {-# INLINE size #-}
  loadArrayM scheduler arr uWrite = loadWithIxN scheduler arr uWrite
  {-# INLINE loadArrayM #-}

-- | Strided loading of 'IxN' arrays goes through 'loadArrayWithIxN'.
instance (Index (IxN n), StrideLoad DW (Ix (n - 1)) e) => StrideLoad DW (IxN n) e where
  loadArrayWithStrideM scheduler stride sz arr uWrite =
    loadArrayWithIxN scheduler stride sz arr uWrite
  {-# INLINE loadArrayWithStrideM #-}
-- | Generic strided loader that peels off the outermost dimension. The outer
-- dimension is walked with its stride component @s@. Each visited slice @i@ is
-- turned into a lower-rank windowed array and loaded through the lower rank's
-- 'StrideLoad' instance. Slices that cross the window get the matching
-- lower-rank part of the window.
--
-- This function does not call 'scheduleWork' itself; the scheduler is passed
-- down to the lower-rank loads.
loadArrayWithIxN ::
     (Index ix, Monad m, StrideLoad DW (Lower ix) e)
  => Scheduler m ()
  -> Stride ix
  -> Sz ix -- ^ Size of the strided result
  -> Array DW ix e
  -> (Int -> e -> m ())
  -> m ()
loadArrayWithIxN scheduler stride szResult arr uWrite = do
  let DWArray darr window = arr
      DArray {dSize = szSource, dIndex = indexBorder} = darr
      Window {windowStart, windowSize, windowIndex, windowUnrollIx2} = fromMaybe zeroWindow window
      !(headSourceSize, lowerSourceSize) = unconsSz szSource
      !lowerSize = snd $ unconsSz szResult
      !(s, lowerStrideIx) = unconsDim $ unStride stride
      !(curWindowStart, lowerWindowStart) = unconsDim windowStart
      !(headWindowSz, tailWindowSz) = unconsSz windowSize
      -- Exclusive end of the window along the outermost dimension.
      !curWindowEnd = curWindowStart + unSz headWindowSz
      -- Number of result elements in one outermost page of the strided output.
      !pageElements = totalElem lowerSize
      -- Lower-rank part of the window, as seen from source slice @i@.
      mkLowerWindow i =
        Window
          { windowStart = lowerWindowStart
          , windowSize = tailWindowSz
          , windowIndex = windowIndex . consDim i
          , windowUnrollIx2 = windowUnrollIx2
          }
      -- Source slice @i@ (computation strategy 'Seq'), with the window obtained by
      -- applying @mw@ to @i@, if any.
      mkLowerArray mw i =
        DWArray
          {dwArray = DArray Seq lowerSourceSize (indexBorder . consDim i), dwWindow = ($ i) <$> mw}
      -- Load source slice @i@; its results go to output page @i `div` s@.
      loadLower mw !i =
        loadArrayWithStrideM
          scheduler
          (Stride lowerStrideIx)
          lowerSize
          (mkLowerArray mw i)
          (\k -> uWrite (k + pageElements * (i `div` s)))
      {-# NOINLINE loadLower #-}
  -- Slices before the window, slices crossing the window, and slices after it.
  -- The last two loops start at 'strideStart' of their lower bound.
  loopM_ 0 (< headDim windowStart) (+ s) (loadLower Nothing)
  loopM_
    (strideStart (Stride s) curWindowStart)
    (< curWindowEnd)
    (+ s)
    (loadLower (Just mkLowerWindow))
  loopM_ (strideStart (Stride s) curWindowEnd) (< unSz headSourceSize) (+ s) (loadLower Nothing)
{-# INLINE loadArrayWithIxN #-}
-- | Generic non-strided loader that peels off the outermost dimension. Each
-- outermost slice is turned into a lower-rank windowed array and loaded through
-- the lower rank's 'Load' instance, inside a job submitted with
-- 'scheduleWork_'. Slices that cross the window get the matching lower-rank
-- part of the window; the others get none.
loadWithIxN ::
     (Index ix, Monad m, Load DW (Lower ix) e)
  => Scheduler m ()
  -> Array DW ix e
  -> (Int -> e -> m ())
  -> m ()
loadWithIxN scheduler arr uWrite = do
  let DWArray darr window = arr
      DArray {dSize = sz, dIndex = indexBorder} = darr
      Window {windowStart, windowSize, windowIndex, windowUnrollIx2} = fromMaybe zeroWindow window
      !(si, szL) = unconsSz sz
      -- Exclusive end of the window in every dimension.
      !windowEnd = liftIndex2 (+) windowStart (unSz windowSize)
      !(t, windowStartL) = unconsDim windowStart
      -- Number of elements in one outermost slice; the linear offset step per slice.
      !pageElements = totalElem szL
      -- Lower-rank part of the window, as seen from slice @i@.
      mkLowerWindow i =
        Window
          { windowStart = windowStartL
          , windowSize = snd $ unconsSz windowSize
          , windowIndex = windowIndex . consDim i
          , windowUnrollIx2 = windowUnrollIx2
          }
      -- Slice @i@ of the source (computation strategy 'Seq'), with the window
      -- obtained by applying @mw@ to @i@, if any.
      mkLowerArray mw i =
        DWArray {dwArray = DArray Seq szL (indexBorder . consDim i), dwWindow = ($ i) <$> mw}
      -- Submit a job loading slice @i@, shifting its linear indices to page @i@.
      loadLower mw !i =
        scheduleWork_ scheduler $
          loadArrayM scheduler (mkLowerArray mw i) (\k -> uWrite (k + pageElements * i))
      {-# NOINLINE loadLower #-}
  -- Slices before the window, slices crossing the window, and slices after it.
  loopM_ 0 (< headDim windowStart) (+ 1) (loadLower Nothing)
  loopM_ t (< headDim windowEnd) (+ 1) (loadLower (Just mkLowerWindow))
  loopM_ (headDim windowEnd) (< unSz si) (+ 1) (loadLower Nothing)
{-# INLINE loadWithIxN #-}
-- | Apply @f@ to every index of the rectangle @[it, ib) x [jt, jb)@. Rows (the
-- first index) advance by 1 and columns by @js@.
--
-- Rows are handled in blocks of @bH@. For each column, the block's rows are
-- visited together by the hand-unrolled helpers @f2@ .. @f7@. Rows left over
-- when the row count is not a multiple of @bH@ are then visited one at a time.
--
-- Precondition: @1 <= bH <= 7@ (the callers clamp it). @bH@ is used as a
-- divisor, and any value other than 1..6 selects @f7@, which covers only 7
-- rows per step.
unrollAndJam :: Monad m =>
     Int -- ^ Row-block height @bH@
  -> Ix2 -- ^ Start index @it :. jt@
  -> Ix2 -- ^ Exclusive end index @ib :. jb@
  -> Int -- ^ Column step @js@
  -> (Ix2 -> m ()) -- ^ Action applied to each visited index
  -> m ()
unrollAndJam !bH (it :. jt) (ib :. jb) js f = do
  -- fK visits rows i .. i+K-1 of column j.
  let f2 (i :. j) = f (i :. j) >> f ((i + 1) :. j)
  let f3 (i :. j) = f (i :. j) >> f2 ((i + 1) :. j)
  let f4 (i :. j) = f (i :. j) >> f3 ((i + 1) :. j)
  let f5 (i :. j) = f (i :. j) >> f4 ((i + 1) :. j)
  let f6 (i :. j) = f (i :. j) >> f5 ((i + 1) :. j)
  let f7 (i :. j) = f (i :. j) >> f6 ((i + 1) :. j)
  let f' =
        case bH of
          1 -> f
          2 -> f2
          3 -> f3
          4 -> f4
          5 -> f5
          6 -> f6
          _ -> f7
  -- Last row (exclusive) reachable in whole blocks of bH rows.
  let !ibS = ib - ((ib - it) `mod` bH)
  loopM_ it (< ibS) (+ bH) $ \ !i ->
    loopM_ jt (< jb) (+ js) $ \ !j ->
      f' (i :. j)
  -- Remaining rows, one at a time.
  loopM_ ibS (< ib) (+ 1) $ \ !i ->
    loopM_ jt (< jb) (+ js) $ \ !j ->
      f (i :. j)
{-# INLINE unrollAndJam #-}
-- | Reinterpret a window over 2-tuple indices ('Ix2T') as a window over 'Ix2'.
-- The indexing function converts indices back with 'fromIx2'.
toIx2Window :: Window Ix2T e -> Window Ix2 e
toIx2Window (Window start sz index unrollHint) =
  Window
    { windowStart = toIx2 start
    , windowSize = SafeSz (toIx2 (unSz sz))
    , windowIndex = \ix -> index (fromIx2 ix)
    , windowUnrollIx2 = unrollHint
    }
{-# INLINE toIx2Window #-}
-- | Reinterpret a windowed array indexed by 2-tuples as one indexed by 'Ix2'.
-- Both the delayed array and its window are converted.
toIx2ArrayDW :: Array DW Ix2T e -> Array DW Ix2 e
toIx2ArrayDW (DWArray darr mWindow) =
  DWArray
    { dwArray =
        darr
          { dIndex = \ix -> dIndex darr (fromIx2 ix)
          , dSize = SafeSz (toIx2 (unSz (dSize darr)))
          }
    , dwWindow = toIx2Window <$> mWindow
    }
{-# INLINE toIx2ArrayDW #-}
-- | 2-tuple indexed arrays: plain loading is strided loading with 'oneStride'.
instance Load DW Ix2T e where
  size = dSize . dwArray
  {-# INLINE size #-}
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  loadArrayM scheduler arr uWrite =
    loadArrayWithStrideM scheduler oneStride (size arr) arr uWrite
  {-# INLINE loadArrayM #-}

-- | Strided loading converts the stride, the result size and the array to 'Ix2'
-- and defers to the 'Ix2' instance.
instance StrideLoad DW Ix2T e where
  loadArrayWithStrideM scheduler stride sz arr =
    let stride2 = Stride (toIx2 (unStride stride))
        sz2 = SafeSz (toIx2 (unSz sz))
     in loadArrayWithStrideM scheduler stride2 sz2 (toIx2ArrayDW arr)
  {-# INLINE loadArrayWithStrideM #-}
-- | 3-tuple indexed arrays: plain loading is strided loading with 'oneStride'.
instance Load DW Ix3T e where
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  size = dSize . dwArray
  {-# INLINE size #-}
  loadArrayM scheduler arr uWrite =
    loadArrayWithStrideM scheduler oneStride (size arr) arr uWrite
  {-# INLINE loadArrayM #-}

-- | Strided loading goes through the generic 'loadArrayWithIxN'.
instance StrideLoad DW Ix3T e where
  loadArrayWithStrideM scheduler stride sz arr uWrite =
    loadArrayWithIxN scheduler stride sz arr uWrite
  {-# INLINE loadArrayWithStrideM #-}
-- | 4-tuple indexed arrays: plain loading is strided loading with 'oneStride'.
instance Load DW Ix4T e where
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  size = dSize . dwArray
  {-# INLINE size #-}
  loadArrayM scheduler arr uWrite =
    loadArrayWithStrideM scheduler oneStride (size arr) arr uWrite
  {-# INLINE loadArrayM #-}

-- | Strided loading goes through the generic 'loadArrayWithIxN'.
instance StrideLoad DW Ix4T e where
  loadArrayWithStrideM scheduler stride sz arr uWrite =
    loadArrayWithIxN scheduler stride sz arr uWrite
  {-# INLINE loadArrayWithStrideM #-}
-- | 5-tuple indexed arrays: plain loading is strided loading with 'oneStride'.
instance Load DW Ix5T e where
  getComp = dComp . dwArray
  {-# INLINE getComp #-}
  size = dSize . dwArray
  {-# INLINE size #-}
  loadArrayM scheduler arr uWrite =
    loadArrayWithStrideM scheduler oneStride (size arr) arr uWrite
  {-# INLINE loadArrayM #-}

-- | Strided loading goes through the generic 'loadArrayWithIxN'.
instance StrideLoad DW Ix5T e where
  loadArrayWithStrideM scheduler stride sz arr uWrite =
    loadArrayWithIxN scheduler stride sz arr uWrite
  {-# INLINE loadArrayWithStrideM #-}