Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tensorflow_io/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ cc_library(
],
)

# Kernel and op-registration sources for the file ops, linked statically on
# top of :dataset_ops.
cc_library(
    name = "file_ops",
    srcs = [
        "kernels/file_input.cc",
        "ops/file_ops.cc",
    ],
    copts = tf_io_copts(),
    includes = ["."],
    linkstatic = True,
    deps = [":dataset_ops"],
)

cc_library(
name = "ffmpeg_3.4",
srcs = [
Expand Down Expand Up @@ -107,6 +123,7 @@ cc_binary(
copts = tf_io_copts(),
linkshared = 1,
deps = [
":file_ops",
"//tensorflow_io/audio:audio_ops",
"//tensorflow_io/avro:avro_ops",
"//tensorflow_io/azure:azfs_ops",
Expand Down
89 changes: 89 additions & 0 deletions tensorflow_io/core/kernels/file_input.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "kernels/dataset_ops.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"

namespace tensorflow {
namespace data {

// FileInput implementation that yields the entire content of a stream as a
// single string record.
//
// The per-stream state is a `bool` used purely as a "this stream has already
// been read" marker: it is null before the first ReadRecord call and non-null
// afterwards.
class FileContentInput : public FileInput<bool> {
 public:
  // Reads all remaining bytes of `s` and appends one DT_STRING tensor of shape
  // {1} holding them to `out_tensors`.
  //
  // Args:
  //   s: stream to read from.
  //   ctx: iterator context; supplies the allocator for the output tensor.
  //   state: null on the first call for a stream; set by that call so that
  //     every later call returns immediately without producing a record.
  //   record_to_read: must be 1.
  //   record_read: incremented by one when a record is produced, left
  //     untouched otherwise.
  //   out_tensors: receives the produced tensor.
  //
  // Returns InvalidArgument when `record_to_read` is not 1, the underlying
  // read error when reading fails with anything other than OutOfRange, and OK
  // otherwise.
  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr<bool>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
    if (record_to_read != 1) {
      return errors::InvalidArgument("FileDataset only accept reading one record at a time");
    }
    if (state != nullptr) {
      // We only read file once.
      return Status::OK();
    }
    state.reset(new bool(true));

    // Default chunk size, used whenever the stream cannot report a usable size.
    int64 chunk_size = 4096;
    SizedRandomAccessInputStreamInterface* sized_stream = dynamic_cast<SizedRandomAccessInputStreamInterface*>(s);
    if (sized_stream != nullptr) {
      // When the file size is available, read the whole file in one request.
      uint64 file_size = 0;
      // A size of zero must not become the chunk size: a zero-byte read
      // request makes no progress and is not guaranteed to report the end of
      // the stream, so the read loop below could spin forever.
      if (sized_stream->GetFileSize(&file_size).ok() && file_size > 0) {
        // NOTE(review): a reviewer asked whether a warning should be logged
        // here when file_size is big, since the whole file ends up in memory.
        chunk_size = static_cast<int64>(file_size);
      }
    }

    // Accumulate straight into one string. In the common case the first read
    // returns the whole file, which is moved (not copied) into `content`.
    string content;
    Status status = Status::OK();
    while (status.ok()) {
      string buffer;
      status = s->ReadNBytes(chunk_size, &buffer);
      // Bytes returned together with OutOfRange are kept as well: that status
      // marks the end of the stream, not a failed read.
      if (status.ok() || errors::IsOutOfRange(status)) {
        if (content.empty()) {
          content = std::move(buffer);
        } else {
          content.append(buffer);
        }
      }
    }
    if (!errors::IsOutOfRange(status)) {
      return status;
    }

    Tensor value_tensor(ctx->allocator({}), DT_STRING, {1});
    // The tensor holds exactly one element, so 0 is the only valid index.
    value_tensor.flat<string>()(0) = std::move(content);
    (*record_read)++;
    out_tensors->emplace_back(std::move(value_tensor));
    return Status::OK();
  }

  // Nothing is read from the stream up front; always succeeds.
  Status FromStream(io::InputStreamInterface* s) override {
    return Status::OK();
  }

  // This input carries no attributes, so nothing is written to `data`.
  void EncodeAttributes(VariantTensorData* data) const override {
  }

  // This input carries no attributes, so decoding always succeeds.
  bool DecodeAttributes(const VariantTensorData& data) override {
    return true;
  }
};

// Register the variant decode function for FileContentInput under its fully
// qualified type name.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(FileContentInput, "tensorflow::data::FileContentInput");

// CPU kernels for the "FileInput" and "FileDataset" ops, both instantiated
// with FileContentInput (the dataset kernel additionally names its `bool`
// state type).
REGISTER_KERNEL_BUILDER(Name("FileInput").Device(DEVICE_CPU),
FileInputOp<FileContentInput>);
REGISTER_KERNEL_BUILDER(Name("FileDataset").Device(DEVICE_CPU),
FileInputDatasetOp<FileContentInput, bool>);
} // namespace data
} // namespace tensorflow
46 changes: 46 additions & 0 deletions tensorflow_io/core/ops/file_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// Op "FileInput": takes a string `source` and produces a variant `handle`.
// `filters`, `columns` and `schema` are optional attributes with empty
// defaults.
REGISTER_OP("FileInput")
    .Input("source: string")
    .Output("handle: variant")
    .Attr("filters: list(string) = []")
    .Attr("columns: list(string) = []")
    .Attr("schema: string = ''")
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // The handle output is a vector of unknown length.
      ctx->set_output(0, ctx->Vector(ctx->UnknownDim()));
      return Status::OK();
    });

// Op "FileDataset": stateful op that turns `input` (a string or a variant,
// variant by default) plus an int64 `batch` into a scalar variant `handle`.
// `output_types` and `output_shapes` must each list at least one entry.
REGISTER_OP("FileDataset")
    .Input("input: T")
    .Input("batch: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("T: {string, variant} = DT_VARIANT")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // The handle output is a scalar.
      ctx->set_output(0, ctx->Scalar());
      return Status::OK();
    });

} // namespace tensorflow
16 changes: 16 additions & 0 deletions tensorflow_io/core/python/ops/data_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_io.core.python.ops import core_ops

# Note: BaseDataset could be used by Dataset implementations
# that does not utilize DataInput implementation.
Expand Down Expand Up @@ -60,3 +61,18 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
self._batch,
output_types=self._dtypes,
output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)

class FileDataset(BaseDataset):
  """Dataset whose elements are file contents, exposed as `tf.string`."""

  def __init__(self, filename):
    """Create a FileDataset.

    Args:
      filename: value forwarded to `core_ops.file_input` as the data source.
    """
    self._data_input = core_ops.file_input(filename)
    self._batch = 0
    self._dtypes = [tf.string]
    self._shapes = [tf.TensorShape([])]

    # Build the dataset variant first, then hand it to the base class together
    # with the same batch/dtype/shape objects stored on this instance.
    dataset_variant = core_ops.file_dataset(
        self._data_input,
        self._batch,
        output_types=self._dtypes,
        output_shapes=self._shapes)
    super(FileDataset, self).__init__(
        dataset_variant, self._batch, self._dtypes, self._shapes)
145 changes: 0 additions & 145 deletions tensorflow_io/image/kernels/webp_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,149 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
// Dataset op kernel that turns a scalar or vector of file names into a dataset
// yielding one decoded WebP image per file, as a DT_UINT8 tensor of shape
// {height, width, 4} in RGBA order.
class WebPDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
explicit WebPDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {
}
// Reads the `filenames` input, rejects tensors of rank > 1, copies the names
// into a vector and hands ownership of a new Dataset to `*output`.
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
const Tensor* filenames_tensor;
OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
OP_REQUIRES(
ctx, filenames_tensor->dims() <= 1,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));

std::vector<string> filenames;
filenames.reserve(filenames_tensor->NumElements());
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
filenames.push_back(filenames_tensor->flat<string>()(i));
}

*output = new Dataset(ctx, filenames);
}

private:
// Dataset holding the list of WebP file names to decode.
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
: DatasetBase(DatasetContext(ctx)),
filenames_(filenames) {}

// Creates an iterator whose prefix is `prefix` with "::WebP" appended.
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::WebP")}));
}

// Each element has a single DT_UINT8 component. The vector is a
// function-local static allocated once and never freed.
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_UINT8});
return *dtypes;
}

// Each element is rank 3 with all dimensions unknown.
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{-1, -1, -1}});
return *shapes;
}

string DebugString() const override {
return "WebPDatasetOp::Dataset";
}

protected:
// Serializes the dataset as a graph node whose only input is the vector of
// file names.
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
return Status::OK();
}

private:
// Iterator that decodes the files in order, one per GetNextInternal call.
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}

// Decodes the file at `current_file_index_` into `out_tensors` and advances
// the index; sets `*end_of_sequence` to true once every file was consumed.
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
// NOTE(review): SetupStreamsLocked runs while the index is still 0 and
// returns InvalidArgument when the file list is empty, so an empty list
// yields an error here rather than end-of-sequence.
if (current_file_index_ == 0) {
TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
}
if (current_file_index_ < dataset()->filenames_.size()) {
const string& filename = dataset()->filenames_[current_file_index_];
// Load the whole file into memory before decoding.
uint64 size = 0;
TF_RETURN_IF_ERROR(ctx->env()->GetFileSize(filename, &size));
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(filename, &file));
std::unique_ptr<io::RandomAccessInputStream> stream(new io::RandomAccessInputStream(file.get()));
string data;
TF_RETURN_IF_ERROR(stream->ReadNBytes(size, &data));
// Probe the bitstream first; the image width and height read from
// config.input size the output tensor below.
WebPDecoderConfig config;
WebPInitDecoderConfig(&config);
int returned = WebPGetFeatures(reinterpret_cast<const uint8_t *>(data.c_str()),
size, &config.input);
if (returned != VP8_STATUS_OK) {
return errors::InvalidArgument("File could not be decoded as WebP: ", filename);
}
// TODO (yongtang): Set channel = 4 for now.
static const int channel = 4;
Tensor value_tensor(ctx->allocator({}), DT_UINT8, {config.input.height, config.input.width, channel});

// Point the decoder's RGBA output at the tensor's own buffer
// (is_external_memory = 1), with stride and size derived from the image
// dimensions and the 4 channels.
config.output.colorspace = MODE_RGBA;
config.output.u.RGBA.rgba = value_tensor.flat<uint8_t>().data();
config.output.u.RGBA.stride = config.input.width * channel;
config.output.u.RGBA.size = config.input.height * config.input.width * channel;
config.output.is_external_memory = 1;
returned = WebPDecode(reinterpret_cast<const uint8_t *>(data.c_str()), size, &config);
if (returned != VP8_STATUS_OK) {
return errors::InvalidArgument("File could not be decoded as WebP: ", filename);
}
out_tensors->emplace_back(std::move(value_tensor));
*end_of_sequence = false;
++current_file_index_;
return Status::OK();
}
*end_of_sequence = true;
return Status::OK();
}

protected:
// Checkpointing of the iterator state is not implemented.
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}

private:
// Sets up WebP streams to read from the topic at
// `current_file_index_`.
// As written, this only validates that the index is within the file list;
// `env` is not used.
Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (current_file_index_ >= dataset()->filenames_.size()) {
return errors::InvalidArgument(
"current_file_index_:", current_file_index_,
" >= filenames_.size():", dataset()->filenames_.size());
}
return Status::OK();
}

mutex mu_;
// Index of the next file to decode.
size_t current_file_index_ GUARDED_BY(mu_) = 0;
};

const std::vector<string> filenames_;
// NOTE(review): not referenced anywhere in the code shown here.
const DataTypeVector output_types_;
};
// NOTE(review): not referenced anywhere in the code shown here.
DataTypeVector output_types_;
};

class DecodeWebPOp : public OpKernel {
public:
Expand Down Expand Up @@ -205,8 +62,6 @@ class DecodeWebPOp : public OpKernel {
// TODO (yongtang): Set channels_ = 4 for now.
static const int channels_ = 4;
};
// CPU kernel for the "WebPDataset" op.
REGISTER_KERNEL_BUILDER(Name("WebPDataset").Device(DEVICE_CPU),
WebPDatasetOp);

// CPU kernel for the "DecodeWebP" op.
REGISTER_KERNEL_BUILDER(Name("DecodeWebP").Device(DEVICE_CPU),
DecodeWebPOp);
Expand Down
9 changes: 0 additions & 9 deletions tensorflow_io/image/ops/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,6 @@ limitations under the License.

namespace tensorflow {

// Op "WebPDataset": stateful op mapping string `filenames` to a variant
// dataset `handle`.
REGISTER_OP("WebPDataset")
    .Input("filenames: string")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // Output 0 is declared as rank 3 with every dimension unknown.
      ctx->set_output(0, ctx->UnknownShapeOfRank(3));
      return Status::OK();
    });

REGISTER_OP("TIFFDataset")
.Input("filenames: string")
.Output("handle: variant")
Expand Down
Loading