diff --git a/.kokorun/io_cpu.sh b/.kokorun/io_cpu.sh index cc0215e5b..b301687c1 100755 --- a/.kokorun/io_cpu.sh +++ b/.kokorun/io_cpu.sh @@ -45,6 +45,8 @@ python --version python -m pip --version docker --version +PYTHON_VERSION=$(python -c 'import sys; print(sys.version_info[0])') + ## Set test services bash -x -e tests/test_ignite/start_ignite.sh bash -x -e tests/test_kafka/kafka_test.sh start kafka diff --git a/WORKSPACE b/WORKSPACE index a8427116a..677273d61 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -542,3 +542,13 @@ http_archive( "https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip", ], ) + +http_archive( + name = "rapidjson", + build_file = "//third_party:rapidjson.BUILD", + sha256 = "bf7ced29704a1e696fbccf2a2b4ea068e7774fa37f6d7dd4039d0787f8bed98e", + strip_prefix = "rapidjson-1.1.0", + urls = [ + "https://github.com/miloyip/rapidjson/archive/v1.1.0.tar.gz", + ], +) diff --git a/tensorflow_io/__init__.py b/tensorflow_io/__init__.py index 5c9b94a5d..5d66c595c 100644 --- a/tensorflow_io/__init__.py +++ b/tensorflow_io/__init__.py @@ -17,4 +17,4 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf +from tensorflow_io.core.python.ops.io_tensor import IOTensor diff --git a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc index 63f5b3fe3..cbe0c9a4e 100644 --- a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc +++ b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "arrow/api.h" -#include "arrow/adapters/tensorflow/convert.h" #include "arrow/ipc/api.h" #include "arrow/util/io-util.h" #include "tensorflow/core/framework/dataset.h" @@ -99,8 +98,10 @@ class ArrowColumnTypeChecker : public arrow::TypeVisitor { // Check scalar types with arrow::adapters::tensorflow arrow::Status CheckScalarType(std::shared_ptr scalar_type) { DataType converted_type; - ARROW_RETURN_NOT_OK(arrow::adapters::tensorflow::GetTensorFlowType( - scalar_type, &converted_type)); + ::tensorflow::Status status = GetTensorFlowType(scalar_type, &converted_type); + if (!status.ok()) { + return ::arrow::Status::Invalid(status); + } if (converted_type != expected_type_) { return arrow::Status::TypeError( "Arrow type mismatch: expected dtype=" + @@ -523,11 +524,7 @@ class ArrowOpKernelBase : public DatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); for (const DataType& dt : output_types_) { std::shared_ptr arrow_type; - auto status = arrow::adapters::tensorflow::GetArrowType(dt, &arrow_type); - OP_REQUIRES(ctx, status.ok(), - errors::InvalidArgument( - "Arrow type is unsupported for output_type dtype=" + - std::to_string(dt))); + OP_REQUIRES_OK(ctx, GetArrowType(dt, &arrow_type)); } for (const PartialTensorShape& pts : output_shapes_) { OP_REQUIRES(ctx, -1 <= pts.dims() && pts.dims() <= 2, diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.cc b/tensorflow_io/arrow/kernels/arrow_kernels.cc index 46b3bbcc5..90d337ed0 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.cc +++ b/tensorflow_io/arrow/kernels/arrow_kernels.cc @@ -19,9 +19,27 @@ limitations under the License. 
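Note on the `tensorflow_io/__init__.py` hunk above: the package root now re-exports `IOTensor` instead of importing `tensorflow`. A minimal usage sketch of that entry point, mirroring the docstring examples later in this diff (the `sample.wav` filename is a placeholder, and eager mode is assumed):

```python
import tensorflow_io as tfio

# IOTensor is now reachable from the package root (see the __init__.py hunk above).
samples = tfio.IOTensor.from_audio("sample.wav")  # placeholder filename
print(samples.rate)         # sample rate, exposed as meta data
print(samples[1000:1005])   # indexable access; only this slice is read
```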
#include "arrow/ipc/feather.h" #include "arrow/ipc/feather_generated.h" #include "arrow/buffer.h" +#include "arrow/adapters/tensorflow/convert.h" namespace tensorflow { namespace data { + +Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out) { + ::arrow::Status status = ::arrow::adapters::tensorflow::GetTensorFlowType(dtype, out); + if (!status.ok()) { + return errors::InvalidArgument("arrow data type ", dtype, " is not supported: ", status); + } + return Status::OK(); +} + +Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out) { + ::arrow::Status status = ::arrow::adapters::tensorflow::GetArrowType(dtype, out); + if (!status.ok()) { + return errors::InvalidArgument("tensorflow data type ", dtype, " is not supported: ", status); + } + return Status::OK(); +} + namespace { class ListFeatherColumnsOp : public OpKernel { diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h index b0523ead1..39abef014 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.h +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -19,10 +19,14 @@ limitations under the License. #include "kernels/stream.h" #include "arrow/io/api.h" #include "arrow/buffer.h" +#include "arrow/type.h" namespace tensorflow { namespace data { +Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out); +Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out); + // NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap // with another PR. Will remove duplicate once PR merged @@ -30,7 +34,8 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { public: explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size) : file_(file) - , size_(size) { } + , size_(size) + , position_(0) { } ~ArrowRandomAccessFile() {} arrow::Status Close() override { @@ -40,13 +45,21 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { return false; } arrow::Status Tell(int64_t* position) const override { - return arrow::Status::NotImplemented("Tell"); + *position = position_; + return arrow::Status::OK(); } arrow::Status Seek(int64_t position) override { return arrow::Status::NotImplemented("Seek"); } arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override { - return arrow::Status::NotImplemented("Read (void*)"); + StringPiece result; + Status status = file_->Read(position_, nbytes, &result, (char*)out); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return arrow::Status::IOError(status.error_message()); + } + *bytes_read = result.size(); + position_ += (*bytes_read); + return arrow::Status::OK(); } arrow::Status Read(int64_t nbytes, std::shared_ptr* out) override { return arrow::Status::NotImplemented("Read (Buffer*)"); @@ -81,7 +94,10 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { private: tensorflow::RandomAccessFile* file_; int64 size_; + int64 position_; }; + + } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/audio/__init__.py b/tensorflow_io/audio/__init__.py index b95edeedd..5792023c0 100644 --- a/tensorflow_io/audio/__init__.py +++ b/tensorflow_io/audio/__init__.py @@ -15,8 +15,6 @@ """Audio Dataset. 
@@WAVDataset -@@list_wav_info -@@read_wav """ from __future__ import absolute_import @@ -24,15 +22,11 @@ from __future__ import print_function from tensorflow_io.audio.python.ops.audio_ops import WAVDataset -from tensorflow_io.audio.python.ops.audio_ops import list_wav_info -from tensorflow_io.audio.python.ops.audio_ops import read_wav from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "WAVDataset", - "list_wav_info", - "read_wav", ] remove_undocumented(__name__) diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index aa65ebed4..ae4d43f21 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/io_interface.h" #include "tensorflow_io/core/kernels/stream.h" namespace tensorflow { namespace data { -namespace { // See http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html struct WAVHeader { @@ -68,189 +67,140 @@ Status ValidateWAVHeader(struct WAVHeader *header) { return Status::OK(); } -class ListWAVInfoOp : public OpKernel { +class WAVIndexable : public IOIndexableInterface { public: - explicit ListWAVInfoOp(OpKernelConstruction* context) : OpKernel(context) { - env_ = context->env(); - } - - void Compute(OpKernelContext* context) override { - const Tensor& filename_tensor = context->input(0); - const string filename = filename_tensor.scalar()(); + WAVIndexable(Env* env) + : env_(env) {} - const Tensor& memory_tensor = context->input(1); - const string& memory = memory_tensor.scalar()(); - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); - uint64 size; - OP_REQUIRES_OK(context, file->GetFileSize(&size)); + ~WAVIndexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); StringPiece result; - struct WAVHeader header; - OP_REQUIRES_OK(context, file->Read(0, sizeof(header), &result, (char *)(&header))); + TF_RETURN_IF_ERROR(file_->Read(0, sizeof(header_), &result, (char *)(&header_))); - OP_REQUIRES_OK(context, ValidateWAVHeader(&header)); - if (header.riff_size + 8 != size) { + TF_RETURN_IF_ERROR(ValidateWAVHeader(&header_)); + if (header_.riff_size + 8 != file_size_) { // corrupted file? 
} - int64 filesize = header.riff_size + 8; + int64 filesize = header_.riff_size + 8; int64 position = result.size(); - if (header.fmt_size != 16) { - position += header.fmt_size - 16; + if (header_.fmt_size != 16) { + position += header_.fmt_size - 16; } int64 nSamples = 0; do { struct DataHeader head; - OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + TF_RETURN_IF_ERROR(file_->Read(position, sizeof(head), &result, (char *)(&head))); position += result.size(); if (memcmp(head.mark, "data", 4) == 0) { // Data should be block aligned // bytes = nSamples * nBlockAlign - OP_REQUIRES(context, (head.size % header.nBlockAlign == 0), errors::InvalidArgument("data chunk should be block aligned (", header.nBlockAlign, "), received: ", head.size)); - nSamples += head.size / header.nBlockAlign; + if (head.size % header_.nBlockAlign != 0) { + return errors::InvalidArgument("data chunk should be block aligned (", header_.nBlockAlign, "), received: ", head.size); + } + nSamples += head.size / header_.nBlockAlign; } position += head.size; } while (position < filesize); - string dtype; - switch (header.wBitsPerSample) { + switch (header_.wBitsPerSample) { case 8: - dtype = "int8"; + dtype_ = DT_INT8; break; case 16: - dtype = "int16"; + dtype_ = DT_INT16; break; case 24: - dtype = "int32"; + dtype_ = DT_INT32; break; default: - OP_REQUIRES(context, false, errors::InvalidArgument("unsupported wBitsPerSample: ", header.wBitsPerSample)); + return errors::InvalidArgument("unsupported wBitsPerSample: ", header_.wBitsPerSample); } - Tensor* dtype_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &dtype_tensor)); - Tensor* shape_tensor; - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({2}), &shape_tensor)); - Tensor* rate_tensor; - OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}), &rate_tensor)); + shape_ = TensorShape({nSamples, header_.nChannels}); - dtype_tensor->scalar()() = std::move(dtype); - shape_tensor->flat()(0) = nSamples; - shape_tensor->flat()(1) = header.nChannels; - rate_tensor->scalar()() = header.nSamplesPerSec; + return Status::OK(); } - private: - mutex mu_; - Env* env_ GUARDED_BY(mu_); -}; - -class ReadWAVOp : public OpKernel { - public: - explicit ReadWAVOp(OpKernelConstruction* context) : OpKernel(context) { - env_ = context->env(); + Status Spec(std::vector& dtypes, std::vector& shapes) override { + dtypes.clear(); + dtypes.push_back(dtype_); + shapes.clear(); + shapes.push_back(shape_); + return Status::OK(); } - void Compute(OpKernelContext* context) override { - const Tensor& filename_tensor = context->input(0); - const string& filename = filename_tensor.scalar()(); - - const Tensor& memory_tensor = context->input(1); - const string& memory = memory_tensor.scalar()(); - - const Tensor& start_tensor = context->input(2); - const int64 start = start_tensor.scalar()(); - - const Tensor& stop_tensor = context->input(3); - const int64 stop = stop_tensor.scalar()(); - - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); - uint64 size; - OP_REQUIRES_OK(context, file->GetFileSize(&size)); - - StringPiece result; - struct WAVHeader header; - OP_REQUIRES_OK(context, file->Read(0, sizeof(header), &result, (char *)(&header))); - - OP_REQUIRES_OK(context, ValidateWAVHeader(&header)); - if (header.riff_size + 8 != size) { - // corrupted file? 
- } - int64 filesize = header.riff_size + 8; - - int64 position = result.size(); - if (header.fmt_size != 16) { - position += header.fmt_size - 16; - } - - int64 nSamples = 0; - do { - struct DataHeader head; - OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); - position += result.size(); - if (memcmp(head.mark, "data", 4) == 0) { - // Data should be block aligned - // bytes = nSamples * nBlockAlign - OP_REQUIRES(context, (head.size % header.nBlockAlign == 0), errors::InvalidArgument("data chunk should be block aligned (", header.nBlockAlign, "), received: ", head.size)); - nSamples += head.size / header.nBlockAlign; - } - position += head.size; - } while (position < filesize); - + Status Extra(std::vector* extra) override { + // Expose a sample `rate` + Tensor rate(DT_INT32, TensorShape({})); + rate.scalar()() = header_.nSamplesPerSec; + extra->push_back(rate); + return Status::OK(); + } - int64 sample_start = start; - int64 sample_stop = stop; - if (sample_start > nSamples) { - sample_start = nSamples; + Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + Tensor& output_tensor = tensors[0]; + if (step != 1) { + return errors::InvalidArgument("step ", step, " is not supported"); } - if (sample_stop < 0) { - sample_stop = nSamples; - } - if (sample_stop < sample_start) { - sample_stop = sample_start; - } - - Tensor* output_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({sample_stop - sample_start, header.nChannels}), &output_tensor)); + const int64 sample_start = start; + const int64 sample_stop = stop; int64 sample_offset = 0; - - position = sizeof(header) + header.fmt_size - 16; + if (header_.riff_size + 8 != file_size_) { + // corrupted file? + } + int64 filesize = header_.riff_size + 8; + int64 position = sizeof(header_) + header_.fmt_size - 16; do { + StringPiece result; struct DataHeader head; - OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + TF_RETURN_IF_ERROR(file_->Read(position, sizeof(head), &result, (char *)(&head))); position += result.size(); if (memcmp(head.mark, "data", 4) == 0) { // Already checked the alignment int64 block_sample_start = sample_offset; - int64 block_sample_stop = sample_offset + head.size / header.nBlockAlign; + int64 block_sample_stop = sample_offset + head.size / header_.nBlockAlign; // only read if block_sample_start and block_sample_stop within range if (sample_start < block_sample_stop && sample_stop > block_sample_start) { int64 read_sample_start = (block_sample_start > sample_start ? block_sample_start : sample_start); int64 read_sample_stop = (block_sample_stop < sample_stop ? 
block_sample_stop : sample_stop); - int64 read_bytes_start = position + (read_sample_start - block_sample_start) * header.nBlockAlign; - int64 read_bytes_stop = position + (read_sample_stop - block_sample_start) * header.nBlockAlign; + int64 read_bytes_start = position + (read_sample_start - block_sample_start) * header_.nBlockAlign; + int64 read_bytes_stop = position + (read_sample_stop - block_sample_start) * header_.nBlockAlign; string buffer; buffer.resize(read_bytes_stop - read_bytes_start); - OP_REQUIRES_OK(context, file->Read(read_bytes_start, read_bytes_stop - read_bytes_start, &result, &buffer[0])); - switch (header.wBitsPerSample) { + TF_RETURN_IF_ERROR(file_->Read(read_bytes_start, read_bytes_stop - read_bytes_start, &result, &buffer[0])); + switch (header_.wBitsPerSample) { case 8: - OP_REQUIRES(context, (header.wBitsPerSample * header.nChannels == header.nBlockAlign * 8), errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); - memcpy((char *)(output_tensor->flat().data()) + ((read_sample_start - sample_start) * header.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); + if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); + } + memcpy((char *)(output_tensor.flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); break; case 16: - OP_REQUIRES(context, (header.wBitsPerSample * header.nChannels == header.nBlockAlign * 8), errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); - memcpy((char *)(output_tensor->flat().data()) + ((read_sample_start - sample_start) * header.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); + if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); + } + memcpy((char *)(output_tensor.flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); break; case 24: // NOTE: The conversion is from signed integer 24 to signed integer 32 (left shift 8 bits) - OP_REQUIRES(context, (header.wBitsPerSample * header.nChannels == header.nBlockAlign * 8), errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); + if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); + } for (int64 i = read_sample_start; i < read_sample_stop; i++) { - for (int64 j = 0; j < header.nChannels; j++) { - char *data_p = (char *)(output_tensor->flat().data() + ((i - sample_start) * header.nChannels + j)); - char *read_p = (char *)(&buffer[((i - read_sample_start) * header.nBlockAlign)]) + 3 * j; + for (int64 j = 0; j < header_.nChannels; j++) { + char *data_p = (char *)(output_tensor.flat().data() + ((i - sample_start) * header_.nChannels + j)); + char *read_p = (char *)(&buffer[((i - read_sample_start) * header_.nBlockAlign)]) + 3 * j; data_p[3] = read_p[2]; data_p[2] = read_p[1]; data_p[1] = read_p[0]; @@ -259,25 +209,36 @@ 
class ReadWAVOp : public OpKernel { } break; default: - OP_REQUIRES(context, false, errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header.wBitsPerSample, ", ", header.nBlockAlign)); + return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); } } sample_offset = block_sample_stop; } position += head.size; } while (position < filesize); + + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("WAVIndexable"); } private: - mutex mu_; + mutable mutex mu_; Env* env_ GUARDED_BY(mu_); + std::unique_ptr file_ GUARDED_BY(mu_); + uint64 file_size_ GUARDED_BY(mu_); + DataType dtype_; + TensorShape shape_; + struct WAVHeader header_; }; -REGISTER_KERNEL_BUILDER(Name("ListWAVInfo").Device(DEVICE_CPU), - ListWAVInfoOp); -REGISTER_KERNEL_BUILDER(Name("ReadWAV").Device(DEVICE_CPU), - ReadWAVOp); +REGISTER_KERNEL_BUILDER(Name("WAVIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("WAVIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp); -} // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/audio/ops/audio_ops.cc b/tensorflow_io/audio/ops/audio_ops.cc index 2ca5c66a4..0d6c18fe4 100644 --- a/tensorflow_io/audio/ops/audio_ops.cc +++ b/tensorflow_io/audio/ops/audio_ops.cc @@ -19,29 +19,43 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("ListWAVInfo") - .Input("filename: string") - .Input("memory: string") - .Output("dtype: string") - .Output("shape: int64") - .Output("rate: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({})); - c->set_output(1, c->MakeShape({2})); - c->set_output(2, c->MakeShape({})); - return Status::OK(); - }); - -REGISTER_OP("ReadWAV") - .Input("filename: string") - .Input("memory: string") - .Input("start: int64") - .Input("stop: int64") - .Attr("dtype: type") - .Output("output: dtype") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); - return Status::OK(); - }); +REGISTER_OP("WAVIndexableInit") + .Input("input: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Output("rate: int32") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + c->set_output(3, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("WAVIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow_io/audio/python/ops/audio_ops.py b/tensorflow_io/audio/python/ops/audio_ops.py index 374ef0400..a3379b1d2 100644 --- a/tensorflow_io/audio/python/ops/audio_ops.py +++ b/tensorflow_io/audio/python/ops/audio_ops.py @@ -17,67 +17,17 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow_io.core.python.ops import data_ops -from tensorflow_io.core.python.ops import core_ops - -def list_wav_info(filename, **kwargs): - """list_wav_info""" - if not tf.executing_eagerly(): - raise NotImplementedError("list_wav_info only support eager mode") - memory = kwargs.get("memory", "") - dtype, shape, rate = core_ops.list_wav_info( - filename, memory=memory) - return tf.TensorSpec(shape.numpy(), dtype.numpy().decode()), rate - -def read_wav(filename, spec, **kwargs): - """read_wav""" - memory = kwargs.get("memory", "") - start = kwargs.get("start", 0) - stop = kwargs.get("stop", None) - if stop is None and spec.shape[0] is not None: - stop = spec.shape[0] - start - if stop is None: - stop = -1 - return core_ops.read_wav( - filename, memory=memory, - start=start, stop=stop, dtype=spec.dtype) - -class WAVDataset(data_ops.BaseDataset): - """A WAV Dataset""" - - def __init__(self, filename, **kwargs): - """Create a WAVDataset. - - Args: - filename: A string containing filename. - """ - if not tf.executing_eagerly(): - start = kwargs.get("start") - stop = kwargs.get("stop") - dtype = kwargs.get("dtype") - shape = kwargs.get("shape") - else: - spec, _ = list_wav_info(filename) - start = 0 - stop = spec.shape[0] - dtype = spec.dtype - shape = tf.TensorShape( - [dim if i != 0 else None for i, dim in enumerate( - spec.shape.as_list())]) - # capacity is the rough count for each chunk in dataset - capacity = kwargs.get("capacity", 65536) - entry_start = list(range(start, stop, capacity)) - entry_stop = entry_start[1:] + [stop] - dataset = data_ops.BaseDataset.from_tensor_slices( - (tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64)) - ).map(lambda start, stop: core_ops.read_wav( - filename, memory="", start=start, stop=stop, dtype=dtype)) - self._dataset = dataset +warnings.warn( + "The tensorflow_io.audio.WAVDataset is " + "deprecated. Please look for tfio.IOTensor.from_audio " + "for reading WAV files into tensorflow.", + DeprecationWarning) - super(WAVDataset, self).__init__( - self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access class AudioDataset(data_ops.Dataset): """A Audio File Dataset that reads the audio file.""" diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 192f0c9b9..0005ae2f8 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -29,6 +29,7 @@ cc_library( name = "dataset_ops", srcs = [ "kernels/dataset_ops.h", + "kernels/io_interface.h", "kernels/stream.h", ], copts = tf_io_copts(), diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h new file mode 100644 index 000000000..247804d46 --- /dev/null +++ b/tensorflow_io/core/kernels/io_interface.h @@ -0,0 +1,340 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/util/batch_util.h" + +namespace tensorflow { +namespace data { + +class IOInterface : public ResourceBase { + public: + virtual Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) = 0; + virtual Status Spec(std::vector& dtypes, std::vector& shapes) = 0; + + virtual Status Extra(std::vector* extra) { + // This is the chance to provide additional extra information which should be appended to extra. + return Status::OK(); + } +}; + +class IOIterableInterface : public IOInterface { + public: + virtual Status Next(const int64 capacity, std::vector& tensors, int64* record_read) = 0; +}; + +class IOIndexableInterface : public IOInterface { + public: + virtual Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) = 0; +}; + +template +class IOIndexableImplementation : public IOIndexableInterface { + public: + IOIndexableImplementation(Env* env) + : env_(env) + , iterable_(new Type(env)) {} + + ~IOIndexableImplementation() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + + TF_RETURN_IF_ERROR(iterable_->Init(input, metadata, memory_data, memory_size)); + TF_RETURN_IF_ERROR(iterable_->Spec(dtypes_, shapes_)); + + const int64 capacity = 4096; + std::vector chunk_shapes; + for (size_t component = 0; component < shapes_.size(); component++) { + gtl::InlinedVector dims = shapes_[component].dim_sizes(); + dims[0] = capacity; + chunk_shapes.push_back(TensorShape(dims)); + } + + int64 total = 0; + + int64 record_read = 0; + do { + tensors_.push_back(std::vector()); + for (size_t component = 0; component < shapes_.size(); component++) { + tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component])); + } + TF_RETURN_IF_ERROR(iterable_->Next(capacity, tensors_.back(), &record_read)); + if (record_read == 0) { + tensors_.pop_back(); + break; + } + if (record_read < capacity) { + for (size_t component = 0; component < shapes_.size(); component++) { + tensors_.back()[component] = tensors_.back()[component].Slice(0, record_read); + } + } + total += record_read; + } while (record_read != 0); + for (size_t component = 0; component < shapes_.size(); component++) { + shapes_[component].set_dim(0, total); + } + return Status::OK(); + } + virtual Status Spec(std::vector& dtypes, std::vector& shapes) override { + for (size_t component = 0; component < dtypes_.size(); component++) { + dtypes.push_back(dtypes_[component]); + } + for (size_t component = 0; component < shapes_.size(); component++) { + shapes.push_back(shapes_[component]); + } + return Status::OK(); + } + + Status Extra(std::vector* extra) override { + return iterable_->Extra(extra); + } + string DebugString() 
const override { + mutex_lock l(mu_); + return strings::StrCat("IOIndexableImplementation<", iterable_->DebugString(), ">[]"); + } + + Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + if (step != 1) { + return errors::InvalidArgument("step != 1 is not supported: ", step); + } + // Find first chunk + + int64 chunk_index = 0; + int64 chunk_element = -1; + int64 current_element = 0; + while (chunk_index < tensors_.size()) { + if (current_element <= start && start < current_element + tensors_[chunk_index][0].shape().dim_size(0)) { + chunk_element = start - current_element; + current_element = start; + break; + } + current_element += tensors_[chunk_index][0].shape().dim_size(0); + chunk_index++; + } + if (chunk_element < 0) { + return errors::InvalidArgument("start is out of range: ", start); + } + std::vector elements; + for (size_t component = 0; component < shapes_.size(); component++) { + TensorShape shape(shapes_[component].dim_sizes()); + shape.RemoveDim(0); + elements.push_back(Tensor(dtypes_[component], shape)); + } + + while (current_element < stop) { + for (size_t component = 0; component < shapes_.size(); component++) { + batch_util::CopySliceToElement(tensors_[chunk_index][component], &elements[component], chunk_element); + batch_util::CopyElementToSlice(elements[component], &tensors[component], (current_element - start)); + } + chunk_element++; + if (chunk_element == tensors_[chunk_index][0].shape().dim_size(0)) { + chunk_index++; + chunk_element = 0; + } + current_element++; + } + return Status::OK(); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::unique_ptr iterable_ GUARDED_BY(mu_); + std::vector dtypes_ GUARDED_BY(mu_); + std::vector shapes_ GUARDED_BY(mu_); + std::vector> tensors_; +}; + + +template +class IOInterfaceInitOp : public ResourceOpKernel { + public: + explicit IOInterfaceInitOp(OpKernelConstruction* context) + : ResourceOpKernel(context) { + env_ = context->env(); + } + private: + void Compute(OpKernelContext* context) override { + ResourceOpKernel::Compute(context); + + std::vector input; + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + for (int64 i = 0; i < input_tensor->NumElements(); i++) { + input.push_back(input_tensor->flat()(i)); + } + + Status status; + + std::vector metadata; + const Tensor* metadata_tensor; + status = context->input("metadata", &metadata_tensor); + if (status.ok()) { + for (int64 i = 0; i < metadata_tensor->NumElements(); i++) { + metadata.push_back(metadata_tensor->flat()(i)); + } + } + + const void *memory_data = nullptr; + size_t memory_size = 0; + + const Tensor* memory_tensor; + status = context->input("memory", &memory_tensor); + if (status.ok()) { + memory_data = memory_tensor->scalar()().data(); + memory_size = memory_tensor->scalar()().size(); + } + + OP_REQUIRES_OK(context, this->resource_->Init(input, metadata, memory_data, memory_size)); + + std::vector dtypes; + std::vector shapes; + OP_REQUIRES_OK(context, this->resource_->Spec(dtypes, shapes)); + int64 maxrank = 0; + for (size_t component = 0; component < shapes.size(); component++) { + if (dynamic_cast(this->resource_) != nullptr) { + int64 i = 0; + OP_REQUIRES(context, (shapes[component].dim_size(i) > 0), errors::InvalidArgument("component (", component, ")'s shape[", i, "] should not be None, received: ", shapes[component])); + } + for (int64 i = 1; i < shapes[component].dims(); i++) { + OP_REQUIRES(context, (shapes[component].dim_size(i) > 
0), errors::InvalidArgument("component (", component, ")'s shape[", i, "] should not be None, received: ", shapes[component])); + } + maxrank = maxrank > shapes[component].dims() ? maxrank : shapes[component].dims(); + } + Tensor dtypes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size())})); + for (size_t i = 0; i < dtypes.size(); i++) { + dtypes_tensor.flat()(i) = dtypes[i]; + } + Tensor shapes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size()), maxrank})); + for (size_t component = 0; component < shapes.size(); component++) { + for (int64 i = 0; i < shapes[component].dims(); i++) { + shapes_tensor.tensor()(component, i) = shapes[component].dim_size(i); + } + for (int64 i = shapes[component].dims(); i < maxrank; i++) { + shapes_tensor.tensor()(component, i) = 0; + } + } + context->set_output(1, dtypes_tensor); + context->set_output(2, shapes_tensor); + + std::vector extra; + OP_REQUIRES_OK(context, this->resource_->Extra(&extra)); + for (size_t i = 0; i < extra.size(); i++) { + context->set_output(3 + i, extra[i]); + } + } + Status CreateResource(Type** resource) + EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + *resource = new Type(env_); + return Status::OK(); + } + mutex mu_; + Env* env_; +}; + +template +class IOIterableNextOp : public OpKernel { + public: + explicit IOIterableNextOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + } + + void Compute(OpKernelContext* context) override { + Type* resource; + OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); + core::ScopedUnref unref(resource); + + const Tensor* capacity_tensor; + OP_REQUIRES_OK(context, context->input("capacity", &capacity_tensor)); + const int64 capacity = capacity_tensor->scalar()(); + + OP_REQUIRES(context, (capacity > 0), errors::InvalidArgument("capacity <= 0 is not supported: ", capacity)); + + std::vector dtypes; + std::vector shapes; + OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes)); + + std::vector tensors; + for (size_t i = 0; i < dtypes.size(); i++) { + gtl::InlinedVector dims = shapes[i].dim_sizes(); + dims[0] = capacity; + tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims))); + } + + int64 record_read; + OP_REQUIRES_OK(context, resource->Next(capacity, tensors, &record_read)); + for (size_t i = 0; i < tensors.size(); i++) { + if (record_read < capacity) { + context->set_output(i, tensors[i].Slice(0, record_read)); + } else { + context->set_output(i, tensors[i]); + } + } + } +}; +template +class IOIndexableGetItemOp : public OpKernel { + public: + explicit IOIndexableGetItemOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + } + + void Compute(OpKernelContext* context) override { + Type* resource; + OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); + core::ScopedUnref unref(resource); + + const Tensor* start_tensor; + OP_REQUIRES_OK(context, context->input("start", &start_tensor)); + int64 start = start_tensor->scalar()(); + + const Tensor* stop_tensor; + OP_REQUIRES_OK(context, context->input("stop", &stop_tensor)); + int64 stop = stop_tensor->scalar()(); + + const Tensor* step_tensor; + OP_REQUIRES_OK(context, context->input("step", &step_tensor)); + int64 step = step_tensor->scalar()(); + + OP_REQUIRES(context, (step == 1), errors::InvalidArgument("step != 1 is not supported: ", step)); + + std::vector dtypes; + std::vector shapes; + OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes)); + + int64 count = shapes[0].dim_size(0); + if (start > count) { + start = count; + } + if (stop < 0) { + stop = count; + } + if (stop < start) 
{ + stop = start; + } + + std::vector tensors; + for (size_t i = 0; i < dtypes.size(); i++) { + gtl::InlinedVector dims = shapes[i].dim_sizes(); + dims[0] = stop - start; + tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims))); + } + OP_REQUIRES_OK(context, resource->GetItem(start, stop, step, tensors)); + for (size_t i = 0; i < tensors.size(); i++) { + context->set_output(i, tensors[i]); + } + } +}; +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/core/ops/core_ops.cc b/tensorflow_io/core/ops/core_ops.cc index 61760e457..ee416d014 100644 --- a/tensorflow_io/core/ops/core_ops.cc +++ b/tensorflow_io/core/ops/core_ops.cc @@ -19,21 +19,6 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("AdjustBatchDataset") - .Input("input_dataset: variant") - .Input("batch_size: int64") - .Input("batch_mode: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - // batch_size should be a scalar. - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - // batch_mode should be a scalar. - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - return shape_inference::ScalarShape(c); - }); REGISTER_OP("ListArchiveEntries") .Input("filename: string") .Input("memory: string") diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py new file mode 100644 index 000000000..78742292e --- /dev/null +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -0,0 +1,52 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
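Reviewer note on `IOIndexableGetItemOp` above: before calling `GetItem`, the op clamps `start`/`stop` against the item count reported by `Spec()`, mirroring Python slice semantics. A rough Python equivalent of that clamping, for illustration only (not part of the kernel):

```python
def clamp_range(start, stop, count):
  # Mirrors the clamping in IOIndexableGetItemOp:
  if start > count:   # a start past the end is clipped to the end
    start = count
  if stop < 0:        # a negative stop means "up to the end"
    stop = count
  if stop < start:    # an inverted range collapses to an empty slice
    stop = start
  return start, stop

# e.g. a full-slice request [0:-1] over 22050 items -> (0, 22050)
print(clamp_range(0, -1, 22050))
```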
+# ============================================================================== +"""AudioIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class AudioIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access + """AudioIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + filename, + internal=False): + with tf.name_scope("AudioIOTensor") as scope: + resource, dtypes, shapes, rate = core_ops.wav_indexable_init( + filename, + container=scope, + shared_name="%s/%s" % (filename, uuid.uuid4().hex)) + self._rate = rate.numpy() + super(AudioIOTensor, self).__init__( + shapes, dtypes, resource, core_ops.wav_indexable_get_item, + internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + @io_tensor_ops._BaseIOTensorMeta # pylint: disable=protected-access + def rate(self): + """The sample `rate` of the audio stream""" + return self._rate diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py new file mode 100644 index 000000000..5a237da41 --- /dev/null +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -0,0 +1,266 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""IOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import audio_io_tensor_ops +from tensorflow_io.core.python.ops import json_io_tensor_ops +from tensorflow_io.core.python.ops import kafka_io_tensor_ops + +class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access + """IOTensor + + An `IOTensor` is a tensor with data backed by IO operations. For example, + an `AudioIOTensor` is a tensor with data from an audio file, a + `KafkaIOTensor` is a tensor with data from reading the messages of a Kafka + stream server. + + The `IOTensor` is indexable, supporting `__getitem__()` and + `__len__()` methods in Python. In other words, it is a subclass of + `collections.abc.Sequence`. + + Example: + + ```python + >>> import tensorflow_io as tfio + >>> + >>> samples = tfio.IOTensor.from_audio("sample.wav") + >>> print(samples[1000:1005]) + ... tf.Tensor( + ... [[-3] + ... [-7] + ... [-6] + ... [-6] + ... [-5]], shape=(5, 1), dtype=int16) + ``` + + ### Indexable vs. 
Iterable

+  While many IO formats are naturally considered iterable only, in most
+  situations they could still be accessed by indexing through certain
+  workarounds. For example, a Kafka stream is not directly indexable, yet the
+  stream could be saved in memory or on disk to allow indexing. Another example
+  is the packet capture (PCAP) file format in networking. The packets inside
+  a PCAP file are concatenated sequentially. Since each packet could have
+  a variable length, the only way to access each packet is to read one
+  packet at a time. If the PCAP file is huge (e.g., hundreds of GBs or even
+  TBs), it may not be realistic (or necessary) to save the index of every
+  packet in memory. We could consider the PCAP format as iterable only.
+
+  As we can see, the available memory size could be a factor in deciding
+  whether a format is indexable or not. However, this factor could also be
+  blurred in distributed computing. One common case is a file format that
+  might be splittable, where a file could be split into multiple chunks
+  (without reading the whole file) with no data overlapping between those
+  chunks. For example, a text file could be reliably split into multiple
+  chunks with line feed (LF) as the boundary. Processing of chunks could then
+  be distributed across a group of compute nodes to speed things up (by
+  reading small chunks into memory). From that standpoint, we could still
+  consider splittable formats as indexable.
+
+  For that reason our focus is `IOTensor`, with convenient indexing and slicing
+  through the `__getitem__()` method.
+
+  ### Lazy Read
+
+  One useful feature of `IOTensor` is the lazy read. Data inside a file is not
+  read into memory until needed. This could be convenient where only a small
+  segment of the data is needed. For example, a WAV file could be as big as
+  GBs but in many cases only several seconds of samples are used for training
+  or inference purposes.
+
+  While CPU memory is cheap nowadays, GPU memory is still considered an
+  expensive resource. It is also imperative to fit data in GPU memory for
+  speed-up purposes. From that perspective lazy read could be very helpful.
+
+  ### Association of Meta Data
+
+  While a file format could consist of mostly numeric data, in many situations
+  the meta data is important as well. For example, in an audio file format the
+  sample rate is a number that is necessary for almost everything. Associating
+  the sample rate with the samples of the int16 Tensor is more helpful,
+  especially in eager mode.
+
+  Example:
+
+  ```python
+  >>> import tensorflow_io as tfio
+  >>>
+  >>> samples = tfio.IOTensor.from_audio("sample.wav")
+  >>> print(samples.rate)
+  ... 44100
+  ```
+
+  ### Nested Element Structure
+
+  The concept of `IOTensor` is not limited to a Tensor of a single data type.
+  It supports a nested element structure which could consist of many
+  components and complex structures. The exposed API such as `shape()` or
+  `dtype()` will display the shape and data type of an individual Tensor,
+  or a nested structure of shapes and data types for components of a
+  composite Tensor.
+
+  Example:
+
+  ```python
+  >>> import tensorflow_io as tfio
+  >>>
+  >>> samples = tfio.IOTensor.from_audio("sample.wav")
+  >>> print(samples.shape)
+  ... (22050, 2)
+  >>> print(samples.dtype)
+  ... <dtype: 'int16'>
+  >>>
+  >>> features = tfio.IOTensor.from_json("feature.json")
+  >>> print(features.shape)
+  ... (TensorShape([Dimension(2)]), TensorShape([Dimension(2)]))
+  >>> print(features.dtype)
+  ... (tf.float64, tf.int64)
+  ```
+
+  ### Access Columns of Tabular Data Formats
+
+  Many file formats such as Parquet or JSON are considered tabular because
+  they consist of columns in a table. With `IOTensor` it is possible to
+  access individual columns through `__call__()`.
+
+  Example:
+
+  ```python
+  >>> import tensorflow_io as tfio
+  >>>
+  >>> features = tfio.IOTensor.from_json("feature.json")
+  >>> print(features.shape("floatfeature"))
+  ... (2,)
+  >>> print(features.dtype("floatfeature"))
+  ... <dtype: 'float64'>
+  >>>
+  >>> print(features("floatfeature").shape)
+  ... (2,)
+  >>> print(features("floatfeature").dtype)
+  ... <dtype: 'float64'>
+  ```
+
+  ### Conversion from and to Tensor and Dataset
+
+  When needed, `IOTensor` could be converted into a `Tensor` (through
+  `to_tensor()`) or a `tf.data.Dataset` (through `to_dataset()`) to
+  support operations that are only available through `Tensor` or
+  `tf.data.Dataset`.
+
+  Example:
+
+  ```python
+  >>> import tensorflow as tf
+  >>> import tensorflow_io as tfio
+  >>>
+  >>> features = tfio.IOTensor.from_json("feature.json")
+  >>>
+  >>> features_tensor = features.to_tensor()
+  >>> print(features_tensor)
+  ... (,
+  ... )
+  >>>
+  >>> features_dataset = features.to_dataset()
+  >>> print(features_dataset)
+  ... <_IOTensorDataset shapes: ((), ()), types: (tf.float64, tf.int64)>
+  >>>
+  >>> dataset = tf.data.Dataset.zip((features_dataset, labels_dataset))
+  ```
+
+  """
+
+  #=============================================================================
+  # Factory Methods
+  #=============================================================================
+
+  @classmethod
+  def from_audio(cls,
+                 filename,
+                 **kwargs):
+    """Creates an `IOTensor` from an audio file.
+
+    The following audio file formats are supported:
+    - WAV
+
+    Args:
+      filename: A string, the filename of an audio file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromAudio")):
+      return audio_io_tensor_ops.AudioIOTensor(filename, internal=True)
+
+  @classmethod
+  def from_json(cls,
+                filename,
+                **kwargs):
+    """Creates an `IOTensor` from a JSON file.
+
+    Args:
+      filename: A string, the filename of a JSON file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromJSON")):
+      return json_io_tensor_ops.JSONIOTensor(filename, internal=True)
+
+  @classmethod
+  def from_kafka(cls,
+                 subscription,
+                 **kwargs):
+    """Creates an `IOTensor` from a Kafka stream.
+
+    Args:
+      subscription: A `tf.string` tensor containing subscription,
+        in the format of [topic:partition:offset:length],
+        by default length is -1 for unlimited.
+      servers: An optional list of bootstrap servers, by default
+        `localhost:9092`.
+      configuration: An optional `tf.string` tensor containing
+        configurations in [Key=Value] format. There are three
+        types of configurations:
+        Global configuration: please refer to 'Global configuration properties'
+          in librdkafka doc. Examples include
+          ["enable.auto.commit=false", "heartbeat.interval.ms=2000"]
+        Topic configuration: please refer to 'Topic configuration properties'
+          in librdkafka doc. Note all topic configurations should be
+          prefixed with `configuration.topic.`. Examples include
+          ["conf.topic.auto.offset.reset=earliest"]
+        Dataset configuration: there are two configurations available,
+          `conf.eof=0|1`: if True, the KafkaDataset will stop on EOF (default).
+          `conf.timeout=milliseconds`: timeout value for Kafka Consumer to wait.
+ name: A name prefix for the IOTensor (optional). + + Returns: + A `IOTensor`. + + """ + with tf.name_scope(kwargs.get("name", "IOFromKafka")): + return kafka_io_tensor_ops.KafkaIOTensor( + subscription, + servers=kwargs.get("servers", None), + configuration=kwargs.get("configuration", None), + internal=True) diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py new file mode 100644 index 000000000..c5fb22e14 --- /dev/null +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -0,0 +1,288 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""_BaseIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +class _BaseIOTensorMeta(property): + """_BaseIOTensorMeta is a decorator that is viewable to __repr__""" + pass + +class _BaseIOTensorDataset(tf.compat.v2.data.Dataset): + """_IOTensorDataset""" + + def __init__(self, spec, resource, function): + start = 0 + stop = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.shape, spec))[0][0] + capacity = 4096 + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] + + dtype = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.dtype, spec)) + shape = tf.nest.flatten( + tf.nest.map_structure( + lambda e: tf.TensorShape( + [None]).concatenate(e.shape[1:]), spec)) + + dataset = tf.compat.v2.data.Dataset.from_tensor_slices(( + tf.constant(entry_start, tf.int64), + tf.constant(entry_stop, tf.int64))) + dataset = dataset.map( + lambda start, stop: function( + resource, start, stop, 1, dtype=dtype, shape=shape)) + # Note: tf.data.Dataset consider tuple `(e, )` as one element + # instead of a sequence. So next `unbatch()` will not work. + # The tf.stack() below is necessary. 
+ if len(dtype) == 1: + dataset = dataset.map(tf.stack) + dataset = dataset.apply(tf.data.experimental.unbatch()) + self._dataset = dataset + self._resource = resource + self._function = function + super(_BaseIOTensorDataset, self).__init__( + self._dataset._variant_tensor) # pylint: disable=protected-access + + def _inputs(self): + return [] + + @property + def element_spec(self): + return self._dataset.element_spec + +class _BaseIOTensor(object): + """_BaseIOTensor""" + + def __init__(self, + spec, + resource, + function, + internal=False): + if not internal: + raise ValueError("IOTensor constructor is private; please use one " + "of the factory methods instead (e.g., " + "IOTensor.from_tensor())") + self._spec = spec + self._resource = resource + self._function = function + super(_BaseIOTensor, self).__init__() + + #============================================================================= + # Accessors + #============================================================================= + + @property + def spec(self): + """The `TensorSpec` of values in this tensor.""" + return self._spec + + #============================================================================= + # String Encoding + #============================================================================= + def __repr__(self): + meta = "".join([", %s=%s" % ( + k, repr(v.__get__(self))) for k, v in self.__class__.__dict__.items( + ) if isinstance(v, _BaseIOTensorMeta)]) + return "<%s: spec=%s%s>" % ( + self.__class__.__name__, self.spec, meta) + + + #============================================================================= + # Indexing & Slicing + #============================================================================= + def __getitem__(self, key): + """Returns the specified piece of this IOTensor.""" + if isinstance(key, slice): + start = key.start + stop = key.stop + step = key.step + if start is None: + start = 0 + if stop is None: + stop = -1 + if step is None: + step = 1 + else: + start = key + stop = key + 1 + step = 1 + dtype = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.dtype, self.spec)) + shape = tf.nest.flatten( + tf.nest.map_structure(lambda e: e.shape, self.spec)) + return tf.nest.pack_sequence_as(self.spec, self._function( + self._resource, + start, stop, step, + dtype=dtype, + shape=shape)) + + def __len__(self): + """Returns the total number of items of this IOTensor.""" + return tf.nest.flatten( + tf.nest.map_structure(lambda e: e.shape, self.spec))[0][0] + + #============================================================================= + # Tensor Type Conversions + #============================================================================= + + @classmethod + def from_tensor(cls, + tensor, + **kwargs): + """Converts a `tf.Tensor` into a `IOTensor`. + + Examples: + + ```python + ``` + + Args: + tensor: The `Tensor` to convert. + + Returns: + A `IOTensor`. + + Raises: + ValueError: If tensor is not a `Tensor`. + """ + with tf.name_scope(kwargs.get("name", "IOFromTensor")): + _ = tensor + raise NotImplementedError() + + def to_tensor(self, **kwargs): + """Converts this `IOTensor` into a `tf.Tensor`. + + Example: + + ```python + ``` + + Args: + name: A name prefix for the returned tensors (optional). + + Returns: + A `Tensor` with value obtained from this `IOTensor`. 
+ """ + with tf.name_scope(kwargs.get("name", "IOToTensor")): + return self.__getitem__(slice(None, None)) + + #============================================================================= + # Dataset Conversions + #============================================================================= + + def to_dataset(self): + """Converts this `IOTensor` into a `tf.data.Dataset`. + + Example: + + ```python + ``` + + Args: + + Returns: + A `tf.data.Dataset` with value obtained from this `IOTensor`. + """ + return _BaseIOTensorDataset( + self.spec, self._resource, self._function) + +class _ColumnIOTensor(_BaseIOTensor): + """_ColumnIOTensor""" + + def __init__(self, + shapes, + dtypes, + resource, + function, + internal=False): + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + spec = [tf.TensorSpec(shape, dtype) for ( + shape, dtype) in zip(shapes, dtypes)] + assert len(spec) == 1 + spec = spec[0] + + self._shape = spec.shape + self._dtype = spec.dtype + super(_ColumnIOTensor, self).__init__( + spec, resource, function, internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def shape(self): + """Returns the `TensorShape` that represents the shape of the tensor.""" + return self._shape + + @property + def dtype(self): + """Returns the `dtype` of elements in the tensor.""" + return self._dtype + +class _TableIOTensor(_BaseIOTensor): + """_TableIOTensor""" + + def __init__(self, + shapes, + dtypes, + columns, + filename, + resource, + function, + internal=False): + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + columns = [e.numpy().decode() for e in tf.unstack(columns)] + spec = [tf.TensorSpec(shape, dtype, column) for ( + shape, dtype, column) in zip(shapes, dtypes, columns)] + if len(spec) == 1: + spec = spec[0] + else: + spec = tuple(spec) + self._filename = filename + super(_TableIOTensor, self).__init__( + spec, resource, function, internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + def columns(self): + """The `TensorSpec` of column named `name`""" + return [e.name for e in tf.nest.flatten(self.spec)] + + def shape(self, column): + """Returns the `TensorShape` shape of `column` in the tensor.""" + return next(e.shape for e in tf.nest.flatten(self.spec) if e.name == column) + + def dtype(self, column): + """Returns the `dtype` of `column` in the tensor.""" + return next(e.dtype for e in tf.nest.flatten(self.spec) if e.name == column) + + def __call__(self, column): + """Return a new IOTensor with column named `column`""" + return self.__class__(self._filename, columns=[column], internal=True) # pylint: disable=no-value-for-parameter diff --git a/tensorflow_io/core/python/ops/json_io_tensor_ops.py b/tensorflow_io/core/python/ops/json_io_tensor_ops.py new file mode 100644 index 000000000..f6bab6acc --- /dev/null +++ b/tensorflow_io/core/python/ops/json_io_tensor_ops.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""JSONIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class JSONIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access + """JSONIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + filename, + columns=None, + internal=False): + with tf.name_scope("JSONIOTensor") as scope: + metadata = [] + if columns is not None: + metadata.extend(["column: "+column for column in columns]) + resource, dtypes, shapes, columns = core_ops.json_indexable_init( + filename, metadata=metadata, + container=scope, + shared_name="%s/%s" % (filename, uuid.uuid4().hex)) + self._filename = filename + super(JSONIOTensor, self).__init__( + shapes, dtypes, columns, filename, + resource, core_ops.json_indexable_get_item, + internal=internal) diff --git a/tensorflow_io/core/python/ops/kafka_dataset_ops.py b/tensorflow_io/core/python/ops/kafka_dataset_ops.py new file mode 100644 index 000000000..2c286bf13 --- /dev/null +++ b/tensorflow_io/core/python/ops/kafka_dataset_ops.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""KafkaDataset""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import core_ops + +class KafkaDataset(tf.compat.v2.data.Dataset): + """KafkaDataset""" + + def __init__(self, + subscription, + servers=None, + configuration=None): + """Create a KafkaDataset. + Args: + subscription: A `tf.string` tensor containing subscription, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: A list of bootstrap servers, by default `localhost:9092`. + configuration: A `tf.string` tensor containing configurations + in [Key=Value] format. 
+        There are three types of configurations:
+          Global configuration: please refer to 'Global configuration properties'
+            in the librdkafka doc. Examples include
+            ["enable.auto.commit=false", "heartbeat.interval.ms=2000"]
+          Topic configuration: please refer to 'Topic configuration properties'
+            in the librdkafka doc. Note all topic configurations should be
+            prefixed with `conf.topic.`. Examples include
+            ["conf.topic.auto.offset.reset=earliest"]
+          Dataset configuration: there are two configurations available,
+            `conf.eof=0|1`: if 1 (the default), the KafkaDataset stops once
+              the end of the partition (EOF) is reached.
+            `conf.timeout=milliseconds`: how long the Kafka consumer waits
+              for a message before timing out.
+    """
+    with tf.name_scope("KafkaDataset") as scope:
+      metadata = [e for e in configuration or []]
+      if servers is not None:
+        metadata.append("bootstrap.servers=%s" % servers)
+      resource, _, _ = core_ops.kafka_iterable_init(
+          subscription, metadata=metadata,
+          container=scope,
+          shared_name="%s/%s" % (subscription, uuid.uuid4().hex))
+
+      capacity = 4096
+      dataset = tf.compat.v2.data.Dataset.range(0, sys.maxsize, capacity)
+      dataset = dataset.map(
+          lambda i: core_ops.kafka_iterable_next(
+              resource, capacity,
+              dtype=[tf.string], shape=[tf.TensorShape([None])]))
+      dataset = dataset.apply(
+          tf.data.experimental.take_while(
+              lambda v: tf.greater(tf.shape(v)[0], 0)))
+      # Note: tf.data.Dataset considers the tuple `(e,)` to be a single
+      # element instead of a sequence, so the following `unbatch()` would
+      # not work on its own. The tf.stack() below is necessary.
+      dataset = dataset.map(tf.stack)
+      dataset = dataset.unbatch()
+
+      self._resource = resource
+      self._dataset = dataset
+      super(KafkaDataset, self).__init__(
+          self._dataset._variant_tensor) # pylint: disable=protected-access
+
+  def _inputs(self):
+    return []
+
+  @property
+  def element_spec(self):
+    return self._dataset.element_spec
diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py
new file mode 100644
index 000000000..ee88e4b90
--- /dev/null
+++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py
@@ -0,0 +1,47 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
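The dataset composition in the new `KafkaDataset` above follows a generic "pull fixed-capacity chunks until an empty chunk comes back" pattern. A self-contained sketch of the same pattern, with a plain in-memory generator standing in for `kafka_iterable_next` (all names and data here are illustrative):

```python
import tensorflow as tf

# Pretend each kernel call returns one fixed-capacity chunk of records,
# and an empty chunk once the stream is exhausted.
chunks = [[b"D0", b"D1", b"D2", b"D3"],
          [b"D4", b"D5", b"D6", b"D7"],
          [b"D8", b"D9"],
          []]

dataset = tf.data.Dataset.from_generator(
    lambda: iter(chunks),
    output_types=tf.string,
    output_shapes=tf.TensorShape([None]))
# Stop at the first empty chunk ...
dataset = dataset.apply(
    tf.data.experimental.take_while(lambda v: tf.greater(tf.shape(v)[0], 0)))
# ... then flatten the remaining chunks back into individual records.
dataset = dataset.unbatch()

print([e.numpy() for e in dataset])   # [b'D0', b'D1', ..., b'D9']
```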
+# ============================================================================== +"""KafkaIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class KafkaIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access + """KafkaIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + subscription, + servers=None, + configuration=None, + internal=False): + with tf.name_scope("KafkaIOTensor") as scope: + metadata = [e for e in configuration or []] + if servers is not None: + metadata.append("bootstrap.servers=%s" % servers) + resource, dtypes, shapes = core_ops.kafka_indexable_init( + subscription, metadata=metadata, + container=scope, + shared_name="%s/%s" % (subscription, uuid.uuid4().hex)) + super(KafkaIOTensor, self).__init__( + shapes, dtypes, resource, core_ops.kafka_indexable_get_item, + internal=internal) diff --git a/tensorflow_io/json/BUILD b/tensorflow_io/json/BUILD index ae75eb6df..b3d8b63cf 100644 --- a/tensorflow_io/json/BUILD +++ b/tensorflow_io/json/BUILD @@ -19,7 +19,9 @@ cc_library( ], linkstatic = True, deps = [ + "//tensorflow_io/arrow:arrow_ops", "//tensorflow_io/core:dataset_ops", + "@arrow", "@jsoncpp_git//:jsoncpp", ], ) diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc index 2fcae9863..a2dff6146 100644 --- a/tensorflow_io/json/kernels/json_kernels.cc +++ b/tensorflow_io/json/kernels/json_kernels.cc @@ -21,6 +21,12 @@ limitations under the License. 
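As a counterpart to the streaming `KafkaDataset`, the `KafkaIOTensor` above exposes a topic as an indexable column. A hedged usage sketch, mirroring the eager test added later in this change (it assumes the local Kafka broker and `test` topic provisioned by the repository's test scripts, and that `tfio.IOTensor.from_kafka` routes to this class):

```python
import tensorflow as tf
import tensorflow_io as tfio

# The test topic holds ten messages D0..D9.
kafka = tfio.IOTensor.from_kafka("test")

print(kafka.dtype)                 # tf.string
print(kafka.shape)                 # (10,) i.e. ten messages
print(kafka.to_tensor().numpy())   # [b'D0' b'D1' ... b'D9']

# Column IOTensors also support indexing and slicing via __getitem__.
print(kafka[0].numpy())            # b'D0'
print(kafka[2:5].numpy())          # [b'D2' b'D3' b'D4']
```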
#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/platform/env.h" #include "include/json/json.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/stream.h" +#include "arrow/memory_pool.h" +#include "arrow/json/reader.h" +#include "arrow/table.h" +#include "tensorflow_io/arrow/kernels/arrow_kernels.h" namespace tensorflow { namespace data { @@ -199,5 +205,174 @@ REGISTER_KERNEL_BUILDER(Name("ReadJSON").Device(DEVICE_CPU), ReadJSONOp); } // namespace + + +class JSONIndexable : public IOIndexableInterface { + public: + JSONIndexable(Env* env) + : env_(env) {} + + ~JSONIndexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); + + json_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); + + ::arrow::Status status; + + status = ::arrow::json::TableReader::Make(::arrow::default_memory_pool(), json_file_, ::arrow::json::ReadOptions::Defaults(), ::arrow::json::ParseOptions::Defaults(), &reader_); + if (!status.ok()) { + return errors::InvalidArgument("unable to make a TableReader: ", status); + } + status = reader_->Read(&table_); + if (!status.ok()) { + return errors::InvalidArgument("unable to read table: ", status); + } + + std::vector columns; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find_first_of("column: ") == 0) { + columns.emplace_back(metadata[i].substr(8)); + } + } + + columns_index_.clear(); + if (columns.size() == 0) { + for (int i = 0; i < table_->num_columns(); i++) { + columns_index_.push_back(i); + } + } else { + std::unordered_map columns_map; + for (int i = 0; i < table_->num_columns(); i++) { + columns_map[table_->column(i)->name()] = i; + } + for (size_t i = 0; i < columns.size(); i++) { + columns_index_.push_back(columns_map[columns[i]]); + } + } + + dtypes_.clear(); + shapes_.clear(); + columns_.clear(); + for (size_t i = 0; i < columns_index_.size(); i++) { + int column_index = columns_index_[i]; + ::tensorflow::DataType dtype; + TF_RETURN_IF_ERROR(GetTensorFlowType(table_->column(column_index)->type(), &dtype)); + dtypes_.push_back(dtype); + shapes_.push_back(TensorShape({static_cast(table_->num_rows())})); + columns_.push_back(table_->column(column_index)->name()); + } + + return Status::OK(); + } + Status Spec(std::vector& dtypes, std::vector& shapes) override { + dtypes.clear(); + for (size_t i = 0; i < dtypes_.size(); i++) { + dtypes.push_back(dtypes_[i]); + } + shapes.clear(); + for (size_t i = 0; i < shapes_.size(); i++) { + shapes.push_back(shapes_[i]); + } + return Status::OK(); + } + + Status Extra(std::vector* extra) override { + // Expose columns + Tensor columns(DT_STRING, TensorShape({static_cast(columns_.size())})); + for (size_t i = 0; i < columns_.size(); i++) { + columns.flat()(i) = columns_[i]; + } + extra->push_back(columns); + return Status::OK(); + } + + Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + if (step != 1) { + return errors::InvalidArgument("step ", step, " is not supported"); + } + for (size_t i = 0; i < tensors.size(); i++) { + int column_index = columns_index_[i]; + std::shared_ptr<::arrow::Column> slice = 
table_->column(column_index)->Slice(start, stop); + + #define PROCESS_TYPE(TTYPE,ATYPE) { \ + int64 curr_index = 0; \ + for (auto chunk : slice->data()->chunks()) { \ + for (int64_t item = 0; item < chunk->length(); item++) { \ + tensors[i].flat()(curr_index) = (dynamic_cast(chunk.get()))->Value(item); \ + curr_index++; \ + } \ + } \ + } + switch (tensors[i].dtype()) { + case DT_BOOL: + PROCESS_TYPE(bool, ::arrow::BooleanArray); + break; + case DT_INT8: + PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>); + break; + case DT_UINT8: + PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>); + break; + case DT_INT16: + PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>); + break; + case DT_UINT16: + PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>); + break; + case DT_INT32: + PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>); + break; + case DT_UINT32: + PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>); + break; + case DT_INT64: + PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>); + break; + case DT_UINT64: + PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>); + break; + case DT_FLOAT: + PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>); + break; + case DT_DOUBLE: + PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>); + break; + default: + return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensors[i].dtype())); + } + } + + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("JSONIndexable"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::unique_ptr file_ GUARDED_BY(mu_); + uint64 file_size_ GUARDED_BY(mu_); + std::shared_ptr json_file_; + std::shared_ptr<::arrow::json::TableReader> reader_; + std::shared_ptr<::arrow::Table> table_; + + std::vector dtypes_; + std::vector shapes_; + std::vector columns_; + std::vector columns_index_; +}; + +REGISTER_KERNEL_BUILDER(Name("JSONIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("JSONIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/json/ops/json_ops.cc b/tensorflow_io/json/ops/json_ops.cc index a2a6afc7d..c9206409a 100644 --- a/tensorflow_io/json/ops/json_ops.cc +++ b/tensorflow_io/json/ops/json_ops.cc @@ -19,6 +19,47 @@ limitations under the License. 
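For intuition, the `JSONIndexable` kernel above does in C++ roughly what the pyarrow JSON reader does in Python: parse newline-delimited JSON into an Arrow table and serve typed column slices from it. A pyarrow-based illustration (pyarrow is used only for this sketch; the kernel itself uses Arrow's C++ JSON `TableReader` directly):

```python
import pyarrow.json as pa_json

# Same style of newline-delimited JSON as tests/test_json/feature.ndjson.
table = pa_json.read_json("tests/test_json/feature.ndjson")

print(table.schema)       # floatfeature: double, integerfeature: int64
print(table.num_rows)     # 2
print(table.to_pydict())  # {'floatfeature': [1.1, 2.1], 'integerfeature': [2, 3]}

# Roughly what JSONIndexableGetItem serves for a slice request:
print(table.slice(0, 1).to_pydict())   # first record only (offset, length)
```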
namespace tensorflow { +REGISTER_OP("JSONIndexableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Output("columns: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + c->set_output(3, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("JSONIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); + + REGISTER_OP("ListJSONColumns") .Input("filename: string") .Output("columns: string") diff --git a/tensorflow_io/kafka/BUILD b/tensorflow_io/kafka/BUILD index 21ce10215..0b3976702 100644 --- a/tensorflow_io/kafka/BUILD +++ b/tensorflow_io/kafka/BUILD @@ -11,6 +11,7 @@ cc_library( name = "kafka_ops", srcs = [ "kernels/kafka_dataset_ops.cc", + "kernels/kafka_kernels.cc", "kernels/kafka_sequence.cc", "ops/dataset_ops.cc", "ops/kafka_ops.cc", diff --git a/tensorflow_io/kafka/kernels/kafka_kernels.cc b/tensorflow_io/kafka/kernels/kafka_kernels.cc new file mode 100644 index 000000000..2a3a25f20 --- /dev/null +++ b/tensorflow_io/kafka/kernels/kafka_kernels.cc @@ -0,0 +1,253 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_io/core/kernels/io_interface.h" +//#include "tensorflow/core/platform/logging.h" + +#include "rdkafkacpp.h" + +#include + +namespace tensorflow { +namespace data { + +class KafkaEventCb : public RdKafka::EventCb { +public: + KafkaEventCb() + : run_(true) {} + + bool run() { + return run_; + } + + void event_cb (RdKafka::Event &event) { + switch (event.type()) { + case RdKafka::Event::EVENT_ERROR: + LOG(ERROR) << "EVENT_ERROR: " << "(" << RdKafka::err2str(event.err()) << "): " << event.str(); + { + run_ = (event.err() != RdKafka::ERR__ALL_BROKERS_DOWN); + } + break; + case RdKafka::Event::EVENT_STATS: + LOG(ERROR) << "EVENT_STATS: " << event.str(); + break; + case RdKafka::Event::EVENT_LOG: + LOG(ERROR) << "EVENT_LOG: " << event.severity() << "-" << event.fac().c_str() << "-" << event.str().c_str(); + break; + case RdKafka::Event::EVENT_THROTTLE: + LOG(ERROR) << "EVENT_THROTTLE: " << event.throttle_time() << "ms by " << event.broker_name() << " id " << (int)event.broker_id(); + break; + default: + LOG(ERROR) << "EVENT: " << event.type() << " (" << RdKafka::err2str(event.err()) << "): " << event.str(); + break; + } + } +private: + mutable mutex mu_; + bool run_ GUARDED_BY(mu_) = true; +}; + +class KafkaIterable : public IOIterableInterface { + public: + KafkaIterable(Env* env) + : env_(env) {} + + ~KafkaIterable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + std::unique_ptr conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); + std::unique_ptr conf_topic( + RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); + + string errstr; + RdKafka::Conf::ConfResult result = RdKafka::Conf::CONF_UNKNOWN; + + eof_ = true; + timeout_ = 1000; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find_first_of("conf.eof") == 0) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2) { + return errors::InvalidArgument("invalid timeout configuration: ", metadata[i]); + } + eof_ = (parts[1] != "0"); + } else if (metadata[i].find_first_of("conf.timeout") == 0) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2 || !strings::safe_strto64(parts[1], &timeout_)) { + return errors::InvalidArgument("invalid timeout configuration: ", metadata[i]); + } + } else if (metadata[i].find_first_of("conf.topic.") == 0) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2) { + return errors::InvalidArgument("invalid topic configuration: ", metadata[i]); + } + result = conf_topic->set(parts[0].substr(11), parts[1], errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to do topic configuration:", metadata[i], "error:", errstr); + } + } else if (metadata[i] != "" && metadata[i].find_first_of("conf.") == string::npos) { + std::vector parts = str_util::Split(metadata[i], "="); + if (parts.size() != 2) { + return errors::InvalidArgument("invalid topic configuration: ", metadata[i]); + } + if ((result = conf->set(parts[0], parts[1], errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to do global configuration: ", metadata[i], "error:", errstr); + } + } + } + if ((result = conf->set("default_topic_conf", conf_topic.get(), errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set default_topic_conf:", errstr); + } + + // consumer.properties: + // 
bootstrap.servers=localhost:9092 + // group.id=test-consumer-group + string bootstrap_servers; + if ((result = conf->get("bootstrap.servers", bootstrap_servers)) != RdKafka::Conf::CONF_OK) { + bootstrap_servers = "localhost:9092"; + if ((result = conf->set("bootstrap.servers", bootstrap_servers, errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set bootstrap.servers [", bootstrap_servers, "]:", errstr); + } + } + string group_id; + if ((result = conf->get("group.id", group_id)) != RdKafka::Conf::CONF_OK) { + group_id = "test-consumer-group"; + if ((result = conf->set("group.id", group_id, errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set group.id [", group_id, "]:", errstr); + } + } + if ((result = conf->set("event_cb", &kafka_event_cb_, errstr)) != RdKafka::Conf::CONF_OK) { + return errors::Internal("failed to set event_cb:", errstr); + } + + // TODO: multiple topic and partitions + + const string& entry = input[0]; + std::vector parts = str_util::Split(entry, ":"); + string topic = parts[0]; + int32 partition = 0; + if (parts.size() > 1) { + if (!strings::safe_strto32(parts[1], &partition)) { + return errors::InvalidArgument("invalid parameters: ", entry); + } + } + + int64 start = 0; + if (parts.size() > 2) { + if (!strings::safe_strto64(parts[2], &start)) { + return errors::InvalidArgument("invalid parameters: ", entry); + } + } + subscription_.reset(RdKafka::TopicPartition::create(topic, partition, start)); + start = subscription_->offset(); + + offset_ = start; + + int64 stop = -1; + if (parts.size() > 3) { + if (!strings::safe_strto64(parts[3], &stop)) { + return errors::InvalidArgument("invalid parameters: ", entry); + } + } + range_ = std::pair(start, stop); + + consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); + if (!consumer_.get()) { + return errors::Internal("failed to create consumer:", errstr); + } + + std::vector partitions; + partitions.emplace_back(subscription_.get()); + RdKafka::ErrorCode err = consumer_->assign(partitions); + if (err != RdKafka::ERR_NO_ERROR) { + return errors::Internal("failed to assign partition: ", RdKafka::err2str(err)); + } + + return Status::OK(); + } + Status Next(const int64 capacity, std::vector& tensors, int64* record_read) override { + *record_read = 0; + while (consumer_.get() != nullptr && (*record_read) < capacity) { + if (!kafka_event_cb_.run()) { + return errors::Internal("failed to consume due to all brokers down"); + } + if (range_.second >= 0 && (subscription_->offset() >= range_.second || offset_ >= range_.second)) { + // EOF of topic + consumer_.reset(nullptr); + return Status::OK(); + } + + std::unique_ptr message(consumer_->consume(timeout_)); + if (message->err() == RdKafka::ERR_NO_ERROR) { + // Produce the line as output. + tensors[0].flat()((*record_read)) = std::string(static_cast(message->payload()), message->len()); + // Sync offset + offset_ = message->offset(); + (*record_read)++; + continue; + } + if (message->err() == RdKafka::ERR__PARTITION_EOF) { + LOG(INFO) << "Partition reach EOF, current offset: " << offset_; + if (eof_) { + consumer_.reset(nullptr); + return Status::OK(); + } + } + else if (message->err() == RdKafka::ERR__TRANSPORT) { + // Not return error here because consumer will try re-connect. 
+ LOG(ERROR) << "Broker transport failure: " << message->errstr(); + } + else if (message->err() != RdKafka::ERR__TIMED_OUT) { + LOG(ERROR) << "Failed to consume: " << message->errstr(); + return errors::Internal("Failed to consume: ", message->errstr()); + } + } + return Status::OK(); + } + Status Spec(std::vector& dtypes, std::vector& shapes) override { + dtypes.clear(); + dtypes.push_back(DT_STRING); + shapes.clear(); + shapes.push_back(PartialTensorShape({-1})); + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("KafkaIterable[]"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::pair range_ GUARDED_BY(mu_); + std::unique_ptr subscription_ GUARDED_BY(mu_); + std::unique_ptr consumer_ GUARDED_BY(mu_); + KafkaEventCb kafka_event_cb_ = KafkaEventCb(); + int64 timeout_ GUARDED_BY(mu_); + bool eof_ GUARDED_BY(mu_); + int64 offset_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("KafkaIterableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("KafkaIterableNext").Device(DEVICE_CPU), + IOIterableNextOp); +REGISTER_KERNEL_BUILDER(Name("KafkaIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp>); +REGISTER_KERNEL_BUILDER(Name("KafkaIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp>); + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/kafka/ops/kafka_ops.cc b/tensorflow_io/kafka/ops/kafka_ops.cc index 7142ae719..1ca333eac 100644 --- a/tensorflow_io/kafka/ops/kafka_ops.cc +++ b/tensorflow_io/kafka/ops/kafka_ops.cc @@ -19,6 +19,81 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("KafkaIndexableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("KafkaIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); + +REGISTER_OP("KafkaIterableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("dtypes: int64") + .Output("shapes: int64") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("KafkaIterableNext") + .Input("input: resource") + .Input("capacity: int64") + .Output("output: dtype") + .Attr("dtype: list(type) >= 1") + .Attr("shape: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + if (shape.size() != c->num_outputs()) { + return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < shape.size(); ++i) { + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); + c->set_output(static_cast(i), entry); + } + return Status::OK(); + }); + + REGISTER_OP("KafkaOutputSequence") .Input("topic: string") .Input("servers: string") diff --git a/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py b/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py index 7b6171d85..4b18ec2df 100644 --- a/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow_io/kafka/python/ops/kafka_dataset_ops.py @@ -17,11 +17,21 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow import dtypes from tensorflow.compat.v1 import data from tensorflow_io.core.python.ops import core_ops +warnings.warn( + "implementation of existing tensorflow_io.kafka.KafkaDataset is " + "deprecated and will be replaced with the implementation in " + "tensorflow_io.core.python.ops.kafka_dataset_ops.KafkaDataset, " + "please check the doc of new implementation for API changes", + DeprecationWarning) + + class KafkaDataset(data.Dataset): """A Kafka Dataset that consumes the message. 
""" diff --git a/tests/test_audio_eager.py b/tests/test_audio_eager.py index c848b157b..5ff8e5bb0 100644 --- a/tests/test_audio_eager.py +++ b/tests/test_audio_eager.py @@ -23,7 +23,7 @@ import tensorflow as tf if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() -import tensorflow_io.audio as audio_io # pylint: disable=wrong-import-position +import tensorflow_io as tfio # pylint: disable=wrong-import-position audio_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -37,33 +37,25 @@ def test_audio_dataset(): f = lambda x: float(x) / (1 << 15) - for capacity in [10, 100, 500]: - audio_dataset = audio_io.WAVDataset(audio_path, capacity=capacity).apply( - tf.data.experimental.unbatch()).map(tf.squeeze) - i = 0 - for v in audio_dataset: - assert audio_v.audio[i].numpy() == f(v.numpy()) - i += 1 - assert i == 5760 + audio_dataset = tfio.IOTensor.from_audio(audio_path).to_dataset() + i = 0 + for v in audio_dataset: + assert audio_v.audio[i].numpy() == f(v.numpy()) + i += 1 + assert i == 5760 - for capacity in [10, 100, 500]: - audio_dataset = audio_io.WAVDataset(audio_path, capacity=capacity).apply( - tf.data.experimental.unbatch()).batch(2).map(tf.squeeze) - i = 0 - for v in audio_dataset: - assert audio_v.audio[i].numpy() == f(v[0].numpy()) - assert audio_v.audio[i + 1].numpy() == f(v[1].numpy()) - i += 2 - assert i == 5760 + audio_dataset = tfio.IOTensor.from_audio(audio_path).to_dataset().batch(2) + i = 0 + for v in audio_dataset: + assert audio_v.audio[i].numpy() == f(v[0].numpy()) + assert audio_v.audio[i + 1].numpy() == f(v[1].numpy()) + i += 2 + assert i == 5760 - spec, rate = audio_io.list_wav_info(audio_path) - assert spec.dtype == tf.int16 - assert spec.shape == [5760, 1] - assert rate.numpy() == audio_v.sample_rate.numpy() - - samples = audio_io.read_wav(audio_path, spec) + samples = tfio.IOTensor.from_audio(audio_path) assert samples.dtype == tf.int16 assert samples.shape == [5760, 1] + assert samples.rate == audio_v.sample_rate.numpy() audio_24bit_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -76,12 +68,8 @@ def test_audio_dataset(): expected = np.fromfile(audio_24bit_raw_path, np.int32) expected = np.reshape(expected, [22050, 2]) - spec, rate = audio_io.list_wav_info(audio_24bit_path) - assert spec.dtype == tf.int32 - assert spec.shape == [22050, 2] - assert rate.numpy() == 44100 - - samples = audio_io.read_wav(audio_24bit_path, spec) + samples = tfio.IOTensor.from_audio(audio_24bit_path) assert samples.dtype == tf.int32 assert samples.shape == [22050, 2] - assert np.all(samples.numpy() == expected) + assert samples.rate == 44100 + assert np.all(samples.to_tensor().numpy() == expected) diff --git a/tests/test_json/feature.ndjson b/tests/test_json/feature.ndjson new file mode 100644 index 000000000..aaa67a60d --- /dev/null +++ b/tests/test_json/feature.ndjson @@ -0,0 +1,2 @@ +{ "floatfeature": 1.1, "integerfeature": 2 } +{ "floatfeature": 2.1, "integerfeature": 3 } diff --git a/tests/test_json/label.ndjson b/tests/test_json/label.ndjson new file mode 100644 index 000000000..5c4aa77d0 --- /dev/null +++ b/tests/test_json/label.ndjson @@ -0,0 +1,2 @@ +{ "floatlabel": 2.2, "integerlabel": 3 } +{ "floatlabel": 1.2, "integerlabel": 3 } diff --git a/tests/test_json_eager.py b/tests/test_json_eager.py index 333f24718..55393a9e6 100644 --- a/tests/test_json_eager.py +++ b/tests/test_json_eager.py @@ -23,8 +23,65 @@ import tensorflow as tf if not (hasattr(tf, "version") and 
tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() +import tensorflow_io as tfio # pylint: disable=wrong-import-position import tensorflow_io.json as json_io # pylint: disable=wrong-import-position +def test_io_tensor_json(): + """Test case for tfio.IOTensor.from_json.""" + x_test = [[1.1, 2], [2.1, 3]] + y_test = [[2.2, 3], [1.2, 3]] + feature_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_json", + "feature.ndjson") + feature_filename = "file://" + feature_filename + label_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_json", + "label.ndjson") + label_filename = "file://" + label_filename + + features = tfio.IOTensor.from_json(feature_filename) + assert features.dtype("floatfeature") == tf.float64 + assert features.dtype("integerfeature") == tf.int64 + + labels = tfio.IOTensor.from_json(label_filename) + assert labels.dtype("floatlabel") == tf.float64 + assert labels.dtype("integerlabel") == tf.int64 + + float_feature = features("floatfeature") + integer_feature = features("integerfeature") + float_label = labels("floatlabel") + integer_label = labels("integerlabel") + + for i in range(2): + v_x = x_test[i] + v_y = y_test[i] + assert v_x[0] == float_feature[i].numpy() + assert v_x[1] == integer_feature[i].numpy() + assert v_y[0] == float_label[i].numpy() + assert v_y[1] == integer_label[i].numpy() + + feature_dataset = features.to_dataset() + + label_dataset = labels.to_dataset() + + dataset = tf.data.Dataset.zip(( + feature_dataset, + label_dataset + )) + + i = 0 + for (j_x, j_y) in dataset: + v_x = x_test[i] + v_y = y_test[i] + for index, x in enumerate(j_x): + assert v_x[index] == x.numpy() + for index, y in enumerate(j_y): + assert v_y[index] == y.numpy() + i += 1 + assert i == len(y_test) + def test_json_dataset(): """Test case for JSON Dataset. 
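Finally, a condensed version of what the new `test_io_tensor_json` above exercises end to end: reading feature and label columns from ndjson files and zipping them into one `tf.data` pipeline (paths are relative to the repository root; the printed values correspond to the two-record fixtures added in this change):

```python
import tensorflow as tf
import tensorflow_io as tfio

features = tfio.IOTensor.from_json("tests/test_json/feature.ndjson")
labels = tfio.IOTensor.from_json("tests/test_json/label.ndjson")

# Each table-style IOTensor yields one tuple of column values per record.
dataset = tf.data.Dataset.zip((features.to_dataset(), labels.to_dataset()))

for (floatfeature, integerfeature), (floatlabel, integerlabel) in dataset:
  print(floatfeature.numpy(), integerfeature.numpy(),
        floatlabel.numpy(), integerlabel.numpy())
# 1.1 2 2.2 3
# 2.1 3 1.2 3
```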
""" diff --git a/tests/test_kafka_eager.py b/tests/test_kafka_eager.py index 81f0bbab5..0021aaa84 100644 --- a/tests/test_kafka_eager.py +++ b/tests/test_kafka_eager.py @@ -23,7 +23,23 @@ import numpy as np import tensorflow as tf -import tensorflow_io.kafka as kafka_io +if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): + tf.compat.v1.enable_eager_execution() +import tensorflow_io as tfio # pylint: disable=wrong-import-position +from tensorflow_io.core.python.ops import kafka_dataset_ops # pylint: disable=wrong-import-position + +def test_kafka_dataset(): + dataset = kafka_dataset_ops.KafkaDataset("test").batch(2) + assert np.all([ + e.numpy().tolist() for e in dataset] == np.asarray([ + ("D" + str(i)).encode() for i in range(10)]).reshape((5, 2))) + +def test_kafka_io_tensor(): + kafka = tfio.IOTensor.from_kafka("test") + assert kafka.dtype == tf.string + assert kafka.shape == [10] + assert np.all(kafka.to_tensor().numpy() == [ + ("D" + str(i)).encode() for i in range(10)]) @pytest.mark.skipif( not (hasattr(tf, "version") and @@ -55,7 +71,7 @@ def test_kafka_output_sequence(): class OutputCallback(tf.keras.callbacks.Callback): """KafkaOutputCallback""" def __init__(self, batch_size, topic, servers): - self._sequence = kafka_io.KafkaOutputSequence( + self._sequence = tfio.kafka.KafkaOutputSequence( topic=topic, servers=servers) self._batch_size = batch_size def on_predict_batch_end(self, batch, logs=None): @@ -78,6 +94,6 @@ def flush(self): predictions = [class_names[v] for v in np.argmax(predictions, axis=1)] # Reading from `test_e(time)e` we should get the same result - dataset = kafka_io.KafkaDataset(topics=[topic], group="test", eof=True) + dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True) for entry, prediction in zip(dataset, predictions): assert entry.numpy() == prediction.encode() diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index c1f10307c..d52d8242e 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -47,6 +47,8 @@ cc_library( "cpp/src/arrow/io/*.h", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/ipc/*.h", + "cpp/src/arrow/json/*.cc", + "cpp/src/arrow/json/*.h", "cpp/src/arrow/util/*.cc", "cpp/src/arrow/util/*.h", "cpp/src/arrow/vendored/**/*.cpp", @@ -121,6 +123,7 @@ cc_library( ":arrow_format", "@boost", "@double_conversion//:double-conversion", + "@rapidjson", "@snappy", "@thrift", "@zlib", diff --git a/third_party/rapidjson.BUILD b/third_party/rapidjson.BUILD new file mode 100644 index 000000000..af5f2d26e --- /dev/null +++ b/third_party/rapidjson.BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT/JSON license + +cc_library( + name = "rapidjson", + srcs = glob( + [ + ], + ) + [ + ], + hdrs = glob( + [ + "include/**/*.h", + ], + ) + [ + ], + copts = [ + ], + includes = [ + "include", + ], +)