Skip to content

Commit b109ec4

Browse files
committed
Update JSONIOTensor with __call__ so that it is possible to support column selection
Signed-off-by: Yong Tang <[email protected]>
1 parent d90ef55 commit b109ec4

File tree

6 files changed

+117
-10
lines changed

6 files changed

+117
-10
lines changed

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import sys
2121
import collections
22+
import uuid
2223

2324
import tensorflow as tf
2425
from tensorflow_io.core.python.ops import core_ops
@@ -457,7 +458,8 @@ def __init__(self,
457458
internal=False):
458459
with tf.name_scope("AudioIOTensor") as scope:
459460
resource, dtypes, shapes, rate = core_ops.wav_indexable_init(
460-
filename, container=scope, shared_name=filename)
461+
filename,
462+
container=scope, shared_name="%s/%s" % (filename, uuid.uuid4().hex))
461463
shapes = [
462464
tf.TensorShape(
463465
[None if dim < 0 else dim for dim in e.numpy() if dim != 0]
@@ -493,27 +495,45 @@ class JSONIOTensor(IOTensor):
493495
#=============================================================================
494496
def __init__(self,
             filename,
             columns=None,
             internal=False):
  """Create a JSONIOTensor backed by the JSON file `filename`.

  Args:
    filename: name of the JSON file, passed to `json_indexable_init`.
    columns: optional iterable of column names. Each name is forwarded to
      the kernel as a "column: <name>" metadata entry; `None` forwards no
      entries.
    internal: forwarded unchanged to the parent `__init__`.
  """
  with tf.name_scope("JSONIOTensor") as scope:
    metadata = []
    if columns is not None:
      metadata.extend(["column: " + column for column in columns])
    # The op output gets its own name so the `columns` argument is not
    # overwritten by the returned column-name tensor.
    resource, dtypes, shapes, column_names = core_ops.json_indexable_init(
        filename, metadata=metadata,
        # A random suffix gives every instance its own shared_name.
        container=scope, shared_name="%s/%s" % (filename, uuid.uuid4().hex))
    shapes = [
        tf.TensorShape(
            [None if dim < 0 else dim for dim in e.numpy() if dim != 0]
        ) for e in tf.unstack(shapes)]
    dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)]
    # `.numpy()` on a string tensor yields bytes under Python 3; convert to
    # text so the spec names compare equal to ordinary `str` column names.
    column_names = [
        tf.compat.as_str_any(e.numpy()) for e in tf.unstack(column_names)]
    spec = [tf.TensorSpec(shape, dtype, column) for (
        shape, dtype, column) in zip(shapes, dtypes, column_names)]
    if len(spec) == 1:
      spec = spec[0]
    else:
      spec = tuple(spec)
    # Kept so `__call__` can open the same file again.
    self._filename = filename
    super(JSONIOTensor, self).__init__(
        spec, resource, core_ops.json_indexable_get_item,
        None,
        internal=internal)
516524

525+
#=============================================================================
526+
# Accessors
527+
#=============================================================================
528+
529+
def column(self, name):
  """The `TensorSpec` of column named `name`.

  Args:
    name: column name, as `str` or `bytes`.

  Returns:
    The first spec in `self.spec` whose name matches `name`.

  Raises:
    ValueError: if no spec carries that name.
  """
  wanted = tf.compat.as_str_any(name)
  for entry in tf.nest.flatten(self.spec):
    # Spec names may be bytes (they can originate from `.numpy()` on a
    # string tensor), so both sides are normalized to text before comparing.
    if entry.name is not None and tf.compat.as_str_any(entry.name) == wanted:
      return entry
  # Explicit error instead of a bare StopIteration escaping from next().
  raise ValueError("column %r not found" % (name,))
532+
533+
def __call__(self, column):
  """Return a new JSONIOTensor with column named `column`.

  The new tensor is opened on the same file as this one and restricted
  to the single requested column.
  """
  selected = [column]
  return JSONIOTensor(self._filename, columns=selected, internal=True)
536+
517537
class KafkaIOTensor(IOIterableTensor):
518538
"""KafkaIOTensor"""
519539

@@ -541,7 +561,7 @@ def func_init(data):
541561
"""func_init"""
542562
resource, _, _ = core_ops.kafka_iterable_init(
543563
data["subscription"], metadata=data["metadata"],
544-
container=scope, shared_name=subscription)
564+
container=scope, shared_name="%s/%s" % (filename, uuid.uuid4().hex))
545565
return resource
546566
func_next = core_ops.kafka_iterable_next
547567

tensorflow_io/json/kernels/json_kernels.cc

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,35 @@ class JSONIndexable : public IOIndexableInterface {
234234
return errors::InvalidArgument("unable to read table: ", status);
235235
}
236236

237+
std::vector<string> columns;
238+
for (size_t i = 0; i < metadata.size(); i++) {
239+
if (metadata[i].find_first_of("column: ") == 0) {
240+
columns.emplace_back(metadata[i].substr(8));
241+
}
242+
}
243+
244+
columns_index_.clear();
245+
if (columns.size() == 0) {
246+
for (int i = 0; i < table_->num_columns(); i++) {
247+
columns_index_.push_back(i);
248+
}
249+
} else {
250+
std::unordered_map<string, int> columns_map;
251+
for (int i = 0; i < table_->num_columns(); i++) {
252+
columns_map[table_->column(i)->name()] = i;
253+
}
254+
for (size_t i = 0; i < columns.size(); i++) {
255+
columns_index_.push_back(columns_map[columns[i]]);
256+
}
257+
}
258+
237259
dtypes_.clear();
238260
shapes_.clear();
239261
columns_.clear();
240-
for (int i = 0; i < table_->num_columns(); i++) {
262+
for (size_t i = 0; i < columns_index_.size(); i++) {
263+
int column_index = columns_index_[i];
241264
::tensorflow::DataType dtype;
242-
switch (table_->column(i)->type()->id()) {
265+
switch (table_->column(column_index)->type()->id()) {
243266
case ::arrow::Type::BOOL:
244267
dtype = ::tensorflow::DT_BOOL;
245268
break;
@@ -296,7 +319,7 @@ class JSONIndexable : public IOIndexableInterface {
296319
}
297320
dtypes_.push_back(dtype);
298321
shapes_.push_back(TensorShape({static_cast<int64>(table_->num_rows())}));
299-
columns_.push_back(table_->column(i)->name());
322+
columns_.push_back(table_->column(column_index)->name());
300323
}
301324

302325
return Status::OK();
@@ -329,7 +352,8 @@ class JSONIndexable : public IOIndexableInterface {
329352
return errors::InvalidArgument("step ", step, " is not supported");
330353
}
331354
for (size_t i = 0; i < tensors.size(); i++) {
332-
std::shared_ptr<::arrow::Column> slice = table_->column(i)->Slice(start, stop);
355+
int column_index = columns_index_[i];
356+
std::shared_ptr<::arrow::Column> slice = table_->column(column_index)->Slice(start, stop);
333357

334358
#define PROCESS_TYPE(TTYPE,ATYPE) { \
335359
int64 curr_index = 0; \
@@ -398,6 +422,7 @@ class JSONIndexable : public IOIndexableInterface {
398422
std::vector<DataType> dtypes_;
399423
std::vector<TensorShape> shapes_;
400424
std::vector<string> columns_;
425+
std::vector<int> columns_index_;
401426
};
402427

403428
REGISTER_KERNEL_BUILDER(Name("JSONIndexableInit").Device(DEVICE_CPU),

tensorflow_io/json/ops/json_ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace tensorflow {
2121

2222
REGISTER_OP("JSONIndexableInit")
2323
.Input("input: string")
24+
.Input("metadata: string")
2425
.Output("output: resource")
2526
.Output("dtypes: int64")
2627
.Output("shapes: int64")

tests/test_json/feature.ndjson

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{ "floatfeature": 1.1, "integerfeature": 2 }
2+
{ "floatfeature": 2.1, "integerfeature": 3 }

tests/test_json/label.ndjson

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{ "floatlabel": 2.2, "integerlabel": 3 }
2+
{ "floatlabel": 1.2, "integerlabel": 3 }

tests/test_json_eager.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,65 @@
2323
import tensorflow as tf
2424
if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
2525
tf.compat.v1.enable_eager_execution()
26+
import tensorflow_io as tfio # pylint: disable=wrong-import-position
2627
import tensorflow_io.json as json_io # pylint: disable=wrong-import-position
2728

29+
def test_io_tensor_json():
  """Test case for tfio.IOTensor.from_json.

  Checks the column specs, per-column indexing through `__call__`, and the
  rows produced by zipping the feature and label datasets.
  """
  x_test = [[1.1, 2], [2.1, 3]]
  y_test = [[2.2, 3], [1.2, 3]]
  data_dir = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_json")
  feature_filename = "file://" + os.path.join(data_dir, "feature.ndjson")
  label_filename = "file://" + os.path.join(data_dir, "label.ndjson")

  features = tfio.IOTensor.from_json(feature_filename)
  assert features.column("floatfeature").dtype == tf.float64
  assert features.column("integerfeature").dtype == tf.int64

  labels = tfio.IOTensor.from_json(label_filename)
  assert labels.column("floatlabel").dtype == tf.float64
  assert labels.column("integerlabel").dtype == tf.int64

  # Single-column views obtained through __call__.
  float_feature = features("floatfeature")
  integer_feature = features("integerfeature")
  float_label = labels("floatlabel")
  integer_label = labels("integerlabel")

  for row, (want_x, want_y) in enumerate(zip(x_test, y_test)):
    assert want_x[0] == float_feature[row].numpy()
    assert want_x[1] == integer_feature[row].numpy()
    assert want_y[0] == float_label[row].numpy()
    assert want_y[1] == integer_label[row].numpy()

  feature_dataset = features.to_dataset()
  label_dataset = labels.to_dataset()
  dataset = tf.data.Dataset.zip((feature_dataset, label_dataset))

  seen = 0
  for (got_x, got_y) in dataset:
    want_x = x_test[seen]
    want_y = y_test[seen]
    for position, value in enumerate(got_x):
      assert want_x[position] == value.numpy()
    for position, value in enumerate(got_y):
      assert want_y[position] == value.numpy()
    seen += 1
  # Every expected row must have been produced by the zipped dataset.
  assert seen == len(y_test)
84+
2885
def test_json_dataset():
2986
"""Test case for JSON Dataset.
3087
"""

0 commit comments

Comments
 (0)