From 0b8d221b8784c45d2a2c1c7507dc394ec226b66d Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Tue, 30 Jul 2019 17:57:34 +0000
Subject: [PATCH 1/3] Add read_text to read lines from a splittable text file

This PR is part of the effort to rework Dataset so that large files are
read into Tensors first, to speed up performance. See 382 and 366 for
related discussions.

Summary:
1) read_text reads a text file within the range [offset, offset+length].
2) This makes the text file splittable: the file can be read in chunks
   (similar to Hadoop).
3) The plan is to read a text file in big chunks and then wire it up
   with tf.data.Dataset.
4) read_text is a primitive C++ op, so it can be used in tf.data as
   well as in other places.

Note: once PR 393 is merged I will convert TextDataset to use this op
(and remove the native C++ implementation of TextDataset).

Signed-off-by: Yong Tang
---
 tensorflow_io/text/BUILD | 1 +
 tensorflow_io/text/__init__.py | 3 +
 tensorflow_io/text/kernels/text_kernels.cc | 144 +++++++++++++++++++++
 tensorflow_io/text/ops/text_ops.cc | 11 ++
 tensorflow_io/text/python/ops/text_ops.py | 8 ++
 tests/test_text_eager.py | 30 +++++
 6 files changed, 197 insertions(+)
 create mode 100644 tensorflow_io/text/kernels/text_kernels.cc

diff --git a/tensorflow_io/text/BUILD b/tensorflow_io/text/BUILD
index fc12f92b4..5ef8a0d24 100644
--- a/tensorflow_io/text/BUILD
+++ b/tensorflow_io/text/BUILD
@@ -12,6 +12,7 @@ cc_library(
     srcs = [
         "kernels/csv_output.cc",
         "kernels/text_input.cc",
+        "kernels/text_kernels.cc",
         "kernels/text_output.cc",
         "kernels/text_re2.cc",
         "kernels/text_sequence.cc",
diff --git a/tensorflow_io/text/__init__.py b/tensorflow_io/text/__init__.py
index 5ce061bd2..70173477e 100644
--- a/tensorflow_io/text/__init__.py
+++ b/tensorflow_io/text/__init__.py
@@ -20,6 +20,7 @@
 @@save_csv
 @@from_csv
 @@re2_full_match
+@@read_text
 """

 from __future__ import absolute_import
@@ -32,6 +33,7 @@
 from tensorflow_io.text.python.ops.text_ops import save_csv
 from tensorflow_io.text.python.ops.text_ops import from_csv
 from tensorflow_io.text.python.ops.text_ops import re2_full_match
+from tensorflow_io.text.python.ops.text_ops import read_text

 from tensorflow.python.util.all_util import remove_undocumented
@@ -42,6 +44,7 @@
     "save_csv",
     "from_csv",
     "re2_full_match",
+    "read_text",
 ]

 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow_io/text/kernels/text_kernels.cc b/tensorflow_io/text/kernels/text_kernels.cc
new file mode 100644
index 000000000..7dbc1affb
--- /dev/null
+++ b/tensorflow_io/text/kernels/text_kernels.cc
@@ -0,0 +1,144 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" + +namespace tensorflow { +namespace data { +namespace { + +// Note: This SizedRandomAccessFile should only lives within Compute() +// of the kernel as buffer could be released by outside. +class SizedRandomAccessFile : public tensorflow::RandomAccessFile { + public: + SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory) + : file_(nullptr) + , size_status_(Status::OK()) + , size_(optional_memory.size()) + , buffer_(optional_memory) { + if (size_ == 0) { + size_status_ = env->GetFileSize(filename, &size_); + if (size_status_.ok()) { + size_status_ = env->NewRandomAccessFile(filename, &file_); + } + } + } + + virtual ~SizedRandomAccessFile() {} + Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { + if (file_.get() != nullptr) { + return file_.get()->Read(offset, n, result, scratch); + } + size_t bytes_to_read = 0; + if (offset < size_) { + bytes_to_read = (offset + n < size_) ? n : (size_ - offset); + } + if (bytes_to_read > 0) { + memcpy(scratch, buffer_.data(), bytes_to_read); + } + *result = StringPiece(scratch, bytes_to_read); + if (bytes_to_read < n) { + return errors::OutOfRange("EOF reached"); + } + return Status::OK(); + } + Status GetFileSize(uint64* size) { + if (size_status_.ok()) { + *size = size_; + } + return size_status_; + } + private: + std::unique_ptr file_; + Status size_status_; + uint64 size_; + const string& buffer_; +}; + +class ReadTextOp : public OpKernel { + public: + explicit ReadTextOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string& filename = filename_tensor.scalar()(); + + const Tensor& offset_tensor = context->input(1); + const int64 offset = offset_tensor.scalar()(); + + const Tensor& length_tensor = context->input(2); + int64 length = length_tensor.scalar()(); + + const Tensor& memory_tensor = context->input(3); + const string& memory = memory_tensor.scalar()(); + + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + uint64 size; + OP_REQUIRES_OK(context, file->GetFileSize(&size)); + + if (length < 0) { + length = size; + } + + // This ReadText is a splittable version so that it is possible to read Text from a chunk of a file, + // much like Hadoop. We use the following method to decide if a line belongs to the chunk or not: + // 1) offset = 0: read lines and stop after length is reached. + // 2) offset > 0: back off 1 and skip one line to start with the next line, stop after length is reached. + // + // Note: We use BufferedInputStream which is only able to process separator of "\n", though it could + // be expanded to more than "\n" in the future. 
+ + std::unique_ptr stream(new tensorflow::io::BufferedInputStream(file.get(), 65536)); + if (offset > 0) { + OP_REQUIRES_OK(context, stream->SkipNBytes(offset - 1)); + string line; + OP_REQUIRES_OK(context, stream->ReadLine(&line)); + } + + std::vector lines; + while (stream->Tell() < offset + length) { + string line; + Status status = stream->ReadLine(&line); + OP_REQUIRES(context, (status.ok() || errors::IsOutOfRange(status)), status); + if (!status.ok()) { + break; + } + lines.emplace_back(line); + } + + TensorShape output_shape({static_cast(lines.size())}); + + Tensor* output_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); + + for (size_t i = 0; i < lines.size(); i++) { + output_tensor->flat()(i) = std::move(lines[i]); + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("ReadText").Device(DEVICE_CPU), + ReadTextOp); + + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/text/ops/text_ops.cc b/tensorflow_io/text/ops/text_ops.cc index ac5c6da26..743a229dd 100644 --- a/tensorflow_io/text/ops/text_ops.cc +++ b/tensorflow_io/text/ops/text_ops.cc @@ -46,6 +46,17 @@ REGISTER_OP("RE2FullMatch") return Status::OK(); }); +REGISTER_OP("ReadText") + .Input("filename: string") + .Input("offset: int64") + .Input("length: int64") + .Input("memory: string") + .Output("output: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + REGISTER_OP("TextStreamInput") .Input("source: string") .Output("handle: variant") diff --git a/tensorflow_io/text/python/ops/text_ops.py b/tensorflow_io/text/python/ops/text_ops.py index 0ff77fbbf..34bd4b13a 100644 --- a/tensorflow_io/text/python/ops/text_ops.py +++ b/tensorflow_io/text/python/ops/text_ops.py @@ -25,6 +25,14 @@ from tensorflow_io.core.python.ops import data_ops as data_ops from tensorflow_io.core.python.ops import core_ops as text_ops +def read_text(filename, **kwargs): + """read_text""" + offset = kwargs.get("offset", 0) + length = kwargs.get("length", -1) + memory = kwargs.get("memory", "") + return text_ops.read_text( + filename, offset=offset, length=length, memory=memory) + def save_text(dataset, filename): """Save Dataset to disk. 
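For reference, the intended eager usage of the read_text wrapper above is
sketched here (a minimal sketch: the file path and byte ranges are
illustrative, not part of the patch):

    import tensorflow_io.text as text_io

    # Read the whole file: offset defaults to 0 and length defaults to -1,
    # meaning "to the end of the file". Lines come back as a 1-D string
    # tensor, without their trailing newlines.
    all_lines = text_io.read_text("file:///tmp/lorem.txt")

    # Read one split: only lines whose first byte falls inside
    # [4096, 8192) are returned. A line straddling the right edge is
    # still read to completion, and a non-zero offset skips the partial
    # line at the left edge, so adjacent splits neither drop nor
    # duplicate lines.
    chunk = text_io.read_text("file:///tmp/lorem.txt",
                              offset=4096, length=4096)
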
diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py index 838ce24e7..68f5bd072 100644 --- a/tests/test_text_eager.py +++ b/tests/test_text_eager.py @@ -29,6 +29,36 @@ import tensorflow_io.text as text_io # pylint: disable=wrong-import-position import tensorflow_io.core.python.ops.data_ops as core_io # pylint: disable=wrong-import-position +def test_read_text(): + """test_read_text""" + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt") + with open(filename, 'rb') as f: + lines = [line for line in f] + filename = "file://" + filename + + filesize = tf.io.gfile.GFile(filename).size() + + offset = 0 + offsets = [] + for line in lines: + offsets.append(offset) + offset += len(line) + + lines = zip(offsets, lines) + + for offset, length in [ + (0, -1), (1, -1), (1000, -1), (100, 1000), (1000, 10000)]: + entries = text_io.read_text(filename, offset=offset, length=length) + if length < 0: + length = filesize - offset + expected = [ + line for (k, line) in lines if k >= offset and k < offset + length] + assert entries.shape == len(expected) + for k, v in enumerate(expected): + assert entries[k].numpy().decode() + "\n" == v.decode() + + def test_text_input(): """test_text_input """ From 632330e32125fcc63e98b2fb78beca0863ebbc8a Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 31 Jul 2019 20:59:42 +0000 Subject: [PATCH 2/3] Use read_text to implement TextDataset Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/data_ops.py | 5 +- tensorflow_io/text/BUILD | 1 - tensorflow_io/text/kernels/text_input.cc | 78 ------------- tensorflow_io/text/kernels/text_kernels.cc | 128 ++++++++++++++++----- tensorflow_io/text/python/ops/text_ops.py | 64 ++++++----- tests/test_text/stdin_test.py | 2 +- tests/test_text_eager.py | 23 ++-- 7 files changed, 148 insertions(+), 153 deletions(-) delete mode 100644 tensorflow_io/text/kernels/text_input.cc diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py index bbd3b5d8b..40107037d 100644 --- a/tensorflow_io/core/python/ops/data_ops.py +++ b/tensorflow_io/core/python/ops/data_ops.py @@ -58,9 +58,8 @@ def _apply_fn(dataset): class BaseDataset(tf.compat.v2.data.Dataset): """A Base Dataset""" - def __init__(self, variant, batch, dtypes, shapes): + def __init__(self, variant, dtypes, shapes): """Create a Base Dataset.""" - self._batch = 0 if batch is None else batch self._dtypes = dtypes self._shapes = shapes super(BaseDataset, self).__init__(variant) @@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes): self._data_input, self._batch, output_types=self._dtypes, - output_shapes=self._shapes), self._batch, self._dtypes, self._shapes) + output_shapes=self._shapes), self._dtypes, self._shapes) diff --git a/tensorflow_io/text/BUILD b/tensorflow_io/text/BUILD index 5ef8a0d24..12c169dec 100644 --- a/tensorflow_io/text/BUILD +++ b/tensorflow_io/text/BUILD @@ -11,7 +11,6 @@ cc_library( name = "text_ops", srcs = [ "kernels/csv_output.cc", - "kernels/text_input.cc", "kernels/text_kernels.cc", "kernels/text_output.cc", "kernels/text_re2.cc", diff --git a/tensorflow_io/text/kernels/text_input.cc b/tensorflow_io/text/kernels/text_input.cc deleted file mode 100644 index 8261bcbbc..000000000 --- a/tensorflow_io/text/kernels/text_input.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "kernels/dataset_ops.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" - -namespace tensorflow { -namespace data { - -class TextInput: public FileStreamInput { - public: - Status ReadRecord(io::InputStreamInterface* s, bool owns_input_stream, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { - if (state.get() == nullptr) { - state.reset(new io::BufferedInputStream(s, 4096, owns_input_stream)); - } - Tensor value_tensor(ctx->allocator({}), DT_STRING, {record_to_read}); - while ((*record_read) < record_to_read) { - string buffer; - buffer.clear(); - Status status = state.get()->ReadLine(&buffer); - if (!(status.ok() || errors::IsOutOfRange(status))) { - return status; - } - if (!status.ok()) { - break; - } - value_tensor.flat()((*record_read)) = std::move(buffer); - (*record_read)++; - } - if (*record_read > 0) { - out_tensors->emplace_back(std::move(value_tensor)); - } - return Status::OK(); - } - Status FromStream(io::InputStreamInterface* s) override { - // TODO: Read 4K buffer to detect BOM. - //string header; - //TF_RETURN_IF_ERROR(s.ReadNBytes(4096, &header)); - //for (size i = 0; i < header.size(); i++) { - // if (!isprint(header[i])) { - // return errors::InvalidArgument("text file contains character that is non printable at ", i); - // } - //} - return Status::OK(); - } - void EncodeAttributes(VariantTensorData* data) const override { - } - bool DecodeAttributes(const VariantTensorData& data) override { - return true; - } - protected: -}; - -REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TextInput, "tensorflow::data::TextInput"); - -REGISTER_KERNEL_BUILDER(Name("TextInput").Device(DEVICE_CPU), - FileInputOp); -REGISTER_KERNEL_BUILDER(Name("TextDataset").Device(DEVICE_CPU), - FileInputDatasetOp); - -REGISTER_KERNEL_BUILDER(Name("TextStreamInput").Device(DEVICE_CPU), - StreamInputOp); -REGISTER_KERNEL_BUILDER(Name("TextStreamDataset").Device(DEVICE_CPU), - StreamInputDatasetOp); -} // namespace data -} // namespace tensorflow diff --git a/tensorflow_io/text/kernels/text_kernels.cc b/tensorflow_io/text/kernels/text_kernels.cc index 7dbc1affb..e2bd4ed22 100644 --- a/tensorflow_io/text/kernels/text_kernels.cc +++ b/tensorflow_io/text/kernels/text_kernels.cc @@ -67,6 +67,57 @@ class SizedRandomAccessFile : public tensorflow::RandomAccessFile { uint64 size_; const string& buffer_; }; +class FilenoInputStream : public io::InputStreamInterface { + public: + FilenoInputStream(int fileno) : fileno_(fileno) {} + virtual ~FilenoInputStream() {} + + virtual Status ReadNBytes(int64 bytes_to_read, string* result) override { + if (bytes_to_read < 0) { + return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); + } + + result->clear(); + if (final_) { + return errors::OutOfRange("EOF reached"); + } + + + string buffer; + result->resize(bytes_to_read); + int64 bytes_read 
= 0; + while (bytes_read < bytes_to_read) { + size_t chunk = bytes_to_read - bytes_read; + ssize_t returned = read(fileno_, &(*result)[bytes_read], chunk); + if (returned < 0) { + result->resize(bytes_read); + return errors::Internal("read fileno ", fileno_, " error: ", returned); + } + if (returned == 0) { + break; + } + bytes_read += returned; + } + offset_ += bytes_read; + result->resize(bytes_read); + if (bytes_read < bytes_to_read) { + return errors::OutOfRange("EOF reached"); + } + return Status::OK(); + } + + virtual int64 Tell() const override { + return offset_; + } + + virtual Status Reset() override { + return errors::Unimplemented("Reset fileno stream is not implemented"); + } + private: + int fileno_ = -1; + int64 offset_ = 0; + bool final_ = false; +}; class ReadTextOp : public OpKernel { public: @@ -87,38 +138,59 @@ class ReadTextOp : public OpKernel { const Tensor& memory_tensor = context->input(3); const string& memory = memory_tensor.scalar()(); - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); - uint64 size; - OP_REQUIRES_OK(context, file->GetFileSize(&size)); + std::vector lines; - if (length < 0) { - length = size; - } + if (filename == "file://-" || filename == "file://0") { + // If we read from stdin then let's read until EOF is reached + // Note: It is possible to read data in large slices. However, + // BufferedInputStream takes a cached buffer which complicates + // the data read from stream. Will need to implement a no-cache + // version of ReadLine() in order to read chunks. + std::unique_ptr input_stream(new FilenoInputStream(STDIN_FILENO)); + std::unique_ptr stream(new tensorflow::io::BufferedInputStream(input_stream.get(), 4096)); + + Status status = Status::OK(); + while (status.ok()) { + string line; + status = stream->ReadLine(&line); + OP_REQUIRES(context, (status.ok() || errors::IsOutOfRange(status)), status); + if (!status.ok()) { + break; + } + lines.emplace_back(line); + } + } else { + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + uint64 size; + OP_REQUIRES_OK(context, file->GetFileSize(&size)); + if (length < 0) { + length = size; + } - // This ReadText is a splittable version so that it is possible to read Text from a chunk of a file, - // much like Hadoop. We use the following method to decide if a line belongs to the chunk or not: - // 1) offset = 0: read lines and stop after length is reached. - // 2) offset > 0: back off 1 and skip one line to start with the next line, stop after length is reached. - // - // Note: We use BufferedInputStream which is only able to process separator of "\n", though it could - // be expanded to more than "\n" in the future. - - std::unique_ptr stream(new tensorflow::io::BufferedInputStream(file.get(), 65536)); - if (offset > 0) { - OP_REQUIRES_OK(context, stream->SkipNBytes(offset - 1)); - string line; - OP_REQUIRES_OK(context, stream->ReadLine(&line)); - } + // This ReadText is a splittable version so that it is possible to read Text from a chunk of a file, + // much like Hadoop. We use the following method to decide if a line belongs to the chunk or not: + // 1) offset = 0: read lines and stop after length is reached. + // 2) offset > 0: back off 1 and skip one line to start with the next line, stop after length is reached. + // + // Note: We use BufferedInputStream which is only able to process separator of "\n", though it could + // be expanded to more than "\n" in the future. 
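+      //
+      // For example, with offset=100 and length=1000: skip to byte 99,
+      // discard through the first '\n' at or after byte 99, then
+      // collect every line whose first byte lies in [100, 1100). A line
+      // that starts before byte 1100 but ends after it is still
+      // returned whole, so consecutive splits neither drop nor
+      // duplicate lines.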
+ + std::unique_ptr stream(new tensorflow::io::BufferedInputStream(file.get(), 65536)); + if (offset > 0) { + OP_REQUIRES_OK(context, stream->SkipNBytes(offset - 1)); + string line; + OP_REQUIRES_OK(context, stream->ReadLine(&line)); + } - std::vector lines; - while (stream->Tell() < offset + length) { - string line; - Status status = stream->ReadLine(&line); - OP_REQUIRES(context, (status.ok() || errors::IsOutOfRange(status)), status); - if (!status.ok()) { - break; + while (stream->Tell() < offset + length) { + string line; + Status status = stream->ReadLine(&line); + OP_REQUIRES(context, (status.ok() || errors::IsOutOfRange(status)), status); + if (!status.ok()) { + break; + } + lines.emplace_back(line); } - lines.emplace_back(line); } TensorShape output_shape({static_cast(lines.size())}); diff --git a/tensorflow_io/text/python/ops/text_ops.py b/tensorflow_io/text/python/ops/text_ops.py index 34bd4b13a..c03fa0344 100644 --- a/tensorflow_io/text/python/ops/text_ops.py +++ b/tensorflow_io/text/python/ops/text_ops.py @@ -22,15 +22,15 @@ import numpy as np import tensorflow as tf -from tensorflow_io.core.python.ops import data_ops as data_ops -from tensorflow_io.core.python.ops import core_ops as text_ops +from tensorflow_io.core.python.ops import data_ops +from tensorflow_io.core.python.ops import core_ops def read_text(filename, **kwargs): """read_text""" offset = kwargs.get("offset", 0) length = kwargs.get("length", -1) memory = kwargs.get("memory", "") - return text_ops.read_text( + return core_ops.read_text( filename, offset=offset, length=length, memory=memory) def save_text(dataset, filename): @@ -40,7 +40,7 @@ def save_text(dataset, filename): dataset: A TextDataset to be saved. filename: A `tf.string` tensor containing filename. """ - return text_ops.text_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access + return core_ops.text_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access def save_csv(dataset, filename): @@ -50,7 +50,7 @@ def save_csv(dataset, filename): dataset: A Dataset to be saved. filename: A `tf.string` tensor containing filename. """ - return text_ops.csv_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access + return core_ops.csv_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access def re2_full_match(input, pattern): # pylint: disable=redefined-builtin @@ -60,33 +60,45 @@ def re2_full_match(input, pattern): # pylint: disable=redefined-builtin dataset: A `tf.string` tensor pattern: A pattern string. """ - return text_ops.re2_full_match(input, pattern) + return core_ops.re2_full_match(input, pattern) -class TextDataset(data_ops.Dataset): +class TextDataset(data_ops.BaseDataset): """A Text Dataset""" - def __init__(self, filename, batch=None): + def __init__(self, filename, **kwargs): """Create a Text Reader. Args: - filename: A `tf.string` tensor containing one or more filenames. + filename: A string containing filename to read. 
""" - batch = 0 if batch is None else batch - dtypes = [tf.string] - shapes = [ - tf.TensorShape([])] if batch == 0 else [ - tf.TensorShape([None])] - fn = text_ops.text_stream_dataset if ( - filename == 'file://-') else text_ops.text_dataset - data_input = text_ops.text_stream_input(filename) if ( - filename == 'file://-') else text_ops.text_input( - filename, ["none", "gz"]) - super(TextDataset, self).__init__( - fn, - data_input, - batch, dtypes, shapes) + dtype = tf.string + shape = tf.TensorShape([None]) + + capacity = kwargs.get("capacity", 65536) + + if filename.startswith("file://-") or filename.startswith("file://0"): + dataset = data_ops.BaseDataset.range(1).map( + lambda length: core_ops.read_text(filename, offset=0, length=length, memory="") + ) + else: + filesize = tf.io.gfile.GFile(filename).size() + # capacity is the rough length for each split + entry_offset = list(range(0, filesize, capacity)) + entry_length = [ + min(capacity, filesize - offset) for offset in entry_offset] + dataset = data_ops.BaseDataset.from_tensor_slices( + ( + tf.constant(entry_offset, tf.int64), + tf.constant(entry_length, tf.int64) + ) + ).map(lambda offset, length: core_ops.read_text( + filename, + offset=offset, length=length, memory="")) + self._dataset = dataset + super(TextDataset, self).__init__( + self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access class TextOutputSequence(object): """TextOutputSequence""" @@ -95,10 +107,10 @@ def __init__(self, filenames): """Create a `TextOutputSequence`. """ self._filenames = filenames - self._resource = text_ops.text_output_sequence(destination=filenames) + self._resource = core_ops.text_output_sequence(destination=filenames) def setitem(self, index, item): - text_ops.text_output_sequence_set_item(self._resource, index, item) + core_ops.text_output_sequence_set_item(self._resource, index, item) def _infer_dtype(val): @@ -130,7 +142,7 @@ def from_csv(filename, header=0): """ if not tf.executing_eagerly(): raise NotImplementedError("from_csv only supports eager mode") - dataset = TextDataset(filename) + dataset = TextDataset(filename).unbatch() columns = None if header is not None: if header != 0: diff --git a/tests/test_text/stdin_test.py b/tests/test_text/stdin_test.py index 472647360..e3ef48dbd 100644 --- a/tests/test_text/stdin_test.py +++ b/tests/test_text/stdin_test.py @@ -26,7 +26,7 @@ # tshark -T fields -e frame.number -e ip.dst -e ip.proto -r attack-trace.pcap | python stdin_test.py def f(v): - frame_number, ip_dst, ip_proto = tf.decode_csv( + frame_number, ip_dst, ip_proto = tf.io.decode_csv( v, [[0], [''], [0]], field_delim='\t') return frame_number, ip_dst, ip_proto diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py index 68f5bd072..fdc9e3b14 100644 --- a/tests/test_text_eager.py +++ b/tests/test_text_eager.py @@ -68,13 +68,7 @@ def test_text_input(): lines = [line.strip() for line in f] text_filename = "file://" + text_filename - gzip_text_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt.gz") - gzip_text_filename = "file://" + gzip_text_filename - - lines = lines * 3 - filenames = [text_filename, gzip_text_filename, text_filename] - text_dataset = text_io.TextDataset(filenames, batch=2) + text_dataset = text_io.TextDataset(text_filename).unbatch().batch(2) i = 0 for v in text_dataset: assert lines[i] == v.numpy()[0] @@ -99,7 +93,7 @@ def test_text_input(): for vv in v.numpy(): assert lines[i] == vv i += 1 - assert i == 145 + assert i == 45 
rebatch_dataset = text_dataset.apply(core_io.rebatch(5, "pad")) i = 0 @@ -110,7 +104,7 @@ def test_text_input(): else: assert vv.decode() == "" i += 1 - assert i == 150 + assert i == 50 def test_text_output_sequence(): """Test case based on fashion mnist tutorial""" @@ -172,7 +166,7 @@ def test_text_output(): f, filename = tempfile.mkstemp() os.close(f) - df = text_io.TextDataset(text_filename) + df = text_io.TextDataset(text_filename).unbatch() df = df.take(5) text_io.save_text(df, filename) @@ -261,16 +255,13 @@ def test_from_csv(): def test_re2_extract(): """test_text_input """ - text_filename = os.path.join( + filename = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt") - with open(text_filename, 'rb') as f: + with open(filename, 'rb') as f: lines = [line.strip() for line in f] - - filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt.gz") filename = "file://" + filename - dataset = text_io.TextDataset(filename).map(lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+")) + dataset = text_io.TextDataset(filename).map(lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+")).unbatch() i = 0 for v in dataset: r, g = v From 33220b3fe840c93cc0e149d9a3da21478d4c5cbb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 31 Jul 2019 22:01:34 +0000 Subject: [PATCH 3/3] Fix python 3 failure Signed-off-by: Yong Tang --- tensorflow_io/core/BUILD | 1 + tensorflow_io/core/kernels/stream.h | 71 ++++++++++++++++++++++ tensorflow_io/text/kernels/text_kernels.cc | 60 +++--------------- tensorflow_io/text/ops/text_ops.cc | 2 +- tensorflow_io/text/python/ops/text_ops.py | 10 +-- tests/test_text_eager.py | 12 ++-- 6 files changed, 93 insertions(+), 63 deletions(-) create mode 100644 tensorflow_io/core/kernels/stream.h diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 25c62ec34..07bdc0b36 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -28,6 +28,7 @@ cc_library( name = "dataset_ops", srcs = [ "kernels/dataset_ops.h", + "kernels/stream.h", ], copts = tf_io_copts(), includes = [ diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/stream.h new file mode 100644 index 000000000..e812babf4 --- /dev/null +++ b/tensorflow_io/core/kernels/stream.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace data { + +// Note: This SizedRandomAccessFile should only lives within Compute() +// of the kernel as buffer could be released by outside. 
+class SizedRandomAccessFile : public tensorflow::RandomAccessFile { + public: + SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size) + : file_(nullptr) + , size_(optional_memory_size) + , buff_((const char *)(optional_memory_buff)) + , size_status_(Status::OK()) { + if (size_ == 0) { + size_status_ = env->GetFileSize(filename, &size_); + if (size_status_.ok()) { + size_status_ = env->NewRandomAccessFile(filename, &file_); + } + } + } + + virtual ~SizedRandomAccessFile() {} + Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { + if (file_.get() != nullptr) { + return file_.get()->Read(offset, n, result, scratch); + } + size_t bytes_to_read = 0; + if (offset < size_) { + bytes_to_read = (offset + n < size_) ? n : (size_ - offset); + } + if (bytes_to_read > 0) { + memcpy(scratch, &buff_[offset], bytes_to_read); + } + *result = StringPiece(scratch, bytes_to_read); + if (bytes_to_read < n) { + return errors::OutOfRange("EOF reached"); + } + return Status::OK(); + } + Status GetFileSize(uint64* size) { + if (size_status_.ok()) { + *size = size_; + } + return size_status_; + } + private: + std::unique_ptr file_; + uint64 size_; + const char *buff_; + Status size_status_; +}; + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/text/kernels/text_kernels.cc b/tensorflow_io/text/kernels/text_kernels.cc index e2bd4ed22..0481237a9 100644 --- a/tensorflow_io/text/kernels/text_kernels.cc +++ b/tensorflow_io/text/kernels/text_kernels.cc @@ -15,58 +15,12 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow_io/core/kernels/stream.h" namespace tensorflow { namespace data { namespace { -// Note: This SizedRandomAccessFile should only lives within Compute() -// of the kernel as buffer could be released by outside. -class SizedRandomAccessFile : public tensorflow::RandomAccessFile { - public: - SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory) - : file_(nullptr) - , size_status_(Status::OK()) - , size_(optional_memory.size()) - , buffer_(optional_memory) { - if (size_ == 0) { - size_status_ = env->GetFileSize(filename, &size_); - if (size_status_.ok()) { - size_status_ = env->NewRandomAccessFile(filename, &file_); - } - } - } - - virtual ~SizedRandomAccessFile() {} - Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - if (file_.get() != nullptr) { - return file_.get()->Read(offset, n, result, scratch); - } - size_t bytes_to_read = 0; - if (offset < size_) { - bytes_to_read = (offset + n < size_) ? 
n : (size_ - offset); - } - if (bytes_to_read > 0) { - memcpy(scratch, buffer_.data(), bytes_to_read); - } - *result = StringPiece(scratch, bytes_to_read); - if (bytes_to_read < n) { - return errors::OutOfRange("EOF reached"); - } - return Status::OK(); - } - Status GetFileSize(uint64* size) { - if (size_status_.ok()) { - *size = size_; - } - return size_status_; - } - private: - std::unique_ptr file_; - Status size_status_; - uint64 size_; - const string& buffer_; -}; class FilenoInputStream : public io::InputStreamInterface { public: FilenoInputStream(int fileno) : fileno_(fileno) {} @@ -129,15 +83,15 @@ class ReadTextOp : public OpKernel { const Tensor& filename_tensor = context->input(0); const string& filename = filename_tensor.scalar()(); - const Tensor& offset_tensor = context->input(1); + const Tensor& memory_tensor = context->input(1); + const string& memory = memory_tensor.scalar()(); + + const Tensor& offset_tensor = context->input(2); const int64 offset = offset_tensor.scalar()(); - const Tensor& length_tensor = context->input(2); + const Tensor& length_tensor = context->input(3); int64 length = length_tensor.scalar()(); - const Tensor& memory_tensor = context->input(3); - const string& memory = memory_tensor.scalar()(); - std::vector lines; if (filename == "file://-" || filename == "file://0") { @@ -160,7 +114,7 @@ class ReadTextOp : public OpKernel { lines.emplace_back(line); } } else { - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); uint64 size; OP_REQUIRES_OK(context, file->GetFileSize(&size)); if (length < 0) { diff --git a/tensorflow_io/text/ops/text_ops.cc b/tensorflow_io/text/ops/text_ops.cc index 743a229dd..e8544ca22 100644 --- a/tensorflow_io/text/ops/text_ops.cc +++ b/tensorflow_io/text/ops/text_ops.cc @@ -48,9 +48,9 @@ REGISTER_OP("RE2FullMatch") REGISTER_OP("ReadText") .Input("filename: string") + .Input("memory: string") .Input("offset: int64") .Input("length: int64") - .Input("memory: string") .Output("output: string") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->MakeShape({c->UnknownDim()})); diff --git a/tensorflow_io/text/python/ops/text_ops.py b/tensorflow_io/text/python/ops/text_ops.py index c03fa0344..4447a05e2 100644 --- a/tensorflow_io/text/python/ops/text_ops.py +++ b/tensorflow_io/text/python/ops/text_ops.py @@ -27,9 +27,9 @@ def read_text(filename, **kwargs): """read_text""" + memory = kwargs.get("memory", "") offset = kwargs.get("offset", 0) length = kwargs.get("length", -1) - memory = kwargs.get("memory", "") return core_ops.read_text( filename, offset=offset, length=length, memory=memory) @@ -79,7 +79,7 @@ def __init__(self, filename, **kwargs): if filename.startswith("file://-") or filename.startswith("file://0"): dataset = data_ops.BaseDataset.range(1).map( - lambda length: core_ops.read_text(filename, offset=0, length=length, memory="") + lambda length: core_ops.read_text(filename, memory="", offset=0, length=length) ) else: filesize = tf.io.gfile.GFile(filename).size() @@ -93,8 +93,8 @@ def __init__(self, filename, **kwargs): tf.constant(entry_length, tf.int64) ) ).map(lambda offset, length: core_ops.read_text( - filename, - offset=offset, length=length, memory="")) + filename, memory="", + offset=offset, length=length)) self._dataset = dataset super(TextDataset, self).__init__( @@ -142,7 +142,7 @@ def from_csv(filename, header=0): """ if not tf.executing_eagerly(): raise 
NotImplementedError("from_csv only supports eager mode") - dataset = TextDataset(filename).unbatch() + dataset = TextDataset(filename).apply(tf.data.experimental.unbatch()) columns = None if header is not None: if header != 0: diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py index fdc9e3b14..bd8731834 100644 --- a/tests/test_text_eager.py +++ b/tests/test_text_eager.py @@ -45,7 +45,7 @@ def test_read_text(): offsets.append(offset) offset += len(line) - lines = zip(offsets, lines) + lines = list(zip(offsets, lines)) for offset, length in [ (0, -1), (1, -1), (1000, -1), (100, 1000), (1000, 10000)]: @@ -68,7 +68,8 @@ def test_text_input(): lines = [line.strip() for line in f] text_filename = "file://" + text_filename - text_dataset = text_io.TextDataset(text_filename).unbatch().batch(2) + text_dataset = text_io.TextDataset(text_filename).apply( + tf.data.experimental.unbatch()).batch(2) i = 0 for v in text_dataset: assert lines[i] == v.numpy()[0] @@ -166,7 +167,8 @@ def test_text_output(): f, filename = tempfile.mkstemp() os.close(f) - df = text_io.TextDataset(text_filename).unbatch() + df = text_io.TextDataset(text_filename).apply( + tf.data.experimental.unbatch()) df = df.take(5) text_io.save_text(df, filename) @@ -261,7 +263,9 @@ def test_re2_extract(): lines = [line.strip() for line in f] filename = "file://" + filename - dataset = text_io.TextDataset(filename).map(lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+")).unbatch() + dataset = text_io.TextDataset(filename).map( + lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+")).apply( + tf.data.experimental.unbatch()) i = 0 for v in dataset: r, g = v
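
Taken together, after this third patch the ReadText op takes (filename,
memory, offset, length), and TextDataset emits one string tensor of lines
per capacity-sized split, so per-line iteration goes through unbatch. A
minimal end-to-end sketch of the resulting usage (the file path is
illustrative; the unbatch/batch pattern matches the updated tests):

    import tensorflow as tf
    import tensorflow_io.text as text_io

    # Each TextDataset element is a 1-D string tensor holding all lines
    # of one split (65536 bytes by default), so flatten with unbatch()
    # before re-batching to a per-line batch size.
    dataset = text_io.TextDataset("file:///tmp/lorem.txt").apply(
        tf.data.experimental.unbatch()).batch(2)

    for batch in dataset:
        print(batch.numpy())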