Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 1 addition & 212 deletions cpp/src/arrow/engine/substrait/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,216 +32,6 @@ namespace engine {

namespace {

std::vector<compute::Declaration::Input> MakeDeclarationInputs(
const std::vector<DeclarationInfo>& inputs) {
std::vector<compute::Declaration::Input> input_decls(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
input_decls[i] = inputs[i].declaration;
}
return input_decls;
}

} // namespace

class BaseExtensionProvider : public ExtensionProvider {
public:
Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) override {
auto details = dynamic_cast<const DefaultExtensionDetails&>(ext_details);
return MakeRel(conv_opts, inputs, details.rel, ext_set);
}

virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) = 0;
};

class DefaultExtensionProvider : public BaseExtensionProvider {
public:
Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) override {
if (rel.Is<substrait_ext::AsOfJoinRel>()) {
substrait_ext::AsOfJoinRel as_of_join_rel;
rel.UnpackTo(&as_of_join_rel);
return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
}
if (rel.Is<substrait_ext::NamedTapRel>()) {
substrait_ext::NamedTapRel named_tap_rel;
rel.UnpackTo(&named_tap_rel);
return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set);
}
if (rel.Is<substrait_ext::SegmentedAggregateRel>()) {
substrait_ext::SegmentedAggregateRel seg_agg_rel;
rel.UnpackTo(&seg_agg_rel);
return MakeSegmentedAggregateRel(conv_opts, inputs, seg_agg_rel, ext_set);
}
return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ",
rel.DebugString());
}

private:
Result<RelationInfo> MakeAsOfJoinRel(const std::vector<DeclarationInfo>& inputs,
const substrait_ext::AsOfJoinRel& as_of_join_rel,
const ExtensionSet& ext_set) {
if (inputs.size() < 2) {
return Status::Invalid("substrait_ext::AsOfJoinNode too few input tables: ",
inputs.size());
}
if (static_cast<size_t>(as_of_join_rel.keys_size()) != inputs.size()) {
return Status::Invalid("substrait_ext::AsOfJoinNode mismatched number of inputs");
}

size_t n_input = inputs.size(), i = 0;
std::vector<compute::AsofJoinNodeOptions::Keys> input_keys(n_input);
for (const auto& keys : as_of_join_rel.keys()) {
// on-key
if (!keys.has_on()) {
return Status::Invalid("substrait_ext::AsOfJoinNode missing on-key for input ",
i);
}
ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(keys.on(), ext_set, {}));
if (on_key_expr.field_ref() == NULLPTR) {
return Status::NotImplemented(
"substrait_ext::AsOfJoinNode non-field-ref on-key for input ", i);
}
const FieldRef& on_key = *on_key_expr.field_ref();

// by-key
std::vector<FieldRef> by_key;
for (const auto& by_item : keys.by()) {
ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {}));
if (by_key_expr.field_ref() == NULLPTR) {
return Status::NotImplemented(
"substrait_ext::AsOfJoinNode non-field-ref by-key for input ", i);
}
by_key.push_back(*by_key_expr.field_ref());
}

input_keys[i] = {std::move(on_key), std::move(by_key)};
++i;
}

// schema
int64_t tolerance = as_of_join_rel.tolerance();
std::vector<std::shared_ptr<Schema>> input_schema(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
input_schema[i] = inputs[i].output_schema;
}
std::vector<int> field_output_indices;
ARROW_ASSIGN_OR_RAISE(auto schema,
compute::asofjoin::MakeOutputSchema(input_schema, input_keys,
&field_output_indices));
compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance};

// declaration
auto input_decls = MakeDeclarationInputs(inputs);
return RelationInfo{
{compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)),
std::move(schema)},
std::move(field_output_indices)};
}

Result<RelationInfo> MakeNamedTapRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const substrait_ext::NamedTapRel& named_tap_rel,
const ExtensionSet& ext_set) {
if (inputs.size() != 1) {
return Status::Invalid(
"substrait_ext::NamedTapRel requires a single input but got: ", inputs.size());
}

auto schema = inputs[0].output_schema;
int num_fields = schema->num_fields();
if (named_tap_rel.columns_size() != num_fields) {
return Status::Invalid("Got ", named_tap_rel.columns_size(),
" NamedTapRel columns but expected ", num_fields);
}
std::vector<std::string> columns(named_tap_rel.columns().begin(),
named_tap_rel.columns().end());
ARROW_ASSIGN_OR_RAISE(auto renamed_schema, schema->WithNames(columns));
auto input_decls = MakeDeclarationInputs(inputs);
ARROW_ASSIGN_OR_RAISE(
auto decl, conv_opts.named_tap_provider(named_tap_rel.kind(), input_decls,
named_tap_rel.name(), renamed_schema));
return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt};
}

Result<RelationInfo> MakeSegmentedAggregateRel(
const ConversionOptions& conv_opts, const std::vector<DeclarationInfo>& inputs,
const substrait_ext::SegmentedAggregateRel& seg_agg_rel,
const ExtensionSet& ext_set) {
if (inputs.size() != 1) {
return Status::Invalid(
"substrait_ext::SegmentedAggregateRel requires a single input but got: ",
inputs.size());
}
if (seg_agg_rel.segment_keys_size() == 0) {
return Status::Invalid(
"substrait_ext::SegmentedAggregateRel requires at least one segment key");
}

auto input_schema = inputs[0].output_schema;

// store key fields to be used when output schema is created
std::vector<int> key_field_ids;
std::vector<FieldRef> keys;
for (auto& key_refseg : seg_agg_rel.grouping_keys()) {
ARROW_ASSIGN_OR_RAISE(auto field_ref,
DirectReferenceFromProto(&key_refseg, ext_set, conv_opts));
ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
key_field_ids.emplace_back(std::move(match[0]));
keys.emplace_back(std::move(field_ref));
}

// store segment key fields to be used when output schema is created
std::vector<int> segment_key_field_ids;
std::vector<FieldRef> segment_keys;
for (auto& key_refseg : seg_agg_rel.segment_keys()) {
ARROW_ASSIGN_OR_RAISE(auto field_ref,
DirectReferenceFromProto(&key_refseg, ext_set, conv_opts));
ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
segment_key_field_ids.emplace_back(std::move(match[0]));
segment_keys.emplace_back(std::move(field_ref));
}

std::vector<compute::Aggregate> aggregates;
aggregates.reserve(seg_agg_rel.measures_size());
std::vector<std::vector<int>> agg_src_fieldsets;
agg_src_fieldsets.reserve(seg_agg_rel.measures_size());
for (auto agg_measure : seg_agg_rel.measures()) {
ARROW_ASSIGN_OR_RAISE(
auto parsed_measure,
internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts,
/*is_hash=*/!keys.empty(), input_schema));
aggregates.push_back(std::move(parsed_measure.aggregate));
agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset));
}

ARROW_ASSIGN_OR_RAISE(auto decl_info,
internal::MakeAggregateDeclaration(
std::move(inputs[0].declaration), std::move(input_schema),
seg_agg_rel.measures_size(), std::move(aggregates),
std::move(agg_src_fieldsets), std::move(keys),
std::move(key_field_ids), std::move(segment_keys),
std::move(segment_key_field_ids), ext_set, conv_opts));

const auto& output_schema = decl_info.output_schema;
size_t out_size = output_schema->num_fields();
std::vector<int> field_output_indices(out_size);
for (int i = 0; i < static_cast<int>(out_size); i++) {
field_output_indices[i] = i;
}
return RelationInfo{decl_info, std::move(field_output_indices)};
}
};

namespace {

template <typename T>
class ConfigurableSingleton {
public:
Expand All @@ -264,8 +54,7 @@ class ConfigurableSingleton {

ConfigurableSingleton<std::shared_ptr<ExtensionProvider>>&
default_extension_provider_singleton() {
static ConfigurableSingleton<std::shared_ptr<ExtensionProvider>> singleton(
std::make_shared<DefaultExtensionProvider>());
static ConfigurableSingleton<std::shared_ptr<ExtensionProvider>> singleton(GetDefaultExtensionProvider());
return singleton;
}

Expand Down
Loading