diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 25c62ec34..07bdc0b36 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -28,6 +28,7 @@ cc_library(
     name = "dataset_ops",
     srcs = [
         "kernels/dataset_ops.h",
+        "kernels/stream.h",
     ],
     copts = tf_io_copts(),
     includes = [
diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/stream.h
new file mode 100644
index 000000000..e812babf4
--- /dev/null
+++ b/tensorflow_io/core/kernels/stream.h
@@ -0,0 +1,71 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+
+// Note: A SizedRandomAccessFile should only live within Compute()
+// of the kernel, as the underlying buffer could be released by the caller.
+class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
+ public:
+  SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size)
+    : file_(nullptr)
+    , size_(optional_memory_size)
+    , buff_((const char *)(optional_memory_buff))
+    , size_status_(Status::OK()) {
+    if (size_ == 0) {
+      size_status_ = env->GetFileSize(filename, &size_);
+      if (size_status_.ok()) {
+        size_status_ = env->NewRandomAccessFile(filename, &file_);
+      }
+    }
+  }
+
+  virtual ~SizedRandomAccessFile() {}
+  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
+    if (file_.get() != nullptr) {
+      return file_.get()->Read(offset, n, result, scratch);
+    }
+    size_t bytes_to_read = 0;
+    if (offset < size_) {
+      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
+    }
+    if (bytes_to_read > 0) {
+      memcpy(scratch, &buff_[offset], bytes_to_read);
+    }
+    *result = StringPiece(scratch, bytes_to_read);
+    if (bytes_to_read < n) {
+      return errors::OutOfRange("EOF reached");
+    }
+    return Status::OK();
+  }
+  Status GetFileSize(uint64* size) {
+    if (size_status_.ok()) {
+      *size = size_;
+    }
+    return size_status_;
+  }
+ private:
+  std::unique_ptr<tensorflow::RandomAccessFile> file_;
+  uint64 size_;
+  const char *buff_;
+  Status size_status_;
+};
+
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index bbd3b5d8b..40107037d 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -58,9 +58,8 @@ def _apply_fn(dataset):
 
 class BaseDataset(tf.compat.v2.data.Dataset):
   """A Base Dataset"""
-  def __init__(self, variant, batch, dtypes, shapes):
+  def __init__(self, variant, dtypes, shapes):
     """Create a Base Dataset."""
-    self._batch = 0 if batch is None else batch
     self._dtypes = dtypes
     self._shapes = shapes
     super(BaseDataset, self).__init__(variant)
@@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
         self._data_input, self._batch,
         output_types=self._dtypes,
-        output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
+        output_shapes=self._shapes), self._dtypes, self._shapes)
diff --git a/tensorflow_io/text/BUILD b/tensorflow_io/text/BUILD
index fc12f92b4..12c169dec 100644
--- a/tensorflow_io/text/BUILD
+++ b/tensorflow_io/text/BUILD
@@ -11,7 +11,7 @@ cc_library(
     name = "text_ops",
     srcs = [
         "kernels/csv_output.cc",
-        "kernels/text_input.cc",
+        "kernels/text_kernels.cc",
        "kernels/text_output.cc",
         "kernels/text_re2.cc",
         "kernels/text_sequence.cc",
diff --git a/tensorflow_io/text/__init__.py b/tensorflow_io/text/__init__.py
index 5ce061bd2..70173477e 100644
--- a/tensorflow_io/text/__init__.py
+++ b/tensorflow_io/text/__init__.py
@@ -20,6 +20,7 @@
 @@save_csv
 @@from_csv
 @@re2_full_match
+@@read_text
 """
 
 from __future__ import absolute_import
@@ -32,6 +33,7 @@
 from tensorflow_io.text.python.ops.text_ops import save_csv
 from tensorflow_io.text.python.ops.text_ops import from_csv
 from tensorflow_io.text.python.ops.text_ops import re2_full_match
+from tensorflow_io.text.python.ops.text_ops import read_text
 
 from tensorflow.python.util.all_util import remove_undocumented
 
@@ -42,6 +44,7 @@
     "save_csv",
     "from_csv",
     "re2_full_match",
+    "read_text",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow_io/text/kernels/text_input.cc b/tensorflow_io/text/kernels/text_input.cc
deleted file mode 100644
index 8261bcbbc..000000000
--- a/tensorflow_io/text/kernels/text_input.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "kernels/dataset_ops.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-
-namespace tensorflow {
-namespace data {
-
-class TextInput: public FileStreamInput<io::BufferedInputStream> {
- public:
-  Status ReadRecord(io::InputStreamInterface* s, bool owns_input_stream, IteratorContext* ctx, std::unique_ptr<io::BufferedInputStream>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
-    if (state.get() == nullptr) {
-      state.reset(new io::BufferedInputStream(s, 4096, owns_input_stream));
-    }
-    Tensor value_tensor(ctx->allocator({}), DT_STRING, {record_to_read});
-    while ((*record_read) < record_to_read) {
-      string buffer;
-      buffer.clear();
-      Status status = state.get()->ReadLine(&buffer);
-      if (!(status.ok() || errors::IsOutOfRange(status))) {
-        return status;
-      }
-      if (!status.ok()) {
-        break;
-      }
-      value_tensor.flat<string>()((*record_read)) = std::move(buffer);
-      (*record_read)++;
-    }
-    if (*record_read > 0) {
-      out_tensors->emplace_back(std::move(value_tensor));
-    }
-    return Status::OK();
-  }
-  Status FromStream(io::InputStreamInterface* s) override {
-    // TODO: Read 4K buffer to detect BOM.
-    //string header;
-    //TF_RETURN_IF_ERROR(s.ReadNBytes(4096, &header));
-    //for (size i = 0; i < header.size(); i++) {
-    //  if (!isprint(header[i])) {
-    //    return errors::InvalidArgument("text file contains character that is non printable at ", i);
-    //  }
-    //}
-    return Status::OK();
-  }
-  void EncodeAttributes(VariantTensorData* data) const override {
-  }
-  bool DecodeAttributes(const VariantTensorData& data) override {
-    return true;
-  }
- protected:
-};
-
-REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TextInput, "tensorflow::data::TextInput");
-
-REGISTER_KERNEL_BUILDER(Name("TextInput").Device(DEVICE_CPU),
-                        FileInputOp<TextInput>);
-REGISTER_KERNEL_BUILDER(Name("TextDataset").Device(DEVICE_CPU),
-                        FileInputDatasetOp<TextInput>);
-
-REGISTER_KERNEL_BUILDER(Name("TextStreamInput").Device(DEVICE_CPU),
-                        StreamInputOp<TextInput>);
-REGISTER_KERNEL_BUILDER(Name("TextStreamDataset").Device(DEVICE_CPU),
-                        StreamInputDatasetOp<TextInput>);
-} // namespace data
-} // namespace tensorflow
diff --git a/tensorflow_io/text/kernels/text_kernels.cc b/tensorflow_io/text/kernels/text_kernels.cc
new file mode 100644
index 000000000..0481237a9
--- /dev/null
+++ b/tensorflow_io/text/kernels/text_kernels.cc
@@ -0,0 +1,170 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow_io/core/kernels/stream.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class FilenoInputStream : public io::InputStreamInterface {
+ public:
+  FilenoInputStream(int fileno) : fileno_(fileno) {}
+  virtual ~FilenoInputStream() {}
+
+  virtual Status ReadNBytes(int64 bytes_to_read, string* result) override {
+    if (bytes_to_read < 0) {
+      return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read);
+    }
+
+    result->clear();
+    if (final_) {
+      return errors::OutOfRange("EOF reached");
+    }
+
+    result->resize(bytes_to_read);
+    int64 bytes_read = 0;
+    while (bytes_read < bytes_to_read) {
+      size_t chunk = bytes_to_read - bytes_read;
+      ssize_t returned = read(fileno_, &(*result)[bytes_read], chunk);
+      if (returned < 0) {
+        result->resize(bytes_read);
+        return errors::Internal("read fileno ", fileno_, " error: ", returned);
+      }
+      if (returned == 0) {
+        break;
+      }
+      bytes_read += returned;
+    }
+    offset_ += bytes_read;
+    result->resize(bytes_read);
+    if (bytes_read < bytes_to_read) {
+      return errors::OutOfRange("EOF reached");
+    }
+    return Status::OK();
+  }
+
+  virtual int64 Tell() const override {
+    return offset_;
+  }
+
+  virtual Status Reset() override {
+    return errors::Unimplemented("Reset fileno stream is not implemented");
+  }
+ private:
+  int fileno_ = -1;
+  int64 offset_ = 0;
+  bool final_ = false;
+};
+
+class ReadTextOp : public OpKernel {
+ public:
+  explicit ReadTextOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string& filename = filename_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(1);
+    const string& memory = memory_tensor.scalar<string>()();
+
+    const Tensor& offset_tensor = context->input(2);
+    const int64 offset = offset_tensor.scalar<int64>()();
+
+    const Tensor& length_tensor = context->input(3);
+    int64 length = length_tensor.scalar<int64>()();
+
+    std::vector<string> lines;
+
+    if (filename == "file://-" || filename == "file://0") {
+      // If we read from stdin, read until EOF is reached.
+      // Note: It is possible to read the data in large slices. However,
+      // BufferedInputStream keeps a cache buffer, which complicates reading
+      // from a stream. A no-cache version of ReadLine() would need to be
+      // implemented in order to read chunks.
+      std::unique_ptr<FilenoInputStream> input_stream(new FilenoInputStream(STDIN_FILENO));
+      std::unique_ptr<io::BufferedInputStream> stream(new tensorflow::io::BufferedInputStream(input_stream.get(), 4096));
+
+      Status status = Status::OK();
+      while (status.ok()) {
+        string line;
+        status = stream->ReadLine(&line);
+        OP_REQUIRES(context, (status.ok() || errors::IsOutOfRange(status)), status);
+        if (!status.ok()) {
+          break;
+        }
+        lines.emplace_back(line);
+      }
+    } else {
+      std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
+      uint64 size;
+      OP_REQUIRES_OK(context, file->GetFileSize(&size));
+      if (length < 0) {
+        length = size;
+      }
+
+      // This ReadText is splittable, so that it is possible to read text from
+      // a chunk of a file, much like Hadoop. The following method decides
+      // whether a line belongs to the chunk:
+      // 1) offset == 0: read lines and stop once length is reached.
+      // 2) offset > 0: back off by 1 byte and skip one line, so reading
+      //    starts with the next line; stop once length is reached.
+      //
+      // Note: We use BufferedInputStream, which can only process "\n" as the
+      // separator, though it could be expanded beyond "\n" in the future.
+
+      std::unique_ptr<io::BufferedInputStream> stream(new tensorflow::io::BufferedInputStream(file.get(), 65536));
+      if (offset > 0) {
+        OP_REQUIRES_OK(context, stream->SkipNBytes(offset - 1));
+        string line;
+        OP_REQUIRES_OK(context, stream->ReadLine(&line));
+      }
+
+      while (stream->Tell() < offset + length) {
+        string line;
+        Status status = stream->ReadLine(&line);
+        OP_REQUIRES(context, (status.ok() || errors::IsOutOfRange(status)), status);
+        if (!status.ok()) {
+          break;
+        }
+        lines.emplace_back(line);
+      }
+    }
+
+    TensorShape output_shape({static_cast<int64>(lines.size())});
+
+    Tensor* output_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
+
+    for (size_t i = 0; i < lines.size(); i++) {
+      output_tensor->flat<string>()(i) = std::move(lines[i]);
+    }
+  }
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReadText").Device(DEVICE_CPU),
+                        ReadTextOp);
+
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow_io/text/ops/text_ops.cc b/tensorflow_io/text/ops/text_ops.cc
index ac5c6da26..e8544ca22 100644
--- a/tensorflow_io/text/ops/text_ops.cc
+++ b/tensorflow_io/text/ops/text_ops.cc
@@ -46,6 +46,17 @@ REGISTER_OP("RE2FullMatch")
       return Status::OK();
     });
 
+REGISTER_OP("ReadText")
+    .Input("filename: string")
+    .Input("memory: string")
+    .Input("offset: int64")
+    .Input("length: int64")
+    .Output("output: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
 REGISTER_OP("TextStreamInput")
     .Input("source: string")
     .Output("handle: variant")
diff --git a/tensorflow_io/text/python/ops/text_ops.py b/tensorflow_io/text/python/ops/text_ops.py
index 0ff77fbbf..4447a05e2 100644
--- a/tensorflow_io/text/python/ops/text_ops.py
+++ b/tensorflow_io/text/python/ops/text_ops.py
@@ -22,8 +22,16 @@
 import numpy as np
 import tensorflow as tf
 
-from tensorflow_io.core.python.ops import data_ops as data_ops
-from tensorflow_io.core.python.ops import core_ops as text_ops
+from tensorflow_io.core.python.ops import data_ops
+from tensorflow_io.core.python.ops import core_ops
+
+def read_text(filename, **kwargs):
+  """Read lines of text from a file, or from a memory buffer."""
+  memory = kwargs.get("memory", "")
+  offset = kwargs.get("offset", 0)
+  length = kwargs.get("length", -1)
+  return core_ops.read_text(
+      filename, offset=offset, length=length, memory=memory)
 
 def save_text(dataset, filename):
   """Save Dataset to disk.
@@ -32,7 +40,7 @@ def save_text(dataset, filename):
     dataset: A TextDataset to be saved.
     filename: A `tf.string` tensor containing filename.
   """
-  return text_ops.text_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access
+  return core_ops.text_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access
 
 
 def save_csv(dataset, filename):
@@ -42,7 +50,7 @@ def save_csv(dataset, filename):
     dataset: A Dataset to be saved.
     filename: A `tf.string` tensor containing filename.
""" - return text_ops.csv_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access + return core_ops.csv_dataset_output(dataset._variant_tensor, filename) # pylint: disable=protected-access def re2_full_match(input, pattern): # pylint: disable=redefined-builtin @@ -52,33 +60,45 @@ def re2_full_match(input, pattern): # pylint: disable=redefined-builtin dataset: A `tf.string` tensor pattern: A pattern string. """ - return text_ops.re2_full_match(input, pattern) + return core_ops.re2_full_match(input, pattern) -class TextDataset(data_ops.Dataset): +class TextDataset(data_ops.BaseDataset): """A Text Dataset""" - def __init__(self, filename, batch=None): + def __init__(self, filename, **kwargs): """Create a Text Reader. Args: - filename: A `tf.string` tensor containing one or more filenames. + filename: A string containing filename to read. """ - batch = 0 if batch is None else batch - dtypes = [tf.string] - shapes = [ - tf.TensorShape([])] if batch == 0 else [ - tf.TensorShape([None])] - fn = text_ops.text_stream_dataset if ( - filename == 'file://-') else text_ops.text_dataset - data_input = text_ops.text_stream_input(filename) if ( - filename == 'file://-') else text_ops.text_input( - filename, ["none", "gz"]) - super(TextDataset, self).__init__( - fn, - data_input, - batch, dtypes, shapes) + dtype = tf.string + shape = tf.TensorShape([None]) + + capacity = kwargs.get("capacity", 65536) + + if filename.startswith("file://-") or filename.startswith("file://0"): + dataset = data_ops.BaseDataset.range(1).map( + lambda length: core_ops.read_text(filename, memory="", offset=0, length=length) + ) + else: + filesize = tf.io.gfile.GFile(filename).size() + # capacity is the rough length for each split + entry_offset = list(range(0, filesize, capacity)) + entry_length = [ + min(capacity, filesize - offset) for offset in entry_offset] + dataset = data_ops.BaseDataset.from_tensor_slices( + ( + tf.constant(entry_offset, tf.int64), + tf.constant(entry_length, tf.int64) + ) + ).map(lambda offset, length: core_ops.read_text( + filename, memory="", + offset=offset, length=length)) + self._dataset = dataset + super(TextDataset, self).__init__( + self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access class TextOutputSequence(object): """TextOutputSequence""" @@ -87,10 +107,10 @@ def __init__(self, filenames): """Create a `TextOutputSequence`. 
""" self._filenames = filenames - self._resource = text_ops.text_output_sequence(destination=filenames) + self._resource = core_ops.text_output_sequence(destination=filenames) def setitem(self, index, item): - text_ops.text_output_sequence_set_item(self._resource, index, item) + core_ops.text_output_sequence_set_item(self._resource, index, item) def _infer_dtype(val): @@ -122,7 +142,7 @@ def from_csv(filename, header=0): """ if not tf.executing_eagerly(): raise NotImplementedError("from_csv only supports eager mode") - dataset = TextDataset(filename) + dataset = TextDataset(filename).apply(tf.data.experimental.unbatch()) columns = None if header is not None: if header != 0: diff --git a/tests/test_text/stdin_test.py b/tests/test_text/stdin_test.py index 472647360..e3ef48dbd 100644 --- a/tests/test_text/stdin_test.py +++ b/tests/test_text/stdin_test.py @@ -26,7 +26,7 @@ # tshark -T fields -e frame.number -e ip.dst -e ip.proto -r attack-trace.pcap | python stdin_test.py def f(v): - frame_number, ip_dst, ip_proto = tf.decode_csv( + frame_number, ip_dst, ip_proto = tf.io.decode_csv( v, [[0], [''], [0]], field_delim='\t') return frame_number, ip_dst, ip_proto diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py index 838ce24e7..bd8731834 100644 --- a/tests/test_text_eager.py +++ b/tests/test_text_eager.py @@ -29,6 +29,36 @@ import tensorflow_io.text as text_io # pylint: disable=wrong-import-position import tensorflow_io.core.python.ops.data_ops as core_io # pylint: disable=wrong-import-position +def test_read_text(): + """test_read_text""" + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt") + with open(filename, 'rb') as f: + lines = [line for line in f] + filename = "file://" + filename + + filesize = tf.io.gfile.GFile(filename).size() + + offset = 0 + offsets = [] + for line in lines: + offsets.append(offset) + offset += len(line) + + lines = list(zip(offsets, lines)) + + for offset, length in [ + (0, -1), (1, -1), (1000, -1), (100, 1000), (1000, 10000)]: + entries = text_io.read_text(filename, offset=offset, length=length) + if length < 0: + length = filesize - offset + expected = [ + line for (k, line) in lines if k >= offset and k < offset + length] + assert entries.shape == len(expected) + for k, v in enumerate(expected): + assert entries[k].numpy().decode() + "\n" == v.decode() + + def test_text_input(): """test_text_input """ @@ -38,13 +68,8 @@ def test_text_input(): lines = [line.strip() for line in f] text_filename = "file://" + text_filename - gzip_text_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt.gz") - gzip_text_filename = "file://" + gzip_text_filename - - lines = lines * 3 - filenames = [text_filename, gzip_text_filename, text_filename] - text_dataset = text_io.TextDataset(filenames, batch=2) + text_dataset = text_io.TextDataset(text_filename).apply( + tf.data.experimental.unbatch()).batch(2) i = 0 for v in text_dataset: assert lines[i] == v.numpy()[0] @@ -69,7 +94,7 @@ def test_text_input(): for vv in v.numpy(): assert lines[i] == vv i += 1 - assert i == 145 + assert i == 45 rebatch_dataset = text_dataset.apply(core_io.rebatch(5, "pad")) i = 0 @@ -80,7 +105,7 @@ def test_text_input(): else: assert vv.decode() == "" i += 1 - assert i == 150 + assert i == 50 def test_text_output_sequence(): """Test case based on fashion mnist tutorial""" @@ -142,7 +167,8 @@ def test_text_output(): f, filename = tempfile.mkstemp() os.close(f) - df = text_io.TextDataset(text_filename) + df = 
+      tf.data.experimental.unbatch())
   df = df.take(5)
 
   text_io.save_text(df, filename)
@@ -231,16 +257,15 @@ def test_from_csv():
 def test_re2_extract():
   """test_text_input
   """
-  text_filename = os.path.join(
+  filename = os.path.join(
       os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt")
-  with open(text_filename, 'rb') as f:
+  with open(filename, 'rb') as f:
     lines = [line.strip() for line in f]
-
-  filename = os.path.join(
-      os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt.gz")
   filename = "file://" + filename
 
-  dataset = text_io.TextDataset(filename).map(lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+"))
+  dataset = text_io.TextDataset(filename).map(
+      lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+")).apply(
+      tf.data.experimental.unbatch())
   i = 0
   for v in dataset:
     r, g = v
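
Note on the split semantics in this patch: each (offset, length) chunk emits exactly the lines that start inside [offset, offset + length), so stitching the chunks together yields every line of the file exactly once. Below is a minimal pure-Python sketch of the same rule, for illustration only; it is not part of the patch, and the helper name read_text_split and the lorem.txt path are hypothetical:

def read_text_split(filename, offset, length):
  """Return the lines whose first byte falls in [offset, offset + length)."""
  with open(filename, "rb") as f:
    if offset > 0:
      # Mirror the kernel's SkipNBytes(offset - 1) followed by ReadLine():
      # if offset lands exactly on a line start, byte offset - 1 is the
      # previous line's '\n', so only that newline is consumed and the line
      # at offset is kept; if offset lands mid-line, the partial line is
      # dropped because the previous chunk already read that line in full.
      f.seek(offset - 1)
      f.readline()
    lines = []
    while f.tell() < offset + length:
      line = f.readline()
      if not line:  # EOF
        break
      lines.append(line.rstrip(b"\n"))
    return lines

# Stitching the chunks back together recovers the whole file, in order:
# size = os.path.getsize("lorem.txt")   # hypothetical local file
# capacity = 65536                      # rough length of each split
# splits = [(o, min(capacity, size - o)) for o in range(0, size, capacity)]
# lines = [l for o, l in splits for l in read_text_split("lorem.txt", o, l)]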