Skip to content

Commit 5bbb94c

Browse files
authored
Wrap the IOTensor function as a callable class (#491)
* Remove step arg as it could not be supported in C++ anyway
  Signed-off-by: Yong Tang <[email protected]>
* Fix audio and lmdb
  Signed-off-by: Yong Tang <[email protected]>
* Fix prometheus
  Signed-off-by: Yong Tang <[email protected]>
* Fix Kafka
  Signed-off-by: Yong Tang <[email protected]>
* Fix CSV
  Signed-off-by: Yong Tang <[email protected]>
* Fix HDF5
  Signed-off-by: Yong Tang <[email protected]>
* Fix JSON
  Signed-off-by: Yong Tang <[email protected]>
* Pylint fix
  Signed-off-by: Yong Tang <[email protected]>
* Fix feather
  Signed-off-by: Yong Tang <[email protected]>
* LMDB naming change
  Signed-off-by: Yong Tang <[email protected]>
* Remove the need to specify label
  Signed-off-by: Yong Tang <[email protected]>
* Fix Kafka
  Signed-off-by: Yong Tang <[email protected]>
* Make output name consistent
  Signed-off-by: Yong Tang <[email protected]>
1 parent 55d9750 commit 5bbb94c

26 files changed

+521
-498
lines changed

tensorflow_io/arrow/kernels/arrow_kernels.cc

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,14 @@ class FeatherIndexable : public IOIndexableInterface {
280280

281281
return Status::OK();
282282
}
283-
Status Component(Tensor* component) override {
284-
*component = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
283+
Status Components(Tensor* components) override {
284+
*components = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
285285
for (size_t i = 0; i < columns_.size(); i++) {
286-
component->flat<string>()(i) = columns_[i];
286+
components->flat<string>()(i) = columns_[i];
287287
}
288288
return Status::OK();
289289
}
290-
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
290+
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
291291
if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
292292
return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
293293
}
@@ -297,10 +297,7 @@ class FeatherIndexable : public IOIndexableInterface {
297297
return Status::OK();
298298
}
299299

300-
Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
301-
if (step != 1) {
302-
return errors::InvalidArgument("step ", step, " is not supported");
303-
}
300+
Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
304301
if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
305302
return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
306303
}
@@ -326,12 +323,12 @@ class FeatherIndexable : public IOIndexableInterface {
326323
int64 curr_index = 0; \
327324
for (auto chunk : slice->data()->chunks()) { \
328325
for (int64_t item = 0; item < chunk->length(); item++) { \
329-
tensor->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE *>(chunk.get()))->Value(item); \
326+
value->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE *>(chunk.get()))->Value(item); \
330327
curr_index++; \
331328
} \
332329
} \
333330
}
334-
switch (tensor->dtype()) {
331+
switch (value->dtype()) {
335332
case DT_BOOL:
336333
FEATHER_PROCESS_TYPE(bool, ::arrow::BooleanArray);
337334
break;
@@ -366,7 +363,7 @@ class FeatherIndexable : public IOIndexableInterface {
366363
FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
367364
break;
368365
default:
369-
return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype()));
366+
return errors::InvalidArgument("data type is not supported: ", DataTypeString(value->dtype()));
370367
}
371368

372369
return Status::OK();
@@ -394,7 +391,7 @@ REGISTER_KERNEL_BUILDER(Name("FeatherIndexableInit").Device(DEVICE_CPU),
394391
IOInterfaceInitOp<FeatherIndexable>);
395392
REGISTER_KERNEL_BUILDER(Name("FeatherIndexableSpec").Device(DEVICE_CPU),
396393
IOInterfaceSpecOp<FeatherIndexable>);
397-
REGISTER_KERNEL_BUILDER(Name("FeatherIndexableGetItem").Device(DEVICE_CPU),
398-
IOIndexableGetItemOp<FeatherIndexable>);
394+
REGISTER_KERNEL_BUILDER(Name("FeatherIndexableRead").Device(DEVICE_CPU),
395+
IOIndexableReadOp<FeatherIndexable>);
399396
} // namespace data
400397
} // namespace tensorflow

tensorflow_io/arrow/ops/dataset_ops.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ REGISTER_OP("ListFeatherColumns")
102102

103103
REGISTER_OP("FeatherIndexableInit")
104104
.Input("input: string")
105-
.Output("output: resource")
106-
.Output("component: string")
105+
.Output("resource: resource")
106+
.Output("components: string")
107107
.Attr("container: string = ''")
108108
.Attr("shared_name: string = ''")
109109
.SetShapeFn([](shape_inference::InferenceContext* c) {
@@ -123,13 +123,12 @@ REGISTER_OP("FeatherIndexableSpec")
123123
return Status::OK();
124124
});
125125

126-
REGISTER_OP("FeatherIndexableGetItem")
126+
REGISTER_OP("FeatherIndexableRead")
127127
.Input("input: resource")
128128
.Input("start: int64")
129129
.Input("stop: int64")
130-
.Input("step: int64")
131130
.Input("component: string")
132-
.Output("output: dtype")
131+
.Output("value: dtype")
133132
.Attr("shape: shape")
134133
.Attr("dtype: type")
135134
.SetShapeFn([](shape_inference::InferenceContext* c) {

tensorflow_io/audio/kernels/audio_kernels.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class WAVIndexable : public IOIndexableInterface {
130130

131131
return Status::OK();
132132
}
133-
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
133+
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
134134
*shape = shape_;
135135
*dtype = dtype_;
136136
return Status::OK();
@@ -144,11 +144,7 @@ class WAVIndexable : public IOIndexableInterface {
144144
return Status::OK();
145145
}
146146

147-
Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
148-
if (step != 1) {
149-
return errors::InvalidArgument("step ", step, " is not supported");
150-
}
151-
147+
Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
152148
const int64 sample_start = start;
153149
const int64 sample_stop = stop;
154150

@@ -181,13 +177,13 @@ class WAVIndexable : public IOIndexableInterface {
181177
if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) {
182178
return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign);
183179
}
184-
memcpy((char *)(tensor->flat<int8>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
180+
memcpy((char *)(value->flat<int8>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
185181
break;
186182
case 16:
187183
if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) {
188184
return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign);
189185
}
190-
memcpy((char *)(tensor->flat<int16>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
186+
memcpy((char *)(value->flat<int16>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
191187
break;
192188
case 24:
193189
// NOTE: The conversion is from signed integer 24 to signed integer 32 (left shift 8 bits)
@@ -196,7 +192,7 @@ class WAVIndexable : public IOIndexableInterface {
196192
}
197193
for (int64 i = read_sample_start; i < read_sample_stop; i++) {
198194
for (int64 j = 0; j < header_.nChannels; j++) {
199-
char *data_p = (char *)(tensor->flat<int32>().data() + ((i - sample_start) * header_.nChannels + j));
195+
char *data_p = (char *)(value->flat<int32>().data() + ((i - sample_start) * header_.nChannels + j));
200196
char *read_p = (char *)(&buffer[((i - read_sample_start) * header_.nBlockAlign)]) + 3 * j;
201197
data_p[3] = read_p[2];
202198
data_p[2] = read_p[1];
@@ -235,8 +231,8 @@ REGISTER_KERNEL_BUILDER(Name("WAVIndexableInit").Device(DEVICE_CPU),
235231
IOInterfaceInitOp<WAVIndexable>);
236232
REGISTER_KERNEL_BUILDER(Name("WAVIndexableSpec").Device(DEVICE_CPU),
237233
IOInterfaceSpecOp<WAVIndexable>);
238-
REGISTER_KERNEL_BUILDER(Name("WAVIndexableGetItem").Device(DEVICE_CPU),
239-
IOIndexableGetItemOp<WAVIndexable>);
234+
REGISTER_KERNEL_BUILDER(Name("WAVIndexableRead").Device(DEVICE_CPU),
235+
IOIndexableReadOp<WAVIndexable>);
240236

241237

242238
} // namespace data

tensorflow_io/audio/ops/audio_ops.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace tensorflow {
2121

2222
REGISTER_OP("WAVIndexableInit")
2323
.Input("input: string")
24-
.Output("output: resource")
24+
.Output("resource: resource")
2525
.Attr("container: string = ''")
2626
.Attr("shared_name: string = ''")
2727
.SetShapeFn([](shape_inference::InferenceContext* c) {
@@ -42,13 +42,12 @@ REGISTER_OP("WAVIndexableSpec")
4242
return Status::OK();
4343
});
4444

45-
REGISTER_OP("WAVIndexableGetItem")
45+
REGISTER_OP("WAVIndexableRead")
4646
.Input("input: resource")
4747
.Input("start: int64")
4848
.Input("stop: int64")
49-
.Input("step: int64")
5049
.Input("component: int64")
51-
.Output("output: dtype")
50+
.Output("value: dtype")
5251
.Attr("shape: shape")
5352
.Attr("dtype: type")
5453
.SetShapeFn([](shape_inference::InferenceContext* c) {

0 commit comments

Comments
 (0)