Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions tensorflow_io/arrow/kernels/arrow_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,14 @@ class FeatherIndexable : public IOIndexableInterface {

return Status::OK();
}
Status Component(Tensor* component) override {
*component = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
Status Components(Tensor* components) override {
*components = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
for (size_t i = 0; i < columns_.size(); i++) {
component->flat<string>()(i) = columns_[i];
components->flat<string>()(i) = columns_[i];
}
return Status::OK();
}
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
}
Expand All @@ -297,10 +297,7 @@ class FeatherIndexable : public IOIndexableInterface {
return Status::OK();
}

Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
if (step != 1) {
return errors::InvalidArgument("step ", step, " is not supported");
}
Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
}
Expand All @@ -326,12 +323,12 @@ class FeatherIndexable : public IOIndexableInterface {
int64 curr_index = 0; \
for (auto chunk : slice->data()->chunks()) { \
for (int64_t item = 0; item < chunk->length(); item++) { \
tensor->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE *>(chunk.get()))->Value(item); \
value->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE *>(chunk.get()))->Value(item); \
curr_index++; \
} \
} \
}
switch (tensor->dtype()) {
switch (value->dtype()) {
case DT_BOOL:
FEATHER_PROCESS_TYPE(bool, ::arrow::BooleanArray);
break;
Expand Down Expand Up @@ -366,7 +363,7 @@ class FeatherIndexable : public IOIndexableInterface {
FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
break;
default:
return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype()));
return errors::InvalidArgument("data type is not supported: ", DataTypeString(value->dtype()));
}

return Status::OK();
Expand Down Expand Up @@ -394,7 +391,7 @@ REGISTER_KERNEL_BUILDER(Name("FeatherIndexableInit").Device(DEVICE_CPU),
IOInterfaceInitOp<FeatherIndexable>);
REGISTER_KERNEL_BUILDER(Name("FeatherIndexableSpec").Device(DEVICE_CPU),
IOInterfaceSpecOp<FeatherIndexable>);
REGISTER_KERNEL_BUILDER(Name("FeatherIndexableGetItem").Device(DEVICE_CPU),
IOIndexableGetItemOp<FeatherIndexable>);
REGISTER_KERNEL_BUILDER(Name("FeatherIndexableRead").Device(DEVICE_CPU),
IOIndexableReadOp<FeatherIndexable>);
} // namespace data
} // namespace tensorflow
9 changes: 4 additions & 5 deletions tensorflow_io/arrow/ops/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ REGISTER_OP("ListFeatherColumns")

REGISTER_OP("FeatherIndexableInit")
.Input("input: string")
.Output("output: resource")
.Output("component: string")
.Output("resource: resource")
.Output("components: string")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetShapeFn([](shape_inference::InferenceContext* c) {
Expand All @@ -123,13 +123,12 @@ REGISTER_OP("FeatherIndexableSpec")
return Status::OK();
});

REGISTER_OP("FeatherIndexableGetItem")
REGISTER_OP("FeatherIndexableRead")
.Input("input: resource")
.Input("start: int64")
.Input("stop: int64")
.Input("step: int64")
.Input("component: string")
.Output("output: dtype")
.Output("value: dtype")
.Attr("shape: shape")
.Attr("dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
Expand Down
18 changes: 7 additions & 11 deletions tensorflow_io/audio/kernels/audio_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class WAVIndexable : public IOIndexableInterface {

return Status::OK();
}
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
*shape = shape_;
*dtype = dtype_;
return Status::OK();
Expand All @@ -144,11 +144,7 @@ class WAVIndexable : public IOIndexableInterface {
return Status::OK();
}

Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
if (step != 1) {
return errors::InvalidArgument("step ", step, " is not supported");
}

Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
const int64 sample_start = start;
const int64 sample_stop = stop;

Expand Down Expand Up @@ -181,13 +177,13 @@ class WAVIndexable : public IOIndexableInterface {
if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) {
return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign);
}
memcpy((char *)(tensor->flat<int8>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
memcpy((char *)(value->flat<int8>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
break;
case 16:
if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) {
return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign);
}
memcpy((char *)(tensor->flat<int16>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
memcpy((char *)(value->flat<int16>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
break;
case 24:
// NOTE: The conversion is from signed integer 24 to signed integer 32 (left shift 8 bits)
Expand All @@ -196,7 +192,7 @@ class WAVIndexable : public IOIndexableInterface {
}
for (int64 i = read_sample_start; i < read_sample_stop; i++) {
for (int64 j = 0; j < header_.nChannels; j++) {
char *data_p = (char *)(tensor->flat<int32>().data() + ((i - sample_start) * header_.nChannels + j));
char *data_p = (char *)(value->flat<int32>().data() + ((i - sample_start) * header_.nChannels + j));
char *read_p = (char *)(&buffer[((i - read_sample_start) * header_.nBlockAlign)]) + 3 * j;
data_p[3] = read_p[2];
data_p[2] = read_p[1];
Expand Down Expand Up @@ -235,8 +231,8 @@ REGISTER_KERNEL_BUILDER(Name("WAVIndexableInit").Device(DEVICE_CPU),
IOInterfaceInitOp<WAVIndexable>);
REGISTER_KERNEL_BUILDER(Name("WAVIndexableSpec").Device(DEVICE_CPU),
IOInterfaceSpecOp<WAVIndexable>);
REGISTER_KERNEL_BUILDER(Name("WAVIndexableGetItem").Device(DEVICE_CPU),
IOIndexableGetItemOp<WAVIndexable>);
REGISTER_KERNEL_BUILDER(Name("WAVIndexableRead").Device(DEVICE_CPU),
IOIndexableReadOp<WAVIndexable>);


} // namespace data
Expand Down
7 changes: 3 additions & 4 deletions tensorflow_io/audio/ops/audio_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace tensorflow {

REGISTER_OP("WAVIndexableInit")
.Input("input: string")
.Output("output: resource")
.Output("resource: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetShapeFn([](shape_inference::InferenceContext* c) {
Expand All @@ -42,13 +42,12 @@ REGISTER_OP("WAVIndexableSpec")
return Status::OK();
});

REGISTER_OP("WAVIndexableGetItem")
REGISTER_OP("WAVIndexableRead")
.Input("input: resource")
.Input("start: int64")
.Input("stop: int64")
.Input("step: int64")
.Input("component: int64")
.Output("output: dtype")
.Output("value: dtype")
.Attr("shape: shape")
.Attr("dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
Expand Down
Loading