Skip to content

Commit b49aaa3

Browse files
maropuviirya
authored andcommitted
[SPARK-32906][SQL] Struct field names should not change after normalizing floats
### What changes were proposed in this pull request? This PR intends to fix a minor bug when normalizing floats for struct types; ``` scala> import org.apache.spark.sql.execution.aggregate.HashAggregateExec scala> val df = Seq(Tuple1(Tuple1(-0.0d)), Tuple1(Tuple1(0.0d))).toDF("k") scala> val agg = df.distinct() scala> agg.explain() == Physical Plan == *(2) HashAggregate(keys=[k#40], functions=[]) +- Exchange hashpartitioning(k#40, 200), true, [id=#62] +- *(1) HashAggregate(keys=[knownfloatingpointnormalized(if (isnull(k#40)) null else named_struct(col1, knownfloatingpointnormalized(normalizenanandzero(k#40._1)))) AS k#40], functions=[]) +- *(1) LocalTableScan [k#40] scala> val aggOutput = agg.queryExecution.sparkPlan.collect { case a: HashAggregateExec => a.output.head } scala> aggOutput.foreach { attr => println(attr.prettyJson) } ### Final Aggregate ### [ { "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference", "num-children" : 0, "name" : "k", "dataType" : { "type" : "struct", "fields" : [ { "name" : "_1", ^^^ "type" : "double", "nullable" : false, "metadata" : { } } ] }, "nullable" : true, "metadata" : { }, "exprId" : { "product-class" : "org.apache.spark.sql.catalyst.expressions.ExprId", "id" : 40, "jvmId" : "a824e83f-933e-4b85-a1ff-577b5a0e2366" }, "qualifier" : [ ] } ] ### Partial Aggregate ### [ { "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference", "num-children" : 0, "name" : "k", "dataType" : { "type" : "struct", "fields" : [ { "name" : "col1", ^^^^ "type" : "double", "nullable" : true, "metadata" : { } } ] }, "nullable" : true, "metadata" : { }, "exprId" : { "product-class" : "org.apache.spark.sql.catalyst.expressions.ExprId", "id" : 40, "jvmId" : "a824e83f-933e-4b85-a1ff-577b5a0e2366" }, "qualifier" : [ ] } ] ``` ### Why are the changes needed? bugfix. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests. Closes #29780 from maropu/FixBugInNormalizedFloatingNumbers. Authored-by: Takeshi Yamamuro <[email protected]> Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 75dd864 commit b49aaa3

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
129129
Coalesce(children.map(normalize))
130130

131131
case _ if expr.dataType.isInstanceOf[StructType] =>
132-
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
133-
normalize(GetStructField(expr, i))
132+
val fields = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
133+
case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i)))
134134
}
135-
val struct = CreateStruct(fields)
135+
val struct = CreateNamedStruct(fields.flatten.toSeq)
136136
KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
137137

138138
case _ if expr.dataType.isInstanceOf[ArrayType] =>

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,14 @@ class DataFrameAggregateSuite extends QueryTest
10431043
checkAnswer(sql(queryTemplate("FIRST")), Row(1))
10441044
checkAnswer(sql(queryTemplate("LAST")), Row(3))
10451045
}
1046+
1047+
test("SPARK-32906: struct field names should not change after normalizing floats") {
1048+
val df = Seq(Tuple1(Tuple2(-0.0d, Double.NaN)), Tuple1(Tuple2(0.0d, Double.NaN))).toDF("k")
1049+
val aggs = df.distinct().queryExecution.sparkPlan.collect { case a: HashAggregateExec => a }
1050+
assert(aggs.length == 2)
1051+
assert(aggs.head.output.map(_.dataType.simpleString).head ===
1052+
aggs.last.output.map(_.dataType.simpleString).head)
1053+
}
10461054
}
10471055

10481056
case class B(c: Option[Double])

0 commit comments

Comments
 (0)