Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) {
const std::string* timezone = nullptr;
bool saw_date32 = false;
bool saw_date64 = false;

bool saw_duration = false;
const TypeHolder* end = begin + count;
for (auto it = begin; it != end; it++) {
auto id = it->type->id();
Expand All @@ -271,6 +271,12 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) {
finest_unit = std::max(finest_unit, ty.unit());
continue;
}
case Type::DURATION: {
const auto& ty = checked_cast<const DurationType&>(*it->type);
finest_unit = std::max(finest_unit, ty.unit());
saw_duration = true;
continue;
}
default:
return TypeHolder(nullptr);
}
Expand All @@ -283,6 +289,8 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) {
return date64();
} else if (saw_date32) {
return date32();
} else if (saw_duration) {
return duration(finest_unit);
}
return TypeHolder(nullptr);
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2798,6 +2798,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveIfElseKernels(func, NumericTypes());
AddPrimitiveIfElseKernels(func, TemporalTypes());
AddPrimitiveIfElseKernels(func, IntervalTypes());
AddPrimitiveIfElseKernels(func, DurationTypes());
AddPrimitiveIfElseKernels(func, {boolean()});
AddNullIfElseKernel(func);
AddBinaryIfElseKernels(func, BaseBinaryTypes());
Expand All @@ -2813,6 +2814,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCaseWhenKernels(func, NumericTypes());
AddPrimitiveCaseWhenKernels(func, TemporalTypes());
AddPrimitiveCaseWhenKernels(func, IntervalTypes());
AddPrimitiveCaseWhenKernels(func, DurationTypes());
AddPrimitiveCaseWhenKernels(func, {boolean(), null()});
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
Expand All @@ -2836,6 +2838,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCoalesceKernels(func, NumericTypes());
AddPrimitiveCoalesceKernels(func, TemporalTypes());
AddPrimitiveCoalesceKernels(func, IntervalTypes());
AddPrimitiveCoalesceKernels(func, DurationTypes());
AddPrimitiveCoalesceKernels(func, {boolean(), null()});
AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
CoalesceFunctor<FixedSizeBinaryType>::Exec);
Expand All @@ -2861,6 +2864,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveChooseKernels(func, NumericTypes());
AddPrimitiveChooseKernels(func, TemporalTypes());
AddPrimitiveChooseKernels(func, IntervalTypes());
AddPrimitiveChooseKernels(func, DurationTypes());
AddPrimitiveChooseKernels(func, {boolean(), null()});
AddChooseKernel(func, Type::FIXED_SIZE_BINARY,
ChooseFunctor<FixedSizeBinaryType>::Exec);
Expand Down
15 changes: 13 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,16 @@ class TestIfElsePrimitive : public ::testing::Test {};
#ifdef ARROW_VALGRIND
using IfElseNumericBasedTypes =
::testing::Types<UInt32Type, FloatType, Date32Type, Time32Type, TimestampType,
MonthIntervalType>;
MonthIntervalType, DurationType>;
using BaseBinaryArrowTypes = ::testing::Types<BinaryType>;
using ListArrowTypes = ::testing::Types<ListType>;
using IntegralArrowTypes = ::testing::Types<Int32Type>;
#else
using IfElseNumericBasedTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type,
Time32Type, Time64Type, TimestampType, MonthIntervalType>;
Time32Type, Time64Type, TimestampType, MonthIntervalType,
DurationType>;
#endif

TYPED_TEST_SUITE(TestIfElsePrimitive, IfElseNumericBasedTypes);
Expand Down Expand Up @@ -505,6 +506,9 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
{boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)},
{boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
CheckDispatchBest(name,
{boolean(), duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)},
{boolean(), duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)});
CheckDispatchBest(name, {boolean(), date32(), date64()},
{boolean(), date64(), date64()});
CheckDispatchBest(name, {boolean(), date32(), date32()},
Expand Down Expand Up @@ -2500,6 +2504,11 @@ TEST(TestCaseWhen, DispatchBest) {
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()},
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND)});
CheckDispatchBest("case_when",
{struct_({field("", boolean())}), duration(TimeUnit::SECOND),
duration(TimeUnit::MILLI)},
{struct_({field("", boolean())}), duration(TimeUnit::MILLI),
duration(TimeUnit::MILLI)});
CheckDispatchBest(
"case_when", {struct_({field("", boolean())}), decimal128(38, 0), decimal128(1, 1)},
{struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)});
Expand Down Expand Up @@ -3350,6 +3359,8 @@ TEST(TestCoalesce, DispatchBest) {
{timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)});
CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
{timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
CheckDispatchBest("coalesce", {duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)},
{duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)});
CheckDispatchFails("coalesce", {
sparse_union({field("a", boolean())}),
dense_union({field("a", boolean())}),
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/kernels/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ template <typename T>
enable_if_decimal<T, std::shared_ptr<DataType>> default_type_instance() {
return std::make_shared<T>(5, 2);
}

template <typename T>
enable_if_duration<T, std::shared_ptr<DataType>> default_type_instance() {
return std::make_shared<T>(TimeUnit::type::SECOND);
}
// Random Generator Helpers
class RandomImpl {
protected:
Expand Down