From cfe8bfe76028f989381113189cc5dd62dce6edb9 Mon Sep 17 00:00:00 2001
From: Jiacheng Xu
Date: Wed, 7 Aug 2019 17:03:04 +0200
Subject: [PATCH] Add read_json and list_json_columns for JSON support.

---
 tensorflow_io/json/BUILD                   |   2 +-
 tensorflow_io/json/__init__.py             |   6 +
 tensorflow_io/json/kernels/json_input.cc   | 166 -----------------
 tensorflow_io/json/kernels/json_kernels.cc | 203 +++++++++++++++++++++
 tensorflow_io/json/ops/json_ops.cc         |  26 ++-
 tensorflow_io/json/python/ops/json_ops.py  |  57 ++++--
 tests/test_json.py                         |  70 ++++---
 tests/test_json_eager.py                   | 114 ++++++------
 8 files changed, 352 insertions(+), 292 deletions(-)
 delete mode 100644 tensorflow_io/json/kernels/json_input.cc
 create mode 100644 tensorflow_io/json/kernels/json_kernels.cc

diff --git a/tensorflow_io/json/BUILD b/tensorflow_io/json/BUILD
index 99b4700a3..ae75eb6df 100644
--- a/tensorflow_io/json/BUILD
+++ b/tensorflow_io/json/BUILD
@@ -10,7 +10,7 @@ load(
 cc_library(
     name = "json_ops",
     srcs = [
-        "kernels/json_input.cc",
+        "kernels/json_kernels.cc",
         "ops/json_ops.cc",
     ],
     copts = tf_io_copts(),
diff --git a/tensorflow_io/json/__init__.py b/tensorflow_io/json/__init__.py
index 6350756ff..56db1f049 100644
--- a/tensorflow_io/json/__init__.py
+++ b/tensorflow_io/json/__init__.py
@@ -15,6 +15,8 @@
 """JSONDataset.

 @@JSONDataset
+@@list_json_columns
+@@read_json
 """

 from __future__ import absolute_import
@@ -24,9 +26,13 @@ from tensorflow.python.util.all_util import remove_undocumented

 from tensorflow_io.json.python.ops.json_ops import JSONDataset
+from tensorflow_io.json.python.ops.json_ops import list_json_columns
+from tensorflow_io.json.python.ops.json_ops import read_json

 _allowed_symbols = [
     "JSONDataset",
+    "list_json_columns",
+    "read_json",
 ]

 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow_io/json/kernels/json_input.cc b/tensorflow_io/json/kernels/json_input.cc
deleted file mode 100644
index 769d77628..000000000
--- a/tensorflow_io/json/kernels/json_input.cc
+++ /dev/null
@@ -1,166 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <string>
-#include "kernels/dataset_ops.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "tensorflow/core/platform/env.h"
-#include "include/json/json.h"
-
-namespace tensorflow {
-namespace data {
-
-class JSONInputStream {
- public:
-  explicit JSONInputStream(const string& filename) {
-    Env* env = Env::Default();
-    uint64 size = 0;
-    Status status = env->GetFileSize(filename, &size);
-    if (status.ok()) {
-      std::unique_ptr<RandomAccessFile> file;
-      status = env->NewRandomAccessFile(filename, &file);
-      if (status.ok()) {
-        StringPiece result;
-        buffer_memory_.resize(size);
-        status = file->Read(0, size, &result, &buffer_memory_[0]);
-      }
-    }
-  }
-
-  ~JSONInputStream() {}
-
-  Status Open() {
-    if (reader_.parse(buffer_memory_, records_)) {
-      return Status::OK();
-    }
-    return errors::InvalidArgument("JSON parsing error: ", reader_.getFormattedErrorMessages());
-  }
-
-  bool ReadRecord(Json::Value& record) {
-    if (index_ < records_.size()) {
-      record = records_[index_];
-      index_++;
-      return true;
-    }
-    return false;
-  }
-
- private:
-  string buffer_memory_;
-  string filename_;
-  Json::Reader reader_;
-  Json::Value records_;
-  Json::ArrayIndex index_ = 0;
-};
-
-class JSONInput : public FileInput<JSONInputStream> {
- public:
-  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr<JSONInputStream>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
-    if (state.get() == nullptr) {
-      state.reset(new JSONInputStream(filename()));
-      TF_RETURN_IF_ERROR(state.get()->Open());
-    }
-
-    Json::Value record;
-    while ((*record_read) < record_to_read && state.get()->ReadRecord(record)) {
-      if (*record_read == 0) {
-        out_tensors->clear();
-        // Allocate enough space for the output tensors.
-        for (size_t i = 0; i < columns().size(); i++) {
-          const string& column = columns()[i];
-          const Json::Value& val = record[column];
-          DataType dtype;
-          switch (val.type()) {
-            case Json::intValue:
-              dtype = DT_INT64;
-              break;
-            case Json::uintValue:
-              dtype = DT_UINT64;
-              break;
-            case Json::realValue:
-              dtype = DT_DOUBLE;
-              break;
-            case Json::stringValue:
-              dtype = DT_STRING;
-              break;
-            case Json::booleanValue:
-              dtype = DT_BOOL;
-              break;
-            // Currently Json::nullValue, Json::arrayValue, and Json::objectValue are not supported.
-            default:
-              return errors::InvalidArgument("Unsupported data type: ", val.type());
-          }
-          Tensor tensor(ctx->allocator({}), dtype, {record_to_read});
-          out_tensors->emplace_back(std::move(tensor));
-        }
-      }
-      for (size_t i = 0; i < columns().size(); i++) {
-        const string& column = columns()[i];
-        const Json::Value& val = record[column];
-        switch (val.type()) {
-          case Json::intValue:
-            ((*out_tensors)[i]).flat<int64>()(*record_read) = val.asInt64();
-            break;
-          case Json::uintValue:
-            ((*out_tensors)[i]).flat<uint64>()(*record_read) = val.asUInt64();
-            break;
-          case Json::realValue:
-            ((*out_tensors)[i]).flat<double>()(*record_read) = val.asDouble();
-            break;
-          case Json::stringValue:
-            ((*out_tensors)[i]).flat<string>()(*record_read) = val.asString();
-            break;
-          case Json::booleanValue:
-            ((*out_tensors)[i]).flat<bool>()(*record_read) = val.asBool();
-            break;
-          // Currently Json::nullValue, Json::arrayValue, and Json::objectValue are not supported.
-          default:
-            return errors::InvalidArgument("Unsupported data type: ", val.type());
-        }
-      }
-      (*record_read)++;
-    }
-
-    // Slice if needed.
-    if (*record_read < record_to_read) {
-      if (*record_read == 0) {
-        out_tensors->clear();
-      }
-      for (size_t i = 0; i < out_tensors->size(); i++) {
-        Tensor tensor = (*out_tensors)[i].Slice(0, *record_read);
-        (*out_tensors)[i] = std::move(tensor);
-      }
-    }
-    return Status::OK();
-  }
-  Status FromStream(io::InputStreamInterface* s) override {
-    return Status::OK();
-  }
-  void EncodeAttributes(VariantTensorData* data) const override {
-  }
-  bool DecodeAttributes(const VariantTensorData& data) override {
-    return true;
-  }
- protected:
-};
-
-REGISTER_UNARY_VARIANT_DECODE_FUNCTION(JSONInput, "tensorflow::data::JSONInput");
-
-REGISTER_KERNEL_BUILDER(Name("JSONInput").Device(DEVICE_CPU),
-                        FileInputOp<JSONInput>);
-REGISTER_KERNEL_BUILDER(Name("JSONDataset").Device(DEVICE_CPU),
-                        FileInputDatasetOp<JSONInput>);
-}  // namespace data
-}  // namespace tensorflow
diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc
new file mode 100644
index 000000000..2fcae9863
--- /dev/null
+++ b/tensorflow_io/json/kernels/json_kernels.cc
@@ -0,0 +1,203 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include "kernels/dataset_ops.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/env.h"
+#include "include/json/json.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ListJSONColumnsOp : public OpKernel {
+ public:
+  explicit ListJSONColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string filename = filename_tensor.scalar<string>()();
+
+    std::vector<string> columns;
+    std::vector<string> dtypes;
+
+    // Read the whole JSON file into memory.
+    uint64 size = 0;
+    OP_REQUIRES_OK(context, env_->GetFileSize(filename, &size));
+    std::unique_ptr<RandomAccessFile> file;
+    OP_REQUIRES_OK(context, env_->NewRandomAccessFile(filename, &file));
+    string buffer_memory;
+    StringPiece result;
+    buffer_memory.resize(size);
+    OP_REQUIRES_OK(context, file->Read(0, size, &result, &buffer_memory[0]));
+
+    // Parse JSON records from the buffer string.
+    Json::Reader reader;
+    Json::Value records;
+    OP_REQUIRES(context, reader.parse(buffer_memory, records),
+                errors::InvalidArgument("JSON parsing error: ", reader.getFormattedErrorMessages()));
+
+    // Read one JSON record to get the list of the columns.
+    OP_REQUIRES(context, records.type() == Json::arrayValue,
+                errors::InvalidArgument("JSON is not in record format!"));
+    const Json::Value& record = records[0];
+    columns = record.getMemberNames();
+    for (size_t i = 0; i < columns.size(); i++) {
+      const string& column = columns[i];
+      const Json::Value& val = record[column];
+      switch (val.type()) {
+        case Json::intValue:
+          dtypes.emplace_back("int64");
+          break;
+        case Json::uintValue:
+          dtypes.emplace_back("uint64");
+          break;
+        case Json::realValue:
+          dtypes.emplace_back("double");
+          break;
+        case Json::stringValue:
+          dtypes.emplace_back("string");
+          break;
+        case Json::booleanValue:
+          dtypes.emplace_back("bool");
+          break;
+        // Currently Json::nullValue, Json::arrayValue, and Json::objectValue are not supported.
+        default:
+          OP_REQUIRES(context, false, errors::InvalidArgument("unsupported data type: ", val.type()));
+      }
+    }
+
+    TensorShape output_shape = filename_tensor.shape();
+    output_shape.AddDim(columns.size());
+
+    Tensor* columns_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
+    Tensor* dtypes_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));
+
+    output_shape.AddDim(1);
+
+    for (size_t i = 0; i < columns.size(); i++) {
+      columns_tensor->flat<string>()(i) = columns[i];
+      dtypes_tensor->flat<string>()(i) = dtypes[i];
+    }
+  }
+
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+class ReadJSONOp : public OpKernel {
+ public:
+  explicit ReadJSONOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string& filename = filename_tensor.scalar<string>()();
+
+    const Tensor& column_tensor = context->input(1);
+    const string& column = column_tensor.scalar<string>()();
+
+    // Read the whole JSON file into memory.
+    uint64 size = 0;
+    OP_REQUIRES_OK(context, env_->GetFileSize(filename, &size));
+    std::unique_ptr<RandomAccessFile> file;
+    OP_REQUIRES_OK(context, env_->NewRandomAccessFile(filename, &file));
+    string buffer_memory;
+    StringPiece result;
+    buffer_memory.resize(size);
+    OP_REQUIRES_OK(context, file->Read(0, size, &result, &buffer_memory[0]));
+
+    // Parse JSON records from the buffer string.
+    Json::Reader reader;
+    Json::Value json_records;
+    OP_REQUIRES(context, reader.parse(buffer_memory, json_records),
+                errors::InvalidArgument("JSON parsing error: ", reader.getFormattedErrorMessages()));
+
+    OP_REQUIRES(context, json_records.type() == Json::arrayValue,
+                errors::InvalidArgument("JSON is not in record format!"));
+
+    #define BOOL_VALUE records.emplace_back(val.asBool())
+    #define INT64_VALUE records.emplace_back(val.asInt64())
+    #define UINT64_VALUE records.emplace_back(val.asUInt64())
+    #define DOUBLE_VALUE records.emplace_back(val.asDouble())
+    #define STRING_VALUE records.emplace_back(val.asString())
+
+    #define PROCESS_RECORD(TYPE, VALUE) {                                      \
+      std::vector<TYPE> records;                                               \
+      for (Json::ArrayIndex i = 0; i < json_records.size(); i++) {             \
+        const Json::Value& record = json_records[i];                           \
+        const Json::Value& val = record[column];                               \
+        VALUE;                                                                 \
+      }                                                                        \
+      Tensor* output_tensor;                                                   \
+      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({static_cast<int64>(records.size())}), &output_tensor)); \
+      for (size_t i = 0; i < records.size(); i++) {                            \
+        output_tensor->flat<TYPE>()(i) = std::move(records[i]);                \
+      }                                                                        \
+    }
+
+    const Json::Value& first_record = json_records[0];
+    const Json::Value& value = first_record[column];
+    switch (value.type()) {
+      case Json::intValue:
+        PROCESS_RECORD(int64, INT64_VALUE);
+        break;
+      case Json::uintValue:
+        PROCESS_RECORD(uint64, UINT64_VALUE);
+        break;
+      case Json::realValue:
+        PROCESS_RECORD(double, DOUBLE_VALUE);
+        break;
+      case Json::stringValue:
+        PROCESS_RECORD(string, STRING_VALUE);
+        break;
+      case Json::booleanValue:
+        PROCESS_RECORD(bool, BOOL_VALUE);
+        break;
+      // Currently Json::nullValue, Json::arrayValue, and Json::objectValue are not supported.
+      default:
+        OP_REQUIRES(context, false, errors::InvalidArgument("unsupported data type: ", value.type()));
+    }
+
+  }
+
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ListJSONColumns").Device(DEVICE_CPU),
+                        ListJSONColumnsOp);
+REGISTER_KERNEL_BUILDER(Name("ReadJSON").Device(DEVICE_CPU),
+                        ReadJSONOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/json/ops/json_ops.cc b/tensorflow_io/json/ops/json_ops.cc
index 1a5b8f870..a2a6afc7d 100644
--- a/tensorflow_io/json/ops/json_ops.cc
+++ b/tensorflow_io/json/ops/json_ops.cc
@@ -19,27 +19,23 @@ limitations under the License.

 namespace tensorflow {

-REGISTER_OP("JSONInput")
-    .Input("source: string")
-    .Output("handle: variant")
-    .Attr("filters: list(string) = []")
-    .Attr("columns: list(string) = []")
-    .Attr("schema: string = ''")
+REGISTER_OP("ListJSONColumns")
+    .Input("filename: string")
+    .Output("columns: string")
+    .Output("dtypes: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
       return Status::OK();
     });

-REGISTER_OP("JSONDataset")
-    .Input("input: T")
-    .Input("batch: int64")
-    .Output("handle: variant")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .Attr("T: {string, variant} = DT_VARIANT")
-    .SetIsStateful()
+REGISTER_OP("ReadJSON")
+    .Input("filename: string")
+    .Input("column: string")
+    .Attr("dtype: type")
+    .Output("output: dtype")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->MakeShape({}));
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
       return Status::OK();
     });
diff --git a/tensorflow_io/json/python/ops/json_ops.py b/tensorflow_io/json/python/ops/json_ops.py
index 5edd6e101..774a6f8d7 100644
--- a/tensorflow_io/json/python/ops/json_ops.py
+++ b/tensorflow_io/json/python/ops/json_ops.py
@@ -13,33 +13,52 @@
 # limitations under the License.
 # ==============================================================================
 """JSONDataset"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 import tensorflow as tf

 from tensorflow_io.core.python.ops import data_ops
-from tensorflow_io.core.python.ops import core_ops as json_ops
+from tensorflow_io.core.python.ops import core_ops
+
+def list_json_columns(filename):
+  """List the columns of a JSON file as a dict of name -> tf.TensorSpec."""
+  if not tf.executing_eagerly():
+    raise NotImplementedError("list_json_columns only supports eager mode")
+  columns, dtypes = core_ops.list_json_columns(filename)
+  entries = zip(tf.unstack(columns), tf.unstack(dtypes))
+  return dict([(column.numpy().decode(), tf.TensorSpec(
+      tf.TensorShape([None]),
+      dtype.numpy().decode(),
+      column.numpy().decode())) for (
+          column, dtype) in entries])
+
+def read_json(filename, column):
+  """Read a column of a JSON file into a tensor."""
+  return core_ops.read_json(
+      filename, column.name, dtype=column.dtype)

-class JSONDataset(data_ops.Dataset):
-  """A JSONLabelDataset.
+class JSONDataset(data_ops.BaseDataset):
+  """A JSONDataset.

   JSON (JavaScript Object Notation) is a lightweight data-interchange format.
   """

-  def __init__(self, filenames, columns, dtypes, batch=None):
-    """Create a JSONLabelDataset.
+  def __init__(self, filename, column, **kwargs):
+    """Create a JSONDataset.

     Args:
-      filenames: A 0-D or 1-D `tf.string` tensor containing one or more
-        filenames.
-      columns: A 0-D or 1-D `tf.int32` tensor containing the columns to extract.
-      dtypes: A tuple of `tf.DType` objects representing the types of the
-        columns returned.
+      filename: The filename of the JSON file to read.
+      column: The name of the column to extract.
+      **kwargs: In graph mode, `dtype` must be provided to specify the
+        column's type.
     """
-    data_input = json_ops.json_input(
-        filenames, ["none", "gz"], columns=columns)
-    dtypes = dtypes
-    batch = 0 if batch is None else batch
-    shapes = [
-        tf.TensorShape([]) for _ in columns] if batch == 0 else [
-            tf.TensorShape([None]) for _ in columns]
+    if not tf.executing_eagerly():
+      dtype = kwargs.get("dtype")
+    else:
+      columns = list_json_columns(filename)
+      dtype = columns[column].dtype
+    shape = tf.TensorShape([None])
+
+    dataset = data_ops.BaseDataset.from_tensors(
+        core_ops.read_json(filename, column, dtype=dtype))
+    self._dataset = dataset
+
     super(JSONDataset, self).__init__(
-        json_ops.json_dataset,
-        data_input,
-        batch, dtypes, shapes
-    )
+        self._dataset._variant_tensor, [dtype], [shape])  # pylint: disable=protected-access
diff --git a/tests/test_json.py b/tests/test_json.py
index 81cd0d849..9240a5d95 100644
--- a/tests/test_json.py
+++ b/tests/test_json.py
@@ -19,49 +19,43 @@ from __future__ import print_function

 import os
+import numpy as np
+import pytest
+
 import tensorflow as tf
 tf.compat.v1.disable_eager_execution()
-
-from tensorflow import dtypes  # pylint: disable=wrong-import-position
-from tensorflow import errors  # pylint: disable=wrong-import-position
-from tensorflow import test  # pylint: disable=wrong-import-position
-from tensorflow.compat.v1 import data  # pylint: disable=wrong-import-position
-
 import tensorflow_io.json as json_io  # pylint: disable=wrong-import-position

-class JSONDatasetTest(test.TestCase):
-  """JSONDatasetTest"""
-
-  def test_json_dataset(self):
-    """Test case for JSONDataset."""
-    filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)),
-        "test_json",
-        "feature.json")
-    filename = "file://" + filename
-
-    columns = ['floatfeature', 'integerfeature']
-    output_types = (dtypes.float64, dtypes.int64)
-    num_repeats = 2
-
-    dataset = json_io.JSONDataset(
-        filename, columns=columns, dtypes=output_types).repeat(num_repeats)
-    iterator = data.make_initializable_iterator(dataset)
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
-
-    test_json = [(1.1, 2), (2.1, 3)]
-    with self.test_session() as sess:
-      sess.run(init_op)
-      for _ in range(num_repeats):
-        for i in range(2):
-          (floatf, intf) = test_json[i]
-          vv = sess.run(get_next)
-          self.assertAllClose((floatf, intf), vv)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+def test_json_dataset():
+  """Test case for JSONDataset."""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_json",
+      "feature.json")
+  filename = "file://" + filename
+
+  num_repeats = 2
+  dataset = tf.compat.v2.data.Dataset.zip(
+      (
+          json_io.JSONDataset(filename, "floatfeature", dtype=tf.float64),
+          json_io.JSONDataset(filename, "integerfeature", dtype=tf.int64)
+      )).repeat(num_repeats).apply(tf.data.experimental.unbatch())
+
+  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
+  init_op = iterator.initializer
+  get_next = iterator.get_next()
+
+  test_json = [(1.1, 2), (2.1, 3)]
+  with tf.compat.v1.Session() as sess:
+    sess.run(init_op)
+    for _ in range(num_repeats):
+      for i in range(2):
+        (floatf, intf) = test_json[i]
+        vv = sess.run(get_next)
+        assert np.allclose((floatf, intf), vv)
+    with pytest.raises(tf.errors.OutOfRangeError):
+      sess.run(get_next)

 if __name__ == "__main__":
-  test.main()
+  tf.test.main()
diff --git a/tests/test_json_eager.py b/tests/test_json_eager.py
index 088a02192..333f24718 100644
--- a/tests/test_json_eager.py
+++ b/tests/test_json_eager.py
@@ -19,12 +19,10 @@ from __future__ import print_function

 import os
-import numpy as np

 import tensorflow as tf
 if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
   tf.compat.v1.enable_eager_execution()
-
 import tensorflow_io.json as json_io  # pylint: disable=wrong-import-position

 def test_json_dataset():
@@ -43,47 +41,49 @@ def test_json_dataset():
       "label.json")
   label_filename = "file://" + label_filename

-  feature_list = ["floatfeature", "integerfeature"]
-  label_list = ["floatlabel", "integerlabel"]
-  feature_dataset = json_io.JSONDataset(
-      feature_filename,
-      feature_list,
-      [tf.float64, tf.int64])
-  label_dataset = json_io.JSONDataset(
-      label_filename,
-      label_list,
-      [tf.float64, tf.int64])
+  feature_cols = json_io.list_json_columns(feature_filename)
+  assert feature_cols["floatfeature"].dtype == tf.float64
+  assert feature_cols["integerfeature"].dtype == tf.int64

-  i = 0
-  for record in feature_dataset:
-    v_x = x_test[i]
-    for index, val in enumerate(record):
-      assert v_x[index] == val.numpy()
-    i += 1
-  assert i == len(y_test)
+  label_cols = json_io.list_json_columns(label_filename)
+  assert label_cols["floatlabel"].dtype == tf.float64
+  assert label_cols["integerlabel"].dtype == tf.int64

-  ## Test of the reverse order of the columns
-  feature_list = ["integerfeature", "floatfeature"]
-  feature_dataset = json_io.JSONDataset(
+  float_feature = json_io.read_json(
       feature_filename,
-      feature_list,
-      [tf.int64, tf.float64])
-
-  i = 0
-  for record in feature_dataset:
-    v_x = np.flip(x_test[i])
-    for index, val in enumerate(record):
-      assert v_x[index] == val.numpy()
-    i += 1
-  assert i == len(y_test)
+      feature_cols["floatfeature"])
+  integer_feature = json_io.read_json(
+      feature_filename,
+      feature_cols["integerfeature"])
+  float_label = json_io.read_json(
+      label_filename,
+      label_cols["floatlabel"])
+  integer_label = json_io.read_json(
+      label_filename,
+      label_cols["integerlabel"])

-  i = 0
-  for record in label_dataset:
+  for i in range(2):
+    v_x = x_test[i]
     v_y = y_test[i]
-    for index, val in enumerate(record):
-      assert v_y[index] == val.numpy()
-    i += 1
-  assert i == len(y_test)
+    assert v_x[0] == float_feature[i].numpy()
+    assert v_x[1] == integer_feature[i].numpy()
+    assert v_y[0] == float_label[i].numpy()
+    assert v_y[1] == integer_label[i].numpy()
+
+  feature_dataset = tf.compat.v2.data.Dataset.zip(
+      (
+          json_io.JSONDataset(feature_filename, "floatfeature"),
+          json_io.JSONDataset(feature_filename, "integerfeature")
+      )
+  ).apply(tf.data.experimental.unbatch())
+
+  label_dataset = tf.compat.v2.data.Dataset.zip(
+      (
+          json_io.JSONDataset(label_filename, "floatlabel"),
+          json_io.JSONDataset(label_filename, "integerlabel")
+      )
+  ).apply(tf.data.experimental.unbatch())

   dataset = tf.data.Dataset.zip((
       feature_dataset,
@@ -92,7 +92,7 @@ def test_json_dataset():

   i = 0
   for (j_x, j_y) in dataset:
-    v_x = np.flip(x_test[i])
+    v_x = x_test[i]
     v_y = y_test[i]
     for index, x in enumerate(j_x):
       assert v_x[index] == x.numpy()
@@ -114,20 +114,28 @@ def test_json_keras():
       "species.json")
   label_filename = "file://" + label_filename

-  feature_list = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth']
-  label_list = ["species"]
-  feature_types = [tf.float64, tf.float64, tf.float64, tf.float64]
-  label_types = [tf.int64]
-  feature_dataset = json_io.JSONDataset(
-      feature_filename,
-      feature_list,
-      feature_types,
-      batch=32)
-  label_dataset = json_io.JSONDataset(
-      label_filename,
-      label_list,
-      label_types,
-      batch=32)
+  feature_cols = json_io.list_json_columns(feature_filename)
+  label_cols = json_io.list_json_columns(label_filename)
+
+  feature_tensors = []
+  for feature in feature_cols:
+    dataset = json_io.JSONDataset(feature_filename, feature)
+    feature_tensors.append(dataset)
+
+  label_tensors = []
+  for label in label_cols:
+    dataset = json_io.JSONDataset(label_filename, label)
+    label_tensors.append(dataset)
+
+  feature_dataset = tf.compat.v2.data.Dataset.zip(
+      tuple(feature_tensors)
+  )
+
+  label_dataset = tf.compat.v2.data.Dataset.zip(
+      tuple(label_tensors)
+  )

   dataset = tf.data.Dataset.zip((
       feature_dataset,
       label_dataset
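
---

For context, a minimal sketch of the API this patch adds. It assumes eager mode
(list_json_columns raises NotImplementedError otherwise) and a hypothetical
"feature.json" containing an array of records such as
[{"floatfeature": 1.1, "integerfeature": 2}, {"floatfeature": 2.1, "integerfeature": 3}]:

  import tensorflow as tf
  import tensorflow_io.json as json_io

  # Inspect the file: returns a dict of column name -> tf.TensorSpec.
  columns = json_io.list_json_columns("file:///tmp/feature.json")
  assert columns["floatfeature"].dtype == tf.float64

  # Read one column into a rank-1 tensor, one element per record.
  floats = json_io.read_json("file:///tmp/feature.json", columns["floatfeature"])

  # Or wrap a column as a dataset; the dtype is inferred in eager mode,
  # while graph mode requires an explicit dtype=... keyword argument.
  dataset = json_io.JSONDataset("file:///tmp/feature.json", "floatfeature")
  for column_tensor in dataset:  # yields the whole column as a single element
    print(column_tensor)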