-- | Expose Java iterators as streams from the
-- <http://hackage.haskell.org/package/streaming streaming> package.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StaticPointers #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}

module Language.Java.Streaming
  ( reifyStreamWithBatching
  , reflectStreamWithBatching
  ) where

import Control.Concurrent.MVar (MVar, modifyMVar_, newMVar)
import Control.Distributed.Closure.TH
import Control.Monad.IO.Class (liftIO)
import qualified Data.Coerce as Coerce
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Int (Int32, Int64)
import Data.Proxy
import qualified Data.Vector as V
import Data.Singletons (SomeSing(..))
import Foreign.Ptr (FunPtr, Ptr, intPtrToPtr, ptrToIntPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import qualified Foreign.JNI as JNI
import qualified Foreign.JNI.Types as JNI
import GHC.Stable
  ( castPtrToStablePtr
  , castStablePtrToPtr
  , deRefStablePtr
  , freeStablePtr
  , newStablePtr
  )
import Language.Java
import Language.Java.Batching
import Language.Java.Inline
import Streaming (Bifunctor(first), Stream, Of)
import qualified Streaming as Streaming
import qualified Streaming.Prelude as Streaming
import System.IO.Unsafe (unsafePerformIO)

imports "io.tweag.jvm.batching.*"
imports "java.util.Iterator"

-- | Haskell-side type of a JNI native method taking the JNI environment, the
-- receiver object and a single Java @long@ argument, and producing @r@.
type JNIFun r =
  JNIEnv -> Ptr JObject -> Int64 -> IO r

-- | Wraps a Haskell closure as a C function pointer (GHC @"wrapper"@ stub) so
-- it can be installed as the body of a Java native method that returns an
-- object reference.
foreign import ccall "wrapper"
  wrapObjectFun :: JNIFun (Ptr (J t)) -> IO (FunPtr (JNIFun (Ptr (J t))))

-- 'freeIterator' is exported to C only so that the @&@-import below can hand
-- out a static 'FunPtr' to it; that pointer is registered as the Java
-- @hsFinalize@ native method.
foreign export ccall "jvm_streaming_freeIterator"
  freeIterator :: JNIFun ()
foreign import ccall "&jvm_streaming_freeIterator"
  freeIteratorPtr :: FunPtr (JNIFun ())

-- | Per-iterator Haskell state referenced from the Java object built by
-- 'newIterator'. A stable pointer to this record is passed to the Java side
-- (as @tblPtr@), and 'freeIterator' uses it to release everything it holds.
data FunPtrTable = FunPtrTable
  { refPtr :: !Int64
    -- ^ Address of the stable pointer to the 'IORef' that holds the
    -- not-yet-consumed part of the stream.
  }

-- | Implementation of the Java @hsFinalize@ native method of the iterator
-- class created in 'newIterator'.
--
-- The third argument is the address of a stable pointer to a 'FunPtrTable'.
-- It frees the stable pointer stored in the table's 'refPtr' field, then the
-- table's own stable pointer. The JNI environment and the receiver object are
-- ignored.
freeIterator :: JNIEnv -> Ptr JObject -> Int64 -> IO ()
freeIterator _ _ ptr = do
    let sptr = castPtrToStablePtr $ intPtrToPtr $ fromIntegral ptr
    -- Read the table before freeing its stable pointer.
    FunPtrTable{..} <- deRefStablePtr sptr
    freeStablePtr $ castPtrToStablePtr $ intPtrToPtr $ fromIntegral refPtr
    freeStablePtr sptr

-- | Reflects a stream with no batching.
--
-- Each element of the stream becomes one element of the returned Java
-- iterator. The stream is kept in an 'IORef' reachable from the Java object
-- through a stable pointer, and is advanced one element at a time by the
-- @hsNext@ native method. The Java object's finalizer calls @hsFinalize@
-- ('freeIterator') to release the stable pointers.
--
-- The iterator is advanced once before being returned to fill its lookahead
-- slot, so the first element of the stream (and its effects) is demanded
-- before 'newIterator' returns.
newIterator
  :: forall ty. Stream (Of (J ty)) IO ()
  -> IO (J ('Iface "java.util.Iterator" <> '[ty]))
newIterator stream0 = do
    ioStreamRef <- newIORef stream0
    refPtr :: Int64 <- fromIntegral . ptrToIntPtr . castStablePtrToPtr <$>
      newStablePtr ioStreamRef
    -- Keep the stable pointer in a table that can be referenced from the Java
    -- side, so that it can be freed when the iterator is finalized.
    tblPtr :: Int64 <- fromIntegral . ptrToIntPtr . castStablePtrToPtr <$>
      newStablePtr FunPtrTable{..}
    iterator <-
      [java| new Iterator() {

          /// A field that the Haskell side sets to true when it reaches the end.
          private boolean end = false;

          /// Lookahead element - it always points to a valid element unless
          /// end is true. There is no constructor, so in order to initialize
          /// it, next() must be invoked once.
          private Object lookahead;

          @Override
          public boolean hasNext() { return !end; }

          @Override
          public Object next() {
            if (hasNext()) {
              final Object temp = lookahead;
              lookahead = hsNext($refPtr);
              return temp;
            } else
              throw new java.util.NoSuchElementException();
          }

          @Override
          public void remove() { throw new UnsupportedOperationException(); }

          private native void hsFinalize(long tblPtr);

          private native Object hsNext(long refPtr);

          @Override
          public void finalize() { hsFinalize($tblPtr); }
        } |]
    -- Every iterator is an instance of the same anonymous class, so its
    -- natives only need registering once. The MVar is held during
    -- registration, so a concurrent caller waits rather than registering
    -- again. If registration throws, the flag stays False and the next call
    -- retries.
    modifyMVar_ iteratorNativesRegistered $ \registered -> do
      if registered
        then return ()
        else do
          klass <- JNI.getObjectClass iterator
          registerNativesForIterator klass <* JNI.deleteLocalRef klass
      return True
    -- Call next once to initialize the iterator.
    () <- [java| { $iterator.next(); } |]
    return $ generic iterator

-- | Whether 'registerNativesForIterator' has already been run for the
-- iterator class of 'newIterator'.
--
-- This is a process-wide flag. It is a top-level CAF so that it is shared by
-- all calls whatever the optimisation level. NOINLINE prevents GHC from
-- duplicating the 'unsafePerformIO' allocation.
iteratorNativesRegistered :: MVar Bool
iteratorNativesRegistered = unsafePerformIO (newMVar False)
{-# NOINLINE iteratorNativesRegistered #-}

-- | Registers functions for the native methods of the inner class created in
-- 'newIterator'.
--
-- We keep this helper as a top-level function to ensure that no state tied
-- to a particular iterator leaks in the registered functions. The methods
-- registered here affect all the instances of the inner class.
--
-- Two natives are installed:
--
-- * @hsNext(long)@ pops the next element of the stream whose 'IORef' is
--   behind the given stable-pointer address. Once the stream is exhausted,
--   it sets the receiver's @end@ field and returns @null@.
-- * @hsFinalize(long)@ is bound to 'freeIterator'.
--
-- Each call allocates a new 'FunPtr' wrapper for @hsNext@ that is never
-- freed, so this should run at most once per class.
registerNativesForIterator :: JClass -> IO ()
registerNativesForIterator klass = do
    fieldEndId <- JNI.getFieldID klass "end"
                    (JNI.signature (sing :: Sing ('Prim "boolean")))
    nextPtr <- wrapObjectFun $ \_ jthis streamRef ->
      -- Conversion is safe, because result is always a reflected object.
      unsafeForeignPtrToPtr . Coerce.coerce <$>
        popStream fieldEndId jthis streamRef
    JNI.registerNatives klass
      [ JNI.JNINativeMethod
          "hsNext"
          (methodSignature
            [SomeSing (sing :: Sing ('Prim "long"))]
            (sing :: Sing ('Class "java.lang.Object"))
          )
          nextPtr
      , JNI.JNINativeMethod
          "hsFinalize"
          (methodSignature
            [SomeSing (sing :: Sing ('Prim "long"))]
            (sing :: Sing 'Void)
          )
          freeIteratorPtr
      ]
  where
    -- Advances by one element the stream stored behind the stable pointer
    -- whose address is @streamRef@. It returns that element, or 'jnull' after
    -- setting the @end@ field of the object at @ptrThis@ once the stream is
    -- exhausted. The stream's result type is left unconstrained here; it is
    -- only discarded by 'Streaming.uncons'.
    popStream :: JFieldID -> Ptr JObject -> Int64 -> IO (J ty)
    popStream fieldEndId ptrThis streamRef = do
      let stableRef = castPtrToStablePtr $ intPtrToPtr $ fromIntegral streamRef
      ref <- deRefStablePtr stableRef
      stream <- readIORef ref
      Streaming.uncons stream >>= \case
        Nothing -> do
          jthis <- JNI.objectFromPtr ptrThis
          -- When the stream ends, set the end field to True
          -- so the Iterator knows not to call hsNext again.
          JNI.setBooleanField jthis fieldEndId 1
          return jnull
        Just (x, stream') -> do
          writeIORef ref stream'
          return x

-- | Reifies streams from iterators in batches of the given size.
--
-- The given iterator is wrapped in a Java iterator whose @next()@ packs up to
-- @batchSize@ elements into one batch through a @BatchWriter@. Each batch is
-- reified with 'reifyBatch', and its elements are yielded one at a time.
--
-- The global references to the batching iterator and its class are deleted
-- when the stream reaches its end. They are not released if the stream is
-- abandoned earlier. Running the stream again after it has ended yields
-- nothing instead of using the deleted references.
reifyStreamWithBatching
  :: forall a. BatchReify a
  => Int32  -- ^ The batch size; values below 1 are treated as 1
  -> J ('Iface "java.util.Iterator" <> '[Interp a])
  -> IO (Stream (Of a) IO ())
reifyStreamWithBatching requestedBatchSize jiterator0 = do
    -- With a batch size below 1 the batching iterator's next() never calls
    -- it.next(), so hasNext() would stay true and 'go' would loop forever.
    let batchSize = max 1 requestedBatchSize
        jiterator1 :: J ('Iface "java.util.Iterator")
        jiterator1 = unsafeUngeneric jiterator0
    jbatcher <- unsafeUngeneric <$> newBatchWriter (Proxy :: Proxy a)
    jiterator <- [java| new Iterator() {
        private final int batchSize = $batchSize;
        private final Iterator it = $jiterator1;
        private final BatchWriter batcher = $jbatcher;
        public int count = 0;

        @Override
        public boolean hasNext() { return it.hasNext(); }

        @Override
        public Object next() {
          int i = 0;
          batcher.start(batchSize);
          while (it.hasNext() && i < batchSize) {
            batcher.set(i, it.next());
            i++;
          }
          count = i;
          return batcher.getBatch();
        }

        @Override
        public void remove() {
          throw new UnsupportedOperationException();
        }
      } |]
      >>= JNI.newGlobalRef
      :: IO (J ('Iface "java.util.Iterator"))
    cls <- JNI.getObjectClass jiterator >>= JNI.newGlobalRef
    fieldId <- JNI.getFieldID cls "count"
                 (JNI.signature (sing :: Sing ('Prim "int")))
    -- Set once the global references above have been deleted, so that a
    -- later run of the returned stream does not touch them.
    releasedRef <- newIORef False

    let go :: Int        -- next element to return from the batch
           -> V.Vector a -- current batch of elements
           -> Stream (Of a) IO ()
        go i v =
          if V.length v == i then do
            released <- liftIO $ readIORef releasedRef
            if released then return () else do
              hasNext <- liftIO [java| $jiterator.hasNext() |]
              if hasNext then do
                v' <- liftIO $
                  [java| $jiterator.next() |] `withLocalRef` \jbatch ->
                    JNI.getIntField jiterator fieldId
                    >>= reifyBatch (unsafeCast (jbatch :: JObject) :: J (Batch a))
                go 0 v'
              else
                liftIO $ do
                  writeIORef releasedRef True
                  JNI.deleteGlobalRef jiterator
                  JNI.deleteGlobalRef cls
          else do
            Streaming.yield $ v V.! i
            go (i + 1) v

    return $ go 0 V.empty

-- | Reflects streams to iterators in batches of the given size.
--
-- The stream is cut into chunks of @batchSize@ elements with
-- 'Streaming.chunksOf'. Each chunk is turned into a batch with 'reflectBatch'
-- and exposed through 'newIterator'. A second Java iterator unpacks those
-- batches with a @BatchReader@ and returns their elements one at a time.
reflectStreamWithBatching
  :: forall a. BatchReflect a
  => Int  -- ^ The batch size
  -> Stream (Of a) IO ()
  -> IO (J ('Iface "java.util.Iterator" <> '[Interp a]))
reflectStreamWithBatching batchSize s0 = do
    jiterator <- unsafeUngeneric <$>
      (reflectStream $ Streaming.mapsM
                        (\s -> first V.fromList <$> Streaming.toList s)
                     $ Streaming.chunksOf batchSize s0
      )
    jbatchReader <- unsafeUngeneric <$> newBatchReader (Proxy :: Proxy a)
    generic <$> [java| new Iterator() {
        private final Iterator it = $jiterator;
        private final BatchReader batchReader = $jbatchReader;
        private int count = 0;

        @Override
        public boolean hasNext() {
          return count < batchReader.getSize() || it.hasNext();
        }
        @Override
        public Object next() {
          if (count == batchReader.getSize()) {
            batchReader.setBatch(it.next());
            count = 0;
          }
          Object o = batchReader.get(count);
          count++;
          return o;
        }
        @Override
        public void remove() {
          throw new UnsupportedOperationException();
        }
      } |]
  where
    -- Reflects every batch and exposes the resulting stream of batch objects
    -- as an unbatched Java iterator.
    reflectStream :: Stream (Of (V.Vector a)) IO ()
                  -> IO (J ('Iface "java.util.Iterator" <> '[Batch a]))
    reflectStream = newIterator . Streaming.mapM reflectBatch

-- Streams are exchanged with Java as java.util.Iterator objects. The Reify and
-- Reflect instances below always batch 1024 elements at a time.
withStatic [d|
  instance Interpretation (Stream (Of a) m r) where
    type Interp (Stream (Of a) m r) = 'Iface "java.util.Iterator"

  instance BatchReify a => Reify (Stream (Of a) IO ()) where
    reify jit = reifyStreamWithBatching 1024 (generic jit)

  instance BatchReflect a => Reflect (Stream (Of a) IO ()) where
    reflect stream = unsafeUngeneric <$> reflectStreamWithBatching 1024 stream
  |]