Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 128 additions & 5 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
Expand All @@ -29,12 +29,13 @@ use url::Url;
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
LogicalPlan, Operator, ScalarUDF, Values,
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
Values,
};

use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion::prelude::JoinType;
Expand Down Expand Up @@ -225,6 +226,7 @@ pub async fn from_substrait_plan(
None => not_impl_err!("Cannot parse empty extension"),
})
.collect::<Result<HashMap<_, _>>>()?;

// Parse relations
match plan.relations.len() {
1 => {
Expand All @@ -234,7 +236,29 @@ pub async fn from_substrait_plan(
Ok(from_substrait_rel(ctx, rel, &function_extension).await?)
},
plan_rel::RelType::Root(root) => {
Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?)
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
}
let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
if renamed_schema.equivalent_names_and_types(plan.schema()) {
// Nothing to do if the schema is already equivalent
return Ok(plan);
}

match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)),
LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
},
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
}
}
},
None => plan_err!("Cannot parse plan relation: None")
Expand Down Expand Up @@ -284,6 +308,105 @@ pub fn extract_projection(
}
}

fn rename_expressions(
exprs: impl IntoIterator<Item = Expr>,
input_schema: &DFSchema,
new_schema: DFSchemaRef,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
.zip(new_schema.fields())
.map(|(old_expr, new_field)| {
if &old_expr.get_type(input_schema)? == new_field.data_type() {
// Alias column if needed
old_expr.alias_if_changed(new_field.name().into())
} else {
// Use Cast to rename inner struct fields + alias column if needed
Expr::Cast(Cast::new(
Box::new(old_expr),
new_field.data_type().to_owned(),
))
.alias_if_changed(new_field.name().into())
}
})
.collect()
}

fn make_renamed_schema(
schema: &DFSchemaRef,
dfs_names: &Vec<String>,
) -> Result<DFSchemaRef> {
fn rename_inner_fields(
dtype: &DataType,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<DataType> {
match dtype {
DataType::Struct(fields) => {
let fields = fields
.iter()
.map(|f| {
let name = next_struct_field_name(0, dfs_names, name_idx)?;
Ok((**f).to_owned().with_name(name).with_data_type(
rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
))
})
.collect::<Result<_>>()?;
Ok(DataType::Struct(fields))
}
DataType::List(inner) => Ok(DataType::List(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
_ => Ok(dtype.to_owned()),
}
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
Ok((
q.cloned(),
(**f)
.to_owned()
.with_name(name)
.with_data_type(rename_inner_fields(
f.data_type(),
dfs_names,
&mut name_idx,
)?),
))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();

if name_idx != dfs_names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
dfs_names.len());
}

Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
)?))
}

/// Convert Substrait Rel to DataFusion DataFrame
#[async_recursion]
pub async fn from_substrait_rel(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box
let plan_rels = vec![PlanRel {
rel_type: Some(plan_rel::RelType::Root(RelRoot {
input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?),
names: plan.schema().field_names(),
names: to_substrait_named_struct(plan.schema())?.names,
})),
}];

Expand Down
17 changes: 9 additions & 8 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ mod tests {

let plan = from_substrait_plan(&ctx, &proto).await?;

assert!(
format!("{:?}", plan).eq_ignore_ascii_case(
"Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \
Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \
Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \
Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \
TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
)
let plan_str = format!("{:?}", plan);
assert_eq!(
plan_str,
"Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG, FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS, sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY, sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS SUM_CHARGE, AVG(FILENAME_PLACEHOLDER_0.l_quantity) AS AVG_QTY, AVG(FILENAME_PLACEHOLDER_0.l_extendedprice) AS AVG_PRICE, AVG(FILENAME_PLACEHOLDER_0.l_discount) AS AVG_DISC, COUNT(Int64(1)) AS COUNT_ORDER\
\n Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\
\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity), sum(FILENAME_PLACEHOLDER_0.l_extendedprice), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\
\n Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\
\n Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\
\n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/tests/cases/logical_plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ mod tests {

assert_eq!(
format!("{:?}", plan),
"Projection: NOT DATA.a\
"Projection: NOT DATA.a AS EXPR$0\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
);
Ok(())
Expand Down
43 changes: 17 additions & 26 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ async fn wildcard_select() -> Result<()> {
roundtrip("SELECT * FROM data").await
}

#[tokio::test]
async fn select_with_alias() -> Result<()> {
roundtrip("SELECT a AS aliased_a FROM data").await
}

#[tokio::test]
async fn select_with_filter() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a > 1").await
Expand Down Expand Up @@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> {
async fn aggregate_case() -> Result<()> {
assert_expected_plan(
"SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
"Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\
"Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\
\n TableScan: data projection=[a]",
false // NULL vs Int64(NULL)
true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

)
.await
}
Expand Down Expand Up @@ -589,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> {

#[tokio::test]
async fn simple_intersect() -> Result<()> {
// Substrait treats both COUNT(*) and COUNT(1) the same
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\
\n Projection: \
\n LeftSemi Join: data.a = data2.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data2 projection=[a]",
false // COUNT(*) vs COUNT(Int64(1))
true
)
.await
}

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);",
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n Projection: \
\n LeftSemi Join: data.a = data.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data projection=[a]",
false // COUNT(*) vs COUNT(Int64(1))
)
.await
roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await
}

#[tokio::test]
Expand Down Expand Up @@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> {

#[tokio::test]
async fn roundtrip_literal_list() -> Result<()> {
assert_expected_plan(
"SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
"Projection: List([[1, 2, 3], [], , []])\
\n TableScan: data projection=[]",
false, // "List(..)" vs "make_array(..)"
)
.await
roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified that roundtrip actually is more stringent than assert_expected_plan as it checks the plan before and after roundtripping 👍

}

#[tokio::test]
async fn roundtrip_literal_struct() -> Result<()> {
assert_expected_plan(
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
"Projection: Struct({c0:1,c1:true,c2:})\
"Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\
\n TableScan: data projection=[]",
false, // "Struct(..)" vs "struct(..)"
)
Expand Down Expand Up @@ -980,12 +970,13 @@ async fn assert_expected_plan(

println!("{proto:?}");

let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);

if assert_schema {
assert_eq!(plan.schema(), plan2.schema());
}

let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);

Ok(())
}

Expand Down