Skip to content

Commit 2dce36e

Browse files
committed
test -0.0 equality for complex types
1 parent 1116d3d commit 2dce36e

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ object LiteralGenerator {
178178
case BinaryType => binaryLiteralGen
179179
case CalendarIntervalType => calendarIntervalLiterGen
180180
case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale)
181+
case ArrayType(et, _) => randomGen(et).map(
182+
lit => Literal.create(Array(lit.value), ArrayType(et)))
181183
case dt => throw new IllegalArgumentException(s"not supported type $dt")
182184
}
183185
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
2222
import scala.collection.immutable.HashSet
2323

2424
import org.apache.spark.SparkFunSuite
25-
import org.apache.spark.sql.RandomDataGenerator
25+
import org.apache.spark.sql.{RandomDataGenerator, Row}
2626
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2727
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2828
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
@@ -91,6 +91,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
9191
DataTypeTestUtils.propertyCheckSupported.foreach { dt =>
9292
checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt)
9393
checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt)
94+
95+
val arrayType = ArrayType(dt)
96+
checkConsistencyBetweenInterpretedAndCodegen(EqualTo, arrayType, arrayType)
97+
checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, arrayType, arrayType)
9498
}
9599
}
96100

@@ -496,11 +500,30 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
496500
checkEvaluation(EqualTo(infinity, infinity), true)
497501
}
498502

499-
test("SPARK-32688: 0.0 and -0.0 should be considered equal") {
500-
checkEvaluation(EqualTo(Literal(0.0), Literal(-0.0)), true)
501-
checkEvaluation(EqualNullSafe(Literal(0.0), Literal(-0.0)), true)
502-
checkEvaluation(EqualTo(Literal(0.0f), Literal(-0.0f)), true)
503-
checkEvaluation(EqualNullSafe(Literal(0.0f), Literal(-0.0f)), true)
503+
private def testEquality(literals: Seq[Literal]): Unit = {
504+
literals.foreach(left => {
505+
literals.foreach(right => {
506+
checkEvaluation(EqualTo(left, right), true)
507+
checkEvaluation(EqualNullSafe(left, right), true)
508+
509+
val leftArray = Literal.create(Array(left.value), ArrayType(left.dataType))
510+
val rightArray = Literal.create(Array(right.value), ArrayType(right.dataType))
511+
checkEvaluation(EqualTo(leftArray, rightArray), true)
512+
checkEvaluation(EqualNullSafe(leftArray, rightArray), true)
513+
514+
val leftStruct = Literal.create(
515+
Row(left.value), new StructType().add("a", left.dataType))
516+
val rightStruct = Literal.create(
517+
Row(right.value), new StructType().add("a", right.dataType))
518+
checkEvaluation(EqualTo(leftStruct, rightStruct), true)
519+
checkEvaluation(EqualNullSafe(leftStruct, rightStruct), true)
520+
})
521+
})
522+
}
523+
524+
test("SPARK-32688: 0.0 and -0.0 should be equal") {
525+
testEquality(Seq(Literal(0.0), Literal(-0.0)))
526+
testEquality(Seq(Literal(0.0f), Literal(-0.0f)))
504527
}
505528

506529
test("SPARK-22693: InSet should not use global variables") {

0 commit comments

Comments
 (0)