
Commit 566fa69

Add to_dataset to IOIterableTensor (e.g., Kafka) as well
Signed-off-by: Yong Tang <[email protected]>
1 parent 302db61 commit 566fa69

4 files changed (+82 lines, -29 lines)
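
For context, the new `to_dataset()` call on an `IOIterableTensor` can be used as in the test added by this commit (tests/test_kafka_eager.py). A minimal usage sketch, assuming `kafka` is a `KafkaIOTensor` over a topic holding the ten records b"D0" through b"D9" as set up in that test (the construction of `kafka` is omitted here):

```python
# Minimal usage sketch; the construction of `kafka` (topic, servers) is
# omitted and assumed to match tests/test_kafka_eager.py.
dataset = kafka.to_dataset().batch(2)

# In eager mode each element is a batch of two Kafka messages:
# [b"D0", b"D1"], [b"D2", b"D3"], ..., [b"D8", b"D9"]
for batch in dataset:
  print(batch.numpy().tolist())
```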

tensorflow_io/core/kernels/io_interface.h

Lines changed: 1 addition & 5 deletions

@@ -160,11 +160,7 @@ class IOIterableNextOp : public OpKernel {
     OP_REQUIRES_OK(context, resource->Next(capacity, tensors, &record_read));
     for (size_t i = 0; i < tensors.size(); i++) {
       if (record_read < capacity) {
-        gtl::InlinedVector<int64, 4> dims = shapes[i].dim_sizes();
-        dims[0] = record_read;
-        Tensor value;
-        value.CopyFrom(tensors[i], TensorShape(dims));
-        context->set_output(i, value);
+        context->set_output(i, tensors[i].Slice(0, record_read));
       } else {
         context->set_output(i, tensors[i]);
       }

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 73 additions & 23 deletions

@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import sys
 import collections
 
 import tensorflow as tf

@@ -325,7 +326,55 @@ def __init__(self,
   # Iterator
   #=============================================================================
   def __iter__(self):
-    return self._function()
+    resource = self._function["init"](self._function["data"])
+    capacity = 1
+    while True:
+      value = self._function["next"](resource, capacity=capacity)
+      if tf.shape(value)[0].numpy() < capacity:
+        return
+      yield value
+
+  #=============================================================================
+  # Dataset Conversions
+  #=============================================================================
+
+  def to_dataset(self):
+    """Converts this `IOIterableTensor` into a `tf.data.Dataset`.
+
+    Example:
+
+    ```python
+    ```
+
+    Args:
+
+    Returns:
+      A `tf.data.Dataset` with value obtained from this `IOIterableTensor`.
+    """
+    class _IOIterableTensorDataset(data_ops.BaseDataset):
+      """_IOIterableTensorDataset"""
+
+      def __init__(self, dtype, shape, function):
+        func_init = function["init"]
+        func_next = function["next"]
+        func_data = function["data"]
+        resource = func_init(func_data)
+        capacity = 4096
+        dataset = data_ops.BaseDataset.range(
+            0, sys.maxsize, capacity).map(
+                lambda i: func_next(resource, capacity)).apply(
+                    tf.data.experimental.take_while(
+                        lambda v: tf.greater(tf.shape(v)[0], 0))).apply(
+                            tf.data.experimental.unbatch())
+        self._dataset = dataset
+        self._resource = resource
+        self._function = function
+        shape = shape[1:]
+        super(_IOIterableTensorDataset, self).__init__(
+            self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access
+
+    return _IOIterableTensorDataset(
+        self._dtype, self._shape, self._function)
 
 class AudioIOTensor(IOTensor):
   """AudioIOTensor"""

@@ -369,30 +418,31 @@ def __init__(self,
                subscription, servers, timeout, eof, conf,
                internal=False):
     with tf.name_scope("KafkaIOTensor") as scope:
-      dtype = tf.string
-      shape = tf.TensorShape([None])
 
-      def function():
-        """_iter_func"""
-        metadata = []
-        if servers is not None:
-          metadata.append("bootstrap.servers=%s" % servers)
-        if timeout is not None:
-          metadata.append("conf.timeout=%d" % timeout)
-        if eof is not None:
-          metadata.append("conf.eof=%d" % (1 if eof else 0))
-        if conf is not None:
-          for e in conf:
-            metadata.append(e)
+      metadata = []
+      if servers is not None:
+        metadata.append("bootstrap.servers=%s" % servers)
+      if timeout is not None:
+        metadata.append("conf.timeout=%d" % timeout)
+      if eof is not None:
+        metadata.append("conf.eof=%d" % (1 if eof else 0))
+      if conf is not None:
+        for e in conf:
+          metadata.append(e)
+
+      func_data = {"subscription": subscription, "metadata": metadata}
+      def func_init(data):
+        """func_init"""
         resource, _, _ = kafka_ops.kafka_iterable_init(
-            subscription, metadata=metadata,
+            data["subscription"], metadata=data["metadata"],
             container=scope, shared_name=subscription)
-        capacity = 1
-        while True:
-          value = kafka_ops.kafka_iterable_next(resource, capacity=capacity)
-          if tf.shape(value)[0].numpy() < capacity:
-            return
-          yield value
+        return resource
+      func_next = kafka_ops.kafka_iterable_next
+
+      dtype = tf.string
+      shape = tf.TensorShape([None])
 
     super(KafkaIOTensor, self).__init__(
-        dtype, shape, function, properties=None, internal=internal)
+        dtype, shape,
+        {"init": func_init, "next": func_next, "data": func_data},
+        properties=None, internal=internal)
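
The `to_dataset()` implementation above turns the resource's chunked `next` reads into a `tf.data` pipeline: an effectively unbounded `range` of offsets is mapped through the next-function, `take_while` stops at the first empty chunk, and `unbatch` flattens the chunks back into individual records. A self-contained sketch of the same pattern with an in-memory stand-in (`records`, `fake_next`, and `read_chunk` are illustrative only, not part of the tensorflow_io API):

```python
import sys

import tensorflow as tf

# In-memory stand-ins: `records` plays the role of the Kafka topic and
# `fake_next` the role of the resource's next-function (kafka_iterable_next).
records = [("D%d" % i).encode() for i in range(10)]
capacity = 4

def fake_next(offset):
  # Return up to `capacity` records starting at `offset`; an empty batch
  # signals that the source is exhausted.
  chunk = records[int(offset):int(offset) + capacity]
  return tf.constant(chunk, dtype=tf.string)

def read_chunk(i):
  chunk = tf.py_function(fake_next, [i], tf.string)
  chunk.set_shape([None])  # rank must be known before unbatch()
  return chunk

dataset = tf.data.Dataset.range(0, sys.maxsize, capacity).map(read_chunk).apply(
    tf.data.experimental.take_while(
        lambda v: tf.greater(tf.shape(v)[0], 0))).apply(
            tf.data.experimental.unbatch())

print([e.numpy() for e in dataset])  # [b'D0', b'D1', ..., b'D9']
```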

tensorflow_io/kafka/kernels/kafka_kernels.cc

Lines changed: 3 additions & 1 deletion

@@ -179,12 +179,13 @@ class KafkaIterable : public IOIterableInterface {
   }
   Status Next(const int64 capacity, std::vector<Tensor>& tensors, int64* record_read) override {
     *record_read = 0;
-    while ((*record_read) < capacity) {
+    while (consumer_.get() != nullptr && (*record_read) < capacity) {
       if (!kafka_event_cb_.run()) {
         return errors::Internal("failed to consume due to all brokers down");
       }
       if (range_.second >= 0 && (subscription_->offset() >= range_.second || offset_ >= range_.second)) {
         // EOF of topic
+        consumer_.reset(nullptr);
         return Status::OK();
       }

@@ -200,6 +201,7 @@ class KafkaIterable : public IOIterableInterface {
       if (message->err() == RdKafka::ERR__PARTITION_EOF) {
         LOG(INFO) << "Partition reach EOF, current offset: " << offset_;
         if (eof_) {
+          consumer_.reset(nullptr);
           return Status::OK();
         }
       }
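
Releasing `consumer_` once EOF is reached gives `Next` a simple contract: every later call returns an empty batch instead of re-reading or blocking, and that empty batch is the stop signal used both by `__iter__` (`record_read < capacity`) and by `to_dataset` (`take_while` on a non-empty shape). A rough Python sketch of that contract (the `FakeIterable` class is purely illustrative):

```python
class FakeIterable(object):
  """Hypothetical stand-in for the kernel's Next() behaviour after this change."""

  def __init__(self, records):
    self._records = list(records)
    self._offset = 0

  def next(self, capacity):
    # Once past the end this keeps returning an empty list forever,
    # mirroring the kernel where the reset consumer_ skips the read loop.
    chunk = self._records[self._offset:self._offset + capacity]
    self._offset += len(chunk)
    return chunk

it = FakeIterable([b"D0", b"D1", b"D2"])
print(it.next(2))  # [b'D0', b'D1']
print(it.next(2))  # [b'D2']  -> shorter than capacity: __iter__ stops here
print(it.next(2))  # []       -> empty: take_while stops the Dataset here
```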

tests/test_kafka_eager.py

Lines changed: 5 additions & 0 deletions

@@ -34,6 +34,11 @@ def test_kafka_io_tensor():
       e.numpy() for e in kafka] == [
           ("D" + str(i)).encode() for i in range(10)])
 
+  dataset = kafka.to_dataset().batch(2)
+  assert np.all([
+      e.numpy().tolist() for e in dataset] == np.asarray([
+          ("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)))
+
 @pytest.mark.skipif(
     not (hasattr(tf, "version") and
          tf.version.VERSION.startswith("2.0.")), reason=None)
