Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
}

private[sql] def normalize(expr: Expression): Expression = expr match {
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
NormalizeNaNAndZero(expr)
case _ if !needNormalize(expr.dataType) => expr

case a: Alias =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small refactor so that we can retain the alias when normalizing.

a.withNewChildren(Seq(normalize(a.child)))

case CreateNamedStruct(children) =>
CreateNamedStruct(children.map(normalize))
Expand All @@ -113,22 +115,22 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case CreateMap(children) =>
CreateMap(children.map(normalize))

case a: Alias if needNormalize(a.dataType) =>
a.withNewChildren(Seq(normalize(a.child)))
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
NormalizeNaNAndZero(expr)

case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) =>
case _ if expr.dataType.isInstanceOf[StructType] =>
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
normalize(GetStructField(expr, i))
}
CreateStruct(fields)

case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) =>
case _ if expr.dataType.isInstanceOf[ArrayType] =>
val ArrayType(et, containsNull) = expr.dataType
val lv = NamedLambdaVariable("arg", et, containsNull)
val function = normalize(lv)
ArrayTransform(expr, LambdaFunction(function, Seq(lv)))

case _ => expr
case _ => throw new IllegalStateException(s"fail to normalize $expr")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -331,8 +332,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
Copy link
Member

@viirya viirya Jan 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to remove it from AggUtils.createAggregate?

// `groupingExpressions` is not extracted during the logical phase.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be refactored after https://issues.apache.org/jira/browse/SPARK-25914?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

val normalizedGroupingExpressions = namedGroupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
case other => Alias(other, e.name)(exprId = e.exprId)
}
}

aggregate.AggUtils.planStreamingAggregation(
namedGroupingExpressions,
normalizedGroupingExpressions,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
Expand Down Expand Up @@ -414,16 +424,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"Spark user mailing list.")
}

// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during the logical phase.
val normalizedGroupingExpressions = groupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
case other => Alias(other, e.name)(exprId = e.exprId)
}
}

val aggregateOperator =
if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
normalizedGroupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
normalizedGroupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
resultExpressions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,12 @@ object AggUtils {
initialInputBufferOffset: Int = 0,
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during logical phase.
val normalizedGroupingExpressions = groupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
case other => Alias(other, e.name)(exprId = e.exprId)
}
}
val useHash = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
if (useHash) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = normalizedGroupingExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
Expand All @@ -61,7 +53,7 @@ object AggUtils {
if (objectHashEnabled && useObjectHash) {
ObjectHashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = normalizedGroupingExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
Expand All @@ -70,7 +62,7 @@ object AggUtils {
} else {
SortAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = normalizedGroupingExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
Expand Down