@@ -72,6 +72,7 @@ class Analyzer(
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.withNewAggs(assignAliases(g.aggregations))

case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
if child.resolved && hasUnresolvedAlias(groupByExprs) =>
Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)

case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
}
@@ -248,6 +253,43 @@ class Analyzer(
}
}

object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Pivot if !p.childrenResolved => p
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
val singleAgg = aggregates.size == 1
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(expr: Expression) = {
If(EqualTo(pivotColumn, value), expr, Literal(null))
}
aggregates.map { aggregate =>
val filteredAggregate = aggregate.transformDown {
// Assumption is that the aggregate function ignores nulls. This is true for all current
// AggregateFunctions, with the exception of First and Last in their default mode
// (which we handle) and possibly some Hive UDAFs.
case First(expr, _) =>
First(ifExpr(expr), Literal(true))
case Last(expr, _) =>
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
Contributor: Seems we still need to check the number of children and make sure we have a single child?

Contributor Author: It should now work fine with aggregate functions that have multiple children as long as they ignore updates when all values are null. For example, Corr should work since it only updates its aggregation buffer if both its arguments are non-null.

Contributor: Oh, yes. You are right.
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$aggregate'")
}
val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
Alias(filteredAggregate, name)()
}
}
val newGroupByExprs = groupByExprs.map {
case UnresolvedAlias(e) => e
case e => e
}
Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
}
}
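For intuition, a minimal DataFrame-level sketch of what this rule rewrites a pivot into, using the courseSales data from the test suite added in this PR (assumes sqlContext.implicits._ and org.apache.spark.sql.functions._ are in scope, as in the tests). The when(...) call stands in for the If(EqualTo(pivotColumn, value), expr, Literal(null)) expression that ifExpr builds, and sum ignores the resulting nulls:

// What the user writes:
//   courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")).agg(sum($"earnings"))
// Roughly what ResolvePivot turns it into: one conditional aggregate per pivot value.
courseSales.groupBy($"year").agg(
  sum(when($"course" === "dotNET", $"earnings")).as("dotNET"),  // null unless course = "dotNET"
  sum(when($"course" === "Java", $"earnings")).as("Java"))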

/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
@@ -385,6 +385,20 @@ case class Rollup(
this.copy(aggregations = aggs)
}

case class Pivot(
groupByExprs: Seq[NamedExpression],
pivotColumn: Expression,
pivotValues: Seq[Literal],
aggregates: Seq[Expression],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
case _ => pivotValues.flatMap{ value =>
aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
}
}
}
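A hedged sketch of the column naming encoded here, using the courseSales example from the tests (names are approximate since they come from prettyString; lit, sum, avg and the $ interpolator assumed in scope):

val pivoted = courseSales.groupBy($"year")
  .pivot($"course", lit("dotNET"), lit("Java"))
  .agg(sum($"earnings"), avg($"earnings"))
// With several aggregates the names follow value + "_" + aggregate.prettyString, i.e. roughly
//   year, dotNET_sum(earnings), dotNET_avg(earnings), Java_sum(earnings), Java_avg(earnings)
// With a single aggregate the pivot values themselves become the column names: year, dotNET, Java.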

case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

103 changes: 93 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.{StringType, NumericType}


/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
aggExprs
}

val aliasedAgg = aggregates.map {
// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
// make it a NamedExpression.
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
val aliasedAgg = aggregates.map(alias)

groupType match {
case GroupedData.GroupByType =>
DataFrame(
@@ -68,9 +62,22 @@
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
DataFrame(
df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
}
}

// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
// make it a NamedExpression.
private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}

private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {

@@ -273,6 +280,77 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}

/**
* (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
* aggregation.
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
* // Or without specifying column values
* df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
* }}}
Contributor: How about we let users know that if no pivot values are provided, we will launch a job to find all distinct values of the pivot column?

Contributor: Another question I have: what will happen if we have too many distinct values? I am wondering if we should always ask users to provide pivot values.

Contributor Author: There is a note below in the description of the values parameter: "If values are not provided the method will do an immediate call to .distinct() on the pivot column." Do we need to duplicate/move that note to the method description? Right now, if there are too many distinct values we probably get an OOM. The obvious solution is a configurable maximum above which we give an error. Should I try to add that now?

Contributor: Yeah. Let's add that check. In our error message, we can ask users to change the value of that conf (and let's make sure the default value of that conf is large enough for common use cases).
* @param pivotColumn Column to pivot
* @param values Optional list of values of pivotColumn that will be translated to columns in the
* output data frame. If values are not provided, the method will do an immediate
* call to .distinct() on the pivot column.
* @since 1.6.0
*/
@scala.annotation.varargs
def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
case _: GroupedData.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case GroupedData.GroupByType =>
val pivotValues = if (values.nonEmpty) {
values.map {
case Column(literal: Literal) => literal
case other =>
throw new UnsupportedOperationException(
s"The values of a pivot must be literals, found $other")
}
} else {
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
// Get the distinct values of the column and sort them so the result is consistent
val values = df.select(pivotColumn)
.distinct()
.sort(pivotColumn)
.map(_.get(0))
.take(maxValues + 1)
.map(Literal(_)).toSeq
if (values.length > maxValues) {
throw new RuntimeException(
s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
"this could indicate an error. " +
"If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
s"to at least the number of distinct values of the pivot column.")
}
values
}
new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
case _ =>
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
}

/**
* Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
* // Or without specifying column values
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
* @param pivotColumn Column to pivot
* @param values Optional list of values of pivotColumn that will be translated to columns in the
* output data frame. If values are not provided, the method will do an immediate
* call to .distinct() on the pivot column.
* @since 1.6.0
*/
@scala.annotation.varargs
def pivot(pivotColumn: String, values: Any*): GroupedData = {
val resolvedPivotColumn = Column(df.resolve(pivotColumn))
pivot(resolvedPivotColumn, values.map(functions.lit): _*)
}
Contributor: For the first version, maybe we can just have the API using Column as the argument type? (I am thinking about the type of values; I am not sure String is the right type.)
}
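Regarding the Column-vs-String question above, a small sketch of the intended equivalence between the two overloads, based on the tests added in this PR (courseSales comes from SQLTestData; lit, sum and the $ interpolator assumed in scope):

// String-based overload: the column name is resolved and the values are lifted to Literals via functions.lit.
courseSales.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
// Column-based overload: same result, e.g. Row(2012, 15000.0, 20000.0) and Row(2013, 48000.0, 30000.0).
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")).agg(sum($"earnings"))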


@@ -307,4 +385,9 @@ private[sql] object GroupedData {
* To indicate it's the ROLLUP
*/
private[sql] object RollupType extends GroupType

/**
* To indicate it's the PIVOT
*/
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
}
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -453,6 +453,13 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)

val DATAFRAME_PIVOT_MAX_VALUES = intConf(
"spark.sql.pivotMaxValues",
defaultValue = Some(10000),
doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
"number of (distinct) values that will be collected without error."
)
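A hedged usage sketch for this new conf: when no explicit values are passed to pivot, the distinct values of the pivot column are collected, and exceeding the cap raises an error that names this key. Raising it (the threshold below is illustrative) is the intended fix for columns that legitimately have many distinct values:

// Default cap is 10000 distinct values.
sqlContext.setConf("spark.sql.pivotMaxValues", "50000")
// Values are then inferred via distinct() + sort on the pivot column.
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))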

val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext

class DataFramePivotSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("pivot courses with literals") {
checkAnswer(
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
}

test("pivot year with literals") {
checkAnswer(
courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot courses with literals and multiple aggregations") {
checkAnswer(
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
)
}

test("pivot year with string values (cast)") {
checkAnswer(
courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot year with int values") {
checkAnswer(
courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot courses with no values") {
// Note Java comes before dotNet in sorted order
checkAnswer(
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
}

test("pivot year with no values") {
checkAnswer(
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot max values inforced") {
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
intercept[RuntimeException](
courseSales.groupBy($"year").pivot($"course")
)
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
}
}
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val courseSales: DataFrame = {
val df = sqlContext.sparkContext.parallelize(
CourseSales("dotNET", 2012, 10000) ::
CourseSales("Java", 2012, 20000) ::
CourseSales("dotNET", 2012, 5000) ::
CourseSales("dotNET", 2013, 48000) ::
CourseSales("Java", 2013, 30000) :: Nil).toDF()
df.registerTempTable("courseSales")
df
}

/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)
}