From 8f0d8dd9d7f51fdf02ae25c728d4667859c8f555 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 21 Aug 2019 23:16:51 +0000 Subject: [PATCH 01/11] Expose tfio.IOTensor class, from_audio, and tfio.IOTensor.to_dataset() This PR tries to expose a tfio.IOTensor which could be applied to any IO related data that is indexable (__getitem__ and __len__). The idea is to bind __getitem__ and __len__ to kernel ops at run time, so that it is not necessary to read everything into memory. The first supported file format is the WAV file. With tfio.IOTensor, dtype and shape are exposed along with __getitem__ and __len__. Further, a rate property has been exposed specifically for Audio/WAV files, which gives the sample rate. This tfio.IOTensor only works in eager mode. In addition, this PR also converts WAVDataset to use IOTensor (instead of a direct C++ implementation). This PR also carries PR 420. Note that, as was discussed, rebatch has been dropped. Instead, a PR to the core tensorflow repo will be opened. Signed-off-by: Yong Tang --- WORKSPACE | 10 + tensorflow_io/__init__.py | 2 +- tensorflow_io/arrow/kernels/arrow_kernels.h | 16 +- tensorflow_io/audio/__init__.py | 6 - tensorflow_io/audio/kernels/audio_kernels.cc | 225 ++++++-------- tensorflow_io/audio/ops/audio_ops.cc | 62 ++-- tensorflow_io/audio/python/ops/audio_ops.py | 59 ---- tensorflow_io/core/BUILD | 1 + tensorflow_io/core/kernels/io_interface.h | 224 ++++++++++++++ tensorflow_io/core/ops/core_ops.cc | 15 - .../core/python/ops/audio_io_tensor_ops.py | 54 ++++ tensorflow_io/core/python/ops/io_tensor.py | 254 ++++++++++++++++ .../core/python/ops/io_tensor_ops.py | 286 ++++++++++++++++++ .../core/python/ops/json_io_tensor_ops.py | 50 +++ tensorflow_io/json/BUILD | 2 + tensorflow_io/json/kernels/json_kernels.cc | 230 ++++++++++++++ tensorflow_io/json/ops/json_ops.cc | 41 +++ tests/test_audio_eager.py | 50 ++- tests/test_json/feature.ndjson | 2 + tests/test_json/label.ndjson | 2 + tests/test_json_eager.py | 57 ++++ third_party/arrow.BUILD | 5 +- third_party/rapidjson.BUILD | 23 ++ 23 files changed, 1404 insertions(+), 272 deletions(-) create mode 100644 tensorflow_io/core/kernels/io_interface.h create mode 100644 tensorflow_io/core/python/ops/audio_io_tensor_ops.py create mode 100644 tensorflow_io/core/python/ops/io_tensor.py create mode 100644 tensorflow_io/core/python/ops/io_tensor_ops.py create mode 100644 tensorflow_io/core/python/ops/json_io_tensor_ops.py create mode 100644 tests/test_json/feature.ndjson create mode 100644 tests/test_json/label.ndjson create mode 100644 third_party/rapidjson.BUILD diff --git a/WORKSPACE b/WORKSPACE index a8427116a..677273d61 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -542,3 +542,13 @@ http_archive( "https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip", ], ) + +http_archive( + name = "rapidjson", + build_file = "//third_party:rapidjson.BUILD", + sha256 = "bf7ced29704a1e696fbccf2a2b4ea068e7774fa37f6d7dd4039d0787f8bed98e", + strip_prefix = "rapidjson-1.1.0", + urls = [ + "https://github.com/miloyip/rapidjson/archive/v1.1.0.tar.gz", + ], +) diff --git a/tensorflow_io/__init__.py b/tensorflow_io/__init__.py index 5c9b94a5d..5d66c595c 100644 --- a/tensorflow_io/__init__.py +++ b/tensorflow_io/__init__.py @@ -17,4 +17,4 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf +from tensorflow_io.core.python.ops.io_tensor import IOTensor diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h index
b0523ead1..30da9745f 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.h +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -30,7 +30,8 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { public: explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size) : file_(file) - , size_(size) { } + , size_(size) + , position_(0) { } ~ArrowRandomAccessFile() {} arrow::Status Close() override { @@ -40,13 +41,21 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { return false; } arrow::Status Tell(int64_t* position) const override { - return arrow::Status::NotImplemented("Tell"); + *position = position_; + return arrow::Status::OK(); } arrow::Status Seek(int64_t position) override { return arrow::Status::NotImplemented("Seek"); } arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override { - return arrow::Status::NotImplemented("Read (void*)"); + StringPiece result; + Status status = file_->Read(position_, nbytes, &result, (char*)out); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return arrow::Status::IOError(status.error_message()); + } + *bytes_read = result.size(); + position_ += (*bytes_read); + return arrow::Status::OK(); } arrow::Status Read(int64_t nbytes, std::shared_ptr* out) override { return arrow::Status::NotImplemented("Read (Buffer*)"); @@ -81,6 +90,7 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { private: tensorflow::RandomAccessFile* file_; int64 size_; + int64 position_; }; } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/audio/__init__.py b/tensorflow_io/audio/__init__.py index b95edeedd..5792023c0 100644 --- a/tensorflow_io/audio/__init__.py +++ b/tensorflow_io/audio/__init__.py @@ -15,8 +15,6 @@ """Audio Dataset. @@WAVDataset -@@list_wav_info -@@read_wav """ from __future__ import absolute_import @@ -24,15 +22,11 @@ from __future__ import print_function from tensorflow_io.audio.python.ops.audio_ops import WAVDataset -from tensorflow_io.audio.python.ops.audio_ops import list_wav_info -from tensorflow_io.audio.python.ops.audio_ops import read_wav from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "WAVDataset", - "list_wav_info", - "read_wav", ] remove_undocumented(__name__) diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index aa65ebed4..62f9e0737 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/io_interface.h" #include "tensorflow_io/core/kernels/stream.h" namespace tensorflow { namespace data { -namespace { // See http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html struct WAVHeader { @@ -68,189 +67,140 @@ Status ValidateWAVHeader(struct WAVHeader *header) { return Status::OK(); } -class ListWAVInfoOp : public OpKernel { +class WAVIndexable : public IOIndexableInterface { public: - explicit ListWAVInfoOp(OpKernelConstruction* context) : OpKernel(context) { - env_ = context->env(); - } - - void Compute(OpKernelContext* context) override { - const Tensor& filename_tensor = context->input(0); - const string filename = filename_tensor.scalar()(); + WAVIndexable(Env* env) + : env_(env) {} - const Tensor& memory_tensor = context->input(1); - const string& memory = memory_tensor.scalar()(); - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); - uint64 size; - OP_REQUIRES_OK(context, file->GetFileSize(&size)); + ~WAVIndexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); StringPiece result; - struct WAVHeader header; - OP_REQUIRES_OK(context, file->Read(0, sizeof(header), &result, (char *)(&header))); + TF_RETURN_IF_ERROR(file_->Read(0, sizeof(header_), &result, (char *)(&header_))); - OP_REQUIRES_OK(context, ValidateWAVHeader(&header)); - if (header.riff_size + 8 != size) { + TF_RETURN_IF_ERROR(ValidateWAVHeader(&header_)); + if (header_.riff_size + 8 != file_size_) { // corrupted file? 
} - int64 filesize = header.riff_size + 8; + int64 filesize = header_.riff_size + 8; int64 position = result.size(); - if (header.fmt_size != 16) { - position += header.fmt_size - 16; + if (header_.fmt_size != 16) { + position += header_.fmt_size - 16; } int64 nSamples = 0; do { struct DataHeader head; - OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + TF_RETURN_IF_ERROR(file_->Read(position, sizeof(head), &result, (char *)(&head))); position += result.size(); if (memcmp(head.mark, "data", 4) == 0) { // Data should be block aligned // bytes = nSamples * nBlockAlign - OP_REQUIRES(context, (head.size % header.nBlockAlign == 0), errors::InvalidArgument("data chunk should be block aligned (", header.nBlockAlign, "), received: ", head.size)); - nSamples += head.size / header.nBlockAlign; + if (head.size % header_.nBlockAlign != 0) { + return errors::InvalidArgument("data chunk should be block aligned (", header_.nBlockAlign, "), received: ", head.size); + } + nSamples += head.size / header_.nBlockAlign; } position += head.size; } while (position < filesize); - string dtype; - switch (header.wBitsPerSample) { + switch (header_.wBitsPerSample) { case 8: - dtype = "int8"; + dtype_ = DT_INT8; break; case 16: - dtype = "int16"; + dtype_ = DT_INT16; break; case 24: - dtype = "int32"; + dtype_ = DT_INT32; break; default: - OP_REQUIRES(context, false, errors::InvalidArgument("unsupported wBitsPerSample: ", header.wBitsPerSample)); + return errors::InvalidArgument("unsupported wBitsPerSample: ", header_.wBitsPerSample); } - Tensor* dtype_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &dtype_tensor)); - Tensor* shape_tensor; - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({2}), &shape_tensor)); - Tensor* rate_tensor; - OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}), &rate_tensor)); + shape_ = TensorShape({nSamples, header_.nChannels}); - dtype_tensor->scalar()() = std::move(dtype); - shape_tensor->flat()(0) = nSamples; - shape_tensor->flat()(1) = header.nChannels; - rate_tensor->scalar()() = header.nSamplesPerSec; + return Status::OK(); } - private: - mutex mu_; - Env* env_ GUARDED_BY(mu_); -}; - -class ReadWAVOp : public OpKernel { - public: - explicit ReadWAVOp(OpKernelConstruction* context) : OpKernel(context) { - env_ = context->env(); + Status Spec(std::vector& dtypes, std::vector& shapes) override { + dtypes.clear(); + dtypes.push_back(dtype_); + shapes.clear(); + shapes.push_back(shape_); + return Status::OK(); } - void Compute(OpKernelContext* context) override { - const Tensor& filename_tensor = context->input(0); - const string& filename = filename_tensor.scalar()(); - - const Tensor& memory_tensor = context->input(1); - const string& memory = memory_tensor.scalar()(); - - const Tensor& start_tensor = context->input(2); - const int64 start = start_tensor.scalar()(); - - const Tensor& stop_tensor = context->input(3); - const int64 stop = stop_tensor.scalar()(); - - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); - uint64 size; - OP_REQUIRES_OK(context, file->GetFileSize(&size)); - - StringPiece result; - struct WAVHeader header; - OP_REQUIRES_OK(context, file->Read(0, sizeof(header), &result, (char *)(&header))); - - OP_REQUIRES_OK(context, ValidateWAVHeader(&header)); - if (header.riff_size + 8 != size) { - // corrupted file? 
- } - int64 filesize = header.riff_size + 8; - - int64 position = result.size(); - if (header.fmt_size != 16) { - position += header.fmt_size - 16; - } - - int64 nSamples = 0; - do { - struct DataHeader head; - OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); - position += result.size(); - if (memcmp(head.mark, "data", 4) == 0) { - // Data should be block aligned - // bytes = nSamples * nBlockAlign - OP_REQUIRES(context, (head.size % header.nBlockAlign == 0), errors::InvalidArgument("data chunk should be block aligned (", header.nBlockAlign, "), received: ", head.size)); - nSamples += head.size / header.nBlockAlign; - } - position += head.size; - } while (position < filesize); - + Status Extra(std::vector* extra) override { + // Expose a sample `rate` + Tensor rate(DT_INT32, TensorShape({})); + rate.scalar()() = header_.nSamplesPerSec; + extra->push_back(rate); + return Status::OK(); + } - int64 sample_start = start; - int64 sample_stop = stop; - if (sample_start > nSamples) { - sample_start = nSamples; + Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + Tensor& output_tensor = tensors[0]; + if (step != 1) { + return errors::InvalidArgument("step ", step, " is not supported"); } - if (sample_stop < 0) { - sample_stop = nSamples; - } - if (sample_stop < sample_start) { - sample_stop = sample_start; - } - - Tensor* output_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({sample_stop - sample_start, header.nChannels}), &output_tensor)); + const int64 sample_start = start; + const int64 sample_stop = stop; int64 sample_offset = 0; - - position = sizeof(header) + header.fmt_size - 16; + if (header_.riff_size + 8 != file_size_) { + // corrupted file? + } + int64 filesize = header_.riff_size + 8; + int64 position = sizeof(header_) + header_.fmt_size - 16; do { + StringPiece result; struct DataHeader head; - OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + TF_RETURN_IF_ERROR(file_->Read(position, sizeof(head), &result, (char *)(&head))); position += result.size(); if (memcmp(head.mark, "data", 4) == 0) { // Already checked the alignment int64 block_sample_start = sample_offset; - int64 block_sample_stop = sample_offset + head.size / header.nBlockAlign; + int64 block_sample_stop = sample_offset + head.size / header_.nBlockAlign; // only read if block_sample_start and block_sample_stop within range if (sample_start < block_sample_stop && sample_stop > block_sample_start) { int64 read_sample_start = (block_sample_start > sample_start ? block_sample_start : sample_start); int64 read_sample_stop = (block_sample_stop < sample_stop ? 
block_sample_stop : sample_stop); - int64 read_bytes_start = position + (read_sample_start - block_sample_start) * header.nBlockAlign; - int64 read_bytes_stop = position + (read_sample_stop - block_sample_start) * header.nBlockAlign; + int64 read_bytes_start = position + (read_sample_start - block_sample_start) * header_.nBlockAlign; + int64 read_bytes_stop = position + (read_sample_stop - block_sample_start) * header_.nBlockAlign; string buffer; buffer.resize(read_bytes_stop - read_bytes_start); - OP_REQUIRES_OK(context, file->Read(read_bytes_start, read_bytes_stop - read_bytes_start, &result, &buffer[0])); - switch (header.wBitsPerSample) { + TF_RETURN_IF_ERROR(file_->Read(read_bytes_start, read_bytes_stop - read_bytes_start, &result, &buffer[0])); + switch (header_.wBitsPerSample) { case 8: - OP_REQUIRES(context, (header.wBitsPerSample * header.nChannels == header.nBlockAlign * 8), errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); - memcpy((char *)(output_tensor->flat().data()) + ((read_sample_start - sample_start) * header.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); + if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); + } + memcpy((char *)(output_tensor.flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); break; case 16: - OP_REQUIRES(context, (header.wBitsPerSample * header.nChannels == header.nBlockAlign * 8), errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); - memcpy((char *)(output_tensor->flat().data()) + ((read_sample_start - sample_start) * header.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); + if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); + } + memcpy((char *)(output_tensor.flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); break; case 24: // NOTE: The conversion is from signed integer 24 to signed integer 32 (left shift 8 bits) - OP_REQUIRES(context, (header.wBitsPerSample * header.nChannels == header.nBlockAlign * 8), errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); + if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); + } for (int64 i = read_sample_start; i < read_sample_stop; i++) { - for (int64 j = 0; j < header.nChannels; j++) { - char *data_p = (char *)(output_tensor->flat().data() + ((i - sample_start) * header.nChannels + j)); - char *read_p = (char *)(&buffer[((i - read_sample_start) * header.nBlockAlign)]) + 3 * j; + for (int64 j = 0; j < header_.nChannels; j++) { + char *data_p = (char *)(output_tensor.flat().data() + ((i - sample_start) * header_.nChannels + j)); + char *read_p = (char *)(&buffer[((i - read_sample_start) * header_.nBlockAlign)]) + 3 * j; data_p[3] = read_p[2]; data_p[2] = read_p[1]; data_p[1] = read_p[0]; @@ -259,25 +209,36 @@ 
class ReadWAVOp : public OpKernel { } break; default: - OP_REQUIRES(context, false, errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); } } sample_offset = block_sample_stop; } position += head.size; } while (position < filesize); + + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("WAVIndexable"); } private: - mutex mu_; + mutable mutex mu_; Env* env_ GUARDED_BY(mu_); + std::unique_ptr file_ GUARDED_BY(mu_); + uint64 file_size_ GUARDED_BY(mu_); + DataType dtype_; + TensorShape shape_; + struct WAVHeader header_; }; -REGISTER_KERNEL_BUILDER(Name("ListWAVInfo").Device(DEVICE_CPU), - ListWAVInfoOp); -REGISTER_KERNEL_BUILDER(Name("ReadWAV").Device(DEVICE_CPU), - ReadWAVOp); +REGISTER_KERNEL_BUILDER(Name("WAVIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("WAVIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp); -} // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/audio/ops/audio_ops.cc b/tensorflow_io/audio/ops/audio_ops.cc index 2ca5c66a4..0d6c18fe4 100644 --- a/tensorflow_io/audio/ops/audio_ops.cc +++ b/tensorflow_io/audio/ops/audio_ops.cc @@ -19,29 +19,43 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("ListWAVInfo") - .Input("filename: string") - .Input("memory: string") - .Output("dtype: string") - .Output("shape: int64") - .Output("rate: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({})); - c->set_output(1, c->MakeShape({2})); - c->set_output(2, c->MakeShape({})); - return Status::OK(); - }); - -REGISTER_OP("ReadWAV") - .Input("filename: string") - .Input("memory: string") - .Input("start: int64") - .Input("stop: int64") - .Attr("dtype: type") - .Output("output: dtype") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); - return Status::OK(); - }); +REGISTER_OP("WAVIndexableInit") + .Input("input: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Output("rate: int32") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + c->set_output(3, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("WAVIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow_io/audio/python/ops/audio_ops.py b/tensorflow_io/audio/python/ops/audio_ops.py index 374ef0400..8c01b594d 100644 --- a/tensorflow_io/audio/python/ops/audio_ops.py +++ b/tensorflow_io/audio/python/ops/audio_ops.py @@ -19,65 +19,6 @@ import tensorflow as tf from tensorflow_io.core.python.ops import data_ops -from tensorflow_io.core.python.ops import core_ops - -def list_wav_info(filename, **kwargs): - """list_wav_info""" - if not tf.executing_eagerly(): - raise NotImplementedError("list_wav_info only support eager mode") - memory = kwargs.get("memory", "") - dtype, shape, rate = core_ops.list_wav_info( - filename, memory=memory) - return tf.TensorSpec(shape.numpy(), dtype.numpy().decode()), rate - -def read_wav(filename, spec, **kwargs): - """read_wav""" - memory = kwargs.get("memory", "") - start = kwargs.get("start", 0) - stop = kwargs.get("stop", None) - if stop is None and spec.shape[0] is not None: - stop = spec.shape[0] - start - if stop is None: - stop = -1 - return core_ops.read_wav( - filename, memory=memory, - start=start, stop=stop, dtype=spec.dtype) - -class WAVDataset(data_ops.BaseDataset): - """A WAV Dataset""" - - def __init__(self, filename, **kwargs): - """Create a WAVDataset. - - Args: - filename: A string containing filename. - """ - if not tf.executing_eagerly(): - start = kwargs.get("start") - stop = kwargs.get("stop") - dtype = kwargs.get("dtype") - shape = kwargs.get("shape") - else: - spec, _ = list_wav_info(filename) - start = 0 - stop = spec.shape[0] - dtype = spec.dtype - shape = tf.TensorShape( - [dim if i != 0 else None for i, dim in enumerate( - spec.shape.as_list())]) - - # capacity is the rough count for each chunk in dataset - capacity = kwargs.get("capacity", 65536) - entry_start = list(range(start, stop, capacity)) - entry_stop = entry_start[1:] + [stop] - dataset = data_ops.BaseDataset.from_tensor_slices( - (tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64)) - ).map(lambda start, stop: core_ops.read_wav( - filename, memory="", start=start, stop=stop, dtype=dtype)) - self._dataset = dataset - - super(WAVDataset, self).__init__( - self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access class AudioDataset(data_ops.Dataset): """A Audio File Dataset that reads the audio file.""" diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 192f0c9b9..0005ae2f8 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -29,6 +29,7 @@ cc_library( name = "dataset_ops", srcs = [ "kernels/dataset_ops.h", + "kernels/io_interface.h", "kernels/stream.h", ], copts = tf_io_copts(), diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h new file mode 100644 index 000000000..1d7e49269 --- /dev/null +++ b/tensorflow_io/core/kernels/io_interface.h @@ -0,0 +1,224 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/util/batch_util.h" + +namespace tensorflow { +namespace data { + +class IOInterface : public ResourceBase { + public: + virtual Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) = 0; + virtual Status Spec(std::vector& dtypes, std::vector& shapes) = 0; + + virtual Status Extra(std::vector* extra) { + // This is the chance to provide additional extra information which should be appended to extra. + return Status::OK(); + } +}; + +class IOIterableInterface : public IOInterface { + public: + virtual Status Next(const int64 capacity, std::vector& tensors, int64* record_read) = 0; +}; + +class IOIndexableInterface : public IOInterface { + public: + virtual Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) = 0; +}; + +template +class IOInterfaceInitOp : public ResourceOpKernel { + public: + explicit IOInterfaceInitOp(OpKernelConstruction* context) + : ResourceOpKernel(context) { + env_ = context->env(); + } + private: + void Compute(OpKernelContext* context) override { + ResourceOpKernel::Compute(context); + + std::vector input; + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + for (int64 i = 0; i < input_tensor->NumElements(); i++) { + input.push_back(input_tensor->flat()(i)); + } + + Status status; + + std::vector metadata; + const Tensor* metadata_tensor; + status = context->input("metadata", &metadata_tensor); + if (status.ok()) { + for (int64 i = 0; i < metadata_tensor->NumElements(); i++) { + metadata.push_back(metadata_tensor->flat()(i)); + } + } + + const void *memory_data = nullptr; + size_t memory_size = 0; + + const Tensor* memory_tensor; + status = context->input("memory", &memory_tensor); + if (status.ok()) { + memory_data = memory_tensor->scalar()().data(); + memory_size = memory_tensor->scalar()().size(); + } + + OP_REQUIRES_OK(context, this->resource_->Init(input, metadata, memory_data, memory_size)); + + std::vector dtypes; + std::vector shapes; + OP_REQUIRES_OK(context, this->resource_->Spec(dtypes, shapes)); + int64 maxrank = 0; + for (size_t component = 0; component < shapes.size(); component++) { + if (dynamic_cast(this->resource_) != nullptr) { + int64 i = 0; + OP_REQUIRES(context, (shapes[component].dim_size(i) > 0), errors::InvalidArgument("component (", component, ")'s shape[", i, "] should not be None, received: ", shapes[component])); + } + for (int64 i = 1; i < shapes[component].dims(); i++) { + OP_REQUIRES(context, (shapes[component].dim_size(i) > 0), errors::InvalidArgument("component (", component, ")'s shape[", i, "] should not be None, received: ", shapes[component])); + } + maxrank = maxrank > shapes[component].dims() ? 
maxrank : shapes[component].dims(); + } + Tensor dtypes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size())})); + for (size_t i = 0; i < dtypes.size(); i++) { + dtypes_tensor.flat()(i) = dtypes[i]; + } + Tensor shapes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size()), maxrank})); + for (size_t component = 0; component < shapes.size(); component++) { + for (int64 i = 0; i < shapes[component].dims(); i++) { + shapes_tensor.tensor()(component, i) = shapes[component].dim_size(i); + } + for (int64 i = shapes[component].dims(); i < maxrank; i++) { + shapes_tensor.tensor()(component, i) = 0; + } + } + context->set_output(1, dtypes_tensor); + context->set_output(2, shapes_tensor); + + std::vector extra; + OP_REQUIRES_OK(context, this->resource_->Extra(&extra)); + for (size_t i = 0; i < extra.size(); i++) { + context->set_output(3 + i, extra[i]); + } + } + Status CreateResource(Type** resource) + EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + *resource = new Type(env_); + return Status::OK(); + } + mutex mu_; + Env* env_; +}; + +template +class IOIterableNextOp : public OpKernel { + public: + explicit IOIterableNextOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + } + + void Compute(OpKernelContext* context) override { + Type* resource; + OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); + core::ScopedUnref unref(resource); + + const Tensor* capacity_tensor; + OP_REQUIRES_OK(context, context->input("capacity", &capacity_tensor)); + const int64 capacity = capacity_tensor->scalar()(); + + OP_REQUIRES(context, (capacity > 0), errors::InvalidArgument("capacity <= 0 is not supported: ", capacity)); + + std::vector dtypes; + std::vector shapes; + OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes)); + + std::vector tensors; + for (size_t i = 0; i < dtypes.size(); i++) { + gtl::InlinedVector dims = shapes[i].dim_sizes(); + dims[0] = capacity; + tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims))); + } + + int64 record_read; + OP_REQUIRES_OK(context, resource->Next(capacity, tensors, &record_read)); + for (size_t i = 0; i < tensors.size(); i++) { + if (record_read < capacity) { + context->set_output(i, tensors[i].Slice(0, record_read)); + } else { + context->set_output(i, tensors[i]); + } + } + } +}; +template +class IOIndexableGetItemOp : public OpKernel { + public: + explicit IOIndexableGetItemOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + } + + void Compute(OpKernelContext* context) override { + Type* resource; + OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); + core::ScopedUnref unref(resource); + + const Tensor* start_tensor; + OP_REQUIRES_OK(context, context->input("start", &start_tensor)); + int64 start = start_tensor->scalar()(); + + const Tensor* stop_tensor; + OP_REQUIRES_OK(context, context->input("stop", &stop_tensor)); + int64 stop = stop_tensor->scalar()(); + + const Tensor* step_tensor; + OP_REQUIRES_OK(context, context->input("step", &step_tensor)); + int64 step = step_tensor->scalar()(); + + OP_REQUIRES(context, (step == 1), errors::InvalidArgument("step != 1 is not supported: ", step)); + + std::vector dtypes; + std::vector shapes; + OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes)); + + int64 count = shapes[0].dim_size(0); + if (start > count) { + start = count; + } + if (stop < 0) { + stop = count; + } + if (stop < start) { + stop = start; + } + + std::vector tensors; + for (size_t i = 0; i < dtypes.size(); i++) { + gtl::InlinedVector dims = shapes[i].dim_sizes(); + dims[0] = stop - start; + 
tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims))); + } + OP_REQUIRES_OK(context, resource->GetItem(start, stop, step, tensors)); + for (size_t i = 0; i < tensors.size(); i++) { + context->set_output(i, tensors[i]); + } + } +}; +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/core/ops/core_ops.cc b/tensorflow_io/core/ops/core_ops.cc index 61760e457..ee416d014 100644 --- a/tensorflow_io/core/ops/core_ops.cc +++ b/tensorflow_io/core/ops/core_ops.cc @@ -19,21 +19,6 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("AdjustBatchDataset") - .Input("input_dataset: variant") - .Input("batch_size: int64") - .Input("batch_mode: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - // batch_size should be a scalar. - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - // batch_mode should be a scalar. - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - return shape_inference::ScalarShape(c); - }); REGISTER_OP("ListArchiveEntries") .Input("filename: string") .Input("memory: string") diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py new file mode 100644 index 000000000..81d666520 --- /dev/null +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""AudioIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import collections +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class AudioIOTensor(io_tensor_ops._ColumnIOTensor): + """AudioIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + filename, + internal=False): + with tf.name_scope("AudioIOTensor") as scope: + resource, dtypes, shapes, rate = core_ops.wav_indexable_init( + filename, + container=scope, + shared_name="%s/%s" % (filename, uuid.uuid4().hex)) + self._rate = rate.numpy() + super(AudioIOTensor, self).__init__( + shapes, dtypes, resource, core_ops.wav_indexable_get_item, + internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def rate(self): + """The sample `rate` of the audio stream.""" + return self._rate diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py new file mode 100644 index 000000000..3089e4e47 --- /dev/null +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -0,0 +1,254 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""IOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import collections +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import audio_io_tensor_ops +from tensorflow_io.core.python.ops import json_io_tensor_ops + +class IOTensor(io_tensor_ops._BaseIOTensor): + """IOTensor + + An `IOTensor` is a tensor with data backed by IO operations. For example, + an `AudioIOTensor` is a tensor with data from an audio file, and a + `KafkaIOTensor` is a tensor with data from reading the messages of a Kafka + stream server. + + There are two types of `IOTensor`: a normal `IOTensor` which itself is + indexable, or a degenerate `IOIterableTensor` which only supports + accessing the tensor iteratively. + + Since `IOTensor` is indexable, it supports the `__getitem__()` and + `__len__()` methods in Python. In other words, it is a subclass of + `collections.abc.Sequence`. + + Example: + + ```python + >>> import tensorflow_io as tfio + >>> + >>> samples = tfio.IOTensor.from_audio("sample.wav") + >>> print(samples[1000:1005]) + ... tf.Tensor( + ... [[-3] + ... [-7] + ... [-6] + ... [-6] + ...
[-5]], shape=(5, 1), dtype=int16) + ``` + + An `IOIterableTensor` is really a subclass of `collections.abc.Iterable`. + It provides a `__iter__()` method that could be used (through `iter` + indirectly) to access data in an iterative fashion. + + Example: + + ```python + >>> import tensorflow_io as tfio + >>> + >>> kafka = tfio.IOTensor.from_kafka("test", eof=True) + >>> for message in kafka: + >>> print(message) + ... tf.Tensor(['D0'], shape=(1,), dtype=string) + ... tf.Tensor(['D1'], shape=(1,), dtype=string) + ... tf.Tensor(['D2'], shape=(1,), dtype=string) + ... tf.Tensor(['D3'], shape=(1,), dtype=string) + ... tf.Tensor(['D4'], shape=(1,), dtype=string) + ``` + + ### Indexable vs. Iterable + + While many IO formats are naturally considered as iterable only, in most + situations they could still be accessed by indexing through certain + workarounds. For example, a Kafka stream is not directly indexable, yet the + stream could be saved in memory or on disk to allow indexing. Another example + is the packet capture (PCAP) file in the networking area. The packets inside + a PCAP file are concatenated sequentially. Since each packet could have + a variable length, the only way to access each packet is to read one + packet at a time. If the PCAP file is huge (e.g., hundreds of GBs or even + TBs), it may not be realistic (or necessary) to save the index of every + packet in memory. We could consider the PCAP format as iterable only. + + As we can see, the available memory size could be a factor to decide + if a format is indexable or not. However, this factor could also be blurred + in distributed computing. One common case is a file format that + might be splittable, where a file could be split into multiple chunks + (without reading the whole file) with no data overlapping in between those + chunks. For example, a text file could be reliably split into multiple + chunks with line feed (LF) as the boundary. Processing of chunks could then + be distributed across a group of compute nodes to speed things up (by reading small + chunks into memory). From that standpoint, we could still consider splittable + formats as indexable. + + For that reason our focus is `IOTensor`, with convenient indexing and slicing + through the `__getitem__()` method. + + ### Lazy Read + + One useful feature of `IOTensor` is the lazy read. Data inside a file is not + read into memory until needed. This could be convenient where only a small + segment of the data is needed. For example, a WAV file could be as big as + GBs but in many cases only several seconds of samples are used for training + or inference purposes. + + While CPU memory is cheap nowadays, GPU memory is still considered an + expensive resource. It is also imperative to fit data in GPU memory for + speed-up purposes. From that perspective lazy read could be very helpful. + + ### Association of Meta Data + + While a file format could consist of mostly numeric data, in many situations + the meta data is important as well. For example, in audio file formats the + sample rate is a number that is necessary for almost everything. Association + of the sample rate with the samples of the int16 Tensor is more helpful, + especially in eager mode. + + Example: + + ```python + >>> import tensorflow_io as tfio + >>> + >>> samples = tfio.IOTensor.from_audio("sample.wav") + >>> print(samples.rate) + ... 44100 + ``` + + ### Nested Element Structure + + The concept of `IOTensor` is not limited to a Tensor of a single data type.
+ It supports a nested element structure which could consist of many + components and complex structures. The exposed API such as `shape()` or + `dtype()` will display the shape and data type of an individual Tensor, + or a nested structure of shapes and data types for components of a + composite Tensor. + + Example: + + ```python + >>> import tensorflow_io as tfio + >>> + >>> samples = tfio.IOTensor.from_audio("sample.wav") + >>> print(samples.shape) + ... (22050, 2) + >>> print(samples.dtype) + ... + >>> + >>> features = tfio.IOTensor.from_json("feature.json") + >>> print(features.shape) + ... (TensorShape([Dimension(2)]), TensorShape([Dimension(2)])) + >>> print(features.dtype) + ... (tf.float64, tf.int64) + ``` + + ### Access Columns of Tabular Data Formats + + Many file formats such as Parquet or JSON are considered as tabular because + they consist of columns in a table. With `IOTensor` it is possible to + access individual columns through `__call__()`. + + Example: + + ```python + >>> import tensorflow_io as tfio + >>> + >>> features = tfio.IOTensor.from_json("feature.json") + >>> print(features.shape) + ... (TensorShape([Dimension(2)]), TensorShape([Dimension(2)])) + >>> print(features.dtype) + ... (tf.float64, tf.int64) + >>> + >>> print(features("floatfeature").shape) + ... (2,) + >>> print(features("floatfeature").dtype) + ... + ``` + + ### Conversion from and to Tensor and Dataset + + When needed, an `IOTensor` could be converted into a `Tensor` (through + `to_tensor()`), or a `tf.data.Dataset` (through `to_dataset()`), to + support operations that are only available through `Tensor` or + `tf.data.Dataset`. + + Example: + + ```python + >>> import tensorflow as tf + >>> import tensorflow_io as tfio + >>> + >>> features = tfio.IOTensor.from_json("feature.json") + >>> + >>> features_tensor = features.to_tensor() + >>> print(features_tensor()) + ... (, + ... ) + >>> + >>> features_dataset = features.to_dataset() + >>> print(features_dataset) + ... <_IOTensorDataset shapes: ((), ()), types: (tf.float64, tf.int64)> + >>> + >>> dataset = tf.data.Dataset.zip((features_dataset, labels_dataset)) + ``` + + """ + + #============================================================================= + # Factory Methods + #============================================================================= + + @classmethod + def from_audio(cls, + filename, + **kwargs): + """Creates an `IOTensor` from an audio file. + + The following audio file formats are supported: + - WAV + + Args: + filename: A string, the filename of an audio file. + name: A name prefix for the IOTensor (optional). + + Returns: + An `IOTensor`. + + """ + with tf.name_scope(kwargs.get("name", "IOFromAudio")): + return audio_io_tensor_ops.AudioIOTensor(filename, internal=True) + + @classmethod + def from_json(cls, + filename, + **kwargs): + """Creates an `IOTensor` from a JSON file. + + Args: + filename: A string, the filename of a JSON file. + name: A name prefix for the IOTensor (optional). + + Returns: + An `IOTensor`. + + """ + with tf.name_scope(kwargs.get("name", "IOFromJSON")): + return json_io_tensor_ops.JSONIOTensor(filename, internal=True) diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py new file mode 100644 index 000000000..989b717c7 --- /dev/null +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -0,0 +1,286 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""_BaseIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import collections +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import core_ops + +class _BaseIOTensorDataset(tf.compat.v2.data.Dataset): + """_IOTensorDataset""" + + def __init__(self, spec, resource, function): + start = 0 + stop = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.shape, spec))[0][0] + capacity = 4096 + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] + + dtype = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.dtype, spec)) + shape = tf.nest.flatten( + tf.nest.map_structure( + lambda e: tf.TensorShape( + [None]).concatenate(e.shape[1:]), spec)) + + dataset = tf.compat.v2.data.Dataset.from_tensor_slices(( + tf.constant(entry_start, tf.int64), + tf.constant(entry_stop, tf.int64))) + dataset = dataset.map( + lambda start, stop: function( + resource, start, stop, 1, dtype=dtype, shape=shape)) + # Note: tf.data.Dataset consider tuple `(e, )` as one element + # instead of a sequence. So next `unbatch()` will not work. + # The tf.stack() below is necessary. 
+ if len(dtype) == 1: + dataset = dataset.map(tf.stack) + dataset = dataset.apply(tf.data.experimental.unbatch()) + self._dataset = dataset + self._resource = resource + self._function = function + super(_BaseIOTensorDataset, self).__init__( + self._dataset._variant_tensor) # pylint: disable=protected-access + + def _inputs(self): + return [] + + @property + def element_spec(self): + return self._dataset.element_spec + +class _BaseIOTensor(object): + """_BaseIOTensor""" + + def __init__(self, + spec, + resource, + function, + internal=False): + if not internal: + raise ValueError("IOTensor constructor is private; please use one " + "of the factory methods instead (e.g., " + "IOTensor.from_tensor())") + self._spec = spec + self._resource = resource + self._function = function + super(_BaseIOTensor, self).__init__() + + #============================================================================= + # Accessors + #============================================================================= + + @property + def spec(self): + """The `TensorSpec` of values in this tensor.""" + return self._spec + + #============================================================================= + # String Encoding + #============================================================================= + def __repr__(self): + return "<%s: spec=%s>" % ( + self.__class__.__name__, self.spec) + + + #============================================================================= + # Indexing & Slicing + #============================================================================= + def __getitem__(self, key): + """Returns the specified piece of this IOTensor.""" + if isinstance(key, slice): + start = key.start + stop = key.stop + step = key.step + if start is None: + start = 0 + if stop is None: + stop = -1 + if step is None: + step = 1 + else: + start = key + stop = key + 1 + step = 1 + dtype = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.dtype, self.spec)) + shape = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.shape, self.spec)) + return tf.nest.pack_sequence_as(self.spec, self._function( + self._resource, + start, stop, step, + dtype=dtype, + shape=shape)) + + def __len__(self): + """Returns the total number of items of this IOTensor.""" + return tf.nest.flatten( + tf.nest.map_structure(lambda e: e.shape, self.spec))[0][0] + + #============================================================================= + # Tensor Type Conversions + #============================================================================= + + @classmethod + def from_tensor(cls, + tensor, + **kwargs): + """Converts a `tf.Tensor` into a `IOTensor`. + + Examples: + + ```python + ``` + + Args: + tensor: The `Tensor` to convert. + + Returns: + A `IOTensor`. + + Raises: + ValueError: If tensor is not a `Tensor`. + """ + with tf.name_scope(kwargs.get("name", "IOFromTensor")): + _ = tensor + raise NotImplementedError() + + def to_tensor(self, **kwargs): + """Converts this `IOTensor` into a `tf.Tensor`. + + Example: + + ```python + ``` + + Args: + name: A name prefix for the returned tensors (optional). + + Returns: + A `Tensor` with value obtained from this `IOTensor`. + """ + with tf.name_scope(kwargs.get("name", "IOToTensor")): + return self.__getitem__(slice(None, None)) + + #============================================================================= + # Dataset Conversions + #============================================================================= + + def to_dataset(self): + """Converts this `IOTensor` into a `tf.data.Dataset`. 
+ + Example: + + ```python + ``` + + Args: + + Returns: + A `tf.data.Dataset` with value obtained from this `IOTensor`. + """ + return _BaseIOTensorDataset( + self.spec, self._resource, self._function) + +class _ColumnIOTensor(_BaseIOTensor): + """_ColumnIOTensor""" + + def __init__(self, + shapes, + dtypes, + resource, + function, + internal=False): + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + spec = [tf.TensorSpec(shape, dtype) for ( + shape, dtype) in zip(shapes, dtypes)] + assert len(spec) == 1 + spec = spec[0] + + self._shape = spec.shape + self._dtype = spec.dtype + super(_ColumnIOTensor, self).__init__( + spec, resource, function, internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def shape(self): + """Returns the `TensorShape` that represents the shape of the tensor.""" + return self._shape + + @property + def dtype(self): + """Returns the `dtype` of elements in the tensor.""" + return self._dtype + +class _TableIOTensor(_BaseIOTensor): + """_TableIOTensor""" + + def __init__(self, + shapes, + dtypes, + columns, + filename, + resource, + function, + internal=False): + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + columns = [e.numpy() for e in tf.unstack(columns)] + spec = [tf.TensorSpec(shape, dtype, column) for ( + shape, dtype, column) in zip(shapes, dtypes, columns)] + if len(spec) == 1: + spec = spec[0] + else: + spec = tuple(spec) + self._filename = filename + super(_TableIOTensor, self).__init__( + spec, resource, function, internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + def columns(self): + """The `TensorSpec` of column named `name`""" + return [e.name for e in tf.nest.flatten(self.spec)] + + def shape(self, column): + """Returns the `TensorShape` shape of `column` in the tensor.""" + return next(e.shape for e in tf.nest.flatten(self.spec) if e.name == column) + + def dtype(self, column): + """Returns the `dtype` of `column` in the tensor.""" + return next(e.dtype for e in tf.nest.flatten(self.spec) if e.name == column) + + def __call__(self, column): + """Return a new IOTensor with column named `column`""" + return self.__class__(self._filename, columns=[column], internal=True) diff --git a/tensorflow_io/core/python/ops/json_io_tensor_ops.py b/tensorflow_io/core/python/ops/json_io_tensor_ops.py new file mode 100644 index 000000000..73fe4a924 --- /dev/null +++ b/tensorflow_io/core/python/ops/json_io_tensor_ops.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""JSONIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import collections +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class JSONIOTensor(io_tensor_ops._TableIOTensor): + """JSONIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + filename, + columns=None, + internal=False): + with tf.name_scope("JSONIOTensor") as scope: + metadata = [] + if columns is not None: + metadata.extend(["column: "+column for column in columns]) + resource, dtypes, shapes, columns = core_ops.json_indexable_init( + filename, metadata=metadata, + container=scope, + shared_name="%s/%s" % (filename, uuid.uuid4().hex)) + self._filename = filename + super(JSONIOTensor, self).__init__( + shapes, dtypes, columns, filename, + resource, core_ops.json_indexable_get_item, + internal=internal) diff --git a/tensorflow_io/json/BUILD b/tensorflow_io/json/BUILD index ae75eb6df..b3d8b63cf 100644 --- a/tensorflow_io/json/BUILD +++ b/tensorflow_io/json/BUILD @@ -19,7 +19,9 @@ cc_library( ], linkstatic = True, deps = [ + "//tensorflow_io/arrow:arrow_ops", "//tensorflow_io/core:dataset_ops", + "@arrow", "@jsoncpp_git//:jsoncpp", ], ) diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc index 2fcae9863..d7841ec21 100644 --- a/tensorflow_io/json/kernels/json_kernels.cc +++ b/tensorflow_io/json/kernels/json_kernels.cc @@ -21,6 +21,12 @@ limitations under the License. 
#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/platform/env.h" #include "include/json/json.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/stream.h" +#include "arrow/memory_pool.h" +#include "arrow/json/reader.h" +#include "arrow/table.h" +#include "tensorflow_io/arrow/kernels/arrow_kernels.h" namespace tensorflow { namespace data { @@ -199,5 +205,229 @@ REGISTER_KERNEL_BUILDER(Name("ReadJSON").Device(DEVICE_CPU), ReadJSONOp); } // namespace + + +class JSONIndexable : public IOIndexableInterface { + public: + JSONIndexable(Env* env) + : env_(env) {} + + ~JSONIndexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); + + json_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); + + ::arrow::Status status; + + status = ::arrow::json::TableReader::Make(::arrow::default_memory_pool(), json_file_, ::arrow::json::ReadOptions::Defaults(), ::arrow::json::ParseOptions::Defaults(), &reader_); + if (!status.ok()) { + return errors::InvalidArgument("unable to make a TableReader: ", status); + } + status = reader_->Read(&table_); + if (!status.ok()) { + return errors::InvalidArgument("unable to read table: ", status); + } + + std::vector columns; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find_first_of("column: ") == 0) { + columns.emplace_back(metadata[i].substr(8)); + } + } + + columns_index_.clear(); + if (columns.size() == 0) { + for (int i = 0; i < table_->num_columns(); i++) { + columns_index_.push_back(i); + } + } else { + std::unordered_map columns_map; + for (int i = 0; i < table_->num_columns(); i++) { + columns_map[table_->column(i)->name()] = i; + } + for (size_t i = 0; i < columns.size(); i++) { + columns_index_.push_back(columns_map[columns[i]]); + } + } + + dtypes_.clear(); + shapes_.clear(); + columns_.clear(); + for (size_t i = 0; i < columns_index_.size(); i++) { + int column_index = columns_index_[i]; + ::tensorflow::DataType dtype; + switch (table_->column(column_index)->type()->id()) { + case ::arrow::Type::BOOL: + dtype = ::tensorflow::DT_BOOL; + break; + case ::arrow::Type::UINT8: + dtype= ::tensorflow::DT_UINT8; + break; + case ::arrow::Type::INT8: + dtype= ::tensorflow::DT_INT8; + break; + case ::arrow::Type::UINT16: + dtype= ::tensorflow::DT_UINT16; + break; + case ::arrow::Type::INT16: + dtype= ::tensorflow::DT_INT16; + break; + case ::arrow::Type::UINT32: + dtype= ::tensorflow::DT_UINT32; + break; + case ::arrow::Type::INT32: + dtype= ::tensorflow::DT_INT32; + break; + case ::arrow::Type::UINT64: + dtype= ::tensorflow::DT_UINT64; + break; + case ::arrow::Type::INT64: + dtype= ::tensorflow::DT_INT64; + break; + case ::arrow::Type::HALF_FLOAT: + dtype= ::tensorflow::DT_HALF; + break; + case ::arrow::Type::FLOAT: + dtype= ::tensorflow::DT_FLOAT; + break; + case ::arrow::Type::DOUBLE: + dtype= ::tensorflow::DT_DOUBLE; + break; + case ::arrow::Type::STRING: + case ::arrow::Type::BINARY: + case ::arrow::Type::FIXED_SIZE_BINARY: + case ::arrow::Type::DATE32: + case ::arrow::Type::DATE64: + case ::arrow::Type::TIMESTAMP: + case ::arrow::Type::TIME32: + case ::arrow::Type::TIME64: + case 
::arrow::Type::INTERVAL: + case ::arrow::Type::DECIMAL: + case ::arrow::Type::LIST: + case ::arrow::Type::STRUCT: + case ::arrow::Type::UNION: + case ::arrow::Type::DICTIONARY: + case ::arrow::Type::MAP: + default: + return errors::InvalidArgument("arrow data type is not supported: ", table_->column(i)->type()->ToString()); + } + dtypes_.push_back(dtype); + shapes_.push_back(TensorShape({static_cast(table_->num_rows())})); + columns_.push_back(table_->column(column_index)->name()); + } + + return Status::OK(); + } + Status Spec(std::vector& dtypes, std::vector& shapes) override { + dtypes.clear(); + for (size_t i = 0; i < dtypes_.size(); i++) { + dtypes.push_back(dtypes_[i]); + } + shapes.clear(); + for (size_t i = 0; i < shapes_.size(); i++) { + shapes.push_back(shapes_[i]); + } + return Status::OK(); + } + + Status Extra(std::vector* extra) override { + // Expose columns + Tensor columns(DT_STRING, TensorShape({static_cast(columns_.size())})); + for (size_t i = 0; i < columns_.size(); i++) { + columns.flat()(i) = columns_[i]; + } + extra->push_back(columns); + return Status::OK(); + } + + Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + Tensor& output_tensor = tensors[0]; + if (step != 1) { + return errors::InvalidArgument("step ", step, " is not supported"); + } + for (size_t i = 0; i < tensors.size(); i++) { + int column_index = columns_index_[i]; + std::shared_ptr<::arrow::Column> slice = table_->column(column_index)->Slice(start, stop); + + #define PROCESS_TYPE(TTYPE,ATYPE) { \ + int64 curr_index = 0; \ + for (auto chunk : slice->data()->chunks()) { \ + for (int64_t item = 0; item < chunk->length(); item++) { \ + tensors[i].flat()(curr_index) = (dynamic_cast(chunk.get()))->Value(item); \ + curr_index++; \ + } \ + } \ + } + switch (tensors[i].dtype()) { + case DT_BOOL: + PROCESS_TYPE(bool, ::arrow::BooleanArray); + break; + case DT_INT8: + PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>); + break; + case DT_UINT8: + PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>); + break; + case DT_INT16: + PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>); + break; + case DT_UINT16: + PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>); + break; + case DT_INT32: + PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>); + break; + case DT_UINT32: + PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>); + break; + case DT_INT64: + PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>); + break; + case DT_UINT64: + PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>); + break; + case DT_FLOAT: + PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>); + break; + case DT_DOUBLE: + PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>); + break; + default: + return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensors[i].dtype())); + } + } + + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("JSONIndexable"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::unique_ptr file_ GUARDED_BY(mu_); + uint64 file_size_ GUARDED_BY(mu_); + std::shared_ptr json_file_; + std::shared_ptr<::arrow::json::TableReader> reader_; + std::shared_ptr<::arrow::Table> table_; + + std::vector dtypes_; + std::vector shapes_; + std::vector columns_; + std::vector columns_index_; +}; + +REGISTER_KERNEL_BUILDER(Name("JSONIndexableInit").Device(DEVICE_CPU), + 
IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("JSONIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/json/ops/json_ops.cc b/tensorflow_io/json/ops/json_ops.cc index a2a6afc7d..c9206409a 100644 --- a/tensorflow_io/json/ops/json_ops.cc +++ b/tensorflow_io/json/ops/json_ops.cc @@ -19,6 +19,47 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("JSONIndexableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Output("columns: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + c->set_output(3, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("JSONIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); + + REGISTER_OP("ListJSONColumns") .Input("filename: string") .Output("columns: string") diff --git a/tests/test_audio_eager.py b/tests/test_audio_eager.py index c848b157b..5ff8e5bb0 100644 --- a/tests/test_audio_eager.py +++ b/tests/test_audio_eager.py @@ -23,7 +23,7 @@ import tensorflow as tf if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() -import tensorflow_io.audio as audio_io # pylint: disable=wrong-import-position +import tensorflow_io as tfio # pylint: disable=wrong-import-position audio_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -37,33 +37,25 @@ def test_audio_dataset(): f = lambda x: float(x) / (1 << 15) - for capacity in [10, 100, 500]: - audio_dataset = audio_io.WAVDataset(audio_path, capacity=capacity).apply( - tf.data.experimental.unbatch()).map(tf.squeeze) - i = 0 - for v in audio_dataset: - assert audio_v.audio[i].numpy() == f(v.numpy()) - i += 1 - assert i == 5760 + audio_dataset = tfio.IOTensor.from_audio(audio_path).to_dataset() + i = 0 + for v in audio_dataset: + assert audio_v.audio[i].numpy() == f(v.numpy()) + i += 1 + assert i == 5760 - for capacity in [10, 100, 500]: - audio_dataset = audio_io.WAVDataset(audio_path, capacity=capacity).apply( - tf.data.experimental.unbatch()).batch(2).map(tf.squeeze) - i = 0 - for v in audio_dataset: - assert audio_v.audio[i].numpy() == f(v[0].numpy()) - assert audio_v.audio[i + 1].numpy() == f(v[1].numpy()) - i += 2 - assert i == 5760 + audio_dataset = tfio.IOTensor.from_audio(audio_path).to_dataset().batch(2) + i = 0 + for v in audio_dataset: + assert audio_v.audio[i].numpy() == f(v[0].numpy()) + assert audio_v.audio[i + 1].numpy() == f(v[1].numpy()) + i += 2 + assert i == 5760 - spec, 
rate = audio_io.list_wav_info(audio_path) - assert spec.dtype == tf.int16 - assert spec.shape == [5760, 1] - assert rate.numpy() == audio_v.sample_rate.numpy() - - samples = audio_io.read_wav(audio_path, spec) + samples = tfio.IOTensor.from_audio(audio_path) assert samples.dtype == tf.int16 assert samples.shape == [5760, 1] + assert samples.rate == audio_v.sample_rate.numpy() audio_24bit_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -76,12 +68,8 @@ def test_audio_dataset(): expected = np.fromfile(audio_24bit_raw_path, np.int32) expected = np.reshape(expected, [22050, 2]) - spec, rate = audio_io.list_wav_info(audio_24bit_path) - assert spec.dtype == tf.int32 - assert spec.shape == [22050, 2] - assert rate.numpy() == 44100 - - samples = audio_io.read_wav(audio_24bit_path, spec) + samples = tfio.IOTensor.from_audio(audio_24bit_path) assert samples.dtype == tf.int32 assert samples.shape == [22050, 2] - assert np.all(samples.numpy() == expected) + assert samples.rate == 44100 + assert np.all(samples.to_tensor().numpy() == expected) diff --git a/tests/test_json/feature.ndjson b/tests/test_json/feature.ndjson new file mode 100644 index 000000000..aaa67a60d --- /dev/null +++ b/tests/test_json/feature.ndjson @@ -0,0 +1,2 @@ +{ "floatfeature": 1.1, "integerfeature": 2 } +{ "floatfeature": 2.1, "integerfeature": 3 } diff --git a/tests/test_json/label.ndjson b/tests/test_json/label.ndjson new file mode 100644 index 000000000..5c4aa77d0 --- /dev/null +++ b/tests/test_json/label.ndjson @@ -0,0 +1,2 @@ +{ "floatlabel": 2.2, "integerlabel": 3 } +{ "floatlabel": 1.2, "integerlabel": 3 } diff --git a/tests/test_json_eager.py b/tests/test_json_eager.py index 333f24718..55393a9e6 100644 --- a/tests/test_json_eager.py +++ b/tests/test_json_eager.py @@ -23,8 +23,65 @@ import tensorflow as tf if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() +import tensorflow_io as tfio # pylint: disable=wrong-import-position import tensorflow_io.json as json_io # pylint: disable=wrong-import-position +def test_io_tensor_json(): + """Test case for tfio.IOTensor.from_json.""" + x_test = [[1.1, 2], [2.1, 3]] + y_test = [[2.2, 3], [1.2, 3]] + feature_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_json", + "feature.ndjson") + feature_filename = "file://" + feature_filename + label_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_json", + "label.ndjson") + label_filename = "file://" + label_filename + + features = tfio.IOTensor.from_json(feature_filename) + assert features.dtype("floatfeature") == tf.float64 + assert features.dtype("integerfeature") == tf.int64 + + labels = tfio.IOTensor.from_json(label_filename) + assert labels.dtype("floatlabel") == tf.float64 + assert labels.dtype("integerlabel") == tf.int64 + + float_feature = features("floatfeature") + integer_feature = features("integerfeature") + float_label = labels("floatlabel") + integer_label = labels("integerlabel") + + for i in range(2): + v_x = x_test[i] + v_y = y_test[i] + assert v_x[0] == float_feature[i].numpy() + assert v_x[1] == integer_feature[i].numpy() + assert v_y[0] == float_label[i].numpy() + assert v_y[1] == integer_label[i].numpy() + + feature_dataset = features.to_dataset() + + label_dataset = labels.to_dataset() + + dataset = tf.data.Dataset.zip(( + feature_dataset, + label_dataset + )) + + i = 0 + for (j_x, j_y) in dataset: + v_x = x_test[i] + v_y = y_test[i] + for index, x in enumerate(j_x): + assert 
v_x[index] == x.numpy() + for index, y in enumerate(j_y): + assert v_y[index] == y.numpy() + i += 1 + assert i == len(y_test) + def test_json_dataset(): """Test case for JSON Dataset. """ diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index c1f10307c..8ea8145e4 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -47,6 +47,8 @@ cc_library( "cpp/src/arrow/io/*.h", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/ipc/*.h", + "cpp/src/arrow/json/*.cc", + "cpp/src/arrow/json/*.h", "cpp/src/arrow/util/*.cc", "cpp/src/arrow/util/*.h", "cpp/src/arrow/vendored/**/*.cpp", @@ -120,9 +122,10 @@ cc_library( deps = [ ":arrow_format", "@boost", + "@zlib", "@double_conversion//:double-conversion", + "@rapidjson", "@snappy", "@thrift", - "@zlib", ], ) diff --git a/third_party/rapidjson.BUILD b/third_party/rapidjson.BUILD new file mode 100644 index 000000000..af5f2d26e --- /dev/null +++ b/third_party/rapidjson.BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT/JSON license + +cc_library( + name = "rapidjson", + srcs = glob( + [ + ], + ) + [ + ], + hdrs = glob( + [ + "include/**/*.h", + ], + ) + [ + ], + copts = [ + ], + includes = [ + "include", + ], +) From 1527f0ec5d76c9739ed22b6e684c89ffe3fe6aea Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 21 Aug 2019 23:22:51 +0000 Subject: [PATCH 02/11] Remove Iterable from reference Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/io_tensor.py | 35 ++++------------------ 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index 3089e4e47..666865478 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -34,11 +34,7 @@ class IOTensor(io_tensor_ops._BaseIOTensor): `KafkaIOTensor` is a tensor with data from reading the messages of a Kafka stream server. - There are two types of `IOTensor`, a normal `IOTensor` which itself is - indexable, or a degenerated `IOIterableTensor` which only supports - accessing the tensor iteratively. - - Since `IOTensor` is indexable, it support `__getitem__()` and + The `IOTensor` is indexable, supporting `__getitem__()` and `__len__()` methods in Python. In other words, it is a subclass of `collections.abc.Sequence`. @@ -57,25 +53,6 @@ class IOTensor(io_tensor_ops._BaseIOTensor): ... [-5]], shape=(5, 1), dtype=int16) ``` - A `IOIterableTensor` is really a subclass of `collections.abc.Iterable`. - It provides a `__iter__()` method that could be used (through `iter` - indirectly) to access data in an iterative fashion. - - Example: - - ```python - >>> import tensorflow_io as tfio - >>> - >>> kafka = tfio.IOTensor.from_kafka("test", eof=True) - >>> for message in kafka: - >>> print(message) - ... tf.Tensor(['D0'], shape=(1,), dtype=string) - ... tf.Tensor(['D1'], shape=(1,), dtype=string) - ... tf.Tensor(['D2'], shape=(1,), dtype=string) - ... tf.Tensor(['D3'], shape=(1,), dtype=string) - ... tf.Tensor(['D4'], shape=(1,), dtype=string) - ``` - ### Indexable vs. Iterable While many IO formats are natually considered as iterable only, in most @@ -89,7 +66,7 @@ class IOTensor(io_tensor_ops._BaseIOTensor): TBs), it may not be realistic (or necessarily) to save the index of every packet in memory. We could consider PCAP format as iterable only. 
- As we could see the availability of memory size could be a factor to decide + As we could see the, availability of memory size could be a factor to decide if a format is indexable or not. However, this factor could also be blurred as well in distributed computing. One common case is the file format that might be splittable where a file could be split into multiple chunks @@ -172,10 +149,10 @@ class IOTensor(io_tensor_ops._BaseIOTensor): >>> import tensorflow_io as tfio >>> >>> features = tfio.IOTensor.from_json("feature.json") - >>> print(features.shape) - ... (TensorShape([Dimension(2)]), TensorShape([Dimension(2)])) - >>> print(features.dtype) - ... (tf.float64, tf.int64) + >>> print(features.shape("floatfeature")) + ... (2,) + >>> print(features.dtype("floatfeature")) + ... >>> >>> print(features("floatfeature").shape) ... (2,) From 83450acbef5309e39182d11a107c56dd36319b3c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 22 Aug 2019 00:12:39 +0000 Subject: [PATCH 03/11] Pylint fix Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/audio_io_tensor_ops.py | 4 +--- tensorflow_io/core/python/ops/io_tensor.py | 6 +----- tensorflow_io/core/python/ops/io_tensor_ops.py | 7 +------ tensorflow_io/core/python/ops/json_io_tensor_ops.py | 4 +--- third_party/arrow.BUILD | 2 +- 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py index 81d666520..969e1bfdf 100644 --- a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -17,15 +17,13 @@ from __future__ import division from __future__ import print_function -import sys -import collections import uuid import tensorflow as tf from tensorflow_io.core.python.ops import io_tensor_ops from tensorflow_io.core.python.ops import core_ops -class AudioIOTensor(io_tensor_ops._ColumnIOTensor): +class AudioIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access """AudioIOTensor""" #============================================================================= diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index 666865478..f31cf9437 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -17,16 +17,12 @@ from __future__ import division from __future__ import print_function -import sys -import collections -import uuid - import tensorflow as tf from tensorflow_io.core.python.ops import io_tensor_ops from tensorflow_io.core.python.ops import audio_io_tensor_ops from tensorflow_io.core.python.ops import json_io_tensor_ops -class IOTensor(io_tensor_ops._BaseIOTensor): +class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access """IOTensor An `IOTensor` is a tensor with data backed by IO operations. 
For example, diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py index 989b717c7..6d0b9ca51 100644 --- a/tensorflow_io/core/python/ops/io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -17,12 +17,7 @@ from __future__ import division from __future__ import print_function -import sys -import collections -import uuid - import tensorflow as tf -from tensorflow_io.core.python.ops import core_ops class _BaseIOTensorDataset(tf.compat.v2.data.Dataset): """_IOTensorDataset""" @@ -283,4 +278,4 @@ def dtype(self, column): def __call__(self, column): """Return a new IOTensor with column named `column`""" - return self.__class__(self._filename, columns=[column], internal=True) + return self.__class__(self._filename, columns=[column], internal=True) # pylint: disable=no-value-for-parameter diff --git a/tensorflow_io/core/python/ops/json_io_tensor_ops.py b/tensorflow_io/core/python/ops/json_io_tensor_ops.py index 73fe4a924..f6bab6acc 100644 --- a/tensorflow_io/core/python/ops/json_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/json_io_tensor_ops.py @@ -17,15 +17,13 @@ from __future__ import division from __future__ import print_function -import sys -import collections import uuid import tensorflow as tf from tensorflow_io.core.python.ops import io_tensor_ops from tensorflow_io.core.python.ops import core_ops -class JSONIOTensor(io_tensor_ops._TableIOTensor): +class JSONIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access """JSONIOTensor""" #============================================================================= diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index 8ea8145e4..d52d8242e 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -122,10 +122,10 @@ cc_library( deps = [ ":arrow_format", "@boost", - "@zlib", "@double_conversion//:double-conversion", "@rapidjson", "@snappy", "@thrift", + "@zlib", ], ) From eb19c0c83af2daf0e99c88ee3c10ea6488c37a0a Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 22 Aug 2019 02:20:52 +0000 Subject: [PATCH 04/11] Add a decorator so that it could be picked up by __repr__ automatically Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/audio_io_tensor_ops.py | 4 ++-- tensorflow_io/core/python/ops/io_tensor_ops.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py index 969e1bfdf..78742292e 100644 --- a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -46,7 +46,7 @@ def __init__(self, # Accessors #============================================================================= - @property + @io_tensor_ops._BaseIOTensorMeta # pylint: disable=protected-access def rate(self): - """The sampel `rate` of the audio stream""" + """The sample `rate` of the audio stream""" return self._rate diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py index 6d0b9ca51..b6873c5ed 100644 --- a/tensorflow_io/core/python/ops/io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -19,6 +19,10 @@ import tensorflow as tf +class _BaseIOTensorMeta(property): + """_BaseIOTensorMeta is a decorator that is viewable to __repr__""" + pass + class _BaseIOTensorDataset(tf.compat.v2.data.Dataset): """_IOTensorDataset""" @@ -92,8 +96,11 @@ def spec(self): # String Encoding 
#============================================================================= def __repr__(self): - return "<%s: spec=%s>" % ( - self.__class__.__name__, self.spec) + meta = "".join([", %s=%s" % ( + k, repr(v.__get__(self))) for k, v in self.__class__.__dict__.items( + ) if isinstance(v, _BaseIOTensorMeta)]) + return "<%s: spec=%s%s>" % ( + self.__class__.__name__, self.spec, meta) #============================================================================= From 3a4870af1c390b77be6fe45cbd4f5781ae8db1a1 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 22 Aug 2019 14:13:48 +0000 Subject: [PATCH 05/11] Fix python 3 issue Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/io_tensor_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py index b6873c5ed..c5fb22e14 100644 --- a/tensorflow_io/core/python/ops/io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -256,7 +256,7 @@ def __init__(self, [None if dim < 0 else dim for dim in e.numpy() if dim != 0] ) for e in tf.unstack(shapes)] dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] - columns = [e.numpy() for e in tf.unstack(columns)] + columns = [e.numpy().decode() for e in tf.unstack(columns)] spec = [tf.TensorSpec(shape, dtype, column) for ( shape, dtype, column) in zip(shapes, dtypes, columns)] if len(spec) == 1: From c1323edd6b8ab769f3081f6270829af142771a23 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 22 Aug 2019 15:17:33 +0000 Subject: [PATCH 06/11] Add KafkaDataset to tensorflow_io.core.python.ops.kafka_ops.KafkaDataset intend to deprecate old KafkaDataset soon. Signed-off-by: Yong Tang --- .../core/python/ops/kafka_dataset_ops.py | 84 ++++++ tensorflow_io/kafka/BUILD | 1 + tensorflow_io/kafka/kernels/kafka_kernels.cc | 250 ++++++++++++++++++ tensorflow_io/kafka/ops/kafka_ops.cc | 37 +++ .../kafka/python/ops/kafka_dataset_ops.py | 10 + tests/test_kafka_eager.py | 11 +- 6 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 tensorflow_io/core/python/ops/kafka_dataset_ops.py create mode 100644 tensorflow_io/kafka/kernels/kafka_kernels.cc diff --git a/tensorflow_io/core/python/ops/kafka_dataset_ops.py b/tensorflow_io/core/python/ops/kafka_dataset_ops.py new file mode 100644 index 000000000..2c286bf13 --- /dev/null +++ b/tensorflow_io/core/python/ops/kafka_dataset_ops.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""KafkaDataset""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import core_ops + +class KafkaDataset(tf.compat.v2.data.Dataset): + """KafkaDataset""" + + def __init__(self, + subscription, + servers=None, + configuration=None): + """Create a KafkaDataset. + Args: + subscription: A `tf.string` tensor containing subscription, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: A list of bootstrap servers, by default `localhost:9092`. + configuration: A `tf.string` tensor containing configurations + in [Key=Value] format. There are three types of configurations, + Global configuration: please refer to 'Global configuration properties' + in librdkafka doc. Examples include + ["enable.auto.commit=false", "heartbeat.interval.ms=2000"] + Topic configuration: please refer to 'Topic configuration properties' + in librdkafka doc. Note all topic configurations should be + prefixed with `configuration.topic.`. Examples include + ["conf.topic.auto.offset.reset=earliest"] + Dataset configuration: there are two configurations available, + `conf.eof=0|1`: if True, the KafkaDaset will stop on EOF (default). + `conf.timeout=milliseconds`: timeout value for Kafka Consumer to wait. + """ + with tf.name_scope("KafkaDataset") as scope: + metadata = [e for e in configuration or []] + if servers is not None: + metadata.append("bootstrap.servers=%s" % servers) + resource, _, _ = core_ops.kafka_iterable_init( + subscription, metadata=metadata, + container=scope, + shared_name="%s/%s" % (subscription, uuid.uuid4().hex)) + + capacity = 4096 + dataset = tf.compat.v2.data.Dataset.range(0, sys.maxsize, capacity) + dataset = dataset.map( + lambda i: core_ops.kafka_iterable_next(resource, capacity, dtype=[tf.string], shape=[tf.TensorShape([None])])) + dataset = dataset.apply( + tf.data.experimental.take_while( + lambda v: tf.greater(tf.shape(v)[0], 0))) + # Note: tf.data.Dataset consider tuple `(e, )` as one element + # instead of a sequence. So next `unbatch()` will not work. + # The tf.stack() below is necessary. + dataset = dataset.map(tf.stack) + dataset = dataset.unbatch() + + self._resource = resource + self._dataset = dataset + super(KafkaDataset, self).__init__( + self._dataset._variant_tensor) # pylint: disable=protected-access + + def _inputs(self): + return [] + + @property + def element_spec(self): + return self._dataset.element_spec diff --git a/tensorflow_io/kafka/BUILD b/tensorflow_io/kafka/BUILD index 21ce10215..0b3976702 100644 --- a/tensorflow_io/kafka/BUILD +++ b/tensorflow_io/kafka/BUILD @@ -11,6 +11,7 @@ cc_library( name = "kafka_ops", srcs = [ "kernels/kafka_dataset_ops.cc", + "kernels/kafka_kernels.cc", "kernels/kafka_sequence.cc", "ops/dataset_ops.cc", "ops/kafka_ops.cc", diff --git a/tensorflow_io/kafka/kernels/kafka_kernels.cc b/tensorflow_io/kafka/kernels/kafka_kernels.cc new file mode 100644 index 000000000..30a595afb --- /dev/null +++ b/tensorflow_io/kafka/kernels/kafka_kernels.cc @@ -0,0 +1,250 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_io/core/kernels/io_interface.h" +//#include "tensorflow/core/platform/logging.h" + +#include "rdkafkacpp.h" + +#include + +namespace tensorflow { +namespace data { + +class KafkaEventCb : public RdKafka::EventCb { +public: + KafkaEventCb() + : run_(true) {} + + bool run() { + return run_; + } + + void event_cb (RdKafka::Event &event) { + switch (event.type()) { + case RdKafka::Event::EVENT_ERROR: + LOG(ERROR) << "EVENT_ERROR: " << "(" << RdKafka::err2str(event.err()) << "): " << event.str(); + { + run_ = (event.err() != RdKafka::ERR__ALL_BROKERS_DOWN); + } + break; + case RdKafka::Event::EVENT_STATS: + LOG(ERROR) << "EVENT_STATS: " << event.str(); + break; + case RdKafka::Event::EVENT_LOG: + LOG(ERROR) << "EVENT_LOG: " << event.severity() << "-" << event.fac().c_str() << "-" << event.str().c_str(); + break; + case RdKafka::Event::EVENT_THROTTLE: + LOG(ERROR) << "EVENT_THROTTLE: " << event.throttle_time() << "ms by " << event.broker_name() << " id " << (int)event.broker_id(); + break; + default: + LOG(ERROR) << "EVENT: " << event.type() << " (" << RdKafka::err2str(event.err()) << "): " << event.str(); + break; + } + } +private: + mutable mutex mu_; + bool run_ GUARDED_BY(mu_) = true; +}; + +class KafkaIterable : public IOIterableInterface { + public: + KafkaIterable(Env* env) + : env_(env) {} + + ~KafkaIterable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + std::unique_ptr conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); + std::unique_ptr conf_topic( + RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); + + string errstr; + RdKafka::Conf::ConfResult result = RdKafka::Conf::CONF_UNKNOWN; + + eof_ = true; + timeout_ = 1000; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find_first_of("conf.eof") == 0) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2) { + return errors::InvalidArgument("invalid timeout configuration: ", metadata[i]); + } + eof_ = (parts[1] != "0"); + } else if (metadata[i].find_first_of("conf.timeout") == 0) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2 || !strings::safe_strto64(parts[1], &timeout_)) { + return errors::InvalidArgument("invalid timeout configuration: ", metadata[i]); + } + } else if (metadata[i].find_first_of("conf.topic.") == 0) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2) { + return errors::InvalidArgument("invalid topic configuration: ", metadata[i]); + } + result = conf_topic->set(parts[0].substr(11), parts[1], errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to do topic configuration:", metadata[i], "error:", errstr); + } + } else if (metadata[i] != "" && metadata[i].find_first_of("conf.") == string::npos) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2) { + return errors::InvalidArgument("invalid topic configuration: ", metadata[i]); + } + if ((result = 
conf->set(parts[0], parts[1], errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to do global configuration: ", metadata[i], "error:", errstr); + } + } + } + if ((result = conf->set("default_topic_conf", conf_topic.get(), errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set default_topic_conf:", errstr); + } + + // consumer.properties: + // bootstrap.servers=localhost:9092 + // group.id=test-consumer-group + string bootstrap_servers; + if ((result = conf->get("bootstrap.servers", bootstrap_servers)) != RdKafka::Conf::CONF_OK) { + bootstrap_servers = "localhost:9092"; + if ((result = conf->set("bootstrap.servers", bootstrap_servers, errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set bootstrap.servers [", bootstrap_servers, "]:", errstr); + } + } + string group_id; + if ((result = conf->get("group.id", group_id)) != RdKafka::Conf::CONF_OK) { + group_id = "test-consumer-group"; + if ((result = conf->set("group.id", group_id, errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set group.id [", group_id, "]:", errstr); + } + } + if ((result = conf->set("event_cb", &kafka_event_cb_, errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set event_cb:", errstr); + } + + // TODO: multiple topic and partitions + + const string& entry = input[0]; + std::vector parts = str_util::Split(entry, ":"); + string topic = parts[0]; + int32 partition = 0; + if (parts.size() > 1) { + if (!strings::safe_strto32(parts[1], &partition)) { + return errors::InvalidArgument("invalid parameters: ", entry); + } + } + + int64 start = 0; + if (parts.size() > 2) { + if (!strings::safe_strto64(parts[2], &start)) { + return errors::InvalidArgument("invalid parameters: ", entry); + } + } + subscription_.reset(RdKafka::TopicPartition::create(topic, partition, start)); + start = subscription_->offset(); + + offset_ = start; + + int64 stop = -1; + if (parts.size() > 3) { + if (!strings::safe_strto64(parts[3], &stop)) { + return errors::InvalidArgument("invalid parameters: ", entry); + } + } + range_ = std::pair(start, stop); + + consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); + if (!consumer_.get()) { + return errors::Internal("failed to create consumer:", errstr); + } + + std::vector partitions; + partitions.emplace_back(subscription_.get()); + RdKafka::ErrorCode err = consumer_->assign(partitions); + if (err != RdKafka::ERR_NO_ERROR) { + return errors::Internal("failed to assign partition: ", RdKafka::err2str(err)); + } + + return Status::OK(); + } + Status Next(const int64 capacity, std::vector& tensors, int64* record_read) override { + *record_read = 0; + while (consumer_.get() != nullptr && (*record_read) < capacity) { + if (!kafka_event_cb_.run()) { + return errors::Internal("failed to consume due to all brokers down"); + } + if (range_.second >= 0 && (subscription_->offset() >= range_.second || offset_ >= range_.second)) { + // EOF of topic + consumer_.reset(nullptr); + return Status::OK(); + } + + std::unique_ptr message(consumer_->consume(timeout_)); + if (message->err() == RdKafka::ERR_NO_ERROR) { + // Produce the line as output. 
+ tensors[0].flat()((*record_read)) = std::string(static_cast(message->payload()), message->len()); + // Sync offset + offset_ = message->offset(); + (*record_read)++; + continue; + } + if (message->err() == RdKafka::ERR__PARTITION_EOF) { + LOG(INFO) << "Partition reach EOF, current offset: " << offset_; + if (eof_) { + consumer_.reset(nullptr); + return Status::OK(); + } + } + else if (message->err() == RdKafka::ERR__TRANSPORT) { + // Not return error here because consumer will try re-connect. + LOG(ERROR) << "Broker transport failure: " << message->errstr(); + } + else if (message->err() != RdKafka::ERR__TIMED_OUT) { + LOG(ERROR) << "Failed to consume: " << message->errstr(); + return errors::Internal("Failed to consume: ", message->errstr()); + } + } + return Status::OK(); + } + Status Spec(std::vector& dtypes, std::vector& shapes) override { + dtypes.clear(); + dtypes.push_back(DT_STRING); + shapes.clear(); + shapes.push_back(PartialTensorShape({-1})); + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("KafkaIterable[]"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::pair range_ GUARDED_BY(mu_); + std::unique_ptr subscription_ GUARDED_BY(mu_); + std::unique_ptr consumer_ GUARDED_BY(mu_); + KafkaEventCb kafka_event_cb_ = KafkaEventCb(); + int64 timeout_ GUARDED_BY(mu_); + bool eof_ GUARDED_BY(mu_); + int64 offset_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("KafkaIterableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("KafkaIterableNext").Device(DEVICE_CPU), + IOIterableNextOp); + + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/kafka/ops/kafka_ops.cc b/tensorflow_io/kafka/ops/kafka_ops.cc index 7142ae719..a3fd183ee 100644 --- a/tensorflow_io/kafka/ops/kafka_ops.cc +++ b/tensorflow_io/kafka/ops/kafka_ops.cc @@ -19,6 +19,43 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("KafkaIterableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("KafkaIterableNext") + .Input("input: resource") + .Input("capacity: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); + + REGISTER_OP("KafkaOutputSequence") .Input("topic: string") .Input("servers: string") diff --git a/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py b/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py index 7b6171d85..4b18ec2df 100644 --- a/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py @@ -17,11 +17,21 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow import dtypes from tensorflow.compat.v1 import data from tensorflow_io.core.python.ops import core_ops +warnings.warn( + "implementation of existing tensorflow_io.kafka.KafkaDataset is " + "deprecated and will be replaced with the implementation in " + "tensorflow_io.core.python.ops.kafka_dataset_ops.KafkaDataset, " + "please check the doc of new implementation for API changes", + DeprecationWarning) + + class KafkaDataset(data.Dataset): """A Kafka Dataset that consumes the message. """ diff --git a/tests/test_kafka_eager.py b/tests/test_kafka_eager.py index 81f0bbab5..c350b0cd7 100644 --- a/tests/test_kafka_eager.py +++ b/tests/test_kafka_eager.py @@ -23,7 +23,16 @@ import numpy as np import tensorflow as tf -import tensorflow_io.kafka as kafka_io +if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): + tf.compat.v1.enable_eager_execution() +import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position +from tensorflow_io.core.python.ops import kafka_dataset_ops # pylint: disable=wrong-import-position + +def test_kafka_dataset(): + dataset = kafka_dataset_ops.KafkaDataset("test").batch(2) + assert np.all([ + e.numpy().tolist() for e in dataset] == np.asarray([ + ("D" + str(i)).encode() for i in range(10)]).reshape((5, 2))) @pytest.mark.skipif( not (hasattr(tf, "version") and From 5df11e9be846c93cc3f9115c8d7d9d98eb1b8a1c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 22 Aug 2019 19:16:57 +0000 Subject: [PATCH 07/11] Add KafkaIOTensor which stores data in memory (so that it is indexable) This is build around the same code base as KafkaDataset C++. 
Signed-off-by: Yong Tang --- tensorflow_io/audio/kernels/audio_kernels.cc | 2 +- tensorflow_io/core/kernels/io_interface.h | 116 ++++++++++++++++++ tensorflow_io/core/python/ops/io_tensor.py | 39 ++++++ .../core/python/ops/kafka_io_tensor_ops.py | 48 ++++++++ tensorflow_io/kafka/kernels/kafka_kernels.cc | 5 +- tensorflow_io/kafka/ops/kafka_ops.cc | 38 ++++++ tests/test_kafka_eager.py | 13 +- 7 files changed, 256 insertions(+), 5 deletions(-) create mode 100644 tensorflow_io/core/python/ops/kafka_io_tensor_ops.py diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index 62f9e0737..ae4d43f21 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -73,7 +73,7 @@ class WAVIndexable : public IOIndexableInterface { : env_(env) {} ~WAVIndexable() {} - Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { if (input.size() > 1) { return errors::InvalidArgument("more than 1 filename is not supported"); } diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h index 1d7e49269..247804d46 100644 --- a/tensorflow_io/core/kernels/io_interface.h +++ b/tensorflow_io/core/kernels/io_interface.h @@ -41,6 +41,122 @@ class IOIndexableInterface : public IOInterface { virtual Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) = 0; }; +template +class IOIndexableImplementation : public IOIndexableInterface { + public: + IOIndexableImplementation(Env* env) + : env_(env) + , iterable_(new Type(env)) {} + + ~IOIndexableImplementation() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + + TF_RETURN_IF_ERROR(iterable_->Init(input, metadata, memory_data, memory_size)); + TF_RETURN_IF_ERROR(iterable_->Spec(dtypes_, shapes_)); + + const int64 capacity = 4096; + std::vector chunk_shapes; + for (size_t component = 0; component < shapes_.size(); component++) { + gtl::InlinedVector dims = shapes_[component].dim_sizes(); + dims[0] = capacity; + chunk_shapes.push_back(TensorShape(dims)); + } + + int64 total = 0; + + int64 record_read = 0; + do { + tensors_.push_back(std::vector()); + for (size_t component = 0; component < shapes_.size(); component++) { + tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component])); + } + TF_RETURN_IF_ERROR(iterable_->Next(capacity, tensors_.back(), &record_read)); + if (record_read == 0) { + tensors_.pop_back(); + break; + } + if (record_read < capacity) { + for (size_t component = 0; component < shapes_.size(); component++) { + tensors_.back()[component] = tensors_.back()[component].Slice(0, record_read); + } + } + total += record_read; + } while (record_read != 0); + for (size_t component = 0; component < shapes_.size(); component++) { + shapes_[component].set_dim(0, total); + } + return Status::OK(); + } + virtual Status Spec(std::vector& dtypes, std::vector& shapes) override { + for (size_t component = 0; component < dtypes_.size(); component++) { + dtypes.push_back(dtypes_[component]); + } + for (size_t component = 0; component < shapes_.size(); component++) { + shapes.push_back(shapes_[component]); + } + return Status::OK(); + } + + Status Extra(std::vector* extra) override { + return 
iterable_->Extra(extra); + } + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("IOIndexableImplementation<", iterable_->DebugString(), ">[]"); + } + + Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + if (step != 1) { + return errors::InvalidArgument("step != 1 is not supported: ", step); + } + // Find first chunk + + int64 chunk_index = 0; + int64 chunk_element = -1; + int64 current_element = 0; + while (chunk_index < tensors_.size()) { + if (current_element <= start && start < current_element + tensors_[chunk_index][0].shape().dim_size(0)) { + chunk_element = start - current_element; + current_element = start; + break; + } + current_element += tensors_[chunk_index][0].shape().dim_size(0); + chunk_index++; + } + if (chunk_element < 0) { + return errors::InvalidArgument("start is out of range: ", start); + } + std::vector elements; + for (size_t component = 0; component < shapes_.size(); component++) { + TensorShape shape(shapes_[component].dim_sizes()); + shape.RemoveDim(0); + elements.push_back(Tensor(dtypes_[component], shape)); + } + + while (current_element < stop) { + for (size_t component = 0; component < shapes_.size(); component++) { + batch_util::CopySliceToElement(tensors_[chunk_index][component], &elements[component], chunk_element); + batch_util::CopyElementToSlice(elements[component], &tensors[component], (current_element - start)); + } + chunk_element++; + if (chunk_element == tensors_[chunk_index][0].shape().dim_size(0)) { + chunk_index++; + chunk_element = 0; + } + current_element++; + } + return Status::OK(); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::unique_ptr iterable_ GUARDED_BY(mu_); + std::vector dtypes_ GUARDED_BY(mu_); + std::vector shapes_ GUARDED_BY(mu_); + std::vector> tensors_; +}; + + template class IOInterfaceInitOp : public ResourceOpKernel { public: diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index f31cf9437..55e9410b6 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -21,6 +21,7 @@ from tensorflow_io.core.python.ops import io_tensor_ops from tensorflow_io.core.python.ops import audio_io_tensor_ops from tensorflow_io.core.python.ops import json_io_tensor_ops +from tensorflow_io.core.python.ops import kafka_io_tensor_ops class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access """IOTensor @@ -225,3 +226,41 @@ def from_json(cls, """ with tf.name_scope(kwargs.get("name", "IOFromJSON")): return json_io_tensor_ops.JSONIOTensor(filename, internal=True) + + @classmethod + def from_kafka(cls, + subscription, + **kwargs): + """Creates an `IOTensor` from a Kafka stream. + + Args: + subscription: A `tf.string` tensor containing subscription, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: An optional list of bootstrap servers, by default + `localhost:9092`. + configuration: An optional `tf.string` tensor containing + configurations in [Key=Value] format. There are three + types of configurations: + Global configuration: please refer to 'Global configuration properties' + in librdkafka doc. Examples include + ["enable.auto.commit=false", "heartbeat.interval.ms=2000"] + Topic configuration: please refer to 'Topic configuration properties' + in librdkafka doc. Note all topic configurations should be + prefixed with `configuration.topic.`. 
Examples include + ["conf.topic.auto.offset.reset=earliest"] + Dataset configuration: there are two configurations available, + `conf.eof=0|1`: if True, the KafkaDaset will stop on EOF (default). + `conf.timeout=milliseconds`: timeout value for Kafka Consumer to wait. + name: A name prefix for the IOTensor (optional). + + Returns: + A `IOTensor`. + + """ + with tf.name_scope(kwargs.get("name", "IOFromKafka")): + return kafka_io_tensor_ops.KafkaIOTensor( + subscription, + servers=kwargs.get("servers", None), + configuration=kwargs.get("configuration", None), + internal=True) diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py new file mode 100644 index 000000000..a09f6245e --- /dev/null +++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""KafkaIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class KafkaIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access + """KafkaIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + subscription, + servers=None, + configuration=None, + internal=False): + with tf.name_scope("KafkaIOTensor") as scope: + metadata = [e for e in configuration or []] + if servers is not None: + metadata.append("bootstrap.servers=%s" % servers) + resource, dtypes, shapes = core_ops.kafka_indexable_init( + subscription, metadata=metadata, + container=scope, + shared_name="%s/%s" % (subscription, uuid.uuid4().hex)) + print("VVV: ", dtypes, shapes) + super(KafkaIOTensor, self).__init__( + shapes, dtypes, resource, core_ops.kafka_indexable_get_item, + internal=internal) diff --git a/tensorflow_io/kafka/kernels/kafka_kernels.cc b/tensorflow_io/kafka/kernels/kafka_kernels.cc index 30a595afb..2a3a25f20 100644 --- a/tensorflow_io/kafka/kernels/kafka_kernels.cc +++ b/tensorflow_io/kafka/kernels/kafka_kernels.cc @@ -244,7 +244,10 @@ REGISTER_KERNEL_BUILDER(Name("KafkaIterableInit").Device(DEVICE_CPU), IOInterfaceInitOp); REGISTER_KERNEL_BUILDER(Name("KafkaIterableNext").Device(DEVICE_CPU), IOIterableNextOp); - +REGISTER_KERNEL_BUILDER(Name("KafkaIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp>); +REGISTER_KERNEL_BUILDER(Name("KafkaIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp>); } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/kafka/ops/kafka_ops.cc b/tensorflow_io/kafka/ops/kafka_ops.cc index a3fd183ee..1ca333eac 100644 --- 
a/tensorflow_io/kafka/ops/kafka_ops.cc +++ b/tensorflow_io/kafka/ops/kafka_ops.cc @@ -19,6 +19,44 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("KafkaIndexableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("KafkaIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); + REGISTER_OP("KafkaIterableInit") .Input("input: string") .Input("metadata: string") diff --git a/tests/test_kafka_eager.py b/tests/test_kafka_eager.py index c350b0cd7..2b18954ea 100644 --- a/tests/test_kafka_eager.py +++ b/tests/test_kafka_eager.py @@ -25,7 +25,7 @@ import tensorflow as tf if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() -import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position +import tensorflow_io as tfio # pylint: disable=wrong-import-position from tensorflow_io.core.python.ops import kafka_dataset_ops # pylint: disable=wrong-import-position def test_kafka_dataset(): @@ -34,6 +34,13 @@ def test_kafka_dataset(): e.numpy().tolist() for e in dataset] == np.asarray([ ("D" + str(i)).encode() for i in range(10)]).reshape((5, 2))) +def test_kafka_io_tensor(): + kafka = tfio.IOTensor.from_kafka("test") + assert kafka.dtype == tf.string + assert kafka.shape == [10] + assert np.all(kafka.to_tensor().numpy() == [ + ("D" + str(i)).encode() for i in range(10)]) + @pytest.mark.skipif( not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.0.")), reason=None) @@ -64,7 +71,7 @@ def test_kafka_output_sequence(): class OutputCallback(tf.keras.callbacks.Callback): """KafkaOutputCallback""" def __init__(self, batch_size, topic, servers): - self._sequence = kafka_io.KafkaOutputSequence( + self._sequence = tfio.kafka.KafkaOutputSequence( topic=topic, servers=servers) self._batch_size = batch_size def on_predict_batch_end(self, batch, logs=None): @@ -87,6 +94,6 @@ def flush(self): predictions = [class_names[v] for v in np.argmax(predictions, axis=1)] # Reading from `test_e(time)e` we should get the same result - dataset = kafka_io.KafkaDataset(topics=[topic], group="test", eof=True) + dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True) for entry, prediction in zip(dataset, predictions): assert entry.numpy() == prediction.encode() From 405101082b3d7962cd4b9b73d3a9f2f2cfd6e7c4 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 22 Aug 2019 23:54:28 +0000 Subject: [PATCH 08/11] 
Deprecate WAVDataset, and pylint fix Signed-off-by: Yong Tang --- tensorflow_io/audio/python/ops/audio_ops.py | 9 +++++++++ tensorflow_io/core/python/ops/io_tensor.py | 8 ++++---- tests/test_kafka_eager.py | 2 +- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow_io/audio/python/ops/audio_ops.py b/tensorflow_io/audio/python/ops/audio_ops.py index 8c01b594d..a3379b1d2 100644 --- a/tensorflow_io/audio/python/ops/audio_ops.py +++ b/tensorflow_io/audio/python/ops/audio_ops.py @@ -17,9 +17,18 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow_io.core.python.ops import data_ops +warnings.warn( + "The tensorflow_io.audio.WAVDataset is " + "deprecated. Please look for tfio.IOTensor.from_audio " + "for reading WAV files into tensorflow.", + DeprecationWarning) + + class AudioDataset(data_ops.Dataset): """A Audio File Dataset that reads the audio file.""" diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index 55e9410b6..5a237da41 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -260,7 +260,7 @@ def from_kafka(cls, """ with tf.name_scope(kwargs.get("name", "IOFromKafka")): return kafka_io_tensor_ops.KafkaIOTensor( - subscription, - servers=kwargs.get("servers", None), - configuration=kwargs.get("configuration", None), - internal=True) + subscription, + servers=kwargs.get("servers", None), + configuration=kwargs.get("configuration", None), + internal=True) diff --git a/tests/test_kafka_eager.py b/tests/test_kafka_eager.py index 2b18954ea..0021aaa84 100644 --- a/tests/test_kafka_eager.py +++ b/tests/test_kafka_eager.py @@ -39,7 +39,7 @@ def test_kafka_io_tensor(): assert kafka.dtype == tf.string assert kafka.shape == [10] assert np.all(kafka.to_tensor().numpy() == [ - ("D" + str(i)).encode() for i in range(10)]) + ("D" + str(i)).encode() for i in range(10)]) @pytest.mark.skipif( not (hasattr(tf, "version") and From 5eb166d600fb201e610c695e434e3a2154033140 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 24 Aug 2019 15:00:55 +0000 Subject: [PATCH 09/11] Remove leftover print Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/kafka_io_tensor_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py index a09f6245e..ee88e4b90 100644 --- a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py @@ -42,7 +42,6 @@ def __init__(self, subscription, metadata=metadata, container=scope, shared_name="%s/%s" % (subscription, uuid.uuid4().hex)) - print("VVV: ", dtypes, shapes) super(KafkaIOTensor, self).__init__( shapes, dtypes, resource, core_ops.kafka_indexable_get_item, internal=internal) From a8af64a839f120520cde5f7cc6fb5c9ef17c101a Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 24 Aug 2019 15:38:52 +0000 Subject: [PATCH 10/11] Import GetTensorFlowType and GetArrowType Signed-off-by: Yong Tang --- .../arrow/kernels/arrow_dataset_ops.cc | 13 ++--- tensorflow_io/arrow/kernels/arrow_kernels.cc | 18 ++++++ tensorflow_io/arrow/kernels/arrow_kernels.h | 6 ++ tensorflow_io/json/kernels/json_kernels.cc | 57 +------------------ 4 files changed, 30 insertions(+), 64 deletions(-) diff --git a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc index 63f5b3fe3..cbe0c9a4e 100644 --- 
a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc +++ b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "arrow/api.h" -#include "arrow/adapters/tensorflow/convert.h" #include "arrow/ipc/api.h" #include "arrow/util/io-util.h" #include "tensorflow/core/framework/dataset.h" @@ -99,8 +98,10 @@ class ArrowColumnTypeChecker : public arrow::TypeVisitor { // Check scalar types with arrow::adapters::tensorflow arrow::Status CheckScalarType(std::shared_ptr scalar_type) { DataType converted_type; - ARROW_RETURN_NOT_OK(arrow::adapters::tensorflow::GetTensorFlowType( - scalar_type, &converted_type)); + ::tensorflow::Status status = GetTensorFlowType(scalar_type, &converted_type); + if (!status.ok()) { + return ::arrow::Status::Invalid(status); + } if (converted_type != expected_type_) { return arrow::Status::TypeError( "Arrow type mismatch: expected dtype=" + @@ -523,11 +524,7 @@ class ArrowOpKernelBase : public DatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); for (const DataType& dt : output_types_) { std::shared_ptr arrow_type; - auto status = arrow::adapters::tensorflow::GetArrowType(dt, &arrow_type); - OP_REQUIRES(ctx, status.ok(), - errors::InvalidArgument( - "Arrow type is unsupported for output_type dtype=" + - std::to_string(dt))); + OP_REQUIRES_OK(ctx, GetArrowType(dt, &arrow_type)); } for (const PartialTensorShape& pts : output_shapes_) { OP_REQUIRES(ctx, -1 <= pts.dims() && pts.dims() <= 2, diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.cc b/tensorflow_io/arrow/kernels/arrow_kernels.cc index 46b3bbcc5..90d337ed0 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.cc +++ b/tensorflow_io/arrow/kernels/arrow_kernels.cc @@ -19,9 +19,27 @@ limitations under the License. #include "arrow/ipc/feather.h" #include "arrow/ipc/feather_generated.h" #include "arrow/buffer.h" +#include "arrow/adapters/tensorflow/convert.h" namespace tensorflow { namespace data { + +Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out) { + ::arrow::Status status = ::arrow::adapters::tensorflow::GetTensorFlowType(dtype, out); + if (!status.ok()) { + return errors::InvalidArgument("arrow data type ", dtype, " is not supported: ", status); + } + return Status::OK(); +} + +Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out) { + ::arrow::Status status = ::arrow::adapters::tensorflow::GetArrowType(dtype, out); + if (!status.ok()) { + return errors::InvalidArgument("tensorflow data type ", dtype, " is not supported: ", status); + } + return Status::OK(); +} + namespace { class ListFeatherColumnsOp : public OpKernel { diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h index 30da9745f..39abef014 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.h +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -19,10 +19,14 @@ limitations under the License. #include "kernels/stream.h" #include "arrow/io/api.h" #include "arrow/buffer.h" +#include "arrow/type.h" namespace tensorflow { namespace data { +Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out); +Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out); + // NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap // with another PR. 
Will remove duplicate once PR merged @@ -92,6 +96,8 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { int64 size_; int64 position_; }; + + } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc index d7841ec21..a2dff6146 100644 --- a/tensorflow_io/json/kernels/json_kernels.cc +++ b/tensorflow_io/json/kernels/json_kernels.cc @@ -262,61 +262,7 @@ class JSONIndexable : public IOIndexableInterface { for (size_t i = 0; i < columns_index_.size(); i++) { int column_index = columns_index_[i]; ::tensorflow::DataType dtype; - switch (table_->column(column_index)->type()->id()) { - case ::arrow::Type::BOOL: - dtype = ::tensorflow::DT_BOOL; - break; - case ::arrow::Type::UINT8: - dtype= ::tensorflow::DT_UINT8; - break; - case ::arrow::Type::INT8: - dtype= ::tensorflow::DT_INT8; - break; - case ::arrow::Type::UINT16: - dtype= ::tensorflow::DT_UINT16; - break; - case ::arrow::Type::INT16: - dtype= ::tensorflow::DT_INT16; - break; - case ::arrow::Type::UINT32: - dtype= ::tensorflow::DT_UINT32; - break; - case ::arrow::Type::INT32: - dtype= ::tensorflow::DT_INT32; - break; - case ::arrow::Type::UINT64: - dtype= ::tensorflow::DT_UINT64; - break; - case ::arrow::Type::INT64: - dtype= ::tensorflow::DT_INT64; - break; - case ::arrow::Type::HALF_FLOAT: - dtype= ::tensorflow::DT_HALF; - break; - case ::arrow::Type::FLOAT: - dtype= ::tensorflow::DT_FLOAT; - break; - case ::arrow::Type::DOUBLE: - dtype= ::tensorflow::DT_DOUBLE; - break; - case ::arrow::Type::STRING: - case ::arrow::Type::BINARY: - case ::arrow::Type::FIXED_SIZE_BINARY: - case ::arrow::Type::DATE32: - case ::arrow::Type::DATE64: - case ::arrow::Type::TIMESTAMP: - case ::arrow::Type::TIME32: - case ::arrow::Type::TIME64: - case ::arrow::Type::INTERVAL: - case ::arrow::Type::DECIMAL: - case ::arrow::Type::LIST: - case ::arrow::Type::STRUCT: - case ::arrow::Type::UNION: - case ::arrow::Type::DICTIONARY: - case ::arrow::Type::MAP: - default: - return errors::InvalidArgument("arrow data type is not supported: ", table_->column(i)->type()->ToString()); - } + TF_RETURN_IF_ERROR(GetTensorFlowType(table_->column(column_index)->type(), &dtype)); dtypes_.push_back(dtype); shapes_.push_back(TensorShape({static_cast(table_->num_rows())})); columns_.push_back(table_->column(column_index)->name()); @@ -347,7 +293,6 @@ class JSONIndexable : public IOIndexableInterface { } Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { - Tensor& output_tensor = tensors[0]; if (step != 1) { return errors::InvalidArgument("step ", step, " is not supported"); } From b1f5508e44410b24151ceb39799635318c7919fc Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 24 Aug 2019 15:54:20 +0000 Subject: [PATCH 11/11] Fix kokoro version Signed-off-by: Yong Tang --- .kokorun/io_cpu.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.kokorun/io_cpu.sh b/.kokorun/io_cpu.sh index cc0215e5b..b301687c1 100755 --- a/.kokorun/io_cpu.sh +++ b/.kokorun/io_cpu.sh @@ -45,6 +45,8 @@ python --version python -m pip --version docker --version +PYTHON_VERSION=$(python -c 'import sys; print(sys.version_info[0])') + ## Set test services bash -x -e tests/test_ignite/start_ignite.sh bash -x -e tests/test_kafka/kafka_test.sh start kafka
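
Taken together, the series leaves the following eager-mode workflow in place. The sketch below is assembled from the audio and JSON eager tests earlier in the series; the WAV path is a placeholder, while feature.ndjson and label.ndjson are the fixtures added under tests/test_json:

```python
import tensorflow as tf
import tensorflow_io as tfio

# WAV: an indexable IOTensor whose dtype, shape and sample rate come from the header.
audio = tfio.IOTensor.from_audio("path/to/file.wav")  # placeholder path
print(audio.dtype, audio.shape, audio.rate)
audio_batches = audio.to_dataset().batch(2)

# NDJSON: per-column dtype lookup, column selection, and export to tf.data.
features = tfio.IOTensor.from_json("tests/test_json/feature.ndjson")
labels = tfio.IOTensor.from_json("tests/test_json/label.ndjson")
assert features.dtype("floatfeature") == tf.float64
assert features.dtype("integerfeature") == tf.int64
float_feature = features("floatfeature")  # a single-column IOTensor
print(float_feature[0])

# Zip the two tables into one dataset for training-style iteration.
dataset = tf.data.Dataset.zip((features.to_dataset(), labels.to_dataset()))
for x, y in dataset:
    print(x, y)
```

As in the tests, filenames may also be passed with an explicit `file://` scheme.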