From 2e5aec3860a97a04bfb07da89e8b2a9df390bb4b Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 1 Aug 2019 02:59:57 +0000 Subject: [PATCH] Rework on PrometheusDataset This is part of the effort to reduce the dedicated C++ implementation of Dataset and replace it with primitive ops that could be used both with tf.data, and with Tensor. There is some room for enhancement, for example, a timestamp could be passed to read_prometheus and each call will only read a small slice of the data to tensor. Follow-up PRs will implement that later. Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/data_ops.py | 5 +- tensorflow_io/prometheus/BUILD | 4 +- tensorflow_io/prometheus/__init__.py | 3 + tensorflow_io/prometheus/go/prometheus.go | 24 +++--- .../prometheus/kernels/prometheus_input.cc | 83 ------------------- .../prometheus/kernels/prometheus_kernels.cc | 77 +++++++++++++++++ .../prometheus/ops/prometheus_ops.cc | 25 ++---- .../prometheus/python/ops/prometheus_ops.py | 40 +++++---- tests/test_prometheus_eager.py | 16 ++-- 9 files changed, 137 insertions(+), 140 deletions(-) delete mode 100644 tensorflow_io/prometheus/kernels/prometheus_input.cc create mode 100644 tensorflow_io/prometheus/kernels/prometheus_kernels.cc diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py index bbd3b5d8b..40107037d 100644 --- a/tensorflow_io/core/python/ops/data_ops.py +++ b/tensorflow_io/core/python/ops/data_ops.py @@ -58,9 +58,8 @@ def _apply_fn(dataset): class BaseDataset(tf.compat.v2.data.Dataset): """A Base Dataset""" - def __init__(self, variant, batch, dtypes, shapes): + def __init__(self, variant, dtypes, shapes): """Create a Base Dataset.""" - self._batch = 0 if batch is None else batch self._dtypes = dtypes self._shapes = shapes super(BaseDataset, self).__init__(variant) @@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes): self._data_input, self._batch, output_types=self._dtypes, - 
output_shapes=self._shapes), self._batch, self._dtypes, self._shapes) + output_shapes=self._shapes), self._dtypes, self._shapes) diff --git a/tensorflow_io/prometheus/BUILD b/tensorflow_io/prometheus/BUILD index 32858cf5e..6f3cff021 100644 --- a/tensorflow_io/prometheus/BUILD +++ b/tensorflow_io/prometheus/BUILD @@ -10,9 +10,7 @@ load( cc_library( name = "prometheus_ops", srcs = [ - #"//tensorflow_io/prometheus/go:prometheus.a", - #"//tensorflow_io/prometheus/go:prometheus.h", - "kernels/prometheus_input.cc", + "kernels/prometheus_kernels.cc", "ops/prometheus_ops.cc", ], copts = tf_io_copts(), diff --git a/tensorflow_io/prometheus/__init__.py b/tensorflow_io/prometheus/__init__.py index 7742c1eba..55d4fc376 100644 --- a/tensorflow_io/prometheus/__init__.py +++ b/tensorflow_io/prometheus/__init__.py @@ -15,6 +15,7 @@ """PrometheusInput @@PrometheusDataset +@@read_prometheus """ from __future__ import absolute_import @@ -22,11 +23,13 @@ from __future__ import print_function from tensorflow_io.prometheus.python.ops.prometheus_ops import PrometheusDataset +from tensorflow_io.prometheus.python.ops.prometheus_ops import read_prometheus from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "PrometheusDataset", + "read_prometheus", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow_io/prometheus/go/prometheus.go b/tensorflow_io/prometheus/go/prometheus.go index 36d9b8774..bb7221cbd 100644 --- a/tensorflow_io/prometheus/go/prometheus.go +++ b/tensorflow_io/prometheus/go/prometheus.go @@ -13,26 +13,28 @@ import ( ) //export Query -func Query(endpoint string, query string, sec int64, offset int64, key []int64, val []float64) int64 { +func Query(endpoint string, query string, ts int64, timestamp []int64, value []float64) int { client, err := api.NewClient(api.Config{ Address: endpoint, }) if err != nil { return -1 } - value, err := v1.NewAPI(client).Query(context.Background(), query, time.Unix(sec, 
0)) + v, err := v1.NewAPI(client).Query(context.Background(), query, time.Unix(ts, 0)) if err != nil { return -1 } - if m, ok := value.(model.Matrix); ok && m.Len() > 0 { - index := int64(0) - for index < int64(len(key)) && offset+index < int64(len(m[0].Values)) { - v := m[0].Values[offset+index] - key[index] = v.Timestamp.Unix() - val[index] = float64(v.Value) - index++ + if m, ok := v.(model.Matrix); ok && m.Len() > 0 { + if len(timestamp) >= len(m[0].Values) && len(value) == len(m[0].Values) { + + for i := 0; i < len(m[0].Values); i++ { + v := m[0].Values[i] + timestamp[i] = int64(v.Timestamp) + value[i] = float64(v.Value) + } } - return index + + return len(m[0].Values) } return 0 } @@ -42,7 +44,7 @@ func main() { val := make([]float64, 20, 20) sec := time.Now().Unix() fmt.Println(sec) - returned := Query("http://localhost:9090", "coredns_dns_request_count_total[5m]", sec, 0, key, val) + returned := Query("http://localhost:9090", "coredns_dns_request_count_total[5m]", sec, key, val) fmt.Println(returned) for i := range key { fmt.Printf("%d, %q, %v\n", i, model.TimeFromUnix(key[i]).Time(), val[i]) diff --git a/tensorflow_io/prometheus/kernels/prometheus_input.cc b/tensorflow_io/prometheus/kernels/prometheus_input.cc deleted file mode 100644 index 9738ac4b3..000000000 --- a/tensorflow_io/prometheus/kernels/prometheus_input.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "kernels/dataset_ops.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" - -#include "go/prometheus.h" - -namespace tensorflow { -namespace data { - -class PrometheusState { -public: - PrometheusState() : time_(0), offset_(0) {} - - int64 time_; - int64 offset_; -}; - -class PrometheusInput: public StreamInput { - public: - Status ReadRecord(IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { - if (state.get() == nullptr) { - state.reset(new PrometheusState()); - state.get()->time_ = time(NULL); - } - Tensor key_tensor(ctx->allocator({}), DT_INT64, {record_to_read}); - Tensor val_tensor(ctx->allocator({}), DT_DOUBLE, {record_to_read}); - GoSlice key_go = {key_tensor.flat().data(), record_to_read, record_to_read}; - GoSlice val_go = {val_tensor.flat().data(), record_to_read, record_to_read}; - GoString endpoint_go = {endpoint().c_str(), static_cast(endpoint().size())}; - GoString query_go = {schema().c_str(), static_cast(schema().size())}; - - GoInt returned = Query(endpoint_go, query_go, state.get()->time_, state.get()->offset_, key_go, val_go); - if (returned < 0) { - return errors::InvalidArgument("prometheus server error: ", returned); - } - if (returned > 0) { - state.get()->offset_ += returned; - *record_read = returned; - if (*record_read < record_to_read) { - Tensor key_tensor_final = key_tensor.Slice(0, *record_read); - Tensor val_tensor_final = val_tensor.Slice(0, *record_read); - out_tensors->emplace_back(std::move(key_tensor_final)); - out_tensors->emplace_back(std::move(val_tensor_final)); - } else { - out_tensors->emplace_back(std::move(key_tensor)); - out_tensors->emplace_back(std::move(val_tensor)); - } - } - return Status::OK(); - } - Status FromEndpoint(const string& endpoint) override { - return Status::OK(); - } - void EncodeAttributes(VariantTensorData* data) const 
override { - } - bool DecodeAttributes(const VariantTensorData& data) override { - return true; - } - protected: -}; - -REGISTER_UNARY_VARIANT_DECODE_FUNCTION(PrometheusInput, "tensorflow::data::PrometheusInput"); - -REGISTER_KERNEL_BUILDER(Name("PrometheusInput").Device(DEVICE_CPU), - StreamInputOp); -REGISTER_KERNEL_BUILDER(Name("PrometheusDataset").Device(DEVICE_CPU), - StreamInputDatasetOp); -} // namespace data -} // namespace tensorflow diff --git a/tensorflow_io/prometheus/kernels/prometheus_kernels.cc b/tensorflow_io/prometheus/kernels/prometheus_kernels.cc new file mode 100644 index 000000000..7ea56f5d4 --- /dev/null +++ b/tensorflow_io/prometheus/kernels/prometheus_kernels.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "go/prometheus.h" + +namespace tensorflow { +namespace data { +namespace { + +class ReadPrometheusOp : public OpKernel { + public: + explicit ReadPrometheusOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& endpoint_tensor = context->input(0); + const string& endpoint = endpoint_tensor.scalar()(); + + const Tensor& query_tensor = context->input(1); + const string& query = query_tensor.scalar()(); + + int64 ts = time(NULL); + + GoString endpoint_go = {endpoint.c_str(), static_cast(endpoint.size())}; + GoString query_go = {query.c_str(), static_cast(query.size())}; + + GoSlice timestamp_go = {0, 0, 0}; + GoSlice value_go = {0, 0, 0}; + + GoInt returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go); + OP_REQUIRES(context, returned >= 0, errors::InvalidArgument("unable to query prometheus")); + + TensorShape output_shape({static_cast(returned)}); + + Tensor* timestamp_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, ×tamp_tensor)); + Tensor* value_tensor; + OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &value_tensor)); + + if (returned > 0) { + timestamp_go.data = timestamp_tensor->flat().data(); + timestamp_go.len = returned; + timestamp_go.cap = returned; + value_go.data = value_tensor->flat().data(); + value_go.len = returned; + value_go.cap = returned; + + returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go); + OP_REQUIRES(context, returned >= 0, errors::InvalidArgument("unable to query prometheus to get the value")); + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("ReadPrometheus").Device(DEVICE_CPU), + ReadPrometheusOp); + + +} // namespace +} // namespace data +} // namespace tensorflow diff --git 
a/tensorflow_io/prometheus/ops/prometheus_ops.cc b/tensorflow_io/prometheus/ops/prometheus_ops.cc index e47d5ac4a..bbf995346 100644 --- a/tensorflow_io/prometheus/ops/prometheus_ops.cc +++ b/tensorflow_io/prometheus/ops/prometheus_ops.cc @@ -19,27 +19,14 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("PrometheusInput") - .Input("source: string") - .Output("handle: variant") - .Attr("filters: list(string) = []") - .Attr("columns: list(string) = []") - .Attr("schema: string = ''") +REGISTER_OP("ReadPrometheus") + .Input("endpoint: string") + .Input("query: string") + .Output("timestamp: int64") + .Output("value: float64") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->MakeShape({c->UnknownDim()})); - return Status::OK(); - }); - -REGISTER_OP("PrometheusDataset") - .Input("input: T") - .Input("batch: int64") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Attr("T: {string, variant} = DT_VARIANT") - .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({})); + c->set_output(1, c->MakeShape({c->UnknownDim()})); return Status::OK(); }); diff --git a/tensorflow_io/prometheus/python/ops/prometheus_ops.py b/tensorflow_io/prometheus/python/ops/prometheus_ops.py index 24792d1e9..89b63fa23 100644 --- a/tensorflow_io/prometheus/python/ops/prometheus_ops.py +++ b/tensorflow_io/prometheus/python/ops/prometheus_ops.py @@ -18,29 +18,37 @@ from __future__ import print_function import tensorflow as tf -from tensorflow_io.core.python.ops import data_ops as data_ops -from tensorflow_io.core.python.ops import core_ops as prometheus_ops +from tensorflow_io.core.python.ops import data_ops +from tensorflow_io.core.python.ops import core_ops -class PrometheusDataset(data_ops.Dataset): - """A Prometheus Dataset - """ +def read_prometheus(endpoint, query): + """read_prometheus""" + return core_ops.read_prometheus(endpoint, 
query) - def __init__(self, endpoint, schema=None, batch=None): - """Create a Prometheus Reader. +class PrometheusDataset(data_ops.BaseDataset): + """A Prometheus Dataset""" + + def __init__(self, endpoint, query): + """Create a Prometheus Dataset Args: endpoint: A `tf.string` tensor containing address of the prometheus server. - schema: A `tf.string` tensor containing the query + query: A `tf.string` tensor containing the query string. - batch: Size of the batch. """ - batch = 0 if batch is None else batch dtypes = [tf.int64, tf.float64] - shapes = [ - tf.TensorShape([]), tensorflow.TensorShape([])] if batch == 0 else [ - tf.TensorShape([None]), tf.TensorShape([None])] + shapes = [tf.TensorShape([None]), tf.TensorShape([None])] + # TODO: It could be possible to improve the performance + # by reading a small chunk of the data while at the same + # time allowing reuse of read_prometheus. Essentially + # read_prometheus could take a timestamp and read small chunk + # at a time until running out of data. 
+ timestamp, value = read_prometheus(endpoint, query) + timestamp_dataset = data_ops.BaseDataset.from_tensors(timestamp) + value_dataset = data_ops.BaseDataset.from_tensors(value) + dataset = data_ops.BaseDataset.zip((timestamp_dataset, value_dataset)) + + self._dataset = dataset super(PrometheusDataset, self).__init__( - prometheus_ops.prometheus_dataset, - prometheus_ops.prometheus_input(endpoint, schema=schema), - batch, dtypes, shapes) + self._dataset._variant_tensor, dtypes, shapes) # pylint: disable=protected-access diff --git a/tests/test_prometheus_eager.py b/tests/test_prometheus_eager.py index 9d5610c47..e99f83af3 100644 --- a/tests/test_prometheus_eager.py +++ b/tests/test_prometheus_eager.py @@ -32,17 +32,17 @@ pytest.skip( "prometheus is not supported on macOS yet", allow_module_level=True) -def test_prometheus_input(): - """test_prometheus_input - """ +def test_prometheus(): + """test_prometheus""" for _ in range(6): subprocess.call(["dig", "@localhost", "-p", "1053", "www.google.com"]) time.sleep(1) time.sleep(2) prometheus_dataset = prometheus_io.PrometheusDataset( "http://localhost:9090", - schema="coredns_dns_request_count_total[5s]", - batch=2) + "coredns_dns_request_count_total[5s]").apply( + tf.data.experimental.unbatch()).batch(2) + i = 0 for k, v in prometheus_dataset: print("K, V: ", k.numpy(), v.numpy()) @@ -52,5 +52,11 @@ def test_prometheus_input(): i += 2 assert i == 6 + timestamp, value = prometheus_io.read_prometheus( + "http://localhost:9090", + "coredns_dns_request_count_total[5s]") + assert timestamp.shape == [5] + assert value.shape == [5] + if __name__ == "__main__": test.main()