Skip to content

Commit 00c432a

Browse files
committed
compare unsafe and safe complex-type values correctly
1 parent d9dd979 commit 00c432a

File tree

5 files changed

+55
-35
lines changed

5 files changed

+55
-35
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -578,12 +578,8 @@ public boolean equals(Object other) {
578578
return (sizeInBytes == o.sizeInBytes) &&
579579
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
580580
sizeInBytes);
581-
} else if (!(other instanceof InternalRow)) {
582-
return false;
583-
} else {
584-
throw new IllegalArgumentException(
585-
"Cannot compare UnsafeRow to " + other.getClass().getName());
586581
}
582+
return false;
587583
}
588584

589585
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -481,8 +481,13 @@ class CodegenContext {
481481
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
482482
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
483483
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
484+
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
485+
case array: ArrayType => genComp(array, c1, c2) + " == 0"
486+
case struct: StructType => genComp(struct, c1, c2) + " == 0"
484487
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
485-
case other => s"$c1.equals($c2)"
488+
case _ =>
489+
throw new IllegalArgumentException(
490+
"cannot generate equality code for un-comparable type: " + dataType.simpleString)
486491
}
487492

488493
/**
@@ -512,6 +517,9 @@ class CodegenContext {
512517
val funcCode: String =
513518
s"""
514519
public int $compareFunc(ArrayData a, ArrayData b) {
520+
if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a == b) {
521+
return 0;
522+
}
515523
int lengthA = a.numElements();
516524
int lengthB = b.numElements();
517525
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
@@ -551,6 +559,9 @@ class CodegenContext {
551559
val funcCode: String =
552560
s"""
553561
public int $compareFunc(InternalRow a, InternalRow b) {
562+
if (a instanceof UnsafeRow && b instanceof UnsafeRow && a == b) {
563+
return 0;
564+
}
554565
InternalRow i = null;
555566
$comparisons
556567
return 0;
@@ -561,7 +572,8 @@ class CodegenContext {
561572
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
562573
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
563574
case _ =>
564-
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
575+
throw new IllegalArgumentException(
576+
"cannot generate compare code for un-comparable type: " + dataType.simpleString)
565577
}
566578

567579
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 4 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
388388
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
389389
}
390390
}
391+
392+
protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
391393
}
392394

393395

@@ -414,17 +416,7 @@ case class EqualTo(left: Expression, right: Expression)
414416

415417
override def symbol: String = "="
416418

417-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
418-
if (left.dataType == FloatType) {
419-
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
420-
} else if (left.dataType == DoubleType) {
421-
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
422-
} else if (left.dataType != BinaryType) {
423-
input1 == input2
424-
} else {
425-
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
426-
}
427-
}
419+
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
428420

429421
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
430422
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
@@ -452,15 +444,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
452444
} else if (input1 == null || input2 == null) {
453445
false
454446
} else {
455-
if (left.dataType == FloatType) {
456-
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
457-
} else if (left.dataType == DoubleType) {
458-
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
459-
} else if (left.dataType != BinaryType) {
460-
input1 == input2
461-
} else {
462-
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
463-
}
447+
ordering.equiv(input1, input2)
464448
}
465449
}
466450

@@ -483,8 +467,6 @@ case class LessThan(left: Expression, right: Expression)
483467

484468
override def symbol: String = "<"
485469

486-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
487-
488470
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
489471
}
490472

@@ -497,8 +479,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
497479

498480
override def symbol: String = "<="
499481

500-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
501-
502482
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
503483
}
504484

@@ -511,8 +491,6 @@ case class GreaterThan(left: Expression, right: Expression)
511491

512492
override def symbol: String = ">"
513493

514-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
515-
516494
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
517495
}
518496

@@ -525,7 +503,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
525503

526504
override def symbol: String = ">="
527505

528-
private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
529-
530506
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
531507
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.RandomDataGenerator
24+
import org.apache.spark.sql.catalyst.InternalRow
25+
import org.apache.spark.sql.catalyst.util.GenericArrayData
2426
import org.apache.spark.sql.types._
2527

2628

@@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
293295
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
294296
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
295297
}
298+
299+
test("EqualTo on complex type") {
300+
val array = new GenericArrayData(Array(1, 2, 3))
301+
val struct = create_row("a", 1L, array)
302+
303+
val arrayType = ArrayType(IntegerType)
304+
val structType = new StructType()
305+
.add("1", StringType)
306+
.add("2", LongType)
307+
.add("3", ArrayType(IntegerType))
308+
309+
val projection = UnsafeProjection.create(
310+
new StructType().add("array", arrayType).add("struct", structType))
311+
312+
val unsafeRow = projection(InternalRow(array, struct))
313+
314+
val unsafeArray = unsafeRow.getArray(0)
315+
val unsafeStruct = unsafeRow.getStruct(1, 3)
316+
317+
checkEvaluation(EqualTo(
318+
Literal.create(array, arrayType),
319+
Literal.create(unsafeArray, arrayType)), true)
320+
321+
checkEvaluation(EqualTo(
322+
Literal.create(struct, structType),
323+
Literal.create(unsafeStruct, structType)), true)
324+
}
296325
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2476,4 +2476,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
24762476
}
24772477
}
24782478
}
2479+
2480+
test("SPARK-18053: ARRAY equality is broken") {
2481+
withTable("array_tbl") {
2482+
spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl")
2483+
assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1)
2484+
}
2485+
}
24792486
}

0 commit comments

Comments
 (0)