Skip to content

Commit a165e10

Browse files
authored
PIVOT command (#22)
* Pivot initial changes
* Polishing
* Rebase fixes
* Add slt tests
* Add display for PIVOT
* Only check projections in case of subqueries
* Add builder test
* Add default on null. Cargo fmt
* Add default_on_null to proto and tests
* Fix protoc and clippy
* Cargo fmt
1 parent 3b6996f commit a165e10

File tree

17 files changed

+1541
-18
lines changed

17 files changed

+1541
-18
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 222 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ use datafusion_expr::expr::{
7676
use datafusion_expr::expr_rewriter::unnormalize_cols;
7777
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
7878
use datafusion_expr::{
79-
Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
80-
Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
81-
WindowFrameBound, WriteOp,
79+
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
80+
FetchType, Filter, JoinType, LogicalPlanBuilder, RecursiveQuery, SkipType,
81+
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
8282
};
8383
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
8484
use datafusion_physical_expr::expressions::{Column, Literal};
@@ -91,12 +91,15 @@ use datafusion_physical_plan::unnest::ListUnnest;
9191
use crate::schema_equivalence::schema_satisfied_by;
9292
use async_trait::async_trait;
9393
use datafusion_datasource::file_groups::FileGroup;
94+
use datafusion_expr_common::operator::Operator;
9495
use futures::{StreamExt, TryStreamExt};
9596
use itertools::{multiunzip, Itertools};
9697
use log::{debug, trace};
9798
use sqlparser::ast::NullTreatment;
9899
use tokio::sync::Mutex;
99100

101+
use datafusion_physical_plan::collect;
102+
100103
/// Physical query planner that converts a `LogicalPlan` to an
101104
/// `ExecutionPlan` suitable for execution.
102105
#[async_trait]
@@ -887,7 +890,60 @@ impl DefaultPhysicalPlanner {
887890
options.clone(),
888891
))
889892
}
893+
LogicalPlan::Pivot(pivot) => {
894+
return if !pivot.pivot_values.is_empty() {
895+
let agg_plan = transform_pivot_to_aggregate(
896+
Arc::new(pivot.input.as_ref().clone()),
897+
&pivot.aggregate_expr,
898+
&pivot.pivot_column,
899+
pivot.pivot_values.clone(),
900+
pivot.default_on_null_expr.as_ref(),
901+
)?;
902+
903+
self.create_physical_plan(&agg_plan, session_state).await
904+
} else if let Some(subquery) = &pivot.value_subquery {
905+
let optimized_subquery = session_state.optimize(subquery.as_ref())?;
906+
907+
let subquery_physical_plan = self
908+
.create_physical_plan(&optimized_subquery, session_state)
909+
.await?;
910+
911+
let subquery_results = collect(
912+
Arc::clone(&subquery_physical_plan),
913+
session_state.task_ctx(),
914+
)
915+
.await?;
916+
917+
let mut pivot_values = Vec::new();
918+
for batch in subquery_results.iter() {
919+
if batch.num_columns() != 1 {
920+
return plan_err!(
921+
"Pivot subquery must return a single column"
922+
);
923+
}
924+
925+
let column = batch.column(0);
926+
for row_idx in 0..batch.num_rows() {
927+
if !column.is_null(row_idx) {
928+
pivot_values
929+
.push(ScalarValue::try_from_array(column, row_idx)?);
930+
}
931+
}
932+
}
933+
934+
let agg_plan = transform_pivot_to_aggregate(
935+
Arc::new(pivot.input.as_ref().clone()),
936+
&pivot.aggregate_expr,
937+
&pivot.pivot_column,
938+
pivot_values,
939+
pivot.default_on_null_expr.as_ref(),
940+
)?;
890941

942+
self.create_physical_plan(&agg_plan, session_state).await
943+
} else {
944+
plan_err!("PIVOT operation requires at least one value to pivot on")
945+
}
946+
}
891947
// 2 Children
892948
LogicalPlan::Join(Join {
893949
left,
@@ -1683,6 +1739,136 @@ pub use datafusion_physical_expr::{
16831739
create_physical_sort_expr, create_physical_sort_exprs,
16841740
};
16851741

1742+
/// Transform a PIVOT operation into a more standard Aggregate + Projection plan
1743+
/// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions
1744+
///
1745+
/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create:
1746+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
1747+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
1748+
///
1749+
/// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that
1750+
/// applies COALESCE to the results.
1751+
pub fn transform_pivot_to_aggregate(
1752+
input: Arc<LogicalPlan>,
1753+
aggregate_expr: &Expr,
1754+
pivot_column: &datafusion_common::Column,
1755+
pivot_values: Vec<ScalarValue>,
1756+
default_on_null_expr: Option<&Expr>,
1757+
) -> Result<LogicalPlan> {
1758+
let df_schema = input.schema();
1759+
1760+
let all_columns: Vec<datafusion_common::Column> = df_schema.columns();
1761+
1762+
// Filter to include only columns we want for GROUP BY
1763+
// (exclude pivot column and aggregate expression columns)
1764+
let group_by_columns: Vec<Expr> = all_columns
1765+
.into_iter()
1766+
.filter(|col: &datafusion_common::Column| {
1767+
col.name != pivot_column.name
1768+
&& !aggregate_expr
1769+
.column_refs()
1770+
.iter()
1771+
.any(|agg_col| agg_col.name == col.name)
1772+
})
1773+
.map(|col: datafusion_common::Column| Expr::Column(col))
1774+
.collect();
1775+
1776+
let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(input));
1777+
1778+
// Create the aggregate plan with filtered aggregates
1779+
let mut aggregate_exprs = Vec::new();
1780+
1781+
for value in &pivot_values {
1782+
let filter_condition = Expr::BinaryExpr(BinaryExpr::new(
1783+
Box::new(Expr::Column(pivot_column.clone())),
1784+
Operator::IsNotDistinctFrom,
1785+
Box::new(Expr::Literal(value.clone())),
1786+
));
1787+
1788+
let filtered_agg = match aggregate_expr {
1789+
Expr::AggregateFunction(agg) => {
1790+
let mut new_params = agg.params.clone();
1791+
new_params.filter = Some(Box::new(filter_condition));
1792+
Expr::AggregateFunction(AggregateFunction {
1793+
func: Arc::clone(&agg.func),
1794+
params: new_params,
1795+
})
1796+
}
1797+
_ => {
1798+
return plan_err!(
1799+
"Unsupported aggregate expression should always be AggregateFunction"
1800+
);
1801+
}
1802+
};
1803+
1804+
// Use the pivot value as the column name
1805+
let field_name = value.to_string().trim_matches('\'').to_string();
1806+
let aliased_agg = Expr::Alias(Alias {
1807+
expr: Box::new(filtered_agg),
1808+
relation: None,
1809+
name: field_name,
1810+
metadata: None,
1811+
});
1812+
1813+
aggregate_exprs.push(aliased_agg);
1814+
}
1815+
1816+
// Create the plan with the aggregate
1817+
let aggregate_plan = builder
1818+
.aggregate(group_by_columns, aggregate_exprs)?
1819+
.build()?;
1820+
1821+
// If DEFAULT ON NULL is specified, add a projection to apply COALESCE
1822+
if let Some(default_expr) = default_on_null_expr {
1823+
let schema = aggregate_plan.schema();
1824+
let mut projection_exprs = Vec::new();
1825+
1826+
for field in schema.fields() {
1827+
if !pivot_values
1828+
.iter()
1829+
.any(|v| field.name() == v.to_string().trim_matches('\''))
1830+
{
1831+
projection_exprs.push(Expr::Column(
1832+
datafusion_common::Column::from_name(field.name()),
1833+
));
1834+
}
1835+
}
1836+
1837+
// Apply COALESCE to aggregate columns
1838+
for value in &pivot_values {
1839+
let field_name = value.to_string().trim_matches('\'').to_string();
1840+
let aggregate_col =
1841+
Expr::Column(datafusion_common::Column::from_name(&field_name));
1842+
1843+
// Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END
1844+
let coalesce_expr = Expr::Case(datafusion_expr::expr::Case {
1845+
expr: None,
1846+
when_then_expr: vec![(
1847+
Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))),
1848+
Box::new(default_expr.clone()),
1849+
)],
1850+
else_expr: Some(Box::new(aggregate_col)),
1851+
});
1852+
1853+
let aliased_coalesce = Expr::Alias(Alias {
1854+
expr: Box::new(coalesce_expr),
1855+
relation: None,
1856+
name: field_name,
1857+
metadata: None,
1858+
});
1859+
1860+
projection_exprs.push(aliased_coalesce);
1861+
}
1862+
1863+
// Apply the projection
1864+
LogicalPlanBuilder::from(aggregate_plan)
1865+
.project(projection_exprs)?
1866+
.build()
1867+
} else {
1868+
Ok(aggregate_plan)
1869+
}
1870+
}
1871+
16861872
impl DefaultPhysicalPlanner {
16871873
/// Handles capturing the various plans for EXPLAIN queries
16881874
///
@@ -2044,6 +2230,39 @@ impl DefaultPhysicalPlanner {
20442230
})
20452231
.collect::<Result<Vec<_>>>()?;
20462232

2233+
// When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved
2234+
if let LogicalPlan::Pivot(pivot) = input.as_ref() {
2235+
if pivot.value_subquery.is_some()
2236+
&& input_exec
2237+
.as_any()
2238+
.downcast_ref::<AggregateExec>()
2239+
.is_some()
2240+
{
2241+
let agg_exec =
2242+
input_exec.as_any().downcast_ref::<AggregateExec>().unwrap();
2243+
let schema = input_exec.schema();
2244+
let group_by_len = agg_exec.group_expr().expr().len();
2245+
2246+
if group_by_len < schema.fields().len() {
2247+
let mut all_exprs = physical_exprs.clone();
2248+
2249+
for (i, field) in
2250+
schema.fields().iter().enumerate().skip(group_by_len)
2251+
{
2252+
if !physical_exprs.iter().any(|(_, name)| name == field.name()) {
2253+
all_exprs.push((
2254+
Arc::new(Column::new(field.name(), i))
2255+
as Arc<dyn PhysicalExpr>,
2256+
field.name().clone(),
2257+
));
2258+
}
2259+
}
2260+
2261+
return Ok(Arc::new(ProjectionExec::try_new(all_exprs, input_exec)?));
2262+
}
2263+
}
2264+
}
2265+
20472266
Ok(Arc::new(ProjectionExec::try_new(
20482267
physical_exprs,
20492268
input_exec,

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use crate::expr_rewriter::{
3232
};
3333
use crate::logical_plan::{
3434
Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join,
35-
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
35+
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Prepare,
3636
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
3737
Window,
3838
};
@@ -1427,6 +1427,23 @@ impl LogicalPlanBuilder {
14271427
unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options)
14281428
.map(Self::new)
14291429
}
1430+
1431+
pub fn pivot(
1432+
self,
1433+
aggregate_expr: Expr,
1434+
pivot_column: Column,
1435+
pivot_values: Vec<ScalarValue>,
1436+
default_on_null: Option<Expr>,
1437+
) -> Result<Self> {
1438+
let pivot_plan = Pivot::try_new(
1439+
self.plan,
1440+
aggregate_expr,
1441+
pivot_column,
1442+
pivot_values,
1443+
default_on_null,
1444+
)?;
1445+
Ok(Self::new(LogicalPlan::Pivot(pivot_plan)))
1446+
}
14301447
}
14311448

14321449
impl From<LogicalPlan> for LogicalPlanBuilder {
@@ -2824,4 +2841,30 @@ mod tests {
28242841

28252842
Ok(())
28262843
}
2844+
2845+
/// Builds a PIVOT over a three-column `sales` scan and checks the plan's display text.
#[test]
fn plan_builder_pivot() -> Result<()> {
    let fields = vec![
        Field::new("region", DataType::Utf8, false),
        Field::new("product", DataType::Utf8, false),
        Field::new("sales", DataType::Int32, false),
    ];
    let schema = Schema::new(fields);

    // Pivot on `product`, producing one output column per listed value.
    let pivot_values: Vec<ScalarValue> = ["widget", "gadget"]
        .into_iter()
        .map(|name| ScalarValue::Utf8(Some(name.to_string())))
        .collect();

    let scan = LogicalPlanBuilder::scan("sales", table_source(&schema), None)?;
    let plan = scan
        .pivot(
            col("sales"),
            Column::from_name("product"),
            pivot_values,
            None,
        )?
        .build()?;

    // NOTE(review): confirm the exact whitespace after `\n` matches the plan's
    // Display output for a child node.
    let expected = "Pivot: sales FOR product IN (widget, gadget)\n TableScan: sales";
    assert_eq!(expected, plan.to_string());

    Ok(())
}
28272870
}

0 commit comments

Comments
 (0)