diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index ae4d43f21..a9b699b48 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -130,11 +130,11 @@ class WAVIndexable : public IOIndexableInterface { return Status::OK(); } - Status Spec(std::vector& dtypes, std::vector& shapes) override { - dtypes.clear(); - dtypes.push_back(dtype_); + Status Spec(std::vector& shapes, std::vector& dtypes) override { shapes.clear(); shapes.push_back(shape_); + dtypes.clear(); + dtypes.push_back(dtype_); return Status::OK(); } @@ -146,8 +146,7 @@ class WAVIndexable : public IOIndexableInterface { return Status::OK(); } - Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { - Tensor& output_tensor = tensors[0]; + Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override { if (step != 1) { return errors::InvalidArgument("step ", step, " is not supported"); } @@ -184,13 +183,13 @@ class WAVIndexable : public IOIndexableInterface { if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); } - memcpy((char *)(output_tensor.flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); + memcpy((char *)(tensor->flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); break; case 16: if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) { return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign); } - memcpy((char *)(output_tensor.flat().data()) + ((read_sample_start - sample_start) * 
header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); + memcpy((char *)(tensor->flat().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start)); break; case 24: // NOTE: The conversion is from signed integer 24 to signed integer 32 (left shift 8 bits) @@ -199,7 +198,7 @@ class WAVIndexable : public IOIndexableInterface { } for (int64 i = read_sample_start; i < read_sample_stop; i++) { for (int64 j = 0; j < header_.nChannels; j++) { - char *data_p = (char *)(output_tensor.flat().data() + ((i - sample_start) * header_.nChannels + j)); + char *data_p = (char *)(tensor->flat().data() + ((i - sample_start) * header_.nChannels + j)); char *read_p = (char *)(&buffer[((i - read_sample_start) * header_.nBlockAlign)]) + 3 * j; data_p[3] = read_p[2]; data_p[2] = read_p[1]; diff --git a/tensorflow_io/audio/ops/audio_ops.cc b/tensorflow_io/audio/ops/audio_ops.cc index 0d6c18fe4..f7e795b96 100644 --- a/tensorflow_io/audio/ops/audio_ops.cc +++ b/tensorflow_io/audio/ops/audio_ops.cc @@ -41,20 +41,16 @@ REGISTER_OP("WAVIndexableGetItem") .Input("start: int64") .Input("stop: int64") .Input("step: int64") + .Input("component: int64") .Output("output: dtype") - .Attr("dtype: list(type) >= 1") - .Attr("shape: list(shape) >= 1") + .Attr("shape: shape") + .Attr("dtype: type") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector shape; + PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - if (shape.size() != c->num_outputs()) { - return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); - } - for (size_t i = 0; i < shape.size(); ++i) { - shape_inference::ShapeHandle entry; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); - c->set_output(static_cast(i), entry); - } + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); return Status::OK(); }); diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 0005ae2f8..00c77f49c 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -122,9 +122,9 @@ cc_library( ) go_binary( - name = "prometheus_go", + name = "golang_ops", srcs = ["go/prometheus.go"], - out = "python/ops/libtensorflow_io_prometheus.so", + out = "python/ops/libtensorflow_io_golang.so", cgo = True, linkmode = "c-shared", visibility = ["//visibility:public"], @@ -138,11 +138,11 @@ go_binary( cc_library( name = "prometheus_go_ops", srcs = [ - "prometheus_go.h", + "golang_ops.h", ], copts = tf_io_copts(), data = [ - "//tensorflow_io/core:prometheus_go.h", + "//tensorflow_io/core:golang_ops.h", ], linkstatic = True, ) diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h index 247804d46..e3e937c47 100644 --- a/tensorflow_io/core/kernels/io_interface.h +++ b/tensorflow_io/core/kernels/io_interface.h @@ -23,7 +23,7 @@ namespace data { class IOInterface : public ResourceBase { public: virtual Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) = 0; - virtual Status Spec(std::vector& dtypes, std::vector& shapes) = 0; + virtual Status Spec(std::vector& shapes, std::vector& dtypes) = 0; virtual Status Extra(std::vector* extra) { // This is the chance to provide additional extra information which should be appended to extra. 
@@ -33,12 +33,12 @@ class IOInterface : public ResourceBase { class IOIterableInterface : public IOInterface { public: - virtual Status Next(const int64 capacity, std::vector& tensors, int64* record_read) = 0; + virtual Status Next(const int64 capacity, const int64 component, Tensor* tensor, int64* record_read) = 0; }; class IOIndexableInterface : public IOInterface { public: - virtual Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) = 0; + virtual Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) = 0; }; template @@ -50,9 +50,8 @@ class IOIndexableImplementation : public IOIndexableInterface { ~IOIndexableImplementation() {} Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { - TF_RETURN_IF_ERROR(iterable_->Init(input, metadata, memory_data, memory_size)); - TF_RETURN_IF_ERROR(iterable_->Spec(dtypes_, shapes_)); + TF_RETURN_IF_ERROR(iterable_->Spec(shapes_, dtypes_)); const int64 capacity = 4096; std::vector chunk_shapes; @@ -66,18 +65,23 @@ class IOIndexableImplementation : public IOIndexableInterface { int64 record_read = 0; do { - tensors_.push_back(std::vector()); + chunk_tensors_.push_back(std::vector()); for (size_t component = 0; component < shapes_.size(); component++) { - tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component])); + chunk_tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component])); + int64 chunk_record_read = 0; + TF_RETURN_IF_ERROR(iterable_->Next(capacity, component, &chunk_tensors_.back()[component], &chunk_record_read)); + if (component != 0 && record_read != chunk_record_read) { + return errors::Internal("component ", component, " has different chunk size: ", chunk_record_read, " vs. 
", record_read); + } + record_read = chunk_record_read; } - TF_RETURN_IF_ERROR(iterable_->Next(capacity, tensors_.back(), &record_read)); if (record_read == 0) { - tensors_.pop_back(); + chunk_tensors_.pop_back(); break; } if (record_read < capacity) { for (size_t component = 0; component < shapes_.size(); component++) { - tensors_.back()[component] = tensors_.back()[component].Slice(0, record_read); + chunk_tensors_.back()[component] = chunk_tensors_.back()[component].Slice(0, record_read); } } total += record_read; @@ -87,13 +91,13 @@ class IOIndexableImplementation : public IOIndexableInterface { } return Status::OK(); } - virtual Status Spec(std::vector& dtypes, std::vector& shapes) override { - for (size_t component = 0; component < dtypes_.size(); component++) { - dtypes.push_back(dtypes_[component]); - } + virtual Status Spec(std::vector& shapes, std::vector& dtypes) override { for (size_t component = 0; component < shapes_.size(); component++) { shapes.push_back(shapes_[component]); } + for (size_t component = 0; component < dtypes_.size(); component++) { + dtypes.push_back(dtypes_[component]); + } return Status::OK(); } @@ -105,41 +109,35 @@ class IOIndexableImplementation : public IOIndexableInterface { return strings::StrCat("IOIndexableImplementation<", iterable_->DebugString(), ">[]"); } - Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override { if (step != 1) { return errors::InvalidArgument("step != 1 is not supported: ", step); } // Find first chunk - int64 chunk_index = 0; int64 chunk_element = -1; int64 current_element = 0; - while (chunk_index < tensors_.size()) { - if (current_element <= start && start < current_element + tensors_[chunk_index][0].shape().dim_size(0)) { + while (chunk_index < chunk_tensors_.size()) { + if (current_element <= start && start < current_element + 
chunk_tensors_[chunk_index][component].shape().dim_size(0)) { chunk_element = start - current_element; current_element = start; break; } - current_element += tensors_[chunk_index][0].shape().dim_size(0); + current_element += chunk_tensors_[chunk_index][component].shape().dim_size(0); chunk_index++; } if (chunk_element < 0) { return errors::InvalidArgument("start is out of range: ", start); } - std::vector elements; - for (size_t component = 0; component < shapes_.size(); component++) { - TensorShape shape(shapes_[component].dim_sizes()); - shape.RemoveDim(0); - elements.push_back(Tensor(dtypes_[component], shape)); - } + TensorShape shape(shapes_[component].dim_sizes()); + shape.RemoveDim(0); + Tensor element(dtypes_[component], shape); while (current_element < stop) { - for (size_t component = 0; component < shapes_.size(); component++) { - batch_util::CopySliceToElement(tensors_[chunk_index][component], &elements[component], chunk_element); - batch_util::CopyElementToSlice(elements[component], &tensors[component], (current_element - start)); - } + batch_util::CopySliceToElement(chunk_tensors_[chunk_index][component], &element, chunk_element); + batch_util::CopyElementToSlice(element, tensor, (current_element - start)); chunk_element++; - if (chunk_element == tensors_[chunk_index][0].shape().dim_size(0)) { + if (chunk_element == chunk_tensors_[chunk_index][component].shape().dim_size(0)) { chunk_index++; chunk_element = 0; } @@ -151,9 +149,9 @@ class IOIndexableImplementation : public IOIndexableInterface { mutable mutex mu_; Env* env_ GUARDED_BY(mu_); std::unique_ptr iterable_ GUARDED_BY(mu_); - std::vector dtypes_ GUARDED_BY(mu_); std::vector shapes_ GUARDED_BY(mu_); - std::vector> tensors_; + std::vector dtypes_ GUARDED_BY(mu_); + std::vector> chunk_tensors_; }; @@ -198,9 +196,9 @@ class IOInterfaceInitOp : public ResourceOpKernel { OP_REQUIRES_OK(context, this->resource_->Init(input, metadata, memory_data, memory_size)); - std::vector dtypes; std::vector 
shapes; - OP_REQUIRES_OK(context, this->resource_->Spec(dtypes, shapes)); + std::vector dtypes; + OP_REQUIRES_OK(context, this->resource_->Spec(shapes, dtypes)); int64 maxrank = 0; for (size_t component = 0; component < shapes.size(); component++) { if (dynamic_cast(this->resource_) != nullptr) { @@ -212,10 +210,6 @@ class IOInterfaceInitOp : public ResourceOpKernel { } maxrank = maxrank > shapes[component].dims() ? maxrank : shapes[component].dims(); } - Tensor dtypes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size())})); - for (size_t i = 0; i < dtypes.size(); i++) { - dtypes_tensor.flat()(i) = dtypes[i]; - } Tensor shapes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size()), maxrank})); for (size_t component = 0; component < shapes.size(); component++) { for (int64 i = 0; i < shapes[component].dims(); i++) { @@ -225,8 +219,12 @@ class IOInterfaceInitOp : public ResourceOpKernel { shapes_tensor.tensor()(component, i) = 0; } } - context->set_output(1, dtypes_tensor); - context->set_output(2, shapes_tensor); + Tensor dtypes_tensor(DT_INT64, TensorShape({static_cast(dtypes.size())})); + for (size_t i = 0; i < dtypes.size(); i++) { + dtypes_tensor.flat()(i) = dtypes[i]; + } + context->set_output(1, shapes_tensor); + context->set_output(2, dtypes_tensor); std::vector extra; OP_REQUIRES_OK(context, this->resource_->Extra(&extra)); @@ -259,27 +257,26 @@ class IOIterableNextOp : public OpKernel { OP_REQUIRES_OK(context, context->input("capacity", &capacity_tensor)); const int64 capacity = capacity_tensor->scalar()(); + const Tensor* component_tensor; + OP_REQUIRES_OK(context, context->input("component", &component_tensor)); + const int64 component = component_tensor->scalar()(); + OP_REQUIRES(context, (capacity > 0), errors::InvalidArgument("capacity <= 0 is not supported: ", capacity)); - std::vector dtypes; std::vector shapes; - OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes)); + std::vector dtypes; + OP_REQUIRES_OK(context, resource->Spec(shapes, 
dtypes)); - std::vector tensors; - for (size_t i = 0; i < dtypes.size(); i++) { - gtl::InlinedVector dims = shapes[i].dim_sizes(); - dims[0] = capacity; - tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims))); - } + gtl::InlinedVector dims = shapes[component].dim_sizes(); + dims[0] = capacity; + Tensor tensor(dtypes[component], TensorShape(dims)); int64 record_read; - OP_REQUIRES_OK(context, resource->Next(capacity, tensors, &record_read)); - for (size_t i = 0; i < tensors.size(); i++) { - if (record_read < capacity) { - context->set_output(i, tensors[i].Slice(0, record_read)); - } else { - context->set_output(i, tensors[i]); - } + OP_REQUIRES_OK(context, resource->Next(capacity, component, &tensor, &record_read)); + if (record_read < capacity) { + context->set_output(0, tensor.Slice(0, record_read)); + } else { + context->set_output(0, tensor); } } }; @@ -307,13 +304,17 @@ class IOIndexableGetItemOp : public OpKernel { OP_REQUIRES_OK(context, context->input("step", &step_tensor)); int64 step = step_tensor->scalar()(); + const Tensor* component_tensor; + OP_REQUIRES_OK(context, context->input("component", &component_tensor)); + const int64 component = component_tensor->scalar()(); + OP_REQUIRES(context, (step == 1), errors::InvalidArgument("step != 1 is not supported: ", step)); - std::vector dtypes; std::vector shapes; - OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes)); + std::vector dtypes; + OP_REQUIRES_OK(context, resource->Spec(shapes, dtypes)); - int64 count = shapes[0].dim_size(0); + int64 count = shapes[component].dim_size(0); if (start > count) { start = count; } @@ -324,16 +325,11 @@ class IOIndexableGetItemOp : public OpKernel { stop = start; } - std::vector tensors; - for (size_t i = 0; i < dtypes.size(); i++) { - gtl::InlinedVector dims = shapes[i].dim_sizes(); - dims[0] = stop - start; - tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims))); - } - OP_REQUIRES_OK(context, resource->GetItem(start, stop, step, tensors)); - for (size_t i 
= 0; i < tensors.size(); i++) { - context->set_output(i, tensors[i]); - } + gtl::InlinedVector dims = shapes[component].dim_sizes(); + dims[0] = stop - start; + Tensor tensor(dtypes[component], TensorShape(dims)); + OP_REQUIRES_OK(context, resource->GetItem(start, stop, step, component, &tensor)); + context->set_output(0, tensor); } }; } // namespace data diff --git a/tensorflow_io/core/python/ops/__init__.py b/tensorflow_io/core/python/ops/__init__.py index 4e60b8a1e..8a447089f 100644 --- a/tensorflow_io/core/python/ops/__init__.py +++ b/tensorflow_io/core/python/ops/__init__.py @@ -66,5 +66,5 @@ def _load_library(filename, lib="op"): "unable to open file: " + "{}, from paths: {}\ncaused by: {}".format(filename, filenames, errs)) -_load_library("libtensorflow_io_prometheus.so", "dependency") +_load_library("libtensorflow_io_golang.so", "dependency") core_ops = _load_library('libtensorflow_io.so') diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py index 78742292e..419f9d70d 100644 --- a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -23,8 +23,15 @@ from tensorflow_io.core.python.ops import io_tensor_ops from tensorflow_io.core.python.ops import core_ops -class AudioIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access - """AudioIOTensor""" +class AudioIOTensor(io_tensor_ops.BaseIOTensor): # pylint: disable=protected-access + """AudioIOTensor + + An `AudioIOTensor` is an `IOTensor` backed by audio files such as WAV + format. It consists of only one `Tensor` with `shape` defined as + `[n_samples, n_channels]`. It is a subclass of `BaseIOTensor` + with additional `rate` property exposed, indicating the sample rate + of the audio. 
+ """ #============================================================================= # Constructor (private) @@ -33,20 +40,31 @@ def __init__(self, filename, internal=False): with tf.name_scope("AudioIOTensor") as scope: - resource, dtypes, shapes, rate = core_ops.wav_indexable_init( + resource, shapes, dtypes, rate = core_ops.wav_indexable_init( filename, container=scope, shared_name="%s/%s" % (filename, uuid.uuid4().hex)) + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + assert len(shapes) == 1 + assert len(dtypes) == 1 + shape = shapes[0] + dtype = dtypes[0] + spec = tf.TensorSpec(shape, dtype) + self._rate = rate.numpy() super(AudioIOTensor, self).__init__( - shapes, dtypes, resource, core_ops.wav_indexable_get_item, + spec, resource, core_ops.wav_indexable_get_item, internal=internal) #============================================================================= # Accessors #============================================================================= - @io_tensor_ops._BaseIOTensorMeta # pylint: disable=protected-access + @io_tensor_ops._IOTensorMeta # pylint: disable=protected-access def rate(self): """The sample `rate` of the audio stream""" return self._rate diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index 5a237da41..24358ec39 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -22,8 +22,9 @@ from tensorflow_io.core.python.ops import audio_io_tensor_ops from tensorflow_io.core.python.ops import json_io_tensor_ops from tensorflow_io.core.python.ops import kafka_io_tensor_ops +from tensorflow_io.core.python.ops import prometheus_io_tensor_ops -class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access +class IOTensor(io_tensor_ops._IOTensor): # pylint: disable=protected-access 
"""IOTensor An `IOTensor` is a tensor with data backed by IO operations. For example, @@ -225,7 +226,8 @@ def from_json(cls, """ with tf.name_scope(kwargs.get("name", "IOFromJSON")): - return json_io_tensor_ops.JSONIOTensor(filename, internal=True) + return json_io_tensor_ops.JSONIOTensor( + filename, mode=kwargs.get('mode', None), internal=True) @classmethod def from_kafka(cls, @@ -264,3 +266,24 @@ def from_kafka(cls, servers=kwargs.get("servers", None), configuration=kwargs.get("configuration", None), internal=True) + + @classmethod + def from_prometheus(cls, + query, + **kwargs): + """Creates an `IOTensor` from a prometheus query. + + Args: + query: A string, the query string for prometheus. + endpoint: A string, the server address of prometheus, by default + `http://localhost:9090`. + name: A name prefix for the IOTensor (optional). + + Returns: + A (`IOTensor`, `IOTensor`) tuple that represents `timestamp` + and `value`. + + """ + with tf.name_scope(kwargs.get("name", "IOFromPrometheus")): + return prometheus_io_tensor_ops.PrometheusIOTensor( + query, endpoint=kwargs.get("endpoint", None), internal=True) diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py index c5fb22e14..75c49f14e 100644 --- a/tensorflow_io/core/python/ops/io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -12,51 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""_BaseIOTensor""" +"""_IOTensor""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf -class _BaseIOTensorMeta(property): - """_BaseIOTensorMeta is a decorator that is viewable to __repr__""" +class _IOTensorMeta(property): + """_IOTensorMeta is a decorator that is viewable to __repr__""" pass -class _BaseIOTensorDataset(tf.compat.v2.data.Dataset): +class _IOTensorDataset(tf.compat.v2.data.Dataset): """_IOTensorDataset""" def __init__(self, spec, resource, function): + components = tf.nest.flatten(spec) + start = 0 - stop = tf.nest.flatten( - tf.nest.map_structure(lambda e: e.shape, spec))[0][0] + stop = components[0].shape[0] capacity = 4096 entry_start = list(range(start, stop, capacity)) entry_stop = entry_start[1:] + [stop] - dtype = tf.nest.flatten( - tf.nest.map_structure(lambda e: e.dtype, spec)) - shape = tf.nest.flatten( - tf.nest.map_structure( - lambda e: tf.TensorShape( - [None]).concatenate(e.shape[1:]), spec)) - dataset = tf.compat.v2.data.Dataset.from_tensor_slices(( tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64))) - dataset = dataset.map( - lambda start, stop: function( - resource, start, stop, 1, dtype=dtype, shape=shape)) - # Note: tf.data.Dataset consider tuple `(e, )` as one element - # instead of a sequence. So next `unbatch()` will not work. - # The tf.stack() below is necessary. 
- if len(dtype) == 1: - dataset = dataset.map(tf.stack) - dataset = dataset.apply(tf.data.experimental.unbatch()) + + components = [(component, e) for component, e in enumerate(components)] + components = [ + dataset.map( + lambda start, stop: function( + resource, + start, stop, 1, + component=component, + shape=e.shape, + dtype=e.dtype)) for (component, e) in components] + dataset = tf.compat.v2.data.Dataset.zip( + tf.nest.pack_sequence_as(spec, components)) + dataset = dataset.unbatch() + self._dataset = dataset self._resource = resource self._function = function - super(_BaseIOTensorDataset, self).__init__( + super(_IOTensorDataset, self).__init__( self._dataset._variant_tensor) # pylint: disable=protected-access def _inputs(self): @@ -66,22 +65,18 @@ def _inputs(self): def element_spec(self): return self._dataset.element_spec -class _BaseIOTensor(object): - """_BaseIOTensor""" +class _IOTensor(object): + """_IOTensor""" def __init__(self, spec, - resource, - function, internal=False): if not internal: raise ValueError("IOTensor constructor is private; please use one " "of the factory methods instead (e.g., " "IOTensor.from_tensor())") self._spec = spec - self._resource = resource - self._function = function - super(_BaseIOTensor, self).__init__() + super(_IOTensor, self).__init__() #============================================================================= # Accessors @@ -98,10 +93,71 @@ def spec(self): def __repr__(self): meta = "".join([", %s=%s" % ( k, repr(v.__get__(self))) for k, v in self.__class__.__dict__.items( - ) if isinstance(v, _BaseIOTensorMeta)]) + ) if isinstance(v, _IOTensorMeta)]) return "<%s: spec=%s%s>" % ( self.__class__.__name__, self.spec, meta) + #============================================================================= + # Dataset Conversions + #============================================================================= + + def to_dataset(self): + """Converts this `IOTensor` into a `tf.data.Dataset`. 
+ + Example: + + ```python + ``` + + Args: + + Returns: + A `tf.data.Dataset` with value obtained from this `IOTensor`. + """ + return _IOTensorDataset( + self.spec, self._resource, self._function) + +class BaseIOTensor(_IOTensor): + """BaseIOTensor + + A `BaseIOTensor` is a basic `IOTensor` with only one component. + It is associated with a `Tensor` of `shape` and `dtype`, with + data backed by IO. It is the building block for `IOTensor`. + For example, a `CSVIOTensor` consists of multiple `BaseIOTensor` + where each one is a column of the CSV. + + All `IOTensor` types are either a subclass of `BaseIOTensor`, + or are a composite of a collection of `BaseIOTensor`. + + The additional properties exposed by `BaseIOTensor` are `shape` + and `dtype` associated with counterparts in `Tensor`. + """ + + def __init__(self, + spec, + resource, + function, + component=0, + internal=False): + self._resource = resource + self._function = function + self._component = component + super(BaseIOTensor, self).__init__( + spec, internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def shape(self): + """Returns the `TensorShape` that represents the shape of the tensor.""" + return self.spec.shape + + @property + def dtype(self): + """Returns the `dtype` of elements in the tensor.""" + return self.spec.dtype #============================================================================= # Indexing & Slicing @@ -122,20 +178,15 @@ def __getitem__(self, key): start = key stop = key + 1 step = 1 - dtype = tf.nest.flatten( - tf.nest.map_structure(lambda e: e.dtype, self.spec)) - shape = tf.nest.flatten( - tf.nest.map_structure(lambda e: e.shape, self.spec)) - return tf.nest.pack_sequence_as(self.spec, self._function( + return self._function( self._resource, start, stop, step, - dtype=dtype, - shape=shape)) + component=self._component, 
+ shape=self.spec.shape, dtype=self.spec.dtype) def __len__(self): """Returns the total number of items of this IOTensor.""" - return tf.nest.flatten( - tf.nest.map_structure(lambda e: e.shape, self.spec))[0][0] + return self.shape[0] #============================================================================= # Tensor Type Conversions @@ -182,107 +233,66 @@ def to_tensor(self, **kwargs): with tf.name_scope(kwargs.get("name", "IOToTensor")): return self.__getitem__(slice(None, None)) - #============================================================================= - # Dataset Conversions - #============================================================================= - - def to_dataset(self): - """Converts this `IOTensor` into a `tf.data.Dataset`. - - Example: - - ```python - ``` - - Args: - - Returns: - A `tf.data.Dataset` with value obtained from this `IOTensor`. - """ - return _BaseIOTensorDataset( - self.spec, self._resource, self._function) - -class _ColumnIOTensor(_BaseIOTensor): - """_ColumnIOTensor""" +class _TableIOTensor(_IOTensor): + """_TableIOTensor""" def __init__(self, - shapes, - dtypes, + spec, + columns, resource, function, internal=False): - shapes = [ - tf.TensorShape( - [None if dim < 0 else dim for dim in e.numpy() if dim != 0] - ) for e in tf.unstack(shapes)] - dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] - spec = [tf.TensorSpec(shape, dtype) for ( - shape, dtype) in zip(shapes, dtypes)] - assert len(spec) == 1 - spec = spec[0] - - self._shape = spec.shape - self._dtype = spec.dtype - super(_ColumnIOTensor, self).__init__( - spec, resource, function, internal=internal) + self._columns = columns + self._resource = resource + self._function = function + super(_TableIOTensor, self).__init__( + spec, internal=internal) #============================================================================= # Accessors #============================================================================= @property - def shape(self): - """Returns 
the `TensorShape` that represents the shape of the tensor.""" - return self._shape + def columns(self): + """The names of columns""" + return self._columns - @property - def dtype(self): - """Returns the `dtype` of elements in the tensor.""" - return self._dtype + def __call__(self, column): + """Return a BaseIOTensor with column named `column`""" + component = self.columns.index( + next(e for e in self.columns if e == column)) + spec = tf.nest.flatten(self.spec)[component] + return BaseIOTensor( + spec, self._resource, self._function, + component=component, internal=True) -class _TableIOTensor(_BaseIOTensor): - """_TableIOTensor""" +class _SeriesIOTensor(_IOTensor): + """_SeriesIOTensor""" def __init__(self, - shapes, - dtypes, - columns, - filename, + spec, resource, function, internal=False): - shapes = [ - tf.TensorShape( - [None if dim < 0 else dim for dim in e.numpy() if dim != 0] - ) for e in tf.unstack(shapes)] - dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] - columns = [e.numpy().decode() for e in tf.unstack(columns)] - spec = [tf.TensorSpec(shape, dtype, column) for ( - shape, dtype, column) in zip(shapes, dtypes, columns)] - if len(spec) == 1: - spec = spec[0] - else: - spec = tuple(spec) - self._filename = filename - super(_TableIOTensor, self).__init__( - spec, resource, function, internal=internal) + self._resource = resource + self._function = function + super(_SeriesIOTensor, self).__init__( + spec, internal=internal) #============================================================================= # Accessors #============================================================================= - def columns(self): - """The `TensorSpec` of column named `name`""" - return [e.name for e in tf.nest.flatten(self.spec)] - - def shape(self, column): - """Returns the `TensorShape` shape of `column` in the tensor.""" - return next(e.shape for e in tf.nest.flatten(self.spec) if e.name == column) - - def dtype(self, column): - """Returns the `dtype` of 
`column` in the tensor.""" - return next(e.dtype for e in tf.nest.flatten(self.spec) if e.name == column) + @property + def index(self): + """The index column of the series""" + return BaseIOTensor( + self.spec[0], self._resource, self._function, + component=0, internal=True) - def __call__(self, column): - """Return a new IOTensor with column named `column`""" - return self.__class__(self._filename, columns=[column], internal=True) # pylint: disable=no-value-for-parameter + @property + def value(self): + """The value column of the series""" + return BaseIOTensor( + self.spec[1], self._resource, self._function, + component=1, internal=True) diff --git a/tensorflow_io/core/python/ops/json_io_tensor_ops.py b/tensorflow_io/core/python/ops/json_io_tensor_ops.py index f6bab6acc..3eb7ee57a 100644 --- a/tensorflow_io/core/python/ops/json_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/json_io_tensor_ops.py @@ -31,18 +31,24 @@ class JSONIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-ac #============================================================================= def __init__(self, filename, - columns=None, + mode=None, internal=False): with tf.name_scope("JSONIOTensor") as scope: - metadata = [] - if columns is not None: - metadata.extend(["column: "+column for column in columns]) - resource, dtypes, shapes, columns = core_ops.json_indexable_init( - filename, metadata=metadata, + metadata = [] if mode is None else ["mode: %s" % mode] + resource, shapes, dtypes, columns = core_ops.json_indexable_init( + filename, + metadata=metadata, container=scope, shared_name="%s/%s" % (filename, uuid.uuid4().hex)) - self._filename = filename + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + columns = [e.numpy().decode() for e in tf.unstack(columns)] + spec = tuple([tf.TensorSpec(shape, dtype, column) for ( + shape, 
dtype, column) in zip(shapes, dtypes, columns)]) super(JSONIOTensor, self).__init__( - shapes, dtypes, columns, filename, + spec, columns, resource, core_ops.json_indexable_get_item, internal=internal) diff --git a/tensorflow_io/core/python/ops/kafka_dataset_ops.py b/tensorflow_io/core/python/ops/kafka_dataset_ops.py index 2c286bf13..690c0edbd 100644 --- a/tensorflow_io/core/python/ops/kafka_dataset_ops.py +++ b/tensorflow_io/core/python/ops/kafka_dataset_ops.py @@ -61,14 +61,12 @@ def __init__(self, capacity = 4096 dataset = tf.compat.v2.data.Dataset.range(0, sys.maxsize, capacity) dataset = dataset.map( - lambda i: core_ops.kafka_iterable_next(resource, capacity, dtype=[tf.string], shape=[tf.TensorShape([None])])) + lambda i: core_ops.kafka_iterable_next( + resource, capacity, component=0, + dtype=tf.string, shape=tf.TensorShape([None]))) dataset = dataset.apply( tf.data.experimental.take_while( lambda v: tf.greater(tf.shape(v)[0], 0))) - # Note: tf.data.Dataset consider tuple `(e, )` as one element - # instead of a sequence. So next `unbatch()` will not work. - # The tf.stack() below is necessary. 
- dataset = dataset.map(tf.stack) dataset = dataset.unbatch() self._resource = resource diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py index ee88e4b90..0d64f6f87 100644 --- a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py @@ -23,7 +23,7 @@ from tensorflow_io.core.python.ops import io_tensor_ops from tensorflow_io.core.python.ops import core_ops -class KafkaIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access +class KafkaIOTensor(io_tensor_ops.BaseIOTensor): # pylint: disable=protected-access """KafkaIOTensor""" #============================================================================= @@ -38,10 +38,21 @@ def __init__(self, metadata = [e for e in configuration or []] if servers is not None: metadata.append("bootstrap.servers=%s" % servers) - resource, dtypes, shapes = core_ops.kafka_indexable_init( + resource, shapes, dtypes = core_ops.kafka_indexable_init( subscription, metadata=metadata, container=scope, shared_name="%s/%s" % (subscription, uuid.uuid4().hex)) + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + assert len(shapes) == 1 + assert len(dtypes) == 1 + shape = shapes[0] + dtype = dtypes[0] + spec = tf.TensorSpec(shape, dtype) + super(KafkaIOTensor, self).__init__( - shapes, dtypes, resource, core_ops.kafka_indexable_get_item, + spec, resource, core_ops.kafka_indexable_get_item, internal=internal) diff --git a/tensorflow_io/core/python/ops/prometheus_io_tensor_ops.py b/tensorflow_io/core/python/ops/prometheus_io_tensor_ops.py new file mode 100644 index 000000000..6debc0423 --- /dev/null +++ b/tensorflow_io/core/python/ops/prometheus_io_tensor_ops.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""PrometheusIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class PrometheusIOTensor(io_tensor_ops._SeriesIOTensor): # pylint: disable=protected-access + """PrometheusIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + query, + endpoint=None, + internal=False): + with tf.name_scope("PrometheusIOTensor") as scope: + metadata = [] if endpoint is None else ["endpoint: %s" % endpoint] + resource, shapes, dtypes = core_ops.prometheus_indexable_init( + query, metadata=metadata, + container=scope, shared_name="%s/%s" % (query, uuid.uuid4().hex)) + shapes = [ + tf.TensorShape( + [None if dim < 0 else dim for dim in e.numpy() if dim != 0] + ) for e in tf.unstack(shapes)] + dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)] + assert len(shapes) == 2 + assert len(dtypes) == 2 + assert shapes[0] == shapes[1] + spec = (tf.TensorSpec(shapes[0], dtypes[0]), + tf.TensorSpec(shapes[1], dtypes[1])) + super(PrometheusIOTensor, self).__init__( + spec, 
resource, core_ops.prometheus_indexable_get_item, + internal=internal) diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc index a2dff6146..18932fa33 100644 --- a/tensorflow_io/json/kernels/json_kernels.cc +++ b/tensorflow_io/json/kernels/json_kernels.cc @@ -27,6 +27,7 @@ limitations under the License. #include "arrow/json/reader.h" #include "arrow/table.h" #include "tensorflow_io/arrow/kernels/arrow_kernels.h" +#include "rapidjson/document.h" namespace tensorflow { namespace data { @@ -221,6 +222,55 @@ class JSONIndexable : public IOIndexableInterface { file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); + mode_ = "ndjson"; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find_first_of("mode: ") == 0) { + mode_ = metadata[i].substr(6); + } + } + + if (mode_ == "records") { + string buffer; + buffer.resize(file_size_); + StringPiece result; + TF_RETURN_IF_ERROR(file_->Read(0, file_size_, &result, &buffer[0])); + + rapidjson::Document d; + d.Parse(buffer.c_str()); + // Check the first element only + const rapidjson::Value& a = d.GetArray(); + const rapidjson::Value& o = a[0]; + for (rapidjson::Value::ConstMemberIterator oi = o.MemberBegin(); oi != o.MemberEnd(); ++oi) { + DataType dtype; + if (oi->value.IsInt64()) { + dtype = DT_INT64; + } else if (oi->value.IsDouble()) { + dtype = DT_DOUBLE; + } else { + return errors::InvalidArgument("invalid data type: ", oi->name.GetString()); + } + shapes_.push_back(TensorShape({static_cast(a.MemberCount())})); + dtypes_.push_back(dtype); + columns_.push_back(oi->name.GetString()); + tensors_.emplace_back(Tensor(dtype, TensorShape({static_cast(a.MemberCount())}))); + } + // Fill in the values + for (size_t i = 0; i < a.MemberCount(); i++) { + const rapidjson::Value& o = a[i]; + for (size_t column_index = 0; column_index < columns_.size(); column_index++) { + const 
rapidjson::Value& v = o[columns_[column_index].c_str()]; + if (dtypes_[column_index] == DT_INT64) { + tensors_[column_index].flat()(i) = v.GetInt64(); + } else if (dtypes_[column_index] == DT_DOUBLE) { + tensors_[column_index].flat()(i) = v.GetDouble(); + } + } + } + + return Status::OK(); + } + + json_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); ::arrow::Status status; @@ -234,51 +284,28 @@ class JSONIndexable : public IOIndexableInterface { return errors::InvalidArgument("unable to read table: ", status); } - std::vector columns; - for (size_t i = 0; i < metadata.size(); i++) { - if (metadata[i].find_first_of("column: ") == 0) { - columns.emplace_back(metadata[i].substr(8)); - } - } - - columns_index_.clear(); - if (columns.size() == 0) { - for (int i = 0; i < table_->num_columns(); i++) { - columns_index_.push_back(i); - } - } else { - std::unordered_map columns_map; - for (int i = 0; i < table_->num_columns(); i++) { - columns_map[table_->column(i)->name()] = i; - } - for (size_t i = 0; i < columns.size(); i++) { - columns_index_.push_back(columns_map[columns[i]]); - } - } - - dtypes_.clear(); shapes_.clear(); + dtypes_.clear(); columns_.clear(); - for (size_t i = 0; i < columns_index_.size(); i++) { - int column_index = columns_index_[i]; + for (int i = 0; i < table_->num_columns(); i++) { + shapes_.push_back(TensorShape({static_cast(table_->num_rows())})); ::tensorflow::DataType dtype; - TF_RETURN_IF_ERROR(GetTensorFlowType(table_->column(column_index)->type(), &dtype)); + TF_RETURN_IF_ERROR(GetTensorFlowType(table_->column(i)->type(), &dtype)); dtypes_.push_back(dtype); - shapes_.push_back(TensorShape({static_cast(table_->num_rows())})); - columns_.push_back(table_->column(column_index)->name()); + columns_.push_back(table_->column(i)->name()); } return Status::OK(); } - Status Spec(std::vector& dtypes, std::vector& shapes) override { - dtypes.clear(); - for (size_t i = 0; i < dtypes_.size(); i++) { - dtypes.push_back(dtypes_[i]); - } + 
Status Spec(std::vector& shapes, std::vector& dtypes) override { shapes.clear(); for (size_t i = 0; i < shapes_.size(); i++) { shapes.push_back(shapes_[i]); } + dtypes.clear(); + for (size_t i = 0; i < dtypes_.size(); i++) { + dtypes.push_back(dtypes_[i]); + } return Status::OK(); } @@ -292,60 +319,69 @@ class JSONIndexable : public IOIndexableInterface { return Status::OK(); } - Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector& tensors) override { + Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override { if (step != 1) { return errors::InvalidArgument("step ", step, " is not supported"); } - for (size_t i = 0; i < tensors.size(); i++) { - int column_index = columns_index_[i]; - std::shared_ptr<::arrow::Column> slice = table_->column(column_index)->Slice(start, stop); - - #define PROCESS_TYPE(TTYPE,ATYPE) { \ - int64 curr_index = 0; \ - for (auto chunk : slice->data()->chunks()) { \ - for (int64_t item = 0; item < chunk->length(); item++) { \ - tensors[i].flat()(curr_index) = (dynamic_cast(chunk.get()))->Value(item); \ - curr_index++; \ - } \ + if (mode_ == "records") { + if (dtypes_[component] == DT_INT64) { + memcpy(&tensor->flat().data()[0], &tensors_[component].flat().data()[start], sizeof(int64) * (stop - start)); + } else if (dtypes_[component] == DT_DOUBLE) { + memcpy(&tensor->flat().data()[0], &tensors_[component].flat().data()[start], sizeof(double) * (stop - start)); + } else { + return errors::InvalidArgument("invalid data type: ", dtypes_[component]); + } + + return Status::OK(); + } + + std::shared_ptr<::arrow::Column> slice = table_->column(component)->Slice(start, stop); + + #define PROCESS_TYPE(TTYPE,ATYPE) { \ + int64 curr_index = 0; \ + for (auto chunk : slice->data()->chunks()) { \ + for (int64_t item = 0; item < chunk->length(); item++) { \ + tensor->flat()(curr_index) = (dynamic_cast(chunk.get()))->Value(item); \ + curr_index++; \ } \ - } - switch 
(tensors[i].dtype()) { - case DT_BOOL: - PROCESS_TYPE(bool, ::arrow::BooleanArray); - break; - case DT_INT8: - PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>); - break; - case DT_UINT8: - PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>); - break; - case DT_INT16: - PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>); - break; - case DT_UINT16: - PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>); - break; - case DT_INT32: - PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>); - break; - case DT_UINT32: - PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>); - break; - case DT_INT64: - PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>); - break; - case DT_UINT64: - PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>); - break; - case DT_FLOAT: - PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>); - break; - case DT_DOUBLE: - PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>); - break; - default: - return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensors[i].dtype())); + } \ } + switch (tensor->dtype()) { + case DT_BOOL: + PROCESS_TYPE(bool, ::arrow::BooleanArray); + break; + case DT_INT8: + PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>); + break; + case DT_UINT8: + PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>); + break; + case DT_INT16: + PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>); + break; + case DT_UINT16: + PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>); + break; + case DT_INT32: + PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>); + break; + case DT_UINT32: + PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>); + break; + case DT_INT64: + PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>); + break; + case DT_UINT64: + PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>); + break; + case 
DT_FLOAT: + PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>); + break; + case DT_DOUBLE: + PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>); + break; + default: + return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype())); } return Status::OK(); @@ -364,10 +400,12 @@ class JSONIndexable : public IOIndexableInterface { std::shared_ptr<::arrow::json::TableReader> reader_; std::shared_ptr<::arrow::Table> table_; + std::vector tensors_; + string mode_; + std::vector dtypes_; std::vector shapes_; std::vector columns_; - std::vector columns_index_; }; REGISTER_KERNEL_BUILDER(Name("JSONIndexableInit").Device(DEVICE_CPU), diff --git a/tensorflow_io/json/ops/json_ops.cc b/tensorflow_io/json/ops/json_ops.cc index c9206409a..5d2227d28 100644 --- a/tensorflow_io/json/ops/json_ops.cc +++ b/tensorflow_io/json/ops/json_ops.cc @@ -23,8 +23,8 @@ REGISTER_OP("JSONIndexableInit") .Input("input: string") .Input("metadata: string") .Output("output: resource") - .Output("dtypes: int64") .Output("shapes: int64") + .Output("dtypes: int64") .Output("columns: string") .Attr("container: string = ''") .Attr("shared_name: string = ''") @@ -42,24 +42,19 @@ REGISTER_OP("JSONIndexableGetItem") .Input("start: int64") .Input("stop: int64") .Input("step: int64") + .Input("component: int64") .Output("output: dtype") - .Attr("dtype: list(type) >= 1") - .Attr("shape: list(shape) >= 1") + .Attr("shape: shape") + .Attr("dtype: type") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector shape; + PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - if (shape.size() != c->num_outputs()) { - return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); - } - for (size_t i = 0; i < shape.size(); ++i) { - shape_inference::ShapeHandle entry; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); - c->set_output(static_cast(i), entry); - } + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); return Status::OK(); }); - REGISTER_OP("ListJSONColumns") .Input("filename: string") .Output("columns: string") diff --git a/tensorflow_io/kafka/kernels/kafka_kernels.cc b/tensorflow_io/kafka/kernels/kafka_kernels.cc index 2a3a25f20..25b686fe4 100644 --- a/tensorflow_io/kafka/kernels/kafka_kernels.cc +++ b/tensorflow_io/kafka/kernels/kafka_kernels.cc @@ -177,7 +177,7 @@ class KafkaIterable : public IOIterableInterface { return Status::OK(); } - Status Next(const int64 capacity, std::vector& tensors, int64* record_read) override { + Status Next(const int64 capacity, const int64 component, Tensor* tensor, int64* record_read) override { *record_read = 0; while (consumer_.get() != nullptr && (*record_read) < capacity) { if (!kafka_event_cb_.run()) { @@ -192,7 +192,7 @@ class KafkaIterable : public IOIterableInterface { std::unique_ptr message(consumer_->consume(timeout_)); if (message->err() == RdKafka::ERR_NO_ERROR) { // Produce the line as output. 
- tensors[0].flat()((*record_read)) = std::string(static_cast(message->payload()), message->len()); + tensor->flat()((*record_read)) = std::string(static_cast(message->payload()), message->len()); // Sync offset offset_ = message->offset(); (*record_read)++; @@ -216,11 +216,11 @@ class KafkaIterable : public IOIterableInterface { } return Status::OK(); } - Status Spec(std::vector& dtypes, std::vector& shapes) override { - dtypes.clear(); - dtypes.push_back(DT_STRING); + Status Spec(std::vector& shapes, std::vector& dtypes) override { shapes.clear(); shapes.push_back(PartialTensorShape({-1})); + dtypes.clear(); + dtypes.push_back(DT_STRING); return Status::OK(); } diff --git a/tensorflow_io/kafka/ops/kafka_ops.cc b/tensorflow_io/kafka/ops/kafka_ops.cc index 1ca333eac..0a8a3c75d 100644 --- a/tensorflow_io/kafka/ops/kafka_ops.cc +++ b/tensorflow_io/kafka/ops/kafka_ops.cc @@ -40,20 +40,16 @@ REGISTER_OP("KafkaIndexableGetItem") .Input("start: int64") .Input("stop: int64") .Input("step: int64") + .Input("component: int64") .Output("output: dtype") - .Attr("dtype: list(type) >= 1") - .Attr("shape: list(shape) >= 1") + .Attr("shape: shape") + .Attr("dtype: type") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector shape; + PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - if (shape.size() != c->num_outputs()) { - return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. 
", c->num_outputs()); - } - for (size_t i = 0; i < shape.size(); ++i) { - shape_inference::ShapeHandle entry; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); - c->set_output(static_cast(i), entry); - } + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); return Status::OK(); }); @@ -76,24 +72,19 @@ REGISTER_OP("KafkaIterableInit") REGISTER_OP("KafkaIterableNext") .Input("input: resource") .Input("capacity: int64") + .Input("component: int64") .Output("output: dtype") - .Attr("dtype: list(type) >= 1") - .Attr("shape: list(shape) >= 1") + .Attr("shape: shape") + .Attr("dtype: type") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector shape; + PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - if (shape.size() != c->num_outputs()) { - return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < shape.size(); ++i) { - shape_inference::ShapeHandle entry; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry)); - c->set_output(static_cast(i), entry); - } + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); return Status::OK(); }); - REGISTER_OP("KafkaOutputSequence") .Input("topic: string") .Input("servers: string") diff --git a/tensorflow_io/prometheus/BUILD b/tensorflow_io/prometheus/BUILD index 9348aeae6..7ec41776a 100644 --- a/tensorflow_io/prometheus/BUILD +++ b/tensorflow_io/prometheus/BUILD @@ -19,7 +19,4 @@ cc_library( "//tensorflow_io/core:dataset_ops", "//tensorflow_io/core:prometheus_go_ops", ], - #data = [ - # "//tensorflow_io/core:prometheus_go.h", - #] ) diff --git a/tensorflow_io/prometheus/kernels/prometheus_kernels.cc b/tensorflow_io/prometheus/kernels/prometheus_kernels.cc index 31c94f71b..d65b79897 100644 --- 
a/tensorflow_io/prometheus/kernels/prometheus_kernels.cc +++ b/tensorflow_io/prometheus/kernels/prometheus_kernels.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/prometheus_go.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/golang_ops.h" namespace tensorflow { namespace data { @@ -74,5 +75,109 @@ REGISTER_KERNEL_BUILDER(Name("ReadPrometheus").Device(DEVICE_CPU), } // namespace + + +class PrometheusIndexable : public IOIndexableInterface { + public: + PrometheusIndexable(Env* env) + : env_(env) {} + + ~PrometheusIndexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 query is not supported"); + } + const string& query = input[0]; + + string endpoint = "http://localhost:9090"; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find_first_of("endpoint: ") == 0) { + endpoint = metadata[i].substr(8); + } + } + + int64 ts = time(NULL); + + GoString endpoint_go = {endpoint.c_str(), static_cast(endpoint.size())}; + GoString query_go = {query.c_str(), static_cast(query.size())}; + + GoSlice timestamp_go = {0, 0, 0}; + GoSlice value_go = {0, 0, 0}; + + GoInt returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go); + if (returned < 0) { + return errors::InvalidArgument("unable to query prometheus"); + } + + timestamp_.resize(returned); + value_.resize(returned); + + if (returned > 0) { + timestamp_go.data = ×tamp_[0]; + timestamp_go.len = returned; + timestamp_go.cap = returned; + value_go.data = &value_[0]; + value_go.len = returned; + value_go.cap = returned; + + returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go); + if (returned < 0) { + return 
errors::InvalidArgument("unable to query prometheus to get the value"); + } + } + + // timestamp, value + dtypes_.emplace_back(DT_INT64); + shapes_.emplace_back(TensorShape({static_cast(returned)})); + dtypes_.emplace_back(DT_DOUBLE); + shapes_.emplace_back(TensorShape({static_cast(returned)})); + + return Status::OK(); + } + Status Spec(std::vector& shapes, std::vector& dtypes) override { + shapes.clear(); + for (size_t i = 0; i < shapes_.size(); i++) { + shapes.push_back(shapes_[i]); + } + dtypes.clear(); + for (size_t i = 0; i < dtypes_.size(); i++) { + dtypes.push_back(dtypes_[i]); + } + return Status::OK(); + } + + Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override { + if (step != 1) { + return errors::InvalidArgument("step ", step, " is not supported"); + } + if (component == 0) { + memcpy(&tensor->flat().data()[start], ×tamp_[0], sizeof(int64) * (stop - start)); + } else { + memcpy(&tensor->flat().data()[start], &value_[0], sizeof(double) * (stop - start)); + } + + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("PrometheusIndexable"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + + std::vector dtypes_; + std::vector shapes_; + + std::vector timestamp_; + std::vector value_; +}; + +REGISTER_KERNEL_BUILDER(Name("PrometheusIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("PrometheusIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp); + } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/prometheus/ops/prometheus_ops.cc b/tensorflow_io/prometheus/ops/prometheus_ops.cc index bbf995346..d0c8a51da 100644 --- a/tensorflow_io/prometheus/ops/prometheus_ops.cc +++ b/tensorflow_io/prometheus/ops/prometheus_ops.cc @@ -19,6 +19,40 @@ limitations under the License. 
namespace tensorflow { +REGISTER_OP("PrometheusIndexableInit") + .Input("input: string") + .Input("metadata: string") + .Output("output: resource") + .Output("shapes: int64") + .Output("dtypes: int64") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("PrometheusIndexableGetItem") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Input("component: int64") + .Output("output: dtype") + .Attr("shape: shape") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); + return Status::OK(); + }); + REGISTER_OP("ReadPrometheus") .Input("endpoint: string") .Input("query: string") diff --git a/tensorflow_io/prometheus/python/ops/prometheus_ops.py b/tensorflow_io/prometheus/python/ops/prometheus_ops.py index 89b63fa23..a7e5e6fee 100644 --- a/tensorflow_io/prometheus/python/ops/prometheus_ops.py +++ b/tensorflow_io/prometheus/python/ops/prometheus_ops.py @@ -17,10 +17,19 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow_io.core.python.ops import data_ops from tensorflow_io.core.python.ops import core_ops +warnings.warn( + "The tensorflow_io.prometheus.PrometheusDataset is " + "deprecated. 
Please look for tfio.IOTensor.from_prometheus " + "for reading prometheus observations into tensorflow.", + DeprecationWarning) + + def read_prometheus(endpoint, query): """read_prometheus""" return core_ops.read_prometheus(endpoint, query) diff --git a/tests/test_json_eager.py b/tests/test_json_eager.py index 55393a9e6..e254e3215 100644 --- a/tests/test_json_eager.py +++ b/tests/test_json_eager.py @@ -26,6 +26,62 @@ import tensorflow_io as tfio # pylint: disable=wrong-import-position import tensorflow_io.json as json_io # pylint: disable=wrong-import-position +def test_io_tensor_json_recods_mode(): + """Test case for tfio.IOTensor.from_json.""" + x_test = [[1.1, 2], [2.1, 3]] + y_test = [[2.2, 3], [1.2, 3]] + feature_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_json", + "feature.json") + feature_filename = "file://" + feature_filename + label_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_json", + "label.json") + label_filename = "file://" + label_filename + + features = tfio.IOTensor.from_json(feature_filename, mode='records') + assert features("floatfeature").dtype == tf.float64 + assert features("integerfeature").dtype == tf.int64 + + labels = tfio.IOTensor.from_json(label_filename, mode='records') + assert labels("floatlabel").dtype == tf.float64 + assert labels("integerlabel").dtype == tf.int64 + + float_feature = features("floatfeature") + integer_feature = features("integerfeature") + float_label = labels("floatlabel") + integer_label = labels("integerlabel") + + for i in range(2): + v_x = x_test[i] + v_y = y_test[i] + assert v_x[0] == float_feature[i].numpy() + assert v_x[1] == integer_feature[i].numpy() + assert v_y[0] == float_label[i].numpy() + assert v_y[1] == integer_label[i].numpy() + + feature_dataset = features.to_dataset() + + label_dataset = labels.to_dataset() + + dataset = tf.data.Dataset.zip(( + feature_dataset, + label_dataset + )) + + i = 0 + for (j_x, j_y) in dataset: + v_x = 
x_test[i] + v_y = y_test[i] + for index, x in enumerate(j_x): + assert v_x[index] == x.numpy() + for index, y in enumerate(j_y): + assert v_y[index] == y.numpy() + i += 1 + assert i == len(y_test) + def test_io_tensor_json(): """Test case for tfio.IOTensor.from_json.""" x_test = [[1.1, 2], [2.1, 3]] @@ -42,12 +98,12 @@ def test_io_tensor_json(): label_filename = "file://" + label_filename features = tfio.IOTensor.from_json(feature_filename) - assert features.dtype("floatfeature") == tf.float64 - assert features.dtype("integerfeature") == tf.int64 + assert features("floatfeature").dtype == tf.float64 + assert features("integerfeature").dtype == tf.int64 labels = tfio.IOTensor.from_json(label_filename) - assert labels.dtype("floatlabel") == tf.float64 - assert labels.dtype("integerlabel") == tf.int64 + assert labels("floatlabel").dtype == tf.float64 + assert labels("integerlabel").dtype == tf.int64 float_feature = features("floatfeature") integer_feature = features("integerfeature") diff --git a/tests/test_prometheus_eager.py b/tests/test_prometheus_eager.py index e99f83af3..34744e67c 100644 --- a/tests/test_prometheus_eager.py +++ b/tests/test_prometheus_eager.py @@ -26,7 +26,7 @@ import tensorflow as tf if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() -import tensorflow_io.prometheus as prometheus_io # pylint: disable=wrong-import-position +import tensorflow_io as tfio # pylint: disable=wrong-import-position if sys.platform == "darwin": pytest.skip( @@ -38,25 +38,14 @@ def test_prometheus(): subprocess.call(["dig", "@localhost", "-p", "1053", "www.google.com"]) time.sleep(1) time.sleep(2) - prometheus_dataset = prometheus_io.PrometheusDataset( - "http://localhost:9090", - "coredns_dns_request_count_total[5s]").apply( - tf.data.experimental.unbatch()).batch(2) - - i = 0 - for k, v in prometheus_dataset: - print("K, V: ", k.numpy(), v.numpy()) - if i == 4: - # Last entry guaranteed 6.0 - assert v.numpy() 
== 6.0 - i += 2 - assert i == 6 - - timestamp, value = prometheus_io.read_prometheus( - "http://localhost:9090", + prometheus = tfio.IOTensor.from_prometheus( "coredns_dns_request_count_total[5s]") - assert timestamp.shape == [5] - assert value.shape == [5] + assert prometheus.index.shape == [5] + assert prometheus.index.dtype == tf.int64 + assert prometheus.value.shape == [5] + assert prometheus.value.dtype == tf.float64 + # last value should be 6.0 + assert prometheus.value.to_tensor().numpy()[4] == 6.0 if __name__ == "__main__": test.main()