From d2b4a4a9a2139b1a6c2be5d1f1aa3d98a6c9ed99 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 20:18:05 -0700 Subject: [PATCH 01/28] Add random data generator test utilities to Spark SQL. --- .../spark/sql/test/DataTypeTestUtils.scala | 59 +++++++ .../spark/sql/test/RandomDataGenerator.scala | 151 ++++++++++++++++++ .../sql/test/RandomDataGeneratorSuite.scala | 77 +++++++++ 3 files changed, 287 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala new file mode 100644 index 0000000000000..d862eb7293d6d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.types._ + +/** + * Utility functions for working with DataTypes in tests. + */ +object DataTypeTestUtils { + + /** + * Instances of all [[IntegralType]]s. + */ + val integralType: Set[IntegralType] = Set( + ByteType, ShortType, IntegerType, LongType + ) + + /** + * Instances of all [[FractionalType]]s, including both fixed- and unlimited-precision + * decimal types. + */ + val fractionalTypes: Set[FractionalType] = Set( + DecimalType(precisionInfo = None), + DecimalType(2, 1), + DoubleType, + FloatType + ) + + /** + * Instances of all [[NumericType]]s. + */ + val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + + /** + * Instances of all [[AtomicType]]s. + */ + val atomicTypes: Set[DataType] = Set(BinaryType, StringType, TimestampType) ++ numericTypes + + /** + * Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null. + */ + val atomicArrayTypes: Set[ArrayType] = atomicTypes.map(ArrayType(_, containsNull = true)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala new file mode 100644 index 0000000000000..6ac2ba155655c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.sql.Row
+
+import scala.util.Random
+
+import org.apache.spark.sql.types._
+
+/**
+ * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
+ * values; instead, they're biased to return "interesting" values (such as maximum / minimum values)
+ * with higher probability.
+ */
+object RandomDataGenerator {
+
+  /**
+   * The conditional probability of a non-null value being drawn from a set of "interesting" values
+   * instead of being chosen uniformly at random.
+   */
+  private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.25f
+
+  /**
+   * The probability of the generated value being null
+   */
+  private val PROBABILITY_OF_NULL: Float = 0.1f
+
+  private val MAX_STR_LEN: Int = 1024
+  private val MAX_ARR_SIZE: Int = 128
+  private val MAX_MAP_SIZE: Int = 128
+
+  /**
+   * Helper function for constructing a biased random number generator which returns "interesting"
+   * values with a higher probability.
+   */
+  private def randomNumeric[T](
+      rand: Random,
+      uniformRand: Random => T,
+      interestingValues: Seq[T]): Some[() => T] = {
+    val f = () => {
+      if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) {
+        interestingValues(rand.nextInt(interestingValues.length))
+      } else {
+        uniformRand(rand)
+      }
+    }
+    Some(f)
+  }
+
+  /**
+   * Returns a function which generates random values for the given [[DataType]], or `None` if no
+   * random data generator is defined for that data type. The generated values will use an external
+   * representation of the data type; for example, the random generator for [[DateType]] will return
+   * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a
+   * [[org.apache.spark.sql.Row]].
+   *
+   * @param dataType the type to generate values for
+   * @param nullable whether null values should be generated
+   * @param seed an optional seed for the random number generator
+   * @return a function which can be called to generate random values.
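+   *
+   * A minimal usage sketch (illustrative only; `IntegerType` here stands in for any
+   * supported type):
+   * {{{
+   *   val maybeGen: Option[() => Any] = RandomDataGenerator.forType(IntegerType, seed = Some(42L))
+   *   maybeGen.foreach { gen => val tenRandomValues = Seq.fill(10)(gen()) }
+   * }}}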
+ */ + def forType( + dataType: DataType, + nullable: Boolean = true, + seed: Option[Long] = None): Option[() => Any] = { + val rand = new Random() + seed.foreach(rand.setSeed) + + val valueGenerator: Option[() => Any] = dataType match { + case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) + case BinaryType => Some(() => { + val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN)) + rand.nextBytes(arr) + arr + }) + case BooleanType => Some(() => rand.nextBoolean()) + case DateType => Some(() => new java.sql.Date(rand.nextInt(Int.MaxValue))) + case DoubleType => randomNumeric[Double]( + rand, _.nextDouble(), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, 0.0)) + case FloatType => randomNumeric[Float]( + rand, _.nextFloat(), Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue, 0.0f)) + case ByteType => randomNumeric[Byte]( + rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte)) + case IntegerType => randomNumeric[Int]( + rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0)) + case LongType => randomNumeric[Long]( + rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L)) + case ShortType => randomNumeric[Short]( + rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) + case NullType => Some(() => null) + case ArrayType(elementType, containsNull) => { + forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { + elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) + } + } + case MapType(keyType, valueType, valueContainsNull) => { + for ( + keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); + valueGenerator <- + forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) + ) yield { + () => { + Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + } + } + } + case StructType(fields) => { + val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => + forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) + } + if (maybeFieldGenerators.forall(_.isDefined)) { + val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) + Some(() => Row.fromSeq(fieldGenerators.map(_.apply()))) + } else { + None + } + } + case unsupportedType => None + } + // Handle nullability by wrapping the non-null value generator: + valueGenerator.map { valueGenerator => + if (nullable) { + () => { + if (rand.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + valueGenerator() + } + } + } else { + valueGenerator + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala new file mode 100644 index 0000000000000..fb4ed9028c2c5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.types.{StructField, StructType, MapType, DataType} + +/** + * Tests of [[RandomDataGenerator]]. + */ +class RandomDataGeneratorSuite extends SparkFunSuite { + + /** + * Tests random data generation for the given type by using it to generate random values then + * converting those values into their Catalyst equivalents using CatalystTypeConverters. + */ + def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) + RandomDataGenerator.forType(dataType, nullable, Some(42L)).foreach { generator => + for (_ <- 1 to 10) { + val generatedValue = generator() + val convertedValue = toCatalyst(generatedValue) + if (!nullable) { + assert(convertedValue !== null) + } + } + } + + } + + // Basic types: + + (DataTypeTestUtils.atomicTypes ++ DataTypeTestUtils.atomicArrayTypes).foreach { dataType => + test(s"$dataType") { + testRandomDataGeneration(dataType) + } + } + + // Complex types: + + for ( + keyType <- DataTypeTestUtils.atomicTypes; + valueType <- DataTypeTestUtils.atomicTypes + ) { + val mapType = MapType(keyType, valueType) + test(s"$mapType") { + testRandomDataGeneration(mapType) + } + } + + for ( + colOneType <- DataTypeTestUtils.atomicTypes; + colTwoType <- DataTypeTestUtils.atomicTypes + ) { + val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) + test(s"$structType") { + testRandomDataGeneration(structType) + } + } + +} From ab76cbd89bf800d590b7833f5a25c62df4ec2a95 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 21:37:38 -0700 Subject: [PATCH 02/28] Move code to Catalyst package. 
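
Moving the generators out of sql/core and into the Catalyst test tree lets
Catalyst-level suites use them directly. A sketch of the intended usage from a
Catalyst suite (illustrative only, mirroring RandomDataGeneratorSuite above):

    import org.apache.spark.sql.RandomDataGenerator
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.StringType

    // Generate external-representation values, then convert them to Catalyst's
    // internal representation before feeding them to expression tests:
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(StringType)
    RandomDataGenerator.forType(StringType, nullable = true).foreach { gen =>
      val internalValue = toCatalyst(gen())
    }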
--- .../scala/org/apache/spark/sql}/RandomDataGenerator.scala | 6 ++---- .../org/apache/spark/sql}/RandomDataGeneratorSuite.scala | 4 ++-- .../org/apache/spark/sql/types}/DataTypeTestUtils.scala | 4 +--- 3 files changed, 5 insertions(+), 9 deletions(-) rename sql/{core/src/test/scala/org/apache/spark/sql/test => catalyst/src/test/scala/org/apache/spark/sql}/RandomDataGenerator.scala (98%) rename sql/{core/src/test/scala/org/apache/spark/sql/test => catalyst/src/test/scala/org/apache/spark/sql}/RandomDataGeneratorSuite.scala (95%) rename sql/{core/src/test/scala/org/apache/spark/sql/test => catalyst/src/test/scala/org/apache/spark/sql/types}/DataTypeTestUtils.scala (96%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 6ac2ba155655c..f167557be818f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -15,14 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.test +package org.apache.spark.sql -import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ import scala.util.Random -import org.apache.spark.sql.types._ - /** * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random * values; instead, they're biased to return "interesting" values (such as maximum / minimum values) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index fb4ed9028c2c5..ea70fe03eb912 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -15,11 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.test +package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.types.{StructField, StructType, MapType, DataType} +import org.apache.spark.sql.types._ /** * Tests of [[RandomDataGenerator]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index d862eb7293d6d..0b7ed54c681e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -15,9 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.test - -import org.apache.spark.sql.types._ +package org.apache.spark.sql.types /** * Utility functions for working with DataTypes in tests. 
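
The bias toward "interesting" values that the next patch extends works as in this
standalone sketch (the names here are illustrative; the real helper is
randomNumeric in RandomDataGenerator above):

    import scala.util.Random

    // With probability p, pick one of the hand-chosen edge cases; otherwise
    // fall back to the uniform generator.
    def biased[T](rand: Random, uniform: Random => T, edgeCases: Seq[T], p: Float): () => T =
      () => {
        if (rand.nextFloat() <= p) edgeCases(rand.nextInt(edgeCases.length))
        else uniform(rand)
      }

    val intGen = biased[Int](new Random(42L), _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0), 0.25f)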
From 5acdd5ccf36487ba49815e8e0429f4c99558d427 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 22:15:13 -0700 Subject: [PATCH 03/28] Infinity and NaN are interesting. --- .../scala/org/apache/spark/sql/RandomDataGenerator.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index f167557be818f..cd4ffdfd45173 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -32,7 +32,7 @@ object RandomDataGenerator { * The conditional probability of a non-null value being drawn from a set of "interesting" values * instead of being chosen uniformly at random. */ - private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.25f + private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.5f /** * The probability of the generated value being null @@ -90,9 +90,11 @@ object RandomDataGenerator { case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt(Int.MaxValue))) case DoubleType => randomNumeric[Double]( - rand, _.nextDouble(), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, 0.0)) + rand, _.nextDouble(), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, + Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) case FloatType => randomNumeric[Float]( - rand, _.nextFloat(), Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue, 0.0f)) + rand, _.nextFloat(), Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue, + Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f)) case ByteType => randomNumeric[Byte]( rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte)) case IntegerType => randomNumeric[Int]( From b55875a05e4805cfdf2c3468a6cd50eec6a30578 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 22:23:55 -0700 Subject: [PATCH 04/28] Generate doubles and floats over entire possible range. --- .../org/apache/spark/sql/RandomDataGenerator.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index cd4ffdfd45173..26437c45eb41e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.types._ +import java.lang.Double.longBitsToDouble +import java.lang.Float.intBitsToFloat import scala.util.Random +import org.apache.spark.sql.types._ + /** * Random data generators for Spark SQL DataTypes. 
These generators do not generate uniformly random * values; instead, they're biased to return "interesting" values (such as maximum / minimum values) @@ -90,11 +93,11 @@ object RandomDataGenerator { case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt(Int.MaxValue))) case DoubleType => randomNumeric[Double]( - rand, _.nextDouble(), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, - Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) + rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, + Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) case FloatType => randomNumeric[Float]( - rand, _.nextFloat(), Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue, - Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f)) + rand, r => intBitsToFloat(r.nextInt()), Seq(Float.MinValue, Float.MinPositiveValue, + Float.MaxValue, Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f)) case ByteType => randomNumeric[Byte]( rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte)) case IntegerType => randomNumeric[Int]( From 7d5c13ea39cc0b811cc57b58b4214395026b1432 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 22:40:55 -0700 Subject: [PATCH 05/28] Add regression test for SPARK-8782 (ORDER BY NULL) --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 82dc0e9ce5132..cc6af1ccc1cce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1451,4 +1451,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } + + test("SPARK-8782: ORDER BY NULL") { + withTempTable("t") { + Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") + checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + } + } } From e7dc4fbb7c9e441c4367af7680c3acb42440ef33 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 23:15:49 -0700 Subject: [PATCH 06/28] Add very generic test for ordering --- .../expressions/CodeGenerationSuite.scala | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 481b335d15dfd..491657de6afd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math._ + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType} /** * Additional tests for code generation. 
@@ -42,4 +47,47 @@ class CodeGenerationSuite extends SparkFunSuite { futures.foreach(Await.result(_, 10.seconds)) } + + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. + (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + // Sort ordering is not defined for NaN, so skip any random inputs that contain it: + def isIncomparable(v: Any): Boolean = v match { + case d: Double => java.lang.Double.isNaN(d) + case f: Float => java.lang.Float.isNaN(f) + case _ => false + } + RandomDataGenerator.forType(rowType, nullable = false).foreach { randGenerator => + for (_ <- 1 to 50) { + val aExt = randGenerator().asInstanceOf[Row] + val bExt = randGenerator().asInstanceOf[Row] + if ((aExt.toSeq ++ bExt.toSeq).forall(v => !isIncomparable(v))) { + val a = toCatalyst(aExt).asInstanceOf[InternalRow] + val b = toCatalyst(bExt).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + } + } } From f9efbb5f317d28f8d38e1de9943fa9f976e8b5e5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 23:17:28 -0700 Subject: [PATCH 07/28] Fix ORDER BY NULL --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a64027e48a00b..9f6329bbda4ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -185,6 +185,7 @@ class CodeGenContext { // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case NullType => "0" case other => s"$c1.compare($c2)" } From 13fc06a6e339eda0bb1b775c64fa6c1d78ba19bb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 09:42:13 -0700 Subject: [PATCH 08/28] Add regression test for NaN sorting issue --- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afb1cf5f8d1cb..e8decb028590f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.language.postfixOps +import scala.util.Random import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -737,4 +738,16 @@ class DataFrameSuite extends QueryTest { df.col("") df.col("t.``") } + + test("SPARK-XXXX: sort by float column containing NaN") { + val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("SPARK-XXXX: sort by double column containing NaN") { + val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } } From 9bf195a716a2191e621f7aefba3db329aa7656e4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 09:43:50 -0700 Subject: [PATCH 09/28] Re-enable NaNs in CodeGenerationSuite to produce more regression tests --- .../expressions/CodeGenerationSuite.scala | 36 +++++++------------ 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 491657de6afd8..aa0ee226d7600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.math._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -61,30 +61,20 @@ class CodeGenerationSuite extends SparkFunSuite { StructField("a", dataType, nullable = true) :: StructField("b", dataType, nullable = true) :: Nil) val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - // Sort ordering is not defined for NaN, so skip any random inputs that contain it: - def isIncomparable(v: Any): Boolean = v match { - case d: Double => java.lang.Double.isNaN(d) - case f: Float => java.lang.Float.isNaN(f) - case _ => false - } RandomDataGenerator.forType(rowType, nullable = false).foreach { randGenerator => for (_ <- 1 to 50) { - val aExt = randGenerator().asInstanceOf[Row] - val bExt = randGenerator().asInstanceOf[Row] - if ((aExt.toSeq ++ bExt.toSeq).forall(v => !isIncomparable(v))) { - val a = toCatalyst(aExt).asInstanceOf[InternalRow] - val b = 
toCatalyst(bExt).asInstanceOf[InternalRow] - withClue(s"a = $a, b = $b") { - assert(genOrdering.compare(a, a) === 0) - assert(genOrdering.compare(b, b) === 0) - assert(rowOrdering.compare(a, a) === 0) - assert(rowOrdering.compare(b, b) === 0) - assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) - assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) - assert( - signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), - "Generated and non-generated orderings should agree") - } + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") } } } From 630ebc5756de8db0fe53e820ea70403c6d244ce3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 10:41:58 -0700 Subject: [PATCH 10/28] Specify an ordering for NaN values. --- .../expressions/codegen/CodeGenerator.scala | 2 + .../spark/sql/catalyst/util/TypeUtils.scala | 20 +++++++++ .../apache/spark/sql/types/DoubleType.scala | 5 ++- .../apache/spark/sql/types/FloatType.scala | 5 ++- .../sql/catalyst/util/TypeUtilsSuite.scala | 45 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- 6 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4ec..78aac0207019f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -182,6 +182,8 @@ class CodeGenContext { def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { // java boolean doesn't support > or < operator case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" + case DoubleType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareDoubles($c1, $c2)" + case FloatType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareFloats($c1, $c2)" // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 3148309a2166f..22fe9f297f013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -70,4 +70,24 @@ object TypeUtils { } x.length - y.length } + + def compareDoubles(x: Double, y: Double): Int = { + val xIsNan: Boolean = java.lang.Double.isNaN(x) + val yIsNan: Boolean = java.lang.Double.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) -1 + else if (yIsNan) 1 + else if (x > y) -1 + else 1 + } + + def compareFloats(x: Float, y: Float): Int = { + val xIsNan: Boolean = java.lang.Float.isNaN(x) + val yIsNan: Boolean = java.lang.Float.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) -1 + else if (yIsNan) 1 + else if (x > y) -1 + else 1 + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 66766623213c9..d031f323d4a69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.util.TypeUtils /** * :: DeveloperApi :: @@ -39,7 +40,9 @@ class DoubleType private() extends FractionalType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val ordering = new Ordering[Double] { + override def compare(x: Double, y: Double): Int = TypeUtils.compareDoubles(x, y) + } private[sql] val asIntegral = DoubleAsIfIntegral /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 1d5a2f4f6f86c..61ba9700a7035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.util.TypeUtils /** * :: DeveloperApi :: @@ -39,7 +40,9 @@ class FloatType private() extends FractionalType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val ordering = new Ordering[Float] { + override def compare(x: Float, y: Float): Int = TypeUtils.compareFloats(x, y) + } private[sql] val asIntegral = FloatAsIfIntegral /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala new 
file mode 100644 index 0000000000000..a7326345342fa --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite + +class TypeUtilsSuite extends SparkFunSuite { + + import TypeUtils._ + + test("compareDoubles") { + assert(compareDoubles(0, 0) === 0) + assert(compareDoubles(1, 0) === -1) + assert(compareDoubles(0, 1) === 1) + assert(compareDoubles(Double.MinValue, Double.MaxValue) === 1) + assert(compareDoubles(Double.NaN, Double.NaN) === 0) + assert(compareDoubles(Double.NaN, Double.PositiveInfinity) === -1) + assert(compareDoubles(Double.NaN, Double.NegativeInfinity) === -1) + } + + test("compareFloats") { + assert(compareFloats(0, 0) === 0) + assert(compareFloats(1, 0) === -1) + assert(compareFloats(0, 1) === 1) + assert(compareFloats(Float.MinValue, Float.MaxValue) === 1) + assert(compareFloats(Float.NaN, Float.NaN) === 0) + assert(compareFloats(Float.NaN, Float.PositiveInfinity) === -1) + assert(compareFloats(Float.NaN, Float.NegativeInfinity) === -1) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e8decb028590f..63f0a6e51c197 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -739,13 +739,13 @@ class DataFrameSuite extends QueryTest { df.col("t.``") } - test("SPARK-XXXX: sort by float column containing NaN") { + test("SPARK-8797: sort by float column containing NaN") { val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat)) val df = Random.shuffle(inputData).toDF("a") df.orderBy("a").collect() } - test("SPARK-XXXX: sort by double column containing NaN") { + test("SPARK-8797: sort by double column containing NaN") { val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble)) val df = Random.shuffle(inputData).toDF("a") df.orderBy("a").collect() From 5b88b2b26e93b8f5b475438ec29617f1774fdc58 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 17 Jul 2015 19:04:39 -0700 Subject: [PATCH 11/28] Fix compilation of CodeGenerationSuite --- .../spark/sql/catalyst/expressions/CodeGenerationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index aa0ee226d7600..221cb75276949 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -21,7 +21,7 @@ import scala.math._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}

From b20837bb2a9f377988df8037e6cdccfa946fa137 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 17 Jul 2015 19:06:45 -0700
Subject: [PATCH 12/28] Add failing test for new NaN comparison ordering

---
 .../scala/org/apache/spark/sql/DataFrameSuite.scala | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d1eb6849e13d4..868379de5ac56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -743,18 +743,27 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     df.col("t.``")
   }
 
-  test("SPARK-8797: sort by float column containing NaN") {
+  test("SPARK-8797: sort by float column containing NaN should not crash") {
     val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
     val df = Random.shuffle(inputData).toDF("a")
     df.orderBy("a").collect()
   }
 
-  test("SPARK-8797: sort by double column containing NaN") {
+  test("SPARK-8797: sort by double column containing NaN should not crash") {
     val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
     val df = Random.shuffle(inputData).toDF("a")
     df.orderBy("a").collect()
   }
 
+  test("SPARK-9146: NaN is greater than all other non-NaN numeric values") {
+    val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Double.isNaN(maxDouble.getDouble(0)))
+    val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Float.isNaN(maxFloat.getFloat(0)))
+  }
+
   test("SPARK-8072: Better Exception for Duplicate Columns") {
     // only one duplicate column present
     val e = intercept[org.apache.spark.sql.AnalysisException] {

From 8d7be610fd3ec4be91ccd553a17bb9d2c935d7df Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 17 Jul 2015 19:10:57 -0700
Subject: [PATCH 13/28] Update randomized test to use ScalaTest's assume()

---
 .../expressions/CodeGenerationSuite.scala | 31 ++++++++++---------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 221cb75276949..9024058f01686 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -60,22 +60,23 @@ class CodeGenerationSuite extends SparkFunSuite {
     val rowType = StructType(
       StructField("a", dataType, nullable = 
true) :: StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - RandomDataGenerator.forType(rowType, nullable = false).foreach { randGenerator => - for (_ <- 1 to 50) { - val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - withClue(s"a = $a, b = $b") { - assert(genOrdering.compare(a, a) === 0) - assert(genOrdering.compare(b, b) === 0) - assert(rowOrdering.compare(a, a) === 0) - assert(rowOrdering.compare(b, b) === 0) - assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) - assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) - assert( - signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), - "Generated and non-generated orderings should agree") - } + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") } } } From bfca524765c2fca48f10214cda28770deb0f39ec Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 17 Jul 2015 19:19:49 -0700 Subject: [PATCH 14/28] Change ordering so that NaN is maximum value. 
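
After this change NaN compares equal to itself and greater than any other value,
including positive infinity, which matches the NaN semantics of
java.lang.Double.compare. A small illustration of the target ordering (not part
of the diff):

    // NaN is the maximum value: NaN > +Infinity > MaxValue > ... > -Infinity
    assert(java.lang.Double.compare(Double.NaN, Double.PositiveInfinity) > 0)
    assert(java.lang.Double.compare(Double.PositiveInfinity, Double.MaxValue) > 0)
    assert(java.lang.Double.compare(Double.NaN, Double.NaN) == 0)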
--- .../apache/spark/sql/catalyst/util/TypeUtils.scala | 8 ++++---- .../spark/sql/catalyst/util/TypeUtilsSuite.scala | 12 ++++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 9dd2e37fe81ef..f8d98cf9bceae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -67,8 +67,8 @@ object TypeUtils { val xIsNan: Boolean = java.lang.Double.isNaN(x) val yIsNan: Boolean = java.lang.Double.isNaN(y) if ((xIsNan && yIsNan) || (x == y)) 0 - else if (xIsNan) -1 - else if (yIsNan) 1 + else if (xIsNan) 1 + else if (yIsNan) -1 else if (x > y) -1 else 1 } @@ -77,8 +77,8 @@ object TypeUtils { val xIsNan: Boolean = java.lang.Float.isNaN(x) val yIsNan: Boolean = java.lang.Float.isNaN(y) if ((xIsNan && yIsNan) || (x == y)) 0 - else if (xIsNan) -1 - else if (yIsNan) 1 + else if (xIsNan) 1 + else if (yIsNan) -1 else if (x > y) -1 else 1 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala index a7326345342fa..4de3ec5da844f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala @@ -29,8 +29,10 @@ class TypeUtilsSuite extends SparkFunSuite { assert(compareDoubles(0, 1) === 1) assert(compareDoubles(Double.MinValue, Double.MaxValue) === 1) assert(compareDoubles(Double.NaN, Double.NaN) === 0) - assert(compareDoubles(Double.NaN, Double.PositiveInfinity) === -1) - assert(compareDoubles(Double.NaN, Double.NegativeInfinity) === -1) + assert(compareDoubles(Double.NaN, Double.PositiveInfinity) === 1) + assert(compareDoubles(Double.NaN, Double.NegativeInfinity) === 1) + assert(compareDoubles(Double.PositiveInfinity, Double.NaN) === -1) + assert(compareDoubles(Double.NegativeInfinity, Double.NaN) === -1) } test("compareFloats") { @@ -39,7 +41,9 @@ class TypeUtilsSuite extends SparkFunSuite { assert(compareFloats(0, 1) === 1) assert(compareFloats(Float.MinValue, Float.MaxValue) === 1) assert(compareFloats(Float.NaN, Float.NaN) === 0) - assert(compareFloats(Float.NaN, Float.PositiveInfinity) === -1) - assert(compareFloats(Float.NaN, Float.NegativeInfinity) === -1) + assert(compareFloats(Float.NaN, Float.PositiveInfinity) === 1) + assert(compareFloats(Float.NaN, Float.NegativeInfinity) === 1) + assert(compareFloats(Float.PositiveInfinity, Float.NaN) === -1) + assert(compareFloats(Float.NegativeInfinity, Float.NaN) === -1) } } From 42a1ad54e40784032b09b2b2dc9730d8db1e533d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 17 Jul 2015 19:20:09 -0700 Subject: [PATCH 15/28] Stop filtering NaNs in UnsafeExternalSortSuite --- .../spark/sql/execution/UnsafeExternalSortSuite.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 4f4c1f28564cb..5fe73f7e0b072 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -83,11 +83,7 @@ class 
UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()).filter { - case d: Double => !d.isNaN - case f: Float => !java.lang.Float.isNaN(f) - case x => true - } + val inputData = Seq.fill(1000)(randomDataGenerator()) val inputDf = TestSQLContext.createDataFrame( TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) From 6f03f85cbdd0804bc2f2be42d4c980465b17d928 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 17 Jul 2015 19:40:37 -0700 Subject: [PATCH 16/28] Fix bug in Double / Float ordering --- .../spark/sql/catalyst/util/TypeUtils.scala | 8 +++---- .../sql/catalyst/util/TypeUtilsSuite.scala | 24 ++++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index f8d98cf9bceae..7ecf3e642c534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -69,8 +69,8 @@ object TypeUtils { if ((xIsNan && yIsNan) || (x == y)) 0 else if (xIsNan) 1 else if (yIsNan) -1 - else if (x > y) -1 - else 1 + else if (x > y) 1 + else -1 } def compareFloats(x: Float, y: Float): Int = { @@ -79,7 +79,7 @@ object TypeUtils { if ((xIsNan && yIsNan) || (x == y)) 0 else if (xIsNan) 1 else if (yIsNan) -1 - else if (x > y) -1 - else 1 + else if (x > y) 1 + else -1 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala index 4de3ec5da844f..aae7337069836 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.lang.{Double => JDouble, Float => JFloat} + import org.apache.spark.SparkFunSuite class TypeUtilsSuite extends SparkFunSuite { @@ -24,10 +26,13 @@ class TypeUtilsSuite extends SparkFunSuite { import TypeUtils._ test("compareDoubles") { - assert(compareDoubles(0, 0) === 0) - assert(compareDoubles(1, 0) === -1) - assert(compareDoubles(0, 1) === 1) - assert(compareDoubles(Double.MinValue, Double.MaxValue) === 1) + def shouldMatchDefaultOrder(a: Double, b: Double): Unit = { + assert(compareDoubles(a, b) === JDouble.compare(a, b)) + assert(compareDoubles(b, a) === JDouble.compare(b, a)) + } + shouldMatchDefaultOrder(0d, 0d) + shouldMatchDefaultOrder(0d, 1d) + shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue) assert(compareDoubles(Double.NaN, Double.NaN) === 0) assert(compareDoubles(Double.NaN, Double.PositiveInfinity) === 1) assert(compareDoubles(Double.NaN, Double.NegativeInfinity) === 1) @@ -36,10 +41,13 @@ class TypeUtilsSuite extends SparkFunSuite { } test("compareFloats") { - assert(compareFloats(0, 0) === 0) - assert(compareFloats(1, 0) === -1) - assert(compareFloats(0, 1) === 1) - assert(compareFloats(Float.MinValue, Float.MaxValue) === 1) + def shouldMatchDefaultOrder(a: Float, b: Float): Unit = { + assert(compareFloats(a, b) === JFloat.compare(a, b)) + 
assert(compareFloats(b, a) === JFloat.compare(b, a)) + } + shouldMatchDefaultOrder(0f, 0f) + shouldMatchDefaultOrder(1f, 1f) + shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue) assert(compareFloats(Float.NaN, Float.NaN) === 0) assert(compareFloats(Float.NaN, Float.PositiveInfinity) === 1) assert(compareFloats(Float.NaN, Float.NegativeInfinity) === 1) From a30d3711cc0a8cb41bf8cd103c6df9ab45f373b2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 17 Jul 2015 19:59:38 -0700 Subject: [PATCH 17/28] Compare rows' string representations to work around NaN incomparability. --- .../spark/sql/execution/SparkPlanTest.scala | 26 ++++++++++++++----- .../execution/UnsafeExternalSortSuite.scala | 3 ++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 6a8f394545816..17c05e854e9f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -115,13 +115,17 @@ class SparkPlanTest extends SparkFunSuite { * treated as the source-of-truth for the test. * @param sortAnswers if true, the answers will be sorted by their toString representations prior * to being compared. + * @param compareStrings if true, the answers will be converted to strings before being compared */ protected def checkThatPlansAgree( input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { + sortAnswers: Boolean = true, + compareStrings: Boolean = false): Unit = { + val result = SparkPlanTest.checkAnswer( + input, planFunction, expectedPlanFunction, sortAnswers, compareStrings) + result match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -142,12 +146,14 @@ object SparkPlanTest { * instantiate a reference implementation of the physical operator * that's being tested. The result of executing this plan will be * treated as the source-of-truth for the test. + * @param compareStrings if true, the answers will be converted to strings before being compared */ def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + compareStrings: Boolean): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) @@ -182,7 +188,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + compareAnswers(actualAnswer, expectedAnswer, sortAnswers, compareStrings).map { errorMessage => s""" | Results do not match. 
| Actual result Spark plan:
@@ -226,7 +232,7 @@ object SparkPlanTest {
       return Some(errorMessage)
     }
 
-    compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+    compareAnswers(sparkAnswer, expectedAnswer, sortAnswers, false).map { errorMessage =>
       s"""
          | Results do not match for Spark plan:
          | $outputPlan
@@ -238,7 +244,8 @@ object SparkPlanTest {
   private def compareAnswers(
       sparkAnswer: Seq[Row],
       expectedAnswer: Seq[Row],
-      sort: Boolean): Option[String] = {
+      sort: Boolean,
+      compareStrings: Boolean): Option[String] = {
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -253,11 +260,16 @@ object SparkPlanTest {
           case o => o
         })
       }
-      if (sort) {
+      val maybeSorted = if (sort) {
         converted.sortBy(_.toString())
       } else {
         converted
       }
+      if (compareStrings) {
+        maybeSorted.map(r => Row.fromSeq(r.toSeq.map(String.valueOf))) // valueOf handles nulls
+      } else {
+        maybeSorted
+      }
     }
     if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
       val errorMessage =

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index 5fe73f7e0b072..24f6bdb3136c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -93,7 +93,8 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
         inputDf,
         UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23),
         Sort(sortOrder, global = true, _: SparkPlan),
-        sortAnswers = false
+        sortAnswers = false,
+        compareStrings = true
       )
     }
   }

From a2ba2e77e922e3ed4e69cd9a8c00940d9c64c17c Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 17 Jul 2015 23:01:45 -0700
Subject: [PATCH 18/28] Fix prefix comparison for NaNs

---
 .../unsafe/sort/PrefixComparators.java | 5 +-
 .../scala/org/apache/spark/util/Utils.scala | 28 +++++++++
 .../org/apache/spark/util/UtilsSuite.scala | 31 ++++++++++
 .../unsafe/sort/PrefixComparatorsSuite.scala | 25 ++++++++
 .../spark/sql/catalyst/util/TypeUtils.scala | 10 ----
 .../apache/spark/sql/types/DoubleType.scala | 4 +-
 .../apache/spark/sql/types/FloatType.scala | 4 +-
 .../sql/catalyst/util/TypeUtilsSuite.scala | 57 -------------------
 8 files changed, 91 insertions(+), 73 deletions(-)
 delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 438742565c51d..bf1bc5dffba78 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -23,6 +23,7 @@
 import org.apache.spark.annotation.Private;
 import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
 
 @Private
 public class PrefixComparators {
@@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator {
     public int compare(long aPrefix, long bPrefix) {
       float a = Float.intBitsToFloat((int) aPrefix);
       float b = Float.intBitsToFloat((int) bPrefix);
-      return (a < b) ? -1 : (a > b) ? 
1 : 0; + return Utils.nanSafeCompareFloats(a, b); } public long computePrefix(float value) { @@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator { public int compare(long aPrefix, long bPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); - return (a < b) ? -1 : (a > b) ? 1 : 0; + return Utils.nanSafeCompareDoubles(a, b); } public long computePrefix(double value) { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e6374f17d858f..c5816949cd360 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging { hashAbs } + /** + * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN double. + */ + def nanSafeCompareDoubles(x: Double, y: Double): Int = { + val xIsNan: Boolean = java.lang.Double.isNaN(x) + val yIsNan: Boolean = java.lang.Double.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + + /** + * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN float. + */ + def nanSafeCompareFloats(x: Float, y: Float): Int = { + val xIsNan: Boolean = java.lang.Float.isNaN(x) + val yIsNan: Boolean = java.lang.Float.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + /** Returns the system properties map that is thread-safe to iterator over. It gets the * properties which have been set explicitly, as well as those for which only a default value * has been defined. 
   */
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index c7638507c88c6..8f7e402d5f2a6 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util

 import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.lang.{Double => JDouble, Float => JFloat}
 import java.net.{BindException, ServerSocket, URI}
 import java.nio.{ByteBuffer, ByteOrder}
 import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     // scalastyle:on println
     assert(buffer.toString === "t circular test circular\n")
   }
+
+  test("nanSafeCompareDoubles") {
+    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+      assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
+      assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0d, 0d)
+    shouldMatchDefaultOrder(0d, 1d)
+    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
+    assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
+  }
+
+  test("nanSafeCompareFloats") {
+    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+      assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
+      assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0f, 0f)
+    shouldMatchDefaultOrder(0f, 1f)
+    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
+    assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dd505dfa7d758..dc03e374b51db 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
+
+  test("float prefix comparator handles NaN properly") {
+    val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+    val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+    assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix)
=== 1) + } + + test("double prefix comparator handles NaNs properly") { + val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) + val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 7ecf3e642c534..1800be0aaa12c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -63,16 +63,6 @@ object TypeUtils { x.length - y.length } - def compareDoubles(x: Double, y: Double): Int = { - val xIsNan: Boolean = java.lang.Double.isNaN(x) - val yIsNan: Boolean = java.lang.Double.isNaN(y) - if ((xIsNan && yIsNan) || (x == y)) 0 - else if (xIsNan) 1 - else if (yIsNan) -1 - else if (x > y) 1 - else -1 - } - def compareFloats(x: Float, y: Float): Int = { val xIsNan: Boolean = java.lang.Float.isNaN(x) val yIsNan: Boolean = java.lang.Float.isNaN(y) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index af66d7d4afa3c..2a1bf0938e5a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -39,7 +39,7 @@ class DoubleType private() extends FractionalType { private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = new Ordering[Double] { - override def compare(x: Double, y: Double): Int = TypeUtils.compareDoubles(x, y) + override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y) } private[sql] val asIntegral = DoubleAsIfIntegral diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index bd11cfaba98b7..08e22252aef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -39,7 +39,7 @@ class FloatType private() extends FractionalType { private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = new Ordering[Float] { - override def compare(x: Float, y: Float): Int = TypeUtils.compareFloats(x, y) + override def compare(x: Float, y: Float): Int = 
Utils.nanSafeCompareFloats(x, y) } private[sql] val asIntegral = FloatAsIfIntegral diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala deleted file mode 100644 index aae7337069836..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import java.lang.{Double => JDouble, Float => JFloat} - -import org.apache.spark.SparkFunSuite - -class TypeUtilsSuite extends SparkFunSuite { - - import TypeUtils._ - - test("compareDoubles") { - def shouldMatchDefaultOrder(a: Double, b: Double): Unit = { - assert(compareDoubles(a, b) === JDouble.compare(a, b)) - assert(compareDoubles(b, a) === JDouble.compare(b, a)) - } - shouldMatchDefaultOrder(0d, 0d) - shouldMatchDefaultOrder(0d, 1d) - shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue) - assert(compareDoubles(Double.NaN, Double.NaN) === 0) - assert(compareDoubles(Double.NaN, Double.PositiveInfinity) === 1) - assert(compareDoubles(Double.NaN, Double.NegativeInfinity) === 1) - assert(compareDoubles(Double.PositiveInfinity, Double.NaN) === -1) - assert(compareDoubles(Double.NegativeInfinity, Double.NaN) === -1) - } - - test("compareFloats") { - def shouldMatchDefaultOrder(a: Float, b: Float): Unit = { - assert(compareFloats(a, b) === JFloat.compare(a, b)) - assert(compareFloats(b, a) === JFloat.compare(b, a)) - } - shouldMatchDefaultOrder(0f, 0f) - shouldMatchDefaultOrder(1f, 1f) - shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue) - assert(compareFloats(Float.NaN, Float.NaN) === 0) - assert(compareFloats(Float.NaN, Float.PositiveInfinity) === 1) - assert(compareFloats(Float.NaN, Float.NegativeInfinity) === 1) - assert(compareFloats(Float.PositiveInfinity, Float.NaN) === -1) - assert(compareFloats(Float.NegativeInfinity, Float.NaN) === -1) - } -} From 3998ef208f73b7af6e070f42f9ddfe01efb22890 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 17 Jul 2015 23:03:42 -0700 Subject: [PATCH 19/28] Remove unused code --- .../org/apache/spark/sql/catalyst/util/TypeUtils.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1800be0aaa12c..0103ddcf9cfb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -62,14 +62,4 @@ object TypeUtils { } x.length - y.length } - - def compareFloats(x: Float, y: 
Float): Int = { - val xIsNan: Boolean = java.lang.Float.isNaN(x) - val yIsNan: Boolean = java.lang.Float.isNaN(y) - if ((xIsNan && yIsNan) || (x == y)) 0 - else if (xIsNan) 1 - else if (yIsNan) -1 - else if (x > y) 1 - else -1 - } } From fc6b4d2cb62072a73c1756ee6759bc3157574ec0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 00:39:37 -0700 Subject: [PATCH 20/28] Update CodeGenerator --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1822ec86fe5e3..99a98c01a46d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -194,8 +194,8 @@ class CodeGenContext { def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { // java boolean doesn't support > or < operator case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" - case DoubleType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareDoubles($c1, $c2)" - case FloatType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareFloats($c1, $c2)" + case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)" + case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)" // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" From 58bad2cc91d146b7e33fe7bfb472242f90debb4c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 18:37:48 -0700 Subject: [PATCH 21/28] Revert "Compare rows' string representations to work around NaN incomparability." This reverts commit a30d3711cc0a8cb41bf8cd103c6df9ab45f373b2. --- .../spark/sql/execution/SparkPlanTest.scala | 26 +++++-------------- .../execution/UnsafeExternalSortSuite.scala | 3 +-- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 17c05e854e9f2..6a8f394545816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -115,17 +115,13 @@ class SparkPlanTest extends SparkFunSuite { * treated as the source-of-truth for the test. * @param sortAnswers if true, the answers will be sorted by their toString representations prior * to being compared. 
- * @param compareStrings if true, the answers will be converted to strings before being compared */ protected def checkThatPlansAgree( input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean = true, - compareStrings: Boolean = false): Unit = { - val result = SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, compareStrings) - result match { + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -146,14 +142,12 @@ object SparkPlanTest { * instantiate a reference implementation of the physical operator * that's being tested. The result of executing this plan will be * treated as the source-of-truth for the test. - * @param compareStrings if true, the answers will be converted to strings before being compared */ def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean, - compareStrings: Boolean): Option[String] = { + sortAnswers: Boolean): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) @@ -188,7 +182,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(actualAnswer, expectedAnswer, sortAnswers, compareStrings).map { errorMessage => + compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match. | Actual result Spark plan: @@ -232,7 +226,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(sparkAnswer, expectedAnswer, sortAnswers, false).map { errorMessage => + compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match for Spark plan: | $outputPlan @@ -244,8 +238,7 @@ object SparkPlanTest { private def compareAnswers( sparkAnswer: Seq[Row], expectedAnswer: Seq[Row], - sort: Boolean, - compareStrings: Boolean): Option[String] = { + sort: Boolean): Option[String] = { def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. 
// For BigDecimal type, the Scala type has a better definition of equality test (similar to @@ -260,16 +253,11 @@ object SparkPlanTest { case o => o }) } - val maybeSorted = if (sort) { + if (sort) { converted.sortBy(_.toString()) } else { converted } - if (compareStrings) { - maybeSorted.map(r => Row.fromSeq(r.toSeq.map(String.valueOf))) // valueOf handles nulls - } else { - maybeSorted - } } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { val errorMessage = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 24f6bdb3136c8..5fe73f7e0b072 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -93,8 +93,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { inputDf, UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false, - compareStrings = true + sortAnswers = false ) } } From 7fe67aff227692fcbbc1a65c4627124a4dd21578 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 19:31:40 -0700 Subject: [PATCH 22/28] Support NaN == NaN (SPARK-9145) --- .../main/scala/org/apache/spark/sql/Row.scala | 26 +++++++++++++------ .../expressions/codegen/CodeGenerator.scala | 2 ++ .../sql/catalyst/expressions/predicates.scala | 4 +++ .../catalyst/expressions/PredicateSuite.scala | 13 ++++++++++ .../scala/org/apache/spark/sql/RowSuite.scala | 12 +++++++++ 5 files changed, 49 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 2cb64d00935de..e3f04cc8e318e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -403,20 +403,30 @@ trait Row extends Serializable { if (!isNullAt(i)) { val o1 = get(i) val o2 = other.get(i) - if (o1.isInstanceOf[Array[Byte]]) { - // handle equality of Array[Byte] - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float => + if (!o2.isInstanceOf[Float] || + (java.lang.Float.isNaN(f1) && !java.lang.Float.isNaN(o2.asInstanceOf[Float]))) { + return false + } + case d1: Double => + if (!o2.isInstanceOf[Double] || + (java.lang.Double.isNaN(d1) && !java.lang.Double.isNaN(o2.asInstanceOf[Double]))) { + return false + } + case _ => if (o1 != o2) { return false } - } else if (o1 != o2) { - return false } } i += 1 } - return true + true } /* ---------------------- utility methods for Scala ---------------------- */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 99a98c01a46d2..2db33cd31aa55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -184,6 +184,8 @@ class CodeGenContext { */ def 
genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" + case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case other => s"$c1.equals($c2)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2751c8e75f357..caffa1b102fa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -304,6 +304,8 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { + // Note that we do not have to do anything special here to handle NaN values: boxed Double and + // Float NaNs will be equal (see Float.equals()' Javadoc for more details). if (left.dataType != BinaryType) input1 == input2 else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) } @@ -330,6 +332,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { + // Note that we do not have to do anything special here to handle NaN values: boxed Double and + // Float NaNs will be equal (see Float.equals()' Javadoc for more details). if (left.dataType != BinaryType) { input1 == input2 } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 052abc51af5fd..b16736c984496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -126,6 +126,19 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(IsNaN(Literal(5.5f)), false) } + test("NaN equality and comparison") { + def testNaN(nan: Expression): Unit = { + checkEvaluation(nan === nan, true) + checkEvaluation(nan <=> nan, true) +// checkEvaluation(nan <= nan, true) +// checkEvaluation(nan >= nan, true) +// checkEvaluation(nan < nan, false) +// checkEvaluation(nan > nan, false) + } + testNaN(Literal(Float.NaN)) + testNaN(Literal(Double.NaN)) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index d84b57af9c882..7cc6ffd7548d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite { row.getAs[Int]("c") } } + + test("float NaN == NaN") { + val r1 = Row(Float.NaN) + val r2 = Row(Float.NaN) + assert(r1 === r2) + } + + test("double NaN == NaN") { + val r1 = Row(Double.NaN) + val r2 = Row(Double.NaN) + assert(r1 === r2) + } } From b31eb1907fbefa5d99b48a6e585522b62f12c6a6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 19:34:22 -0700 Subject: 
[PATCH 23/28] Uncomment failing tests --- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 8 ++++---- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index b16736c984496..743b7d509678a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -130,10 +130,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { def testNaN(nan: Expression): Unit = { checkEvaluation(nan === nan, true) checkEvaluation(nan <=> nan, true) -// checkEvaluation(nan <= nan, true) -// checkEvaluation(nan >= nan, true) -// checkEvaluation(nan < nan, false) -// checkEvaluation(nan > nan, false) + checkEvaluation(nan <= nan, true) + checkEvaluation(nan >= nan, true) + checkEvaluation(nan < nan, false) + checkEvaluation(nan > nan, false) } testNaN(Literal(Float.NaN)) testNaN(Literal(Double.NaN)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 868379de5ac56..f67f2c60c0e16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -755,7 +755,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { df.orderBy("a").collect() } - test("SPARK-9146: NaN is greater than all other non-NaN numeric values") { + test("NaN is greater than all other non-NaN numeric values") { val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue) .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() assert(java.lang.Double.isNaN(maxDouble.getDouble(0))) From c1fd4fef1b69c66da2b73af0a641ed0a2e70bea4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 20:03:38 -0700 Subject: [PATCH 24/28] Fold NaN test into existing test framework --- .../catalyst/expressions/PredicateSuite.scala | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 743b7d509678a..089c5a8f4797d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -126,19 +126,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(IsNaN(Literal(5.5f)), false) } - test("NaN equality and comparison") { - def testNaN(nan: Expression): Unit = { - checkEvaluation(nan === nan, true) - checkEvaluation(nan <=> nan, true) - checkEvaluation(nan <= nan, true) - checkEvaluation(nan >= nan, true) - checkEvaluation(nan < nan, false) - checkEvaluation(nan > nan, false) - } - testNaN(Literal(Float.NaN)) - testNaN(Literal(Double.NaN)) - } - test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null @@ -155,11 +142,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) } - private val smallValues = Seq(1, Decimal(1), Array(1.toByte), 
"a").map(Literal(_)) - private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_)) + private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_)) + private val largeValues = + Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_)) - private val equalValues1 = smallValues - private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + private val equalValues1 = smallValues ++ Seq(Float.NaN, Double.NaN).map(Literal(_)) + private val equalValues2 = + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) test("BinaryComparison: <") { for (i <- 0 until smallValues.length) { From fbb2a29ac7189addd8745abf6601532b6e88c781 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 20:16:01 -0700 Subject: [PATCH 25/28] Fix NaN comparisons in BinaryComparison expressions --- .../sql/catalyst/expressions/predicates.scala | 33 ++++++++++++++----- .../catalyst/expressions/PredicateSuite.scala | 3 +- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index caffa1b102fa5..c32fc7fec20b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -272,7 +272,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (ctx.isPrimitiveType(left.dataType)) { + if (ctx.isPrimitiveType(left.dataType) + && left.dataType != FloatType + && left.dataType != DoubleType) { // faster version defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") } else { @@ -304,10 +306,19 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { - // Note that we do not have to do anything special here to handle NaN values: boxed Double and - // Float NaNs will be equal (see Float.equals()' Javadoc for more details). - if (left.dataType != BinaryType) input1 == input2 - else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + if (left.dataType == FloatType) { + val f1 = input1.asInstanceOf[Float] + val f2 = input2.asInstanceOf[Float] + (java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2 + } else if (left.dataType == DoubleType) { + val d1 = input1.asInstanceOf[Double] + val d2 = input2.asInstanceOf[Double] + (java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2 + } else if (left.dataType != BinaryType) { + input1 == input2 + } else { + java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -332,9 +343,15 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { - // Note that we do not have to do anything special here to handle NaN values: boxed Double and - // Float NaNs will be equal (see Float.equals()' Javadoc for more details). 
- if (left.dataType != BinaryType) { + if (left.dataType == FloatType) { + val f1 = input1.asInstanceOf[Float] + val f2 = input2.asInstanceOf[Float] + (java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2 + } else if (left.dataType == DoubleType) { + val d1 = input1.asInstanceOf[Double] + val d2 = input2.asInstanceOf[Double] + (java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2 + } else if (left.dataType != BinaryType) { input1 == input2 } else { java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 089c5a8f4797d..547f2085bb230 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -146,7 +146,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_)) - private val equalValues1 = smallValues ++ Seq(Float.NaN, Double.NaN).map(Literal(_)) + private val equalValues1 = + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) From a7267cf2904c8d1ec856d83b7dc106bb94db454e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 20:29:47 -0700 Subject: [PATCH 26/28] Normalize NaNs in UnsafeRow --- .../sql/catalyst/expressions/UnsafeRow.java | 6 +++++ .../expressions/UnsafeRowConverterSuite.scala | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 87294a0e21441..8cd9e7bc60a03 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -215,6 +215,9 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); + if (Double.isNaN(value)) { + value = Double.NaN; + } PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -243,6 +246,9 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); + if (Float.isNaN(value)) { + value = Float.NaN; + } PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index d00aeb4dfbf47..ef76af190b807 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } + test("NaN normalization") { + val fieldTypes: 
Array[DataType] = Array(FloatType, DoubleType) + + val row1 = new SpecificMutableRow(fieldTypes) + row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) + row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) + + val row2 = new SpecificMutableRow(fieldTypes) + row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) + row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) + + val converter = new UnsafeRowConverter(fieldTypes) + val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1)) + val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2)) + converter.writeRow( + row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null) + converter.writeRow( + row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null) + + assert(row1Buffer.toSeq === row2Buffer.toSeq) + } + } From a702e2eff2f2d9696a0f2ffa5e54736d0e66fa9f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 20:30:45 -0700 Subject: [PATCH 27/28] normalization -> canonicalization --- .../sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index ef76af190b807..dff5faf9f6ec8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -316,7 +316,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - test("NaN normalization") { + test("NaN canonicalization") { val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificMutableRow(fieldTypes) From 88bd73c5ec7ee81aa17eb0de4f1851e60ebc6f51 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Jul 2015 21:56:01 -0700 Subject: [PATCH 28/28] Fix Row.equals() --- .../main/scala/org/apache/spark/sql/Row.scala | 14 ++++++-------- .../sql/catalyst/expressions/predicates.scala | 17 +++++------------ 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e3f04cc8e318e..91449479fa539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -406,17 +406,15 @@ trait Row extends Serializable { o1 match { case b1: Array[Byte] => if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { return false } - case f1: Float => - if (!o2.isInstanceOf[Float] || - (java.lang.Float.isNaN(f1) && !java.lang.Float.isNaN(o2.asInstanceOf[Float]))) { - return false + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false } - case d1: Double => - if (!o2.isInstanceOf[Double] || - (java.lang.Double.isNaN(d1) && !java.lang.Double.isNaN(o2.asInstanceOf[Double]))) { + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! 
java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } case _ => if (o1 != o2) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d4890ad91ca95..a53ec31ee6a4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object InterpretedPredicate { @@ -257,13 +258,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison protected override def nullSafeEval(input1: Any, input2: Any): Any = { if (left.dataType == FloatType) { - val f1 = input1.asInstanceOf[Float] - val f2 = input2.asInstanceOf[Float] - (java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2 + Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 } else if (left.dataType == DoubleType) { - val d1 = input1.asInstanceOf[Double] - val d2 = input2.asInstanceOf[Double] - (java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2 + Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 } else if (left.dataType != BinaryType) { input1 == input2 } else { @@ -294,13 +291,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp false } else { if (left.dataType == FloatType) { - val f1 = input1.asInstanceOf[Float] - val f2 = input2.asInstanceOf[Float] - (java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2 + Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 } else if (left.dataType == DoubleType) { - val d1 = input1.asInstanceOf[Double] - val d2 = input2.asInstanceOf[Double] - (java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2 + Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 } else if (left.dataType != BinaryType) { input1 == input2 } else {
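Note: taken together, PATCH 18 through PATCH 28 converge on a single NaN
contract: NaN == NaN, NaN sorts greater than every other value (including
positive infinity), and all NaN bit patterns are canonicalized before
byte-level row comparison. The standalone Scala sketch below is illustrative
only and is not part of the series; the object name and main method are
invented for the example, while the comparator body mirrors the
Utils.nanSafeCompareDoubles added in PATCH 18 and the canonicalization step
mirrors the UnsafeRow change in PATCH 26.

object NaNSemanticsSketch {
  // NaN == NaN, and NaN is greater than any non-NaN double (PATCH 18).
  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
    val xIsNan: Boolean = java.lang.Double.isNaN(x)
    val yIsNan: Boolean = java.lang.Double.isNaN(y)
    if ((xIsNan && yIsNan) || (x == y)) 0
    else if (xIsNan) 1
    else if (yIsNan) -1
    else if (x > y) 1
    else -1
  }

  def main(args: Array[String]): Unit = {
    // NaN sorts above positive infinity under this total ordering.
    val sorted = Seq(Double.NaN, 1.0, Double.PositiveInfinity, -1.0)
      .sortWith(nanSafeCompareDoubles(_, _) < 0)
    println(sorted) // List(-1.0, 1.0, Infinity, NaN)

    // Canonicalization: an arbitrary NaN bit pattern collapses to the one
    // canonical Double.NaN, so byte-for-byte comparison of rows containing
    // NaNs behaves like logical equality (PATCH 26).
    val oddNaN = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
    val normalized = if (java.lang.Double.isNaN(oddNaN)) Double.NaN else oddNaN
    assert(java.lang.Double.doubleToRawLongBits(normalized) ==
      java.lang.Double.doubleToRawLongBits(Double.NaN))
  }
}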