module Data.Repa.Eval.Chain
( chainOfArray
, unchainToArray
, unchainToArrayIO)
where
import Data.Repa.Chain (Chain(..), Step(..))
import Data.Repa.Array.Generic.Index as A
import Data.Repa.Array.Internals.Bulk as A
import Data.Repa.Array.Internals.Target as A
import qualified Data.Vector.Fusion.Stream.Monadic as S
import qualified Data.Vector.Fusion.Stream.Size as S
import qualified Data.Vector.Fusion.Util as S
import System.IO.Unsafe
#include "repa-array.h"
-- | Stream the elements of an array, in linear-index order, as a 'Chain'.
--
--   The chain's state is the next linear index to read, starting at @0@.
--   Each step converts that linear index into a layout-specific index with
--   'A.fromIndex' and yields the element stored there. The size hint is
--   @S.Exact@ the array length, and once every element has been yielded the
--   chain finishes with the array length as its final state.
chainOfArray
  :: (Monad m, Bulk l a)
  => Array l a -> Chain m Int a
chainOfArray !arr
  = Chain (S.Exact n) 0 next
  where
    -- Array length, evaluated strictly up front.
    !n = A.length arr

    -- Layout used to map linear indices to array indices.
    lay = A.layout arr

    next !ix
      | ix < n    = return (Yield (A.index arr (A.fromIndex lay ix)) (ix + 1))
      | otherwise = return (Done ix)
-- | Run a chain whose steps live in the identity monad 'S.Id' inside any
--   other monad.
--
--   The size hint and initial state are carried over unchanged. Each step is
--   evaluated purely and its result is 'return'ed in the target monad.
liftChain :: Monad m => Chain S.Id s a -> Chain m s a
liftChain (Chain hint start pureStep)
  = Chain hint start liftedStep
  where
    liftedStep st = return (S.unId (pureStep st))
-- | Pure counterpart of 'unchainToArrayIO' for chains whose steps are pure
--   (they run in 'S.Id').
--
--   The chain is lifted into 'IO' with 'liftChain' and then evaluated by
--   'unchainToArrayIO', which returns the filled array and the chain's final
--   state.
--
--   Why 'unsafePerformIO' is acceptable here: the lifted steps only 'return'
--   values, so the only effects are on the destination buffer that
--   'unchainToArrayIO' allocates for this particular call.
unchainToArray
  :: Target l a
  => Name l -> Chain S.Id s a -> (Array l a, s)
unchainToArray nDst c
  = unsafePerformIO action
  where
    action = unchainToArrayIO nDst (liftChain c)
-- | Evaluate a chain in 'IO'. Every yielded element is written into a freshly
--   allocated buffer of layout @l@. The result is the frozen array together
--   with the chain's final state.
--
--   The initial buffer size is chosen from the chain's size hint:
--
--   * @S.Exact i@ and @S.Max i@: allocate room for @i@ elements up front and
--     write each yielded element without any capacity check.
--     NOTE(review): this trusts the hint to be an upper bound on the number of
--     'Yield's. A chain that yields more than @i@ elements would write past
--     the end of the buffer.
--
--   * @S.Unknown@: start with room for 32 elements and double the capacity
--     whenever it is exhausted.
--
--   In both cases the buffer is sliced down to the number of elements
--   actually written before it is frozen. 'Skip' steps advance the chain
--   state without writing anything.
unchainToArrayIO
  :: Target l a
  => Name l -> Chain IO s a -> IO (Array l a, s)
unchainToArrayIO nDst (Chain sz s0 step)
  = case sz of
      S.Exact i -> unchainToArrayIO_max i
      S.Max i   -> unchainToArrayIO_max i
      S.Unknown -> unchainToArrayIO_unknown 32
  where
    -- Unchain when an upper bound (nMax) on the element count is known.
    -- A single buffer of nMax elements is allocated and never grown.
    unchainToArrayIO_max !nMax
      = do !vec0 <- unsafeNewBuffer (create nDst zeroDim)
           -- Grow the empty buffer so it can hold nMax elements.
           !vec <- unsafeGrowBuffer vec0 nMax
           -- Loop state: sPEC is threaded through unchanged (started with
           -- S.SPEC), presumably as a hint for GHC's SpecConstr pass;
           -- i is the next write position / count of elements written;
           -- s is the current chain state.
           let go_unchainIO_max !sPEC !i !s
                 = step s >>= \m
                 -> case m of
                      Yield e s'
                        -> do unsafeWriteBuffer vec i e
                              go_unchainIO_max sPEC (i + 1) s'
                      Skip s'
                        -> go_unchainIO_max sPEC i s'
                      Done s'
                        -- Keep only the i elements actually written.
                        -> do buf' <- unsafeSliceBuffer 0 i vec
                              arr <- unsafeFreezeBuffer buf'
                              return (arr, s')
           go_unchainIO_max S.SPEC 0 s0

    -- Unchain when the element count is unknown: start with room for
    -- nStart elements and double the capacity whenever it runs out.
    unchainToArrayIO_unknown !nStart
      = do !vec0 <- unsafeNewBuffer (create nDst zeroDim)
           !vec1 <- unsafeGrowBuffer vec0 nStart
           -- The loop is split into a driver and a single-step function
           -- that receives explicit continuations (cont: keep going,
           -- done: finish). Presumably this shape is deliberate for the
           -- optimiser; preserve it when editing.
           -- Loop state: uvec/vec is the current buffer, i the next write
           -- position, n the current capacity, s the chain state.
           let go_unchainIO_unknown !uvec !i !n !s
                 = go_unchainIO_unknown1 uvec i n s
                     (\vec' i' n' s' -> go_unchainIO_unknown vec' i' n' s')
                     (\result -> return result)
               go_unchainIO_unknown1 !vec !i !n !s cont done
                 = step s >>= \r
                 -> case r of
                      Yield e s'
                        -- When the buffer is full, grow it by n more elements
                        -- and record the new capacity as n + n (doubling).
                        -> do (vec', n')
                                <- if i >= n
                                     then do vec' <- unsafeGrowBuffer vec n
                                             return (vec', n + n)
                                     else return (vec, n)
                              unsafeWriteBuffer vec' i e
                              cont vec' (i + 1) n' s'
                      Skip s'
                        -> cont vec i n s'
                      Done s'
                        -- Trim the spare capacity: keep only the i elements
                        -- written, then freeze.
                        -> do vec' <- unsafeSliceBuffer 0 i vec
                              arr <- unsafeFreezeBuffer vec'
                              done (arr, s')
           go_unchainIO_unknown vec1 0 nStart s0