Skip to content

Commit 2c547c6

Browse files
committed
Import GetTensorFlowType and GetArrowType
Signed-off-by: Yong Tang <[email protected]>
1 parent 85e2e91 commit 2c547c6

File tree

4 files changed

+30
-64
lines changed

4 files changed

+30
-64
lines changed

tensorflow_io/arrow/kernels/arrow_dataset_ops.cc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "arrow/api.h"
17-
#include "arrow/adapters/tensorflow/convert.h"
1817
#include "arrow/ipc/api.h"
1918
#include "arrow/util/io-util.h"
2019
#include "tensorflow/core/framework/dataset.h"
@@ -99,8 +98,10 @@ class ArrowColumnTypeChecker : public arrow::TypeVisitor {
9998
// Check scalar types with arrow::adapters::tensorflow
10099
arrow::Status CheckScalarType(std::shared_ptr<arrow::DataType> scalar_type) {
101100
DataType converted_type;
102-
ARROW_RETURN_NOT_OK(arrow::adapters::tensorflow::GetTensorFlowType(
103-
scalar_type, &converted_type));
101+
::tensorflow::Status status = GetTensorFlowType(scalar_type, &converted_type);
102+
if (!status.ok()) {
103+
return ::arrow::Status::Invalid(status);
104+
}
104105
if (converted_type != expected_type_) {
105106
return arrow::Status::TypeError(
106107
"Arrow type mismatch: expected dtype=" +
@@ -523,11 +524,7 @@ class ArrowOpKernelBase : public DatasetOpKernel {
523524
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
524525
for (const DataType& dt : output_types_) {
525526
std::shared_ptr<arrow::DataType> arrow_type;
526-
auto status = arrow::adapters::tensorflow::GetArrowType(dt, &arrow_type);
527-
OP_REQUIRES(ctx, status.ok(),
528-
errors::InvalidArgument(
529-
"Arrow type is unsupported for output_type dtype=" +
530-
std::to_string(dt)));
527+
OP_REQUIRES_OK(ctx, GetArrowType(dt, &arrow_type));
531528
}
532529
for (const PartialTensorShape& pts : output_shapes_) {
533530
OP_REQUIRES(ctx, -1 <= pts.dims() && pts.dims() <= 2,

tensorflow_io/arrow/kernels/arrow_kernels.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,27 @@ limitations under the License.
1919
#include "arrow/ipc/feather.h"
2020
#include "arrow/ipc/feather_generated.h"
2121
#include "arrow/buffer.h"
22+
#include "arrow/adapters/tensorflow/convert.h"
2223

2324
namespace tensorflow {
2425
namespace data {
26+
27+
// Converts an Arrow data type to the corresponding TensorFlow data type.
//
// Thin wrapper around ::arrow::adapters::tensorflow::GetTensorFlowType that
// reports failure as a ::tensorflow::Status instead of an ::arrow::Status, so
// callers can use TF_RETURN_IF_ERROR / OP_REQUIRES_OK directly.
//
// Args:
//   dtype: Arrow type to convert; must not be null.
//   out:   forwarded unchanged to the Arrow adapter, which writes the result.
//
// Returns:
//   Status::OK() on success; errors::InvalidArgument if `dtype` is null or
//   the Arrow adapter rejects the type. The message names the Arrow type and
//   carries the adapter's own error text.
Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out) {
  // Reject a null type up front so neither the adapter nor the error path
  // below dereferences it.
  if (dtype == nullptr) {
    return errors::InvalidArgument("arrow data type is null");
  }
  ::arrow::Status status = ::arrow::adapters::tensorflow::GetTensorFlowType(dtype, out);
  if (!status.ok()) {
    // Format the type by name: passing the shared_ptr itself would, at best,
    // put a pointer value in the message rather than the Arrow type.
    return errors::InvalidArgument("arrow data type ", dtype->ToString(),
                                   " is not supported: ", status.ToString());
  }
  return Status::OK();
}
34+
35+
// Converts a TensorFlow data type to the corresponding Arrow data type.
//
// Thin wrapper around ::arrow::adapters::tensorflow::GetArrowType that
// reports failure as a ::tensorflow::Status instead of an ::arrow::Status, so
// callers can use TF_RETURN_IF_ERROR / OP_REQUIRES_OK directly.
//
// Args:
//   dtype: TensorFlow dtype to convert.
//   out:   forwarded unchanged to the Arrow adapter, which writes the result.
//
// Returns:
//   Status::OK() on success; errors::InvalidArgument if the Arrow adapter
//   rejects the dtype. The message names the dtype and carries the adapter's
//   own error text.
Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out) {
  ::arrow::Status status = ::arrow::adapters::tensorflow::GetArrowType(dtype, out);
  if (!status.ok()) {
    // Use the dtype's name; the bare enum would be formatted as an integer,
    // which is not meaningful to users.
    return errors::InvalidArgument("tensorflow data type ", DataTypeString(dtype),
                                   " is not supported: ", status.ToString());
  }
  return Status::OK();
}
42+
2543
namespace {
2644

2745
class ListFeatherColumnsOp : public OpKernel {

tensorflow_io/arrow/kernels/arrow_kernels.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ limitations under the License.
1919
#include "kernels/stream.h"
2020
#include "arrow/io/api.h"
2121
#include "arrow/buffer.h"
22+
#include "arrow/type.h"
2223

2324
namespace tensorflow {
2425
namespace data {
2526

27+
// Maps an Arrow data type to a TensorFlow data type by delegating to
// ::arrow::adapters::tensorflow::GetTensorFlowType (see arrow_kernels.cc).
// `out` is handed to that adapter to receive the result. Returns Status::OK()
// on success and errors::InvalidArgument when the adapter reports a failure.
Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype, ::tensorflow::DataType* out);
28+
// Maps a TensorFlow data type to an Arrow data type by delegating to
// ::arrow::adapters::tensorflow::GetArrowType (see arrow_kernels.cc).
// `out` is handed to that adapter to receive the result. Returns Status::OK()
// on success and errors::InvalidArgument when the adapter reports a failure.
Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<::arrow::DataType>* out);
29+
2630
// NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap
2731
// with another PR. Will remove duplicate once PR merged
2832

@@ -92,6 +96,8 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
9296
int64 size_;
9397
int64 position_;
9498
};
99+
100+
95101
} // namespace data
96102
} // namespace tensorflow
97103

tensorflow_io/json/kernels/json_kernels.cc

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -262,61 +262,7 @@ class JSONIndexable : public IOIndexableInterface {
262262
for (size_t i = 0; i < columns_index_.size(); i++) {
263263
int column_index = columns_index_[i];
264264
::tensorflow::DataType dtype;
265-
switch (table_->column(column_index)->type()->id()) {
266-
case ::arrow::Type::BOOL:
267-
dtype = ::tensorflow::DT_BOOL;
268-
break;
269-
case ::arrow::Type::UINT8:
270-
dtype= ::tensorflow::DT_UINT8;
271-
break;
272-
case ::arrow::Type::INT8:
273-
dtype= ::tensorflow::DT_INT8;
274-
break;
275-
case ::arrow::Type::UINT16:
276-
dtype= ::tensorflow::DT_UINT16;
277-
break;
278-
case ::arrow::Type::INT16:
279-
dtype= ::tensorflow::DT_INT16;
280-
break;
281-
case ::arrow::Type::UINT32:
282-
dtype= ::tensorflow::DT_UINT32;
283-
break;
284-
case ::arrow::Type::INT32:
285-
dtype= ::tensorflow::DT_INT32;
286-
break;
287-
case ::arrow::Type::UINT64:
288-
dtype= ::tensorflow::DT_UINT64;
289-
break;
290-
case ::arrow::Type::INT64:
291-
dtype= ::tensorflow::DT_INT64;
292-
break;
293-
case ::arrow::Type::HALF_FLOAT:
294-
dtype= ::tensorflow::DT_HALF;
295-
break;
296-
case ::arrow::Type::FLOAT:
297-
dtype= ::tensorflow::DT_FLOAT;
298-
break;
299-
case ::arrow::Type::DOUBLE:
300-
dtype= ::tensorflow::DT_DOUBLE;
301-
break;
302-
case ::arrow::Type::STRING:
303-
case ::arrow::Type::BINARY:
304-
case ::arrow::Type::FIXED_SIZE_BINARY:
305-
case ::arrow::Type::DATE32:
306-
case ::arrow::Type::DATE64:
307-
case ::arrow::Type::TIMESTAMP:
308-
case ::arrow::Type::TIME32:
309-
case ::arrow::Type::TIME64:
310-
case ::arrow::Type::INTERVAL:
311-
case ::arrow::Type::DECIMAL:
312-
case ::arrow::Type::LIST:
313-
case ::arrow::Type::STRUCT:
314-
case ::arrow::Type::UNION:
315-
case ::arrow::Type::DICTIONARY:
316-
case ::arrow::Type::MAP:
317-
default:
318-
return errors::InvalidArgument("arrow data type is not supported: ", table_->column(i)->type()->ToString());
319-
}
265+
TF_RETURN_IF_ERROR(GetTensorFlowType(table_->column(column_index)->type(), &dtype));
320266
dtypes_.push_back(dtype);
321267
shapes_.push_back(TensorShape({static_cast<int64>(table_->num_rows())}));
322268
columns_.push_back(table_->column(column_index)->name());
@@ -347,7 +293,6 @@ class JSONIndexable : public IOIndexableInterface {
347293
}
348294

349295
Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
350-
Tensor& output_tensor = tensors[0];
351296
if (step != 1) {
352297
return errors::InvalidArgument("step ", step, " is not supported");
353298
}

0 commit comments

Comments
 (0)