Skip to content

Commit e2c6abf

Browse files
committed
Add tfio.IOTensor.from_prometheus support
In the past couple of months I have been trying to come up with a good example of using Prometheus with TensorFlow for infrastructure/compute usage prediction and alerting in case of anomalies. My plan was to combine an LSTM with Prometheus for that. Though there is PrometheusDataset support, the format of the dataset is not very intuitive and it is hard to use smoothly. The biggest challenge is that, normally for time series data, you have a look-back window to train on, and you normalize the data (with the total). Neither is easily available. Furthermore, processing data with normal TF operations such as tf.roll requires reading the whole data into one Tensor, which is not very straightforward with an iterable dataset either. This PR adds tfio.IOTensor.from_prometheus, which allows easy conversion of the Prometheus observation data into a tuple of Tensors, on which additional feature engineering can be done. Signed-off-by: Yong Tang <[email protected]>
1 parent 525d55f commit e2c6abf

File tree

7 files changed

+281
-18
lines changed

7 files changed

+281
-18
lines changed

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow_io.core.python.ops import audio_io_tensor_ops
2323
from tensorflow_io.core.python.ops import json_io_tensor_ops
2424
from tensorflow_io.core.python.ops import kafka_io_tensor_ops
25+
from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
2526

2627
class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access
2728
"""IOTensor
@@ -264,3 +265,27 @@ def from_kafka(cls,
264265
servers=kwargs.get("servers", None),
265266
configuration=kwargs.get("configuration", None),
266267
internal=True)
268+
269+
@classmethod
def from_prometheus(cls,
                    query,
                    **kwargs):
  """Creates a pair of `IOTensor`s from a prometheus query.

  Args:
    query: A string, the query string for prometheus.
    endpoint: A string, the server address of prometheus, by default
      `http://localhost:9090`.
    name: A name prefix for the IOTensor (optional).

  Returns:
    A (`IOTensor`, `IOTensor`) tuple that represents `timestamp`
      and `value`.

  """
  # Optional arguments are read from kwargs; unknown keys are ignored.
  scope_name = kwargs.get("name", "IOFromPrometheus")
  endpoint = kwargs.get("endpoint", None)
  with tf.name_scope(scope_name):
    # Each column is backed by its own IOTensor built from the same query.
    timestamp = prometheus_io_tensor_ops.PrometheusTimestampIOTensor(
        query, endpoint=endpoint, internal=True)
    value = prometheus_io_tensor_ops.PrometheusValueIOTensor(
        query, endpoint=endpoint, internal=True)
    return timestamp, value
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""PrometheusTimestampIOTensor"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import uuid
21+
22+
import tensorflow as tf
23+
from tensorflow_io.core.python.ops import io_tensor_ops
24+
from tensorflow_io.core.python.ops import core_ops
25+
26+
class PrometheusTimestampIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access
  """Column IOTensor for the `timestamp` column of a prometheus query.

  The query is handed to `core_ops.prometheus_indexable_init` together with
  metadata strings that select the `timestamp` column and, optionally, the
  prometheus endpoint.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               query,
               endpoint=None,
               internal=False):
    """Initializes the IOTensor.

    Args:
      query: A string, the query string for prometheus.
      endpoint: A string, the server address of prometheus. When `None`
        no endpoint entry is added to the metadata.
      internal: Forwarded unchanged to the base class constructor.
    """
    with tf.name_scope("PrometheusTimestampIOTensor") as scope:
      metadata = ["column: timestamp"]
      if endpoint is not None:
        # Append the entry itself, not a one-element list: metadata must
        # stay a flat list of "key: value" strings.
        metadata.append("endpoint: %s" % endpoint)
      # The random suffix keeps the shared resource name unique per instance.
      resource, dtypes, shapes, _ = core_ops.prometheus_indexable_init(
          query, metadata=metadata,
          container=scope, shared_name="%s/%s" % (query, uuid.uuid4().hex))
      super(PrometheusTimestampIOTensor, self).__init__(
          shapes, dtypes, resource, core_ops.prometheus_indexable_get_item,
          internal=internal)
46+
47+
class PrometheusValueIOTensor(io_tensor_ops._ColumnIOTensor): # pylint: disable=protected-access
  """Column IOTensor for the `value` column of a prometheus query.

  The query is handed to `core_ops.prometheus_indexable_init` together with
  metadata strings that select the `value` column and, optionally, the
  prometheus endpoint.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               query,
               endpoint=None,
               internal=False):
    """Initializes the IOTensor.

    Args:
      query: A string, the query string for prometheus.
      endpoint: A string, the server address of prometheus. When `None`
        no endpoint entry is added to the metadata.
      internal: Forwarded unchanged to the base class constructor.
    """
    # Scope is named after this class (it was previously a copy of the
    # timestamp class's scope name).
    with tf.name_scope("PrometheusValueIOTensor") as scope:
      metadata = ["column: value"]
      if endpoint is not None:
        # Append the entry itself, not a one-element list: metadata must
        # stay a flat list of "key: value" strings.
        metadata.append("endpoint: %s" % endpoint)
      # The random suffix keeps the shared resource name unique per instance.
      resource, dtypes, shapes, _ = core_ops.prometheus_indexable_init(
          query, metadata=metadata,
          container=scope, shared_name="%s/%s" % (query, uuid.uuid4().hex))
      super(PrometheusValueIOTensor, self).__init__(
          shapes, dtypes, resource, core_ops.prometheus_indexable_get_item,
          internal=internal)

tensorflow_io/json/kernels/json_kernels.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,6 @@ class JSONIndexable : public IOIndexableInterface {
347347
}
348348

349349
Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
350-
Tensor& output_tensor = tensors[0];
351350
if (step != 1) {
352351
return errors::InvalidArgument("step ", step, " is not supported");
353352
}

tensorflow_io/prometheus/kernels/prometheus_kernels.cc

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow_io/core/kernels/io_interface.h"
1718
#include "tensorflow_io/core/prometheus_go.h"
1819

1920
namespace tensorflow {
@@ -74,5 +75,138 @@ REGISTER_KERNEL_BUILDER(Name("ReadPrometheus").Device(DEVICE_CPU),
7475

7576

7677
} // namespace
78+
79+
80+
class PrometheusIndexable : public IOIndexableInterface {
81+
public:
82+
PrometheusIndexable(Env* env)
83+
: env_(env) {}
84+
85+
~PrometheusIndexable() {}
86+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
87+
if (input.size() > 1) {
88+
return errors::InvalidArgument("more than 1 query is not supported");
89+
}
90+
const string& query = input[0];
91+
92+
string endpoint = "http://localhost:9090";
93+
for (size_t i = 0; i < metadata.size(); i++) {
94+
if (metadata[i].find_first_of("endpoint: ") == 0) {
95+
endpoint = metadata[i].substr(8);
96+
}
97+
}
98+
99+
int64 ts = time(NULL);
100+
101+
GoString endpoint_go = {endpoint.c_str(), static_cast<int64>(endpoint.size())};
102+
GoString query_go = {query.c_str(), static_cast<int64>(query.size())};
103+
104+
GoSlice timestamp_go = {0, 0, 0};
105+
GoSlice value_go = {0, 0, 0};
106+
107+
GoInt returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go);
108+
if (returned < 0) {
109+
return errors::InvalidArgument("unable to query prometheus");
110+
}
111+
112+
timestamp_.resize(returned);
113+
value_.resize(returned);
114+
115+
if (returned > 0) {
116+
timestamp_go.data = &timestamp_[0];
117+
timestamp_go.len = returned;
118+
timestamp_go.cap = returned;
119+
value_go.data = &value_[0];
120+
value_go.len = returned;
121+
value_go.cap = returned;
122+
123+
returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go);
124+
if (returned < 0) {
125+
return errors::InvalidArgument("unable to query prometheus to get the value");
126+
}
127+
}
128+
129+
for (size_t i = 0; i < metadata.size(); i++) {
130+
if (metadata[i].find_first_of("column: ") == 0) {
131+
columns_.emplace_back(metadata[i].substr(8));
132+
}
133+
}
134+
if (columns_.size() == 0) {
135+
columns_.emplace_back("timestamp");
136+
columns_.emplace_back("value");
137+
}
138+
139+
for (size_t i = 0; i < columns_.size(); i++) {
140+
if (columns_[i] == "timestamp") {
141+
dtypes_.emplace_back(DT_INT64);
142+
shapes_.emplace_back(TensorShape({static_cast<int64>(returned)}));
143+
} else if (columns_[i] == "value") {
144+
dtypes_.emplace_back(DT_DOUBLE);
145+
shapes_.emplace_back(TensorShape({static_cast<int64>(returned)}));
146+
} else {
147+
return errors::InvalidArgument("column name other than `timestamp` or `value` is not supported: ", columns_[i]);
148+
}
149+
}
150+
151+
return Status::OK();
152+
}
153+
Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
154+
dtypes.clear();
155+
for (size_t i = 0; i < dtypes_.size(); i++) {
156+
dtypes.push_back(dtypes_[i]);
157+
}
158+
shapes.clear();
159+
for (size_t i = 0; i < shapes_.size(); i++) {
160+
shapes.push_back(shapes_[i]);
161+
}
162+
return Status::OK();
163+
}
164+
165+
Status Extra(std::vector<Tensor>* extra) override {
166+
// Expose columns
167+
Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
168+
for (size_t i = 0; i < columns_.size(); i++) {
169+
columns.flat<string>()(i) = columns_[i];
170+
}
171+
extra->push_back(columns);
172+
return Status::OK();
173+
}
174+
175+
Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
176+
if (step != 1) {
177+
return errors::InvalidArgument("step ", step, " is not supported");
178+
}
179+
for (size_t i = 0; i < columns_.size(); i++) {
180+
if (columns_[i] == "timestamp") {
181+
memcpy(tensors[i].flat<int64>().data(), &timestamp_[0], sizeof(int64) * (stop - start));
182+
} else {
183+
memcpy(tensors[i].flat<double>().data(), &value_[0], sizeof(double) * (stop - start));
184+
}
185+
}
186+
187+
return Status::OK();
188+
}
189+
190+
string DebugString() const override {
191+
mutex_lock l(mu_);
192+
return strings::StrCat("PrometheusIndexable");
193+
}
194+
private:
195+
mutable mutex mu_;
196+
Env* env_ GUARDED_BY(mu_);
197+
198+
std::vector<DataType> dtypes_;
199+
std::vector<TensorShape> shapes_;
200+
std::vector<string> columns_;
201+
202+
std::vector<int64> timestamp_;
203+
std::vector<double> value_;
204+
};
205+
206+
// CPU kernel that creates the PrometheusIndexable resource.
REGISTER_KERNEL_BUILDER(
    Name("PrometheusIndexableInit").Device(DEVICE_CPU),
    IOInterfaceInitOp<PrometheusIndexable>);

// CPU kernel that reads a slice out of a PrometheusIndexable resource.
REGISTER_KERNEL_BUILDER(
    Name("PrometheusIndexableGetItem").Device(DEVICE_CPU),
    IOIndexableGetItemOp<PrometheusIndexable>);
210+
77211
} // namespace data
78212
} // namespace tensorflow

tensorflow_io/prometheus/ops/prometheus_ops.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,47 @@ limitations under the License.
1919

2020
namespace tensorflow {
2121

22+
// Op that initializes a prometheus indexable resource from a query string
// and a list of metadata strings.
REGISTER_OP("PrometheusIndexableInit")
  .Input("input: string")
  .Input("metadata: string")
  .Output("output: resource")
  .Output("dtypes: int64")
  .Output("shapes: int64")
  .Output("columns: string")
  .Attr("container: string = ''")
  .Attr("shared_name: string = ''")
  .SetIsStateful()
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    // output: scalar resource handle
    c->set_output(0, c->Scalar());
    // dtypes: vector of unknown length
    c->set_output(1, c->Vector(c->UnknownDim()));
    // shapes: matrix with both dimensions unknown
    c->set_output(2, c->Matrix(c->UnknownDim(), c->UnknownDim()));
    // columns: vector of unknown length
    c->set_output(3, c->Vector(c->UnknownDim()));
    return Status::OK();
   });
39+
40+
// Op that reads the [start, stop) slice (with the given step) from a
// prometheus indexable resource. The output shapes are taken from the
// `shape` attr, one entry per output.
REGISTER_OP("PrometheusIndexableGetItem")
  .Input("input: resource")
  .Input("start: int64")
  .Input("stop: int64")
  .Input("step: int64")
  .Output("output: dtype")
  .Attr("dtype: list(type) >= 1")
  .Attr("shape: list(shape) >= 1")
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    std::vector<PartialTensorShape> shape;
    TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
    // num_outputs() is signed; cast explicitly to avoid a mixed comparison.
    if (static_cast<int64>(shape.size()) != static_cast<int64>(c->num_outputs())) {
      // Name the attr that actually exists (`dtype`) and close the paren.
      return errors::InvalidArgument("`shape` must be the same length as `dtype` (", shape.size(), " vs. ", c->num_outputs(), ")");
    }
    for (size_t i = 0; i < shape.size(); ++i) {
      shape_inference::ShapeHandle entry;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry));
      c->set_output(static_cast<int64>(i), entry);
    }
    return Status::OK();
   });
61+
62+
2263
REGISTER_OP("ReadPrometheus")
2364
.Input("endpoint: string")
2465
.Input("query: string")

tensorflow_io/prometheus/python/ops/prometheus_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import warnings
21+
2022
import tensorflow as tf
2123
from tensorflow_io.core.python.ops import data_ops
2224
from tensorflow_io.core.python.ops import core_ops
2325

26+
# Emitted once at import time to steer users towards the IOTensor API.
warnings.warn(
    "The tensorflow_io.prometheus.PrometheusDataset is deprecated. "
    "Please look for tfio.IOTensor.from_prometheus for reading "
    "prometheus observations into tensorflow.",
    category=DeprecationWarning)
31+
32+
2433
def read_prometheus(endpoint, query):
  """Runs `query` against the prometheus server at `endpoint`.

  Thin wrapper that forwards both arguments to
  `core_ops.read_prometheus` and returns its result unchanged.
  """
  result = core_ops.read_prometheus(endpoint, query)
  return result

tests/test_prometheus_eager.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import tensorflow as tf
2727
if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
2828
tf.compat.v1.enable_eager_execution()
29-
import tensorflow_io.prometheus as prometheus_io # pylint: disable=wrong-import-position
29+
import tensorflow_io as tfio # pylint: disable=wrong-import-position
3030

3131
if sys.platform == "darwin":
3232
pytest.skip(
@@ -38,25 +38,14 @@ def test_prometheus():
3838
subprocess.call(["dig", "@localhost", "-p", "1053", "www.google.com"])
3939
time.sleep(1)
4040
time.sleep(2)
41-
prometheus_dataset = prometheus_io.PrometheusDataset(
42-
"http://localhost:9090",
43-
"coredns_dns_request_count_total[5s]").apply(
44-
tf.data.experimental.unbatch()).batch(2)
45-
46-
i = 0
47-
for k, v in prometheus_dataset:
48-
print("K, V: ", k.numpy(), v.numpy())
49-
if i == 4:
50-
# Last entry guaranteed 6.0
51-
assert v.numpy() == 6.0
52-
i += 2
53-
assert i == 6
54-
55-
timestamp, value = prometheus_io.read_prometheus(
56-
"http://localhost:9090",
41+
timestamp, value = tfio.IOTensor.from_prometheus(
5742
"coredns_dns_request_count_total[5s]")
5843
assert timestamp.shape == [5]
44+
assert timestamp.dtype == tf.int64
5945
assert value.shape == [5]
46+
assert value.dtype == tf.float64
47+
# last value should be 6.0
48+
assert value.to_tensor().numpy()[4] == 6.0
6049

6150
if __name__ == "__main__":
6251
test.main()

0 commit comments

Comments
 (0)