Skip to content

Commit aefeaa7

Browse files
cloud-fanhvanhovell
authored andcommitted
[SPARK-18053][SQL] compare unsafe and safe complex-type values correctly
## What changes were proposed in this pull request? In Spark SQL, some expression may output safe format values, e.g. `CreateArray`, `CreateStruct`, `Cast`, etc. When we compare 2 values, we should be able to compare safe and unsafe formats. The `GreaterThan`, `LessThan`, etc. in Spark SQL already handles it, but the `EqualTo` doesn't. This PR fixes it. ## How was this patch tested? new unit test and regression test Author: Wenchen Fan <[email protected]> Closes #15929 from cloud-fan/type-aware. (cherry picked from commit 84284e8) Signed-off-by: Herman van Hovell <[email protected]>
1 parent 072f4c5 commit aefeaa7

File tree

5 files changed

+59
-35
lines changed

5 files changed

+59
-35
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -578,12 +578,8 @@ public boolean equals(Object other) {
578578
return (sizeInBytes == o.sizeInBytes) &&
579579
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
580580
sizeInBytes);
581-
} else if (!(other instanceof InternalRow)) {
582-
return false;
583-
} else {
584-
throw new IllegalArgumentException(
585-
"Cannot compare UnsafeRow to " + other.getClass().getName());
586581
}
582+
return false;
587583
}
588584

589585
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,13 @@ class CodegenContext {
464464
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
465465
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
466466
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
467+
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
468+
case array: ArrayType => genComp(array, c1, c2) + " == 0"
469+
case struct: StructType => genComp(struct, c1, c2) + " == 0"
467470
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
468-
case other => s"$c1.equals($c2)"
471+
case _ =>
472+
throw new IllegalArgumentException(
473+
"cannot generate equality code for un-comparable type: " + dataType.simpleString)
469474
}
470475

471476
/**
@@ -495,6 +500,11 @@ class CodegenContext {
495500
val funcCode: String =
496501
s"""
497502
public int $compareFunc(ArrayData a, ArrayData b) {
503+
// when comparing unsafe arrays, try equals first as it compares the binary directly
504+
// which is very fast.
505+
if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) {
506+
return 0;
507+
}
498508
int lengthA = a.numElements();
499509
int lengthB = b.numElements();
500510
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
@@ -534,6 +544,11 @@ class CodegenContext {
534544
val funcCode: String =
535545
s"""
536546
public int $compareFunc(InternalRow a, InternalRow b) {
547+
// when comparing unsafe rows, try equals first as it compares the binary directly
548+
// which is very fast.
549+
if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
550+
return 0;
551+
}
537552
InternalRow i = null;
538553
$comparisons
539554
return 0;
@@ -544,7 +559,8 @@ class CodegenContext {
544559
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
545560
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
546561
case _ =>
547-
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
562+
throw new IllegalArgumentException(
563+
"cannot generate compare code for un-comparable type: " + dataType.simpleString)
548564
}
549565

550566
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
391391
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
392392
}
393393
}
394+
395+
protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
394396
}
395397

396398

@@ -417,17 +419,7 @@ case class EqualTo(left: Expression, right: Expression)
417419

418420
override def symbol: String = "="
419421

420-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
421-
if (left.dataType == FloatType) {
422-
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
423-
} else if (left.dataType == DoubleType) {
424-
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
425-
} else if (left.dataType != BinaryType) {
426-
input1 == input2
427-
} else {
428-
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
429-
}
430-
}
422+
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
431423

432424
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
433425
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
@@ -453,15 +445,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
453445
} else if (input1 == null || input2 == null) {
454446
false
455447
} else {
456-
if (left.dataType == FloatType) {
457-
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
458-
} else if (left.dataType == DoubleType) {
459-
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
460-
} else if (left.dataType != BinaryType) {
461-
input1 == input2
462-
} else {
463-
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
464-
}
448+
ordering.equiv(input1, input2)
465449
}
466450
}
467451

@@ -484,8 +468,6 @@ case class LessThan(left: Expression, right: Expression)
484468

485469
override def symbol: String = "<"
486470

487-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
488-
489471
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
490472
}
491473

@@ -498,8 +480,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
498480

499481
override def symbol: String = "<="
500482

501-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
502-
503483
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
504484
}
505485

@@ -512,8 +492,6 @@ case class GreaterThan(left: Expression, right: Expression)
512492

513493
override def symbol: String = ">"
514494

515-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
516-
517495
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
518496
}
519497

@@ -526,7 +504,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
526504

527505
override def symbol: String = ">="
528506

529-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
530-
531507
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
532508
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.RandomDataGenerator
24+
import org.apache.spark.sql.catalyst.InternalRow
25+
import org.apache.spark.sql.catalyst.util.GenericArrayData
2426
import org.apache.spark.sql.types._
2527

2628

@@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
293295
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
294296
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
295297
}
298+
299+
test("EqualTo on complex type") {
300+
val array = new GenericArrayData(Array(1, 2, 3))
301+
val struct = create_row("a", 1L, array)
302+
303+
val arrayType = ArrayType(IntegerType)
304+
val structType = new StructType()
305+
.add("1", StringType)
306+
.add("2", LongType)
307+
.add("3", ArrayType(IntegerType))
308+
309+
val projection = UnsafeProjection.create(
310+
new StructType().add("array", arrayType).add("struct", structType))
311+
312+
val unsafeRow = projection(InternalRow(array, struct))
313+
314+
val unsafeArray = unsafeRow.getArray(0)
315+
val unsafeStruct = unsafeRow.getStruct(1, 3)
316+
317+
checkEvaluation(EqualTo(
318+
Literal.create(array, arrayType),
319+
Literal.create(unsafeArray, arrayType)), true)
320+
321+
checkEvaluation(EqualTo(
322+
Literal.create(struct, structType),
323+
Literal.create(unsafeStruct, structType)), true)
324+
}
296325
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2618,4 +2618,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
26182618
}
26192619
}
26202620
}
2621+
2622+
test("SPARK-18053: ARRAY equality is broken") {
2623+
withTable("array_tbl") {
2624+
spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl")
2625+
assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1)
2626+
}
2627+
}
26212628
}

0 commit comments

Comments
 (0)