Skip to content

Commit d90ef55

Browse files
committed
Add from_json, and allows multiple columns
Signed-off-by: Yong Tang <[email protected]>
1 parent 6aa7ba1 commit d90ef55

File tree

11 files changed

+489
-83
lines changed

11 files changed

+489
-83
lines changed

WORKSPACE

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,3 +542,13 @@ http_archive(
542542
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
543543
],
544544
)
545+
546+
http_archive(
547+
name = "rapidjson",
548+
build_file = "//third_party:rapidjson.BUILD",
549+
sha256 = "bf7ced29704a1e696fbccf2a2b4ea068e7774fa37f6d7dd4039d0787f8bed98e",
550+
strip_prefix = "rapidjson-1.1.0",
551+
urls = [
552+
"https://github.com/miloyip/rapidjson/archive/v1.1.0.tar.gz",
553+
],
554+
)

tensorflow_io/arrow/kernels/arrow_kernels.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
3030
public:
3131
explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
3232
: file_(file)
33-
, size_(size) { }
33+
, size_(size)
34+
, position_(0) { }
3435

3536
~ArrowRandomAccessFile() {}
3637
arrow::Status Close() override {
@@ -40,13 +41,21 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
4041
return false;
4142
}
4243
arrow::Status Tell(int64_t* position) const override {
43-
return arrow::Status::NotImplemented("Tell");
44+
*position = position_;
45+
return arrow::Status::OK();
4446
}
4547
arrow::Status Seek(int64_t position) override {
4648
return arrow::Status::NotImplemented("Seek");
4749
}
4850
arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
49-
return arrow::Status::NotImplemented("Read (void*)");
51+
StringPiece result;
52+
Status status = file_->Read(position_, nbytes, &result, (char*)out);
53+
if (!(status.ok() || errors::IsOutOfRange(status))) {
54+
return arrow::Status::IOError(status.error_message());
55+
}
56+
*bytes_read = result.size();
57+
position_ += (*bytes_read);
58+
return arrow::Status::OK();
5059
}
5160
arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
5261
return arrow::Status::NotImplemented("Read (Buffer*)");
@@ -81,6 +90,7 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
8190
private:
8291
tensorflow::RandomAccessFile* file_;
8392
int64 size_;
93+
int64 position_;
8494
};
8595
} // namespace data
8696
} // namespace tensorflow

tensorflow_io/audio/ops/audio_ops.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,19 @@ REGISTER_OP("WAVIndexableGetItem")
4242
.Input("stop: int64")
4343
.Input("step: int64")
4444
.Output("output: dtype")
45-
.Attr("dtype: {int8, int16, int32}")
45+
.Attr("dtype: list(type) >= 1")
46+
.Attr("shape: list(shape) >= 1")
4647
.SetShapeFn([](shape_inference::InferenceContext* c) {
47-
c->set_output(0, c->UnknownShape());
48+
std::vector<PartialTensorShape> shape;
49+
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
50+
if (shape.size() != c->num_outputs()) {
51+
return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs());
52+
}
53+
for (size_t i = 0; i < shape.size(); ++i) {
54+
shape_inference::ShapeHandle entry;
55+
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry));
56+
c->set_output(static_cast<int64>(i), entry);
57+
}
4858
return Status::OK();
4959
});
5060

tensorflow_io/audio/python/ops/audio_ops.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
from tensorflow_io.core.python.ops import data_ops
2222
from tensorflow_io.core.python.ops import io_tensor
2323

24-
class WAVDataset(data_ops.BaseDataset):
24+
class WAVDataset(tf.compat.v2.data.Dataset):
2525
"""A WAV Dataset"""
2626

27-
def __init__(self, filename, batch=None, **kwargs):
27+
def __init__(self, filename, batch=None):
2828
"""Create a WAVDataset.
2929
3030
Args:
@@ -35,28 +35,16 @@ def __init__(self, filename, batch=None, **kwargs):
3535
raise NotImplementedError("WAVDataset only support eager mode")
3636

3737
self._wav = io_tensor.IOTensor.from_audio(filename)
38+
self._dataset = self._wav.to_dataset()
39+
super(WAVDataset, self).__init__(
40+
self._dataset._variant_tensor) # pylint: disable=protected-access
3841

39-
dtype = self._wav.dtype
40-
shape = self._wav.shape[1:]
41-
start = 0
42-
stop = self._wav.shape[0]
43-
44-
# capacity is the rough count for each chunk in dataset
45-
capacity = kwargs.get("capacity", 65536)
46-
entry_start = list(range(start, stop, capacity))
47-
entry_stop = entry_start[1:] + [stop]
48-
dataset = data_ops.BaseDataset.from_tensor_slices(
49-
(tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64))
50-
).map(lambda start, stop: self._wav.__getitem__(slice(start, stop)))
51-
52-
dataset = dataset.apply(tf.data.experimental.unbatch())
53-
if batch != 0:
54-
dataset = dataset.batch(batch)
55-
shape = tf.TensorShape([None]).concatenate(shape)
42+
def _inputs(self):
43+
return []
5644

57-
self._dataset = dataset
58-
super(WAVDataset, self).__init__(
59-
self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access
45+
@property
46+
def _element_structure(self):
47+
return self._dataset._element_structure # pylint: disable=protected-access
6048

6149
class AudioDataset(data_ops.Dataset):
6250
"""A Audio File Dataset that reads the audio file."""

0 commit comments

Comments
 (0)