Commit 732579c

Add KafkaIOTensor which stores data in memory (so that it is indexable)
This is built around the same code base as the KafkaDataset C++ implementation. Signed-off-by: Yong Tang <[email protected]>
1 parent 2458835 commit 732579c

7 files changed: +256 −5 lines changed

tensorflow_io/audio/kernels/audio_kernels.cc

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ class WAVIndexable : public IOIndexableInterface {
       : env_(env) {}

   ~WAVIndexable() {}
-  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
     if (input.size() > 1) {
       return errors::InvalidArgument("more than 1 filename is not supported");
     }

tensorflow_io/core/kernels/io_interface.h

Lines changed: 116 additions & 0 deletions

@@ -41,6 +41,122 @@ class IOIndexableInterface : public IOInterface {
   virtual Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) = 0;
 };

+template<typename Type>
+class IOIndexableImplementation : public IOIndexableInterface {
+ public:
+  IOIndexableImplementation<Type>(Env* env)
+  : env_(env)
+  , iterable_(new Type(env)) {}
+
+  ~IOIndexableImplementation<Type>() {}
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+    TF_RETURN_IF_ERROR(iterable_->Init(input, metadata, memory_data, memory_size));
+    TF_RETURN_IF_ERROR(iterable_->Spec(dtypes_, shapes_));
+
+    const int64 capacity = 4096;
+    std::vector<TensorShape> chunk_shapes;
+    for (size_t component = 0; component < shapes_.size(); component++) {
+      gtl::InlinedVector<int64, 4> dims = shapes_[component].dim_sizes();
+      dims[0] = capacity;
+      chunk_shapes.push_back(TensorShape(dims));
+    }
+
+    int64 total = 0;
+
+    int64 record_read = 0;
+    do {
+      tensors_.push_back(std::vector<Tensor>());
+      for (size_t component = 0; component < shapes_.size(); component++) {
+        tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component]));
+      }
+      TF_RETURN_IF_ERROR(iterable_->Next(capacity, tensors_.back(), &record_read));
+      if (record_read == 0) {
+        tensors_.pop_back();
+        break;
+      }
+      if (record_read < capacity) {
+        for (size_t component = 0; component < shapes_.size(); component++) {
+          tensors_.back()[component] = tensors_.back()[component].Slice(0, record_read);
+        }
+      }
+      total += record_read;
+    } while (record_read != 0);
+    for (size_t component = 0; component < shapes_.size(); component++) {
+      shapes_[component].set_dim(0, total);
+    }
+    return Status::OK();
+  }
+  virtual Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
+    for (size_t component = 0; component < dtypes_.size(); component++) {
+      dtypes.push_back(dtypes_[component]);
+    }
+    for (size_t component = 0; component < shapes_.size(); component++) {
+      shapes.push_back(shapes_[component]);
+    }
+    return Status::OK();
+  }
+
+  Status Extra(std::vector<Tensor>* extra) override {
+    return iterable_->Extra(extra);
+  }
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("IOIndexableImplementation<", iterable_->DebugString(), ">[]");
+  }
+
+  Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
+    if (step != 1) {
+      return errors::InvalidArgument("step != 1 is not supported: ", step);
+    }
+    // Find the first chunk, i.e. the chunk that contains `start`.
+    int64 chunk_index = 0;
+    int64 chunk_element = -1;
+    int64 current_element = 0;
+    while (chunk_index < tensors_.size()) {
+      if (current_element <= start && start < current_element + tensors_[chunk_index][0].shape().dim_size(0)) {
+        chunk_element = start - current_element;
+        current_element = start;
+        break;
+      }
+      current_element += tensors_[chunk_index][0].shape().dim_size(0);
+      chunk_index++;
+    }
+    if (chunk_element < 0) {
+      return errors::InvalidArgument("start is out of range: ", start);
+    }
+    std::vector<Tensor> elements;
+    for (size_t component = 0; component < shapes_.size(); component++) {
+      TensorShape shape(shapes_[component].dim_sizes());
+      shape.RemoveDim(0);
+      elements.push_back(Tensor(dtypes_[component], shape));
+    }
+
+    while (current_element < stop) {
+      for (size_t component = 0; component < shapes_.size(); component++) {
+        batch_util::CopySliceToElement(tensors_[chunk_index][component], &elements[component], chunk_element);
+        batch_util::CopyElementToSlice(elements[component], &tensors[component], (current_element - start));
+      }
+      chunk_element++;
+      if (chunk_element == tensors_[chunk_index][0].shape().dim_size(0)) {
+        chunk_index++;
+        chunk_element = 0;
+      }
+      current_element++;
+    }
+    return Status::OK();
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<Type> iterable_ GUARDED_BY(mu_);
+  std::vector<DataType> dtypes_ GUARDED_BY(mu_);
+  std::vector<PartialTensorShape> shapes_ GUARDED_BY(mu_);
+  std::vector<std::vector<Tensor>> tensors_;
+};
+
 template<typename Type>
 class IOInterfaceInitOp : public ResourceOpKernel<Type> {
  public:
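
At Init time this wrapper eagerly drains the underlying iterable into fixed-capacity chunks of 4096 records, then GetItem serves a [start, stop) range (step must be 1) by locating the chunk that contains start and copying element by element across chunk boundaries. Below is a minimal Python sketch of the same chunk-and-lookup scheme; the names are illustrative only, and read_next stands in for the wrapped iterable's Next():

# Sketch of the chunked buffering behind IOIndexableImplementation.
CAPACITY = 4096  # same chunk size as the C++ code

def build_chunks(read_next):
  """Drain the iterable into a list of chunks of up to CAPACITY records."""
  chunks = []
  while True:
    chunk = read_next(CAPACITY)  # a list of at most CAPACITY records
    if not chunk:
      break
    chunks.append(chunk)
  return chunks

def get_item(chunks, start, stop):
  """Gather records in [start, stop), mirroring GetItem with step == 1."""
  # Locate the chunk that contains `start`.
  chunk_index, offset = 0, start
  while chunk_index < len(chunks) and offset >= len(chunks[chunk_index]):
    offset -= len(chunks[chunk_index])
    chunk_index += 1
  if chunk_index == len(chunks):
    raise IndexError("start is out of range: %d" % start)
  # Copy element by element, hopping chunk boundaries as needed.
  out = []
  for _ in range(stop - start):
    out.append(chunks[chunk_index][offset])
    offset += 1
    if offset == len(chunks[chunk_index]):
      chunk_index, offset = chunk_index + 1, 0
  return out

Like the C++ code, the sketch assumes stop does not exceed the total number of buffered records.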

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 39 additions & 0 deletions

@@ -21,6 +21,7 @@
 from tensorflow_io.core.python.ops import io_tensor_ops
 from tensorflow_io.core.python.ops import audio_io_tensor_ops
 from tensorflow_io.core.python.ops import json_io_tensor_ops
+from tensorflow_io.core.python.ops import kafka_io_tensor_ops

 class IOTensor(io_tensor_ops._BaseIOTensor):  # pylint: disable=protected-access
   """IOTensor
@@ -225,3 +226,41 @@ def from_json(cls,
     """
     with tf.name_scope(kwargs.get("name", "IOFromJSON")):
       return json_io_tensor_ops.JSONIOTensor(filename, internal=True)
+
+  @classmethod
+  def from_kafka(cls,
+                 subscription,
+                 **kwargs):
+    """Creates an `IOTensor` from a Kafka stream.
+
+    Args:
+      subscription: A `tf.string` tensor containing the subscription,
+        in the format [topic:partition:offset:length];
+        by default length is -1 for unlimited.
+      servers: An optional list of bootstrap servers, by default
+        `localhost:9092`.
+      configuration: An optional `tf.string` tensor containing
+        configurations in [Key=Value] format. There are three
+        types of configurations:
+        Global configuration: please refer to 'Global configuration properties'
+          in the librdkafka doc. Examples include
+          ["enable.auto.commit=false", "heartbeat.interval.ms=2000"]
+        Topic configuration: please refer to 'Topic configuration properties'
+          in the librdkafka doc. Note all topic configurations should be
+          prefixed with `conf.topic.`. Examples include
+          ["conf.topic.auto.offset.reset=earliest"]
+        Dataset configuration: there are two configurations available,
+          `conf.eof=0|1`: if True, the KafkaDataset will stop on EOF (default).
+          `conf.timeout=milliseconds`: timeout value for the Kafka Consumer to wait.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromKafka")):
+      return kafka_io_tensor_ops.KafkaIOTensor(
+          subscription,
+          servers=kwargs.get("servers", None),
+          configuration=kwargs.get("configuration", None),
+          internal=True)
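
For reference, a minimal eager-mode usage sketch; it assumes a local broker at localhost:9092 with a populated topic named "test", as in the test further down:

import tensorflow as tf
import tensorflow_io as tfio

# Read an entire Kafka topic into an indexable, in-memory IOTensor.
kafka = tfio.IOTensor.from_kafka("test")
print(kafka.dtype)        # tf.string
print(kafka.shape)        # one dimension: the number of messages read
print(kafka.to_tensor())  # all messages as a dense string tensor
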
tensorflow_io/core/python/ops/kafka_io_tensor_ops.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""KafkaIOTensor"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import io_tensor_ops
+from tensorflow_io.core.python.ops import core_ops
+
+class KafkaIOTensor(io_tensor_ops._ColumnIOTensor):  # pylint: disable=protected-access
+  """KafkaIOTensor"""
+
+  #=============================================================================
+  # Constructor (private)
+  #=============================================================================
+  def __init__(self,
+               subscription,
+               servers=None,
+               configuration=None,
+               internal=False):
+    with tf.name_scope("KafkaIOTensor") as scope:
+      metadata = [e for e in configuration or []]
+      if servers is not None:
+        metadata.append("bootstrap.servers=%s" % servers)
+      resource, dtypes, shapes = core_ops.kafka_indexable_init(
+          subscription, metadata=metadata,
+          container=scope,
+          shared_name="%s/%s" % (subscription, uuid.uuid4().hex))
+      print("VVV: ", dtypes, shapes)
+      super(KafkaIOTensor, self).__init__(
+          shapes, dtypes, resource, core_ops.kafka_indexable_get_item,
+          internal=internal)
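
The constructor simply flattens configuration and servers into one list of key=value strings before handing it to the kernel. An illustrative trace with assumed values:

# Illustrative: how the constructor assembles `metadata`.
configuration = ["conf.eof=1", "conf.topic.auto.offset.reset=earliest"]
servers = "localhost:9092"

metadata = [e for e in configuration or []]
if servers is not None:
  metadata.append("bootstrap.servers=%s" % servers)
# metadata == ["conf.eof=1",
#              "conf.topic.auto.offset.reset=earliest",
#              "bootstrap.servers=localhost:9092"]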

tensorflow_io/kafka/kernels/kafka_kernels.cc

Lines changed: 4 additions & 1 deletion

@@ -244,7 +244,10 @@
 REGISTER_KERNEL_BUILDER(Name("KafkaIterableInit").Device(DEVICE_CPU),
                         IOInterfaceInitOp<KafkaIterable>);
 REGISTER_KERNEL_BUILDER(Name("KafkaIterableNext").Device(DEVICE_CPU),
                         IOIterableNextOp<KafkaIterable>);
-
+REGISTER_KERNEL_BUILDER(Name("KafkaIndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<IOIndexableImplementation<KafkaIterable>>);
+REGISTER_KERNEL_BUILDER(Name("KafkaIndexableGetItem").Device(DEVICE_CPU),
+                        IOIndexableGetItemOp<IOIndexableImplementation<KafkaIterable>>);

 }  // namespace data
 }  // namespace tensorflow

tensorflow_io/kafka/ops/kafka_ops.cc

Lines changed: 38 additions & 0 deletions

@@ -19,6 +19,44 @@ limitations under the License.

 namespace tensorflow {

+REGISTER_OP("KafkaIndexableInit")
+    .Input("input: string")
+    .Input("metadata: string")
+    .Output("output: resource")
+    .Output("dtypes: int64")
+    .Output("shapes: int64")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
+      return Status::OK();
+    });
+
+REGISTER_OP("KafkaIndexableGetItem")
+    .Input("input: resource")
+    .Input("start: int64")
+    .Input("stop: int64")
+    .Input("step: int64")
+    .Output("output: dtype")
+    .Attr("dtype: list(type) >= 1")
+    .Attr("shape: list(shape) >= 1")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      std::vector<PartialTensorShape> shape;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+      if (shape.size() != c->num_outputs()) {
+        return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs(), ")");
+      }
+      for (size_t i = 0; i < shape.size(); ++i) {
+        shape_inference::ShapeHandle entry;
+        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry));
+        c->set_output(static_cast<int64>(i), entry);
+      }
+      return Status::OK();
+    });
+
 REGISTER_OP("KafkaIterableInit")
     .Input("input: string")
     .Input("metadata: string")

tests/test_kafka_eager.py

Lines changed: 10 additions & 3 deletions

@@ -25,7 +25,7 @@
 import tensorflow as tf
 if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
   tf.compat.v1.enable_eager_execution()
-import tensorflow_io.kafka as kafka_io  # pylint: disable=wrong-import-position
+import tensorflow_io as tfio  # pylint: disable=wrong-import-position
 from tensorflow_io.core.python.ops import kafka_dataset_ops  # pylint: disable=wrong-import-position

 def test_kafka_dataset():
@@ -34,6 +34,13 @@ def test_kafka_dataset():
     e.numpy().tolist() for e in dataset] == np.asarray([
       ("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)))

+def test_kafka_io_tensor():
+  kafka = tfio.IOTensor.from_kafka("test")
+  assert kafka.dtype == tf.string
+  assert kafka.shape == [10]
+  assert np.all(kafka.to_tensor().numpy() == [
+      ("D" + str(i)).encode() for i in range(10)])
+
 @pytest.mark.skipif(
     not (hasattr(tf, "version") and
          tf.version.VERSION.startswith("2.0.")), reason=None)
@@ -64,7 +71,7 @@ def test_kafka_output_sequence():
   class OutputCallback(tf.keras.callbacks.Callback):
     """KafkaOutputCallback"""
     def __init__(self, batch_size, topic, servers):
-      self._sequence = kafka_io.KafkaOutputSequence(
+      self._sequence = tfio.kafka.KafkaOutputSequence(
           topic=topic, servers=servers)
       self._batch_size = batch_size
     def on_predict_batch_end(self, batch, logs=None):
@@ -87,6 +94,6 @@ def flush(self):
   predictions = [class_names[v] for v in np.argmax(predictions, axis=1)]

   # Reading from `test_e(time)e` we should get the same result
-  dataset = kafka_io.KafkaDataset(topics=[topic], group="test", eof=True)
+  dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True)
   for entry, prediction in zip(dataset, predictions):
     assert entry.numpy() == prediction.encode()
