From 53c38b12a5a14a581c4d58060ccdb96db25b5956 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 00:59:57 -0700 Subject: [PATCH 1/9] fix hashCode() and equals() of BinaryType in Row --- .../java/org/apache/spark/sql/BaseRow.java | 25 ++++++++- .../codegen/GenerateProjection.scala | 1 + .../spark/sql/catalyst/expressions/rows.scala | 16 +++++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/LiteralExpressionSuite.scala | 56 +++++++++++++++++-- 5 files changed, 92 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java index 611e02d8fb666..636355a3472a4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java @@ -167,9 +167,21 @@ public boolean equals(Object other) { return false; } for (int i = 0; i < n; i ++) { - if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) { + if (isNullAt(i) != row.isNullAt(i)) { return false; } + if (!isNullAt(i)) { + Object o1 = get(i); + Object o2 = row.get(i); + if (o1 instanceof byte[]) { + // handle equals() of byte[] + if (!(o2 instanceof byte[]) || !java.util.Arrays.equals((byte[])o1, (byte[])o2)) { + return false; + } + } else if (!o1.equals(o2)) { + return false; + } + } } return true; } @@ -215,4 +227,15 @@ public String mkString(String sep) { public String mkString(String start, String sep, String end) { return toSeq().mkString(start, sep, end); } + + /* + * Returns hash code based on bytes in `arr` + * */ + protected int bytesHashCode(byte[] arr) { + int hash = 0; + for (int i = 0; i < arr.length; i++) { + hash = hash * 37 + (int)arr[i]; + } + return hash; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 8b5dc194be31f..7bdb95d71367a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case FloatType => s"Float.floatToIntBits($col)" case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" + case BinaryType => s"bytesHashCode($col)" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 1098962ddc018..e13297b6691d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -143,6 +143,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { case d: Double => val b = java.lang.Double.doubleToLongBits(d) (b ^ (b >>> 32)).toInt + case a: Array[Byte] => a.map(_.toInt).fold(0)(_ * 37 + _) case other => other.hashCode() } } @@ -163,8 +164,19 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { if (isNullAt(i) != other.isNullAt(i)) { return false } - if (apply(i) != other.apply(i)) { - return false + if (!isNullAt(i)) { + val o1 = apply(i) + val o2 = other.apply(i) + if (o1.isInstanceOf[Array[Byte]]) { + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + + } else if (apply(i) != other.apply(i)) { + return false + } } i += 1 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 12d2da8b33986..aa1ad43149c86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,7 @@ trait ExpressionEvalHelper { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (actual != expected) { + if (actual !== expected) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -83,7 +83,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow).apply(0) - if (actual != expected) { + if (actual !== expected) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f44f55dfb92d1..27cc44c799489 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -18,12 +18,26 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types._ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - // TODO: Add tests for all data types. + test("null") { + checkEvaluation(Literal.create(null, BooleanType), null) + checkEvaluation(Literal.create(null, ByteType), null) + checkEvaluation(Literal.create(null, ShortType), null) + checkEvaluation(Literal.create(null, IntegerType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, FloatType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, BinaryType), null) + checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) + checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) + } test("boolean literals") { checkEvaluation(Literal(true), true) @@ -31,8 +45,16 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("int literals") { - checkEvaluation(Literal(1), 1) - checkEvaluation(Literal(0L), 0L) + List(0, 1, Int.MinValue, Int.MaxValue).foreach { + d => { + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toLong), d.toLong) + checkEvaluation(Literal(d.toShort), d.toShort) + checkEvaluation(Literal(d.toByte), d.toByte) + } + } + checkEvaluation(Literal(Long.MinValue), Long.MinValue) + checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) } test("double literals") { @@ -42,14 +64,38 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal(d.toFloat), d.toFloat) } } + checkEvaluation(Literal(Double.MinValue), Double.MinValue) + checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal(Float.MinValue), Float.MinValue) + checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + } test("string literals") { + checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal("世界"), "世界") + checkEvaluation(Literal("\0"), "\0") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) } + + test("binary literals") { + checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + } + + test("decimal") { + List(0.0, 1.2, 1.1111, 5).foreach { d => + checkEvaluation(Literal(Decimal(d)), Decimal(d)) + checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), + Decimal((d * 1000L).toLong, 10, 1)) + } + } + + // TODO(davies): add tests for ArrayType, MapType and StructType } From 0fff25de628b76f5cf2ce0bfbb4fc9c2f3fdbf1f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 01:13:27 -0700 Subject: [PATCH 2/9] fix style --- .../spark/sql/catalyst/expressions/LiteralExpressionSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 27cc44c799489..833677e1ed711 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -74,7 +74,6 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("string literals") { checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal("世界"), "世界") checkEvaluation(Literal("\0"), "\0") } From 5819d33fd38359c3d10dc6b22edab410b72acab7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 10:35:12 -0700 Subject: [PATCH 3/9] unify equals() and hashCode() --- .../java/org/apache/spark/sql/BaseRow.java | 44 ------------- .../main/scala/org/apache/spark/sql/Row.scala | 32 ---------- .../spark/sql/catalyst/InternalRow.scala | 58 ++++++++++++++++- .../codegen/GenerateProjection.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 64 ------------------- 5 files changed, 58 insertions(+), 142 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java index 636355a3472a4..6a2356f1f9c6f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java @@ -155,39 +155,6 @@ public int fieldIndex(String name) { throw new UnsupportedOperationException(); } - /** - * A generic version of Row.equals(Row), which is used for tests. - */ - @Override - public boolean equals(Object other) { - if (other instanceof Row) { - Row row = (Row) other; - int n = size(); - if (n != row.size()) { - return false; - } - for (int i = 0; i < n; i ++) { - if (isNullAt(i) != row.isNullAt(i)) { - return false; - } - if (!isNullAt(i)) { - Object o1 = get(i); - Object o2 = row.get(i); - if (o1 instanceof byte[]) { - // handle equals() of byte[] - if (!(o2 instanceof byte[]) || !java.util.Arrays.equals((byte[])o1, (byte[])o2)) { - return false; - } - } else if (!o1.equals(o2)) { - return false; - } - } - } - return true; - } - return false; - } - @Override public InternalRow copy() { final int n = size(); @@ -227,15 +194,4 @@ public String mkString(String sep) { public String mkString(String start, String sep, String end) { return toSeq().mkString(start, sep, end); } - - /* - * Returns hash code based on bytes in `arr` - * */ - protected int bytesHashCode(byte[] arr) { - int hash = 0; - for (int i = 0; i < arr.length; i++) { - hash = hash * 37 + (int)arr[i]; - } - return hash; - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 8aaf5d7d89154..e99d5c87a44fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.util.hashing.MurmurHash3 - import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -365,36 +363,6 @@ trait Row extends Serializable { false } - override def equals(that: Any): Boolean = that match { - case null => false - case that: Row => - if (this.length != that.length) { - return false - } - var i = 0 - val len = this.length - while (i < len) { - if (apply(i) != that.apply(i)) { - return false - } - i += 1 - } - true - case _ => false - } - - override def hashCode: Int = { - // Using Scala's Seq hash code implementation. - var n = 0 - var h = MurmurHash3.seqSeed - val len = length - while (n < len) { - h = MurmurHash3.mix(h, apply(n).##) - n += 1 - } - MurmurHash3.finalizeHash(h, n) - } - /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e3c2cc243310b..55ce6572c1fe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -27,6 +27,62 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow abstract class InternalRow extends Row { // A default implementation to change the return type override def copy(): InternalRow = {this} + + // A default version (slow), used for tests + override def equals(o: Any): Boolean = o match { + case other: InternalRow => + if (length != other.length) { + return false + } + + for (i <- 0 until length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = apply(i) + val o2 = other.apply(i) + if (o1.isInstanceOf[Array[Byte]]) { + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(o1.asInstanceOf[Array[Byte]], o2.asInstanceOf[Array[Byte]])) { + return false + } + } else if (o1 != o2) { + return false + } + } + } + true + case _ => false + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + + for (i <- 0 until length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + } + result + } } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 7bdb95d71367a..f0e57e024f35b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -127,7 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case FloatType => s"Float.floatToIntBits($col)" case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" - case BinaryType => s"bytesHashCode($col)" + case BinaryType => s"java.util.Arrays.hashCode($col)" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index e13297b6691d5..0d4c9ace5e124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -121,70 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { } } - // TODO(davies): add getDate and getDecimal - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - - var i = 0 - while (i < values.length) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - apply(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => a.map(_.toInt).fold(0)(_ * 37 + _) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - - override def equals(o: Any): Boolean = o match { - case other: InternalRow => - if (values.length != other.length) { - return false - } - - var i = 0 - while (i < values.length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = apply(i) - val o2 = other.apply(i) - if (o1.isInstanceOf[Array[Byte]]) { - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - - } else if (apply(i) != other.apply(i)) { - return false - } - } - i += 1 - } - true - - case _ => false - } - override def copy(): InternalRow = this } From 6ad2a908956040c4343db82634caa3f347668ad1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 10:47:56 -0700 Subject: [PATCH 4/9] fix style --- .../main/scala/org/apache/spark/sql/catalyst/InternalRow.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 55ce6572c1fe6..60fa6f5740adc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -43,8 +43,9 @@ abstract class InternalRow extends Row { val o1 = apply(i) val o2 = other.apply(i) if (o1.isInstanceOf[Array[Byte]]) { + val b1 = o1.asInstanceOf[Array[Byte]] if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(o1.asInstanceOf[Array[Byte]], o2.asInstanceOf[Array[Byte]])) { + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { return false } } else if (o1 != o2) { From d96929bd4f9f740be82164430c3cbb9f1d8b4035 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 16:46:29 -0700 Subject: [PATCH 5/9] address comment --- python/pyspark/streaming/dstream.py | 2 +- .../spark/sql/catalyst/InternalRow.scala | 47 ++++++++++--------- .../expressions/ExpressionEvalHelper.scala | 4 +- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ff097985fae3e..8dcb9645cdc6b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd): print(record) if len(taken) > num: print("...") - print() + print("") self.foreachRDD(takeAndPrint) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 60fa6f5740adc..5e176fd8175fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -26,35 +26,38 @@ import org.apache.spark.sql.catalyst.expressions._ */ abstract class InternalRow extends Row { // A default implementation to change the return type - override def copy(): InternalRow = {this} + override def copy(): InternalRow = this - // A default version (slow), used for tests - override def equals(o: Any): Boolean = o match { - case other: InternalRow => - if (length != other.length) { + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[Row]) { + return false + } + + val other = o.asInstanceOf[Row] + if (length != other.length) { + return false + } + + for (i <- 0 until length) { + if (isNullAt(i) != other.isNullAt(i)) { return false } - - for (i <- 0 until length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = apply(i) - val o2 = other.apply(i) - if (o1.isInstanceOf[Array[Byte]]) { - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - } else if (o1 != o2) { + if (!isNullAt(i)) { + val o1 = apply(i) + val o2 = other.apply(i) + if (o1.isInstanceOf[Array[Byte]]) { + // handle equality of Array[Byte] + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { return false } + } else if (o1 != o2) { + return false } } - true - case _ => false + } + true } // Custom hashCode function that matches the efficient code generated version. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index aa1ad43149c86..12d2da8b33986 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,7 @@ trait ExpressionEvalHelper { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (actual !== expected) { + if (actual != expected) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -83,7 +83,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow).apply(0) - if (actual !== expected) { + if (actual != expected) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } From 41caec6c9c7bf8d82c714d7ba081e002fa919a9d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 22:09:38 -0700 Subject: [PATCH 6/9] change for to while --- .../org/apache/spark/sql/catalyst/InternalRow.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 5e176fd8175fe..d7b537a9fe3bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -38,7 +38,8 @@ abstract class InternalRow extends Row { return false } - for (i <- 0 until length) { + var i = 0 + while (i < length) { if (isNullAt(i) != other.isNullAt(i)) { return false } @@ -56,6 +57,7 @@ abstract class InternalRow extends Row { return false } } + i += 1 } true } @@ -63,8 +65,8 @@ abstract class InternalRow extends Row { // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { var result: Int = 37 - - for (i <- 0 until length) { + var i = 0 + while (i < length) { val update: Int = if (isNullAt(i)) { 0 @@ -84,6 +86,7 @@ abstract class InternalRow extends Row { } } result = 37 * result + update + i += 1 } result } From bd20780beeb134aa54f41326fa22f7176beee29c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Jun 2015 09:58:36 -0700 Subject: [PATCH 7/9] check with catalyst types --- .../expressions/ExpressionEvalHelper.scala | 27 ++++++++++++++----- .../expressions/StringFunctionsSuite.scala | 5 +--- .../apache/spark/unsafe/types/UTF8String.java | 6 +---- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 12d2da8b33986..158f54af13802 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -38,10 +38,23 @@ trait ExpressionEvalHelper { protected def checkEvaluation( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithoutCodegen(expression, expected, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) - checkEvaluationWithGeneratedProjection(expression, expected, inputRow) - checkEvaluationWithOptimization(expression, expected, inputRow) + val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) + checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expression, catalystValue, inputRow) + } + + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte]. + */ + protected def checkResult(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case _ => result == expected + } } protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { @@ -55,7 +68,7 @@ trait ExpressionEvalHelper { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (actual != expected) { + if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -83,7 +96,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow).apply(0) - if (actual != expected) { + if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } @@ -109,7 +122,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + val expectedRow = new GenericRow(Array[Any](expected)) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index d363e631540d8..5dbb1d562c1d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringLength(regEx), 5, create_row("abdef")) checkEvaluation(StringLength(regEx), 0, create_row("")) checkEvaluation(StringLength(regEx), null, create_row(null)) - // TODO currently bug in codegen, let's temporally disable this - // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) + checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) } - - } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index a35168019549e..3db6993798127 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -17,10 +17,10 @@ package org.apache.spark.unsafe.types; +import javax.annotation.Nullable; import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.util.Arrays; -import javax.annotation.Nullable; import org.apache.spark.unsafe.PlatformDependent; @@ -196,10 +196,6 @@ public int compare(final UTF8String other) { public boolean equals(final Object other) { if (other instanceof UTF8String) { return Arrays.equals(bytes, ((UTF8String) other).getBytes()); - } else if (other instanceof String) { - // Used only in unit tests. - String s = (String) other; - return bytes.length >= s.length() && length() == s.length() && toString().equals(s); } else { return false; } From 89c2432e07a062ccac62a18113f96ab29e21ab57 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Jun 2015 10:04:05 -0700 Subject: [PATCH 8/9] fix style --- python/pyspark/streaming/dstream.py | 2 +- .../expressions/LiteralExpressionSuite.scala | 20 ++++++++----------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8dcb9645cdc6b..ff097985fae3e 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd): print(record) if len(taken) > num: print("...") - print("") + print() self.foreachRDD(takeAndPrint) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 833677e1ed711..d924ff7a102f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -45,24 +45,20 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("int literals") { - List(0, 1, Int.MinValue, Int.MaxValue).foreach { - d => { - checkEvaluation(Literal(d), d) - checkEvaluation(Literal(d.toLong), d.toLong) - checkEvaluation(Literal(d.toShort), d.toShort) - checkEvaluation(Literal(d.toByte), d.toByte) - } + List(0, 1, Int.MinValue, Int.MaxValue).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toLong), d.toLong) + checkEvaluation(Literal(d.toShort), d.toShort) + checkEvaluation(Literal(d.toByte), d.toByte) } checkEvaluation(Literal(Long.MinValue), Long.MinValue) checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) } test("double literals") { - List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { - d => { - checkEvaluation(Literal(d), d) - checkEvaluation(Literal(d.toFloat), d.toFloat) - } + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) } checkEvaluation(Literal(Double.MinValue), Double.MinValue) checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) From 32d981137fd24d1e55c3a4c2c23bb19e494b4f65 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Jun 2015 14:01:29 -0700 Subject: [PATCH 9/9] fix test --- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 80c179a1b5e75..796cdc9dbebdb 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -28,8 +28,6 @@ private void checkBasic(String str, int len) throws UnsupportedEncodingException Assert.assertEquals(UTF8String.fromString(str).length(), len); Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len); - Assert.assertEquals(UTF8String.fromString(str), str); - Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), str); Assert.assertEquals(UTF8String.fromString(str).toString(), str); Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str); Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str));