Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .kokorun/io_cpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ python --version
python -m pip --version
docker --version

PYTHON_VERSION=$(python -c 'import sys; print(sys.version_info[0])')

## Set test services
bash -x -e tests/test_ignite/start_ignite.sh
bash -x -e tests/test_kafka/kafka_test.sh start kafka
Expand Down
10 changes: 10 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,13 @@ http_archive(
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
],
)

# Fetches the RapidJSON v1.1.0 source archive as the external repository
# "rapidjson"; it is built with the rules in //third_party:rapidjson.BUILD.
# The sha256 pins the exact archive contents, and strip_prefix removes the
# top-level "rapidjson-1.1.0" directory from the extracted tree.
http_archive(
name = "rapidjson",
build_file = "//third_party:rapidjson.BUILD",
sha256 = "bf7ced29704a1e696fbccf2a2b4ea068e7774fa37f6d7dd4039d0787f8bed98e",
strip_prefix = "rapidjson-1.1.0",
urls = [
"https://github.com/miloyip/rapidjson/archive/v1.1.0.tar.gz",
],
)
2 changes: 1 addition & 1 deletion tensorflow_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_io.core.python.ops.io_tensor import IOTensor
13 changes: 5 additions & 8 deletions tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/

#include "arrow/api.h"
#include "arrow/adapters/tensorflow/convert.h"
#include "arrow/ipc/api.h"
#include "arrow/util/io-util.h"
#include "tensorflow/core/framework/dataset.h"
Expand Down Expand Up @@ -99,8 +98,10 @@ class ArrowColumnTypeChecker : public arrow::TypeVisitor {
// Check scalar types with arrow::adapters::tensorflow
arrow::Status CheckScalarType(std::shared_ptr<arrow::DataType> scalar_type) {
DataType converted_type;
ARROW_RETURN_NOT_OK(arrow::adapters::tensorflow::GetTensorFlowType(
scalar_type, &converted_type));
::tensorflow::Status status = GetTensorFlowType(scalar_type, &converted_type);
if (!status.ok()) {
return ::arrow::Status::Invalid(status);
}
if (converted_type != expected_type_) {
return arrow::Status::TypeError(
"Arrow type mismatch: expected dtype=" +
Expand Down Expand Up @@ -523,11 +524,7 @@ class ArrowOpKernelBase : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
for (const DataType& dt : output_types_) {
std::shared_ptr<arrow::DataType> arrow_type;
auto status = arrow::adapters::tensorflow::GetArrowType(dt, &arrow_type);
OP_REQUIRES(ctx, status.ok(),
errors::InvalidArgument(
"Arrow type is unsupported for output_type dtype=" +
std::to_string(dt)));
OP_REQUIRES_OK(ctx, GetArrowType(dt, &arrow_type));
}
for (const PartialTensorShape& pts : output_shapes_) {
OP_REQUIRES(ctx, -1 <= pts.dims() && pts.dims() <= 2,
Expand Down
18 changes: 18 additions & 0 deletions tensorflow_io/arrow/kernels/arrow_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,27 @@ limitations under the License.
#include "arrow/ipc/feather.h"
#include "arrow/ipc/feather_generated.h"
#include "arrow/buffer.h"
#include "arrow/adapters/tensorflow/convert.h"

namespace tensorflow {
namespace data {

// Converts an Arrow data type to the corresponding TensorFlow DataType.
//
// Thin wrapper around ::arrow::adapters::tensorflow::GetTensorFlowType that
// translates the resulting ::arrow::Status into a tensorflow Status.
//
// Args:
//   dtype: Arrow type to convert; a null pointer is rejected.
//   out:   receives the TensorFlow type on success.
//
// Returns:
//   Status::OK() on success, or InvalidArgument (including the Arrow type
//   name and the underlying Arrow error) when the type cannot be converted.
Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out) {
  // Reject a null type up front so it is never handed to the Arrow adapter
  // or dereferenced while building the error message below.
  if (!dtype) {
    return errors::InvalidArgument("arrow data type is null");
  }
  ::arrow::Status status = ::arrow::adapters::tensorflow::GetTensorFlowType(dtype, out);
  if (!status.ok()) {
    // Use the type's string form: passing the shared_ptr itself would format
    // the pointer address instead of a readable type name.
    return errors::InvalidArgument("arrow data type ", dtype->ToString(), " is not supported: ", status);
  }
  return Status::OK();
}

// Converts a TensorFlow DataType to the corresponding Arrow data type.
//
// Delegates to ::arrow::adapters::tensorflow::GetArrowType and maps its
// ::arrow::Status result onto a tensorflow Status.
//
// Args:
//   dtype: TensorFlow type to convert.
//   out:   receives the Arrow type on success.
//
// Returns:
//   Status::OK() on success, otherwise InvalidArgument carrying the
//   TensorFlow dtype and the underlying Arrow error.
Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out) {
  const ::arrow::Status arrow_status =
      ::arrow::adapters::tensorflow::GetArrowType(dtype, out);
  // Happy path first: nothing to translate when Arrow accepted the type.
  if (arrow_status.ok()) {
    return Status::OK();
  }
  return errors::InvalidArgument("tensorflow data type ", dtype, " is not supported: ", arrow_status);
}

namespace {

class ListFeatherColumnsOp : public OpKernel {
Expand Down
22 changes: 19 additions & 3 deletions tensorflow_io/arrow/kernels/arrow_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,23 @@ limitations under the License.
#include "kernels/stream.h"
#include "arrow/io/api.h"
#include "arrow/buffer.h"
#include "arrow/type.h"

namespace tensorflow {
namespace data {

Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out);
Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out);

// NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap
// with another PR. Will remove duplicate once PR merged

class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
public:
explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
: file_(file)
, size_(size) { }
, size_(size)
, position_(0) { }

~ArrowRandomAccessFile() {}
arrow::Status Close() override {
Expand All @@ -40,13 +45,21 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
return false;
}
arrow::Status Tell(int64_t* position) const override {
return arrow::Status::NotImplemented("Tell");
*position = position_;
return arrow::Status::OK();
}
arrow::Status Seek(int64_t position) override {
return arrow::Status::NotImplemented("Seek");
}
arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
return arrow::Status::NotImplemented("Read (void*)");
StringPiece result;
Status status = file_->Read(position_, nbytes, &result, (char*)out);
if (!(status.ok() || errors::IsOutOfRange(status))) {
return arrow::Status::IOError(status.error_message());
}
*bytes_read = result.size();
position_ += (*bytes_read);
return arrow::Status::OK();
}
arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
return arrow::Status::NotImplemented("Read (Buffer*)");
Expand Down Expand Up @@ -81,7 +94,10 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
private:
tensorflow::RandomAccessFile* file_;
int64 size_;
int64 position_;
};


} // namespace data
} // namespace tensorflow

Expand Down
6 changes: 0 additions & 6 deletions tensorflow_io/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,18 @@
"""Audio Dataset.

@@WAVDataset
@@list_wav_info
@@read_wav
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_io.audio.python.ops.audio_ops import WAVDataset
from tensorflow_io.audio.python.ops.audio_ops import list_wav_info
from tensorflow_io.audio.python.ops.audio_ops import read_wav

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
"WAVDataset",
"list_wav_info",
"read_wav",
]

remove_undocumented(__name__)
Loading