Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions tensorflow_io/audio/kernels/audio_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ class WAVIndexable : public IOIndexableInterface {

return Status::OK();
}
Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
dtypes.clear();
dtypes.push_back(dtype_);
Status Spec(std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) override {
shapes.clear();
shapes.push_back(shape_);
dtypes.clear();
dtypes.push_back(dtype_);
return Status::OK();
}

Expand All @@ -146,8 +146,7 @@ class WAVIndexable : public IOIndexableInterface {
return Status::OK();
}

Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
Tensor& output_tensor = tensors[0];
Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
if (step != 1) {
return errors::InvalidArgument("step ", step, " is not supported");
}
Expand Down Expand Up @@ -184,13 +183,13 @@ class WAVIndexable : public IOIndexableInterface {
if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) {
return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign);
}
memcpy((char *)(output_tensor.flat<int8>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
memcpy((char *)(tensor->flat<int8>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
break;
case 16:
if (header_.wBitsPerSample * header_.nChannels != header_.nBlockAlign * 8) {
return errors::InvalidArgument("unsupported wBitsPerSample and header.nBlockAlign: ", header_.wBitsPerSample, ", ", header_.nBlockAlign);
}
memcpy((char *)(output_tensor.flat<int16>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
memcpy((char *)(tensor->flat<int16>().data()) + ((read_sample_start - sample_start) * header_.nBlockAlign), &buffer[0], (read_bytes_stop - read_bytes_start));
break;
case 24:
// NOTE: The conversion is from signed integer 24 to signed integer 32 (left shift 8 bits)
Expand All @@ -199,7 +198,7 @@ class WAVIndexable : public IOIndexableInterface {
}
for (int64 i = read_sample_start; i < read_sample_stop; i++) {
for (int64 j = 0; j < header_.nChannels; j++) {
char *data_p = (char *)(output_tensor.flat<int32>().data() + ((i - sample_start) * header_.nChannels + j));
char *data_p = (char *)(tensor->flat<int32>().data() + ((i - sample_start) * header_.nChannels + j));
char *read_p = (char *)(&buffer[((i - read_sample_start) * header_.nBlockAlign)]) + 3 * j;
data_p[3] = read_p[2];
data_p[2] = read_p[1];
Expand Down
18 changes: 7 additions & 11 deletions tensorflow_io/audio/ops/audio_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,16 @@ REGISTER_OP("WAVIndexableGetItem")
.Input("start: int64")
.Input("stop: int64")
.Input("step: int64")
.Input("component: int64")
.Output("output: dtype")
.Attr("dtype: list(type) >= 1")
.Attr("shape: list(shape) >= 1")
.Attr("shape: shape")
.Attr("dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> shape;
PartialTensorShape shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
if (shape.size() != c->num_outputs()) {
return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs());
}
for (size_t i = 0; i < shape.size(); ++i) {
shape_inference::ShapeHandle entry;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry));
c->set_output(static_cast<int64>(i), entry);
}
shape_inference::ShapeHandle entry;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
c->set_output(0, entry);
return Status::OK();
});

Expand Down
8 changes: 4 additions & 4 deletions tensorflow_io/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ cc_library(
)

go_binary(
name = "prometheus_go",
name = "golang_ops",
srcs = ["go/prometheus.go"],
out = "python/ops/libtensorflow_io_prometheus.so",
out = "python/ops/libtensorflow_io_golang.so",
cgo = True,
linkmode = "c-shared",
visibility = ["//visibility:public"],
Expand All @@ -138,11 +138,11 @@ go_binary(
cc_library(
name = "prometheus_go_ops",
srcs = [
"prometheus_go.h",
"golang_ops.h",
],
copts = tf_io_copts(),
data = [
"//tensorflow_io/core:prometheus_go.h",
"//tensorflow_io/core:golang_ops.h",
],
linkstatic = True,
)
Expand Down
132 changes: 64 additions & 68 deletions tensorflow_io/core/kernels/io_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace data {
class IOInterface : public ResourceBase {
public:
virtual Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) = 0;
virtual Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) = 0;
virtual Status Spec(std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) = 0;

virtual Status Extra(std::vector<Tensor>* extra) {
// This is the chance to provide additional extra information which should be appended to extra.
Expand All @@ -33,12 +33,12 @@ class IOInterface : public ResourceBase {

class IOIterableInterface : public IOInterface {
public:
virtual Status Next(const int64 capacity, std::vector<Tensor>& tensors, int64* record_read) = 0;
virtual Status Next(const int64 capacity, const int64 component, Tensor* tensor, int64* record_read) = 0;
};

class IOIndexableInterface : public IOInterface {
public:
virtual Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) = 0;
virtual Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) = 0;
};

template<typename Type>
Expand All @@ -50,9 +50,8 @@ class IOIndexableImplementation : public IOIndexableInterface {

~IOIndexableImplementation<Type>() {}
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {

TF_RETURN_IF_ERROR(iterable_->Init(input, metadata, memory_data, memory_size));
TF_RETURN_IF_ERROR(iterable_->Spec(dtypes_, shapes_));
TF_RETURN_IF_ERROR(iterable_->Spec(shapes_, dtypes_));

const int64 capacity = 4096;
std::vector<TensorShape> chunk_shapes;
Expand All @@ -66,18 +65,23 @@ class IOIndexableImplementation : public IOIndexableInterface {

int64 record_read = 0;
do {
tensors_.push_back(std::vector<Tensor>());
chunk_tensors_.push_back(std::vector<Tensor>());
for (size_t component = 0; component < shapes_.size(); component++) {
tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component]));
chunk_tensors_.back().push_back(Tensor(dtypes_[component], chunk_shapes[component]));
int64 chunk_record_read = 0;
TF_RETURN_IF_ERROR(iterable_->Next(capacity, component, &chunk_tensors_.back()[component], &chunk_record_read));
if (component != 0 && record_read != chunk_record_read) {
return errors::Internal("component ", component, " has differtnt chunk size: ", chunk_record_read, " vs. ", record_read);
}
record_read = chunk_record_read;
}
TF_RETURN_IF_ERROR(iterable_->Next(capacity, tensors_.back(), &record_read));
if (record_read == 0) {
tensors_.pop_back();
chunk_tensors_.pop_back();
break;
}
if (record_read < capacity) {
for (size_t component = 0; component < shapes_.size(); component++) {
tensors_.back()[component] = tensors_.back()[component].Slice(0, record_read);
chunk_tensors_.back()[component] = chunk_tensors_.back()[component].Slice(0, record_read);
}
}
total += record_read;
Expand All @@ -87,13 +91,13 @@ class IOIndexableImplementation : public IOIndexableInterface {
}
return Status::OK();
}
virtual Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
for (size_t component = 0; component < dtypes_.size(); component++) {
dtypes.push_back(dtypes_[component]);
}
virtual Status Spec(std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) override {
for (size_t component = 0; component < shapes_.size(); component++) {
shapes.push_back(shapes_[component]);
}
for (size_t component = 0; component < dtypes_.size(); component++) {
dtypes.push_back(dtypes_[component]);
}
return Status::OK();
}

Expand All @@ -105,41 +109,35 @@ class IOIndexableImplementation : public IOIndexableInterface {
return strings::StrCat("IOIndexableImplementation<", iterable_->DebugString(), ">[]");
}

Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
if (step != 1) {
return errors::InvalidArgument("step != 1 is not supported: ", step);
}
// Find first chunk

int64 chunk_index = 0;
int64 chunk_element = -1;
int64 current_element = 0;
while (chunk_index < tensors_.size()) {
if (current_element <= start && start < current_element + tensors_[chunk_index][0].shape().dim_size(0)) {
while (chunk_index < chunk_tensors_.size()) {
if (current_element <= start && start < current_element + chunk_tensors_[chunk_index][component].shape().dim_size(0)) {
chunk_element = start - current_element;
current_element = start;
break;
}
current_element += tensors_[chunk_index][0].shape().dim_size(0);
current_element += chunk_tensors_[chunk_index][component].shape().dim_size(0);
chunk_index++;
}
if (chunk_element < 0) {
return errors::InvalidArgument("start is out of range: ", start);
}
std::vector<Tensor> elements;
for (size_t component = 0; component < shapes_.size(); component++) {
TensorShape shape(shapes_[component].dim_sizes());
shape.RemoveDim(0);
elements.push_back(Tensor(dtypes_[component], shape));
}
TensorShape shape(shapes_[component].dim_sizes());
shape.RemoveDim(0);
Tensor element(dtypes_[component], shape);

while (current_element < stop) {
for (size_t component = 0; component < shapes_.size(); component++) {
batch_util::CopySliceToElement(tensors_[chunk_index][component], &elements[component], chunk_element);
batch_util::CopyElementToSlice(elements[component], &tensors[component], (current_element - start));
}
batch_util::CopySliceToElement(chunk_tensors_[chunk_index][component], &element, chunk_element);
batch_util::CopyElementToSlice(element, tensor, (current_element - start));
chunk_element++;
if (chunk_element == tensors_[chunk_index][0].shape().dim_size(0)) {
if (chunk_element == chunk_tensors_[chunk_index][component].shape().dim_size(0)) {
chunk_index++;
chunk_element = 0;
}
Expand All @@ -151,9 +149,9 @@ class IOIndexableImplementation : public IOIndexableInterface {
mutable mutex mu_;
Env* env_ GUARDED_BY(mu_);
std::unique_ptr<Type> iterable_ GUARDED_BY(mu_);
std::vector<DataType> dtypes_ GUARDED_BY(mu_);
std::vector<PartialTensorShape> shapes_ GUARDED_BY(mu_);
std::vector<std::vector<Tensor>> tensors_;
std::vector<DataType> dtypes_ GUARDED_BY(mu_);
std::vector<std::vector<Tensor>> chunk_tensors_;
};


Expand Down Expand Up @@ -198,9 +196,9 @@ class IOInterfaceInitOp : public ResourceOpKernel<Type> {

OP_REQUIRES_OK(context, this->resource_->Init(input, metadata, memory_data, memory_size));

std::vector<DataType> dtypes;
std::vector<PartialTensorShape> shapes;
OP_REQUIRES_OK(context, this->resource_->Spec(dtypes, shapes));
std::vector<DataType> dtypes;
OP_REQUIRES_OK(context, this->resource_->Spec(shapes, dtypes));
int64 maxrank = 0;
for (size_t component = 0; component < shapes.size(); component++) {
if (dynamic_cast<IOIndexableInterface *>(this->resource_) != nullptr) {
Expand All @@ -212,10 +210,6 @@ class IOInterfaceInitOp : public ResourceOpKernel<Type> {
}
maxrank = maxrank > shapes[component].dims() ? maxrank : shapes[component].dims();
}
Tensor dtypes_tensor(DT_INT64, TensorShape({static_cast<int64>(dtypes.size())}));
for (size_t i = 0; i < dtypes.size(); i++) {
dtypes_tensor.flat<int64>()(i) = dtypes[i];
}
Tensor shapes_tensor(DT_INT64, TensorShape({static_cast<int64>(dtypes.size()), maxrank}));
for (size_t component = 0; component < shapes.size(); component++) {
for (int64 i = 0; i < shapes[component].dims(); i++) {
Expand All @@ -225,8 +219,12 @@ class IOInterfaceInitOp : public ResourceOpKernel<Type> {
shapes_tensor.tensor<int64, 2>()(component, i) = 0;
}
}
context->set_output(1, dtypes_tensor);
context->set_output(2, shapes_tensor);
Tensor dtypes_tensor(DT_INT64, TensorShape({static_cast<int64>(dtypes.size())}));
for (size_t i = 0; i < dtypes.size(); i++) {
dtypes_tensor.flat<int64>()(i) = dtypes[i];
}
context->set_output(1, shapes_tensor);
context->set_output(2, dtypes_tensor);

std::vector<Tensor> extra;
OP_REQUIRES_OK(context, this->resource_->Extra(&extra));
Expand Down Expand Up @@ -259,27 +257,26 @@ class IOIterableNextOp : public OpKernel {
OP_REQUIRES_OK(context, context->input("capacity", &capacity_tensor));
const int64 capacity = capacity_tensor->scalar<int64>()();

const Tensor* component_tensor;
OP_REQUIRES_OK(context, context->input("component", &component_tensor));
const int64 component = component_tensor->scalar<int64>()();

OP_REQUIRES(context, (capacity > 0), errors::InvalidArgument("capacity <= 0 is not supported: ", capacity));

std::vector<DataType> dtypes;
std::vector<PartialTensorShape> shapes;
OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes));
std::vector<DataType> dtypes;
OP_REQUIRES_OK(context, resource->Spec(shapes, dtypes));

std::vector<Tensor> tensors;
for (size_t i = 0; i < dtypes.size(); i++) {
gtl::InlinedVector<int64, 4> dims = shapes[i].dim_sizes();
dims[0] = capacity;
tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims)));
}
gtl::InlinedVector<int64, 4> dims = shapes[component].dim_sizes();
dims[0] = capacity;
Tensor tensor(dtypes[component], TensorShape(dims));

int64 record_read;
OP_REQUIRES_OK(context, resource->Next(capacity, tensors, &record_read));
for (size_t i = 0; i < tensors.size(); i++) {
if (record_read < capacity) {
context->set_output(i, tensors[i].Slice(0, record_read));
} else {
context->set_output(i, tensors[i]);
}
OP_REQUIRES_OK(context, resource->Next(capacity, component, &tensor, &record_read));
if (record_read < capacity) {
context->set_output(0, tensor.Slice(0, record_read));
} else {
context->set_output(0, tensor);
}
}
};
Expand Down Expand Up @@ -307,13 +304,17 @@ class IOIndexableGetItemOp : public OpKernel {
OP_REQUIRES_OK(context, context->input("step", &step_tensor));
int64 step = step_tensor->scalar<int64>()();

const Tensor* component_tensor;
OP_REQUIRES_OK(context, context->input("component", &component_tensor));
const int64 component = component_tensor->scalar<int64>()();

OP_REQUIRES(context, (step == 1), errors::InvalidArgument("step != 1 is not supported: ", step));

std::vector<DataType> dtypes;
std::vector<PartialTensorShape> shapes;
OP_REQUIRES_OK(context, resource->Spec(dtypes, shapes));
std::vector<DataType> dtypes;
OP_REQUIRES_OK(context, resource->Spec(shapes, dtypes));

int64 count = shapes[0].dim_size(0);
int64 count = shapes[component].dim_size(0);
if (start > count) {
start = count;
}
Expand All @@ -324,16 +325,11 @@ class IOIndexableGetItemOp : public OpKernel {
stop = start;
}

std::vector<Tensor> tensors;
for (size_t i = 0; i < dtypes.size(); i++) {
gtl::InlinedVector<int64, 4> dims = shapes[i].dim_sizes();
dims[0] = stop - start;
tensors.emplace_back(Tensor(dtypes[i], TensorShape(dims)));
}
OP_REQUIRES_OK(context, resource->GetItem(start, stop, step, tensors));
for (size_t i = 0; i < tensors.size(); i++) {
context->set_output(i, tensors[i]);
}
gtl::InlinedVector<int64, 4> dims = shapes[component].dim_sizes();
dims[0] = stop - start;
Tensor tensor(dtypes[component], TensorShape(dims));
OP_REQUIRES_OK(context, resource->GetItem(start, stop, step, component, &tensor));
context->set_output(0, tensor);
}
};
} // namespace data
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_io/core/python/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ def _load_library(filename, lib="op"):
"unable to open file: " +
"{}, from paths: {}\ncaused by: {}".format(filename, filenames, errs))

_load_library("libtensorflow_io_prometheus.so", "dependency")
_load_library("libtensorflow_io_golang.so", "dependency")
core_ops = _load_library('libtensorflow_io.so')
Loading