69 changes: 45 additions & 24 deletions tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
@@ -17,10 +17,12 @@ limitations under the License.
 #include "arrow/adapters/tensorflow/convert.h"
 #include "arrow/ipc/api.h"
 #include "arrow/util/io-util.h"
-#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
-#include "tensorflow_io/arrow/kernels/arrow_util.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
+#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
+#include "tensorflow_io/arrow/kernels/arrow_util.h"
 
 #define CHECK_ARROW(arrow_status) \
   do { \
@@ -31,6 +33,7 @@ limitations under the License.
   } while (false)
 
 namespace tensorflow {
+namespace data {
 
 enum ArrowBatchMode {
   BATCH_KEEP_REMAINDER,
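
The body of CHECK_ARROW is collapsed in the hunk above. A minimal sketch of the usual Arrow-to-TensorFlow status bridge, assuming the conventional pattern (the actual elided body may differ):

    // Sketch only: convert a failed arrow::Status into a TF error so the
    // enclosing function (which returns tensorflow::Status) can bail out.
    #define CHECK_ARROW(arrow_status)                              \
      do {                                                         \
        arrow::Status _s = (arrow_status);                         \
        if (!_s.ok()) {                                            \
          return errors::Internal("Arrow error: ", _s.ToString()); \
        }                                                          \
      } while (false)
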
@@ -294,7 +297,7 @@ class ArrowDatasetBase : public DatasetBase {
 
       // If in initial state, setup and read first batch
       if (current_batch_ == nullptr && current_row_idx_ == 0) {
-        TF_RETURN_IF_ERROR(SetupStreamsLocked());
+        TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
       }
 
       std::vector<Tensor>* result_tensors = out_tensors;
@@ -309,7 +312,7 @@
 
       // Try to go to next batch if consumed all rows in current batch
       if (current_batch_ != nullptr &&
           current_row_idx_ >= current_batch_->num_rows()) {
-        TF_RETURN_IF_ERROR(NextStreamLocked());
+        TF_RETURN_IF_ERROR(NextStreamLocked(ctx->env()));
       }
 
       // Check if reached end of stream
Expand Down Expand Up @@ -465,11 +468,12 @@ class ArrowDatasetBase : public DatasetBase {
}

// Setup Arrow record batch consumer and initialze current_batch_
virtual Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
virtual Status SetupStreamsLocked(Env* env)
EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;

// Get the next Arrow record batch, if available. If not then
// current_batch_ will be set to nullptr to indicate no further batches.
virtual Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
virtual Status NextStreamLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
current_batch_ = nullptr;
current_row_idx_ = 0;
return Status::OK();
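
Threading Env* (obtained from ctx->env() in the hunks above) through SetupStreamsLocked and NextStreamLocked is what lets the iterators open input through TensorFlow's filesystem layer instead of Arrow's own POSIX file APIs, which in turn is what makes the "file://" test at the bottom of this PR pass. A sketch of the call this enables; the helper name is hypothetical, but Env::NewRandomAccessFile is the real TF API:

    #include "tensorflow/core/platform/env.h"

    // Hypothetical helper: resolving the filename through Env means any
    // registered filesystem scheme (such as "file://") behaves like a
    // plain path.
    tensorflow::Status OpenForRead(
        tensorflow::Env* env, const tensorflow::string& filename,
        std::unique_ptr<tensorflow::RandomAccessFile>* file) {
      return env->NewRandomAccessFile(filename, file);
    }
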
@@ -678,7 +682,8 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
           : ArrowBaseIterator<Dataset>(params) {}
 
      private:
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+      Status SetupStreamsLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        buffer_ = std::make_shared<arrow::Buffer>(
            dataset()->buffer_ptr_,
            dataset()->buffer_size_);
@@ -697,8 +702,9 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
         return Status::OK();
       }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        if (++current_batch_idx_ < num_batches_) {
          CHECK_ARROW(
              reader_->ReadRecordBatch(current_batch_idx_, &current_batch_));
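
The zero-copy and serialized variants both pull batches by index from an arrow::ipc::RecordBatchFileReader, as the ReadRecordBatch call above shows. That loop in isolation, using the same status-out-parameter Arrow API this code targets (function name is illustrative):

    #include "arrow/io/api.h"
    #include "arrow/ipc/api.h"

    // Read every record batch in an Arrow IPC file by index.
    arrow::Status ReadAllBatches(arrow::io::RandomAccessFile* file) {
      std::shared_ptr<arrow::ipc::RecordBatchFileReader> reader;
      ARROW_RETURN_NOT_OK(arrow::ipc::RecordBatchFileReader::Open(file, &reader));
      for (int i = 0; i < reader->num_record_batches(); ++i) {
        std::shared_ptr<arrow::RecordBatch> batch;
        ARROW_RETURN_NOT_OK(reader->ReadRecordBatch(i, &batch));
        // ... hand `batch` to the consumer ...
      }
      return arrow::Status::OK();
    }
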
@@ -818,7 +824,8 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
           : ArrowBaseIterator<Dataset>(params) {}
 
      private:
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+      Status SetupStreamsLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        const string& batches = dataset()->batches_.scalar<string>()();
        auto buffer = std::make_shared<arrow::Buffer>(batches);
        auto buffer_reader = std::make_shared<arrow::io::BufferReader>(buffer);
@@ -833,8 +840,9 @@
         return Status::OK();
       }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        if (++current_batch_idx_ < num_batches_) {
          CHECK_ARROW(
              reader_->ReadRecordBatch(current_batch_idx_, &current_batch_));
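
The serialized op's SetupStreamsLocked (three lines at the end of the previous hunk) wraps the incoming string tensor without copying. The same step in isolation, assuming the non-owning arrow::Buffer(const std::string&) constructor that code relies on:

    #include "arrow/buffer.h"
    #include "arrow/io/api.h"

    // Wrap serialized IPC bytes in a reader; `batches` must outlive the
    // reader because the Buffer is a non-owning view of its bytes.
    std::shared_ptr<arrow::io::BufferReader> WrapSerialized(
        const std::string& batches) {
      auto buffer = std::make_shared<arrow::Buffer>(batches);
      return std::make_shared<arrow::io::BufferReader>(buffer);
    }
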
@@ -864,8 +872,6 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
 // ideal for simple writing of Pandas DataFrames.
 class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
  public:
-  //using DatasetOpKernel::DatasetOpKernel;
-
   explicit ArrowFeatherDatasetOp(OpKernelConstruction* ctx)
       : ArrowOpKernelBase(ctx) {}
 
@@ -951,10 +957,22 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
           : ArrowBaseIterator<Dataset>(params) {}
 
      private:
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+      Status SetupStreamsLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        const string& filename = dataset()->filenames_[current_file_idx_];
-        std::shared_ptr<arrow::io::ReadableFile> in_file;
-        CHECK_ARROW(arrow::io::ReadableFile::Open(filename, &in_file));
+
+        // Init a TF file from the filename and determine size
+        // TODO: set optional memory to nullptr until input arg is added
+        std::shared_ptr<SizedRandomAccessFile> tf_file(
+            new SizedRandomAccessFile(env, filename, nullptr, 0));
+        uint64 size;
+        TF_RETURN_IF_ERROR(tf_file->GetFileSize(&size));
+
+        // Wrap the TF file in Arrow interface to be used in Feather reader
+        std::shared_ptr<ArrowRandomAccessFile> in_file(
+            new ArrowRandomAccessFile(tf_file.get(), size));
+
+        // Create the Feather reader
        std::unique_ptr<arrow::ipc::feather::TableReader> reader;
        CHECK_ARROW(arrow::ipc::feather::TableReader::Open(in_file, &reader));
 
Expand Down Expand Up @@ -982,14 +1000,15 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
return Status::OK();
}

Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
ArrowBaseIterator<Dataset>::NextStreamLocked();
Status NextStreamLocked(Env* env)
EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
ArrowBaseIterator<Dataset>::NextStreamLocked(env);
if (++current_batch_idx_ < record_batches_.size()) {
current_batch_ = record_batches_[current_batch_idx_];
} else if (++current_file_idx_ < dataset()->filenames_.size()) {
current_batch_idx_ = 0;
record_batches_.clear();
SetupStreamsLocked();
return SetupStreamsLocked(env);
}
return Status::OK();
}
@@ -1102,7 +1121,7 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
           : ArrowBaseIterator<Dataset>(params) {}
 
      private:
-      Status SetupStreamsLocked()
+      Status SetupStreamsLocked(Env* env)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        const string& endpoint = dataset()->endpoints_[current_endpoint_idx_];
        string endpoint_type;
@@ -1128,13 +1147,14 @@
         return Status::OK();
       }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        CHECK_ARROW(reader_->ReadNext(&current_batch_));
        if (current_batch_ == nullptr &&
            ++current_endpoint_idx_ < dataset()->endpoints_.size()) {
          reader_.reset();
-          SetupStreamsLocked();
+          SetupStreamsLocked(env);
        }
        return Status::OK();
      }
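
The stream op drains each endpoint until ReadNext yields a null batch, then rolls over to the next endpoint, as the hunk above shows. The same pattern as a standalone sketch (function name is illustrative; Open/ReadNext match the Arrow API this code calls):

    #include "arrow/ipc/api.h"

    // Read batches from one stream until end-of-stream (nullptr batch).
    arrow::Status DrainStream(std::shared_ptr<arrow::io::InputStream> stream) {
      std::shared_ptr<arrow::RecordBatchReader> reader;
      ARROW_RETURN_NOT_OK(
          arrow::ipc::RecordBatchStreamReader::Open(stream, &reader));
      std::shared_ptr<arrow::RecordBatch> batch;
      ARROW_RETURN_NOT_OK(reader->ReadNext(&batch));
      while (batch != nullptr) {
        // ... consume `batch` ...
        ARROW_RETURN_NOT_OK(reader->ReadNext(&batch));
      }
      return arrow::Status::OK();
    }
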
@@ -1167,4 +1187,5 @@ REGISTER_KERNEL_BUILDER(Name("ArrowFeatherDataset").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("ArrowStreamDataset").Device(DEVICE_CPU),
                         ArrowStreamDatasetOp);
 
+} // namespace data
 } // namespace tensorflow
5 changes: 5 additions & 0 deletions tensorflow_io/arrow/kernels/arrow_kernels.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#ifndef TENSORFLOW_IO_ARROW_KERNELS_H_
+#define TENSORFLOW_IO_ARROW_KERNELS_H_
+
 #include "kernels/stream.h"
 #include "arrow/io/api.h"
 #include "arrow/buffer.h"
@@ -78,3 +81,5 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
 };
 } // namespace data
 } // namespace tensorflow
+
+#endif // TENSORFLOW_IO_ARROW_KERNELS_H_
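
Only the declaration of ArrowRandomAccessFile is visible in this diff; per its base class in the hunk header above, it adapts a tensorflow::RandomAccessFile to Arrow's random-access interface so the Feather reader can consume TF-opened files. A rough guess at the core of such an adapter — the class name, method bodies, and omissions here are assumptions, not the PR's actual implementation:

    #include <cstring>
    #include "arrow/io/api.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/platform/env.h"

    // Hypothetical adapter: serve Arrow ReadAt calls from a TF file.
    class TfToArrowFile : public ::arrow::io::RandomAccessFile {
     public:
      TfToArrowFile(tensorflow::RandomAccessFile* file, int64_t size)
          : file_(file), size_(size) {}

      ::arrow::Status GetSize(int64_t* size) override {
        *size = size_;
        return ::arrow::Status::OK();
      }

      ::arrow::Status ReadAt(int64_t position, int64_t nbytes,
                             int64_t* bytes_read, void* out) override {
        tensorflow::StringPiece result;
        tensorflow::Status s = file_->Read(position, nbytes, &result,
                                           reinterpret_cast<char*>(out));
        // TF reports a short read as OutOfRange; Arrow wants a byte count.
        if (!s.ok() && !tensorflow::errors::IsOutOfRange(s)) {
          return ::arrow::Status::IOError(s.error_message());
        }
        // TF may return data from its own buffer rather than the scratch.
        if (result.data() != out) {
          std::memmove(out, result.data(), result.size());
        }
        *bytes_read = result.size();
        return ::arrow::Status::OK();
      }

      // Remaining pure-virtual members (Read, Seek, Tell, Close, ...) omitted.

     private:
      tensorflow::RandomAccessFile* file_;  // not owned
      int64_t size_;
    };
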
5 changes: 5 additions & 0 deletions tensorflow_io/core/kernels/stream.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#ifndef TENSORFLOW_IO_CORE_KERNELS_STREAM_H_
+#define TENSORFLOW_IO_CORE_KERNELS_STREAM_H_
+
 #include "tensorflow/core/lib/io/inputstream_interface.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"
 
@@ -69,3 +72,5 @@ class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
 
 } // namespace data
 } // namespace tensorflow
+
+#endif // TENSORFLOW_IO_CORE_KERNELS_STREAM_H_
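
SizedRandomAccessFile (declared in the hunk header above, and used by the Feather iterator in the first file) opens a path through Env and reports its size; the (nullptr, 0) constructor arguments are the optional in-memory buffer the TODO in the .cc refers to. A small usage sketch mirroring the Feather hunk's call — the helper name is illustrative:

    // Open `filename` through Env and return its size via `size`.
    tensorflow::Status FileSizeThroughEnv(tensorflow::Env* env,
                                          const tensorflow::string& filename,
                                          tensorflow::uint64* size) {
      std::shared_ptr<tensorflow::data::SizedRandomAccessFile> file(
          new tensorflow::data::SizedRandomAccessFile(env, filename,
                                                      nullptr, 0));
      return file->GetFileSize(size);
    }
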
8 changes: 8 additions & 0 deletions tests/test_arrow_eager.py
@@ -282,6 +282,14 @@ def test_arrow_feather_dataset(self):
         truth_data.output_shapes)
     self.run_test_case(dataset, truth_data)
 
+    # test single file with 'file://' prefix
+    dataset = arrow_io.ArrowFeatherDataset(
+        "file://{}".format(f.name),
+        list(range(len(truth_data.output_types))),
+        truth_data.output_types,
+        truth_data.output_shapes)
+    self.run_test_case(dataset, truth_data)
+
     # test multiple files
     dataset = arrow_io.ArrowFeatherDataset(
         [f.name, f.name],