From 0d3bb3c04c62f071fc626dba7d7f2cbaf05b865e Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 1 Aug 2015 23:57:15 +0800 Subject: [PATCH 1/7] [WIP] Utilize ScalaCheck to reveal potential bugs in sql expressions --- .../ArithmeticExpressionSuite.scala | 30 ++++++ .../expressions/ExpressionEvalHelper.scala | 69 ++++++++++++- .../expressions/MathFunctionsSuite.scala | 7 +- .../expressions/PropertyGenerator.scala | 97 +++++++++++++++++++ .../spark/sql/types/DataTypeTestUtils.scala | 14 +++ 5 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a1f15e4f0f25a..35d4f887da03f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -52,6 +52,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort) checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistency(tpe, tpe, classOf[Add]) + } } test("- (UnaryMinus)") { @@ -71,6 +74,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistency(tpe, classOf[UnaryMinus]) + } } test("- (Minus)") { @@ -85,6 +91,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort - negativeShort).toShort) checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistency(tpe, tpe, classOf[Subtract]) + } } test("* (Multiply)") { @@ -99,6 +108,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort * negativeShort).toShort) checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistency(tpe, tpe, classOf[Multiply]) + } } test("/ (Divide) basic") { @@ -111,6 +123,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistency(tpe, tpe, classOf[Divide]) + } } test("/ (Divide) for integral type") { @@ -144,6 +159,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(negativeIntLit, negativeIntLit), 0) checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistency(tpe, tpe, classOf[Remainder]) + } } test("Abs") { @@ -161,6 +179,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Abs(negativeIntLit), - negativeInt) checkEvaluation(Abs(positiveLongLit), positiveLong) checkEvaluation(Abs(negativeLongLit), - negativeLong) + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistency(tpe, classOf[Abs]) + } } test("MaxOf basic") { @@ -175,6 +196,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MaxOf(positiveShortLit, negativeShortLit), (positiveShort).toShort) checkEvaluation(MaxOf(positiveIntLit, negativeIntLit), positiveInt) checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong) + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistency(tpe, tpe, classOf[MaxOf]) + } } test("MaxOf for atomic type") { @@ -196,6 +220,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(positiveShortLit, negativeShortLit), (negativeShort).toShort) checkEvaluation(MinOf(positiveIntLit, negativeIntLit), negativeInt) checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong) + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistency(tpe, tpe, classOf[MinOf]) + } } test("MinOf for atomic type") { @@ -222,4 +249,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt) checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong) } + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistency(tpe, tpe, classOf[MinOf]) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a41185b4d8754..da9f8497a7c5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,11 +24,12 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.types.DataType /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper { +trait ExpressionEvalHelper extends PropertyGenerator { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -211,4 +212,70 @@ trait ExpressionEvalHelper { plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } + + def checkConsistency(dt: DataType, clazz: Class[_]): Unit = { + val ctor = clazz.getDeclaredConstructor(classOf[Expression]) + forAll (randomGen(dt)) { (l: Literal) => + val expr = ctor.newInstance(l).asInstanceOf[Expression] + cmpInterpretWithCodegen(EmptyRow, expr) + } + } + + def checkConsistency(dt1: DataType, dt2: DataType, clazz: Class[_]): Unit = { + val ctor = clazz.getDeclaredConstructor(classOf[Expression], classOf[Expression]) + forAll ( + randomGen(dt1), + randomGen(dt2) + ) { (l1: Literal, l2: Literal) => + val expr = ctor.newInstance(l1, l2).asInstanceOf[Expression] + cmpInterpretWithCodegen(EmptyRow, expr) + } + } + + def checkConsistency(dt1: DataType, dt2: DataType, dt3: DataType, + clazz: Class[_]): Unit = { + val ctor = clazz.getDeclaredConstructor( + classOf[Expression], classOf[Expression], classOf[Expression]) + forAll ( + randomGen(dt1), + randomGen(dt2), + randomGen(dt3) + ) { (l1: Literal, l2: Literal, l3: Literal) => + val expr = ctor.newInstance(l1, l2, l3).asInstanceOf[Expression] + cmpInterpretWithCodegen(EmptyRow, expr) + } + } + + private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { + val interpret = try evaluate(expr, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expr", e) + } + + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil)(), + expr) + val codegen = plan(inputRow).get(0, expr.dataType) + + if (!checkResultRegardingNaN(interpret, codegen)) { + fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen") + } + } + + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte] and Spread[Double]. + */ + private[this] def checkResultRegardingNaN(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double]) => + expected.isWithin(result) + case (result: Double, expected: Double) if result.isNaN && expected.isNaN => + true + case (result: Float, expected: Float) if result.isNaN && expected.isNaN => + true + case _ => result == expected + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 033792eee6c0f..d3ef4383de1fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -200,8 +200,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("acos") { - testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + for (i <- 0 to 100) { + testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistency(DoubleType, classOf[Acos]) + } } test("cosh") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala new file mode 100644 index 0000000000000..11fe3397a7019 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.Matchers +import org.scalatest.prop.GeneratorDrivenPropertyChecks + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +trait PropertyGenerator extends GeneratorDrivenPropertyChecks with Matchers { + + lazy val byteLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbByte.arbitrary } yield Literal.create(b, ByteType) + + lazy val shortLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbShort.arbitrary } yield Literal.create(s, ShortType) + + lazy val integerLiteralGen: Gen[Literal] = + for { i <- Arbitrary.arbInt.arbitrary } yield Literal.create(i, IntegerType) + + lazy val longLiteralGen: Gen[Literal] = + for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType) + + lazy val floatLiteralGen: Gen[Literal] = + for { f <- Arbitrary.arbFloat.arbitrary } yield Literal.create(f, FloatType) + + lazy val doubleLiteralGen: Gen[Literal] = + for { d <- Arbitrary.arbDouble.arbitrary } yield Literal.create(d, DoubleType) + + // TODO: decimal type + + lazy val stringLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType) + + lazy val binaryLiteralGen: Gen[Literal] = + for { ab <- Gen.listOf[Byte](Arbitrary.arbByte.arbitrary) } + yield Literal.create(ab, BinaryType) + + lazy val booleanLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbBool.arbitrary } yield Literal.create(b, BooleanType) + + lazy val dateLiteralGen: Gen[Literal] = + for { d <- Arbitrary.arbInt.arbitrary } yield Literal.create(new Date(d), DateType) + + lazy val timestampLiteralGen: Gen[Literal] = + for { t <- Arbitrary.arbLong.arbitrary } yield Literal.create(new Timestamp(t), TimestampType) + + lazy val calendarIntervalLiterGen: Gen[Literal] = + for { m <- Arbitrary.arbInt.arbitrary; s <- Arbitrary.arbLong.arbitrary} + yield Literal.create(new CalendarInterval(m, s), CalendarIntervalType) + + + // Sometimes, it would be quite expensive when unlimited value is used, + // for example, the `times` arguments for StringRepeat would hang the test 'forever' + // if it's tested against Int.MaxValue by ScalaCheck, therefore, use values from a limited + // range is more reasonable + val limitedIntegerLiteralGen: Gen[Literal] = + for {i <- Gen.choose(-100, 100)} yield Literal.create(i, IntegerType) + + def randomGen(dt: DataType): Gen[Literal] = { + dt match { + case ByteType => byteLiteralGen + case ShortType => shortLiteralGen + case IntegerType => integerLiteralGen + case LongType => longLiteralGen + case DoubleType => doubleLiteralGen + case FloatType => floatLiteralGen + case DateType => dateLiteralGen + case TimestampType => timestampLiteralGen + case BooleanType => booleanLiteralGen + case StringType => stringLiteralGen + case BinaryType => binaryLiteralGen + case CalendarIntervalType => calendarIntervalLiterGen + case dt => throw new IllegalArgumentException(s"not supported type $dt") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 417df006ab7c2..18144ad23a07f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -46,6 +46,20 @@ object DataTypeTestUtils { */ val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + // TODO: remove this once we find out how to handle decimal properly in property check + val numericTypeWithoutDecimal: Set[DataType] = integralType ++ Set(DoubleType, FloatType) + + /** + * Instances of all [[NumericType]]s and CalendarIntervalType + */ + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType + + /** + * All the types that support ordering + */ + val ordered: Set[DataType] = + numericTypeWithoutDecimal + BooleanType + TimestampType + DateType + StringType + BinaryType + /** * Instances of all [[AtomicType]]s. */ From e05bbd06361fb3d615f36779f778f5a4cfcd654e Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 2 Aug 2015 14:45:11 +0800 Subject: [PATCH 2/7] property check more expressions --- .../expressions/ArithmeticExpressionSuite.scala | 15 +++++++++++++++ .../expressions/BitwiseFunctionsSuite.scala | 17 ++++++++++++++++- .../ConditionalExpressionSuite.scala | 13 ++++++++++++- .../expressions/DateExpressionsSuite.scala | 17 +++++++++++++++++ .../expressions/ExpressionEvalHelper.scala | 11 +++++++++++ .../expressions/PropertyGenerator.scala | 16 +++++++++++----- .../spark/sql/types/DataTypeTestUtils.scala | 5 +++++ 7 files changed, 87 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 35d4f887da03f..c9a6325bcedb0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -52,6 +52,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort) checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) + DataTypeTestUtils.numericAndInterval.foreach { tpe => checkConsistency(tpe, tpe, classOf[Add]) } @@ -74,6 +75,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) + DataTypeTestUtils.numericAndInterval.foreach { tpe => checkConsistency(tpe, classOf[UnaryMinus]) } @@ -91,6 +93,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort - negativeShort).toShort) checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) + DataTypeTestUtils.numericAndInterval.foreach { tpe => checkConsistency(tpe, tpe, classOf[Subtract]) } @@ -108,6 +111,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort * negativeShort).toShort) checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistency(tpe, tpe, classOf[Multiply]) } @@ -123,6 +127,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistency(tpe, tpe, classOf[Divide]) } @@ -159,9 +164,15 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(negativeIntLit, negativeIntLit), 0) checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistency(tpe, tpe, classOf[Remainder]) } + // TODO: the following lines would fail the test due to inconsistency result of interpret + // and codegen for remainder between giant values, seems like a numeric stability issue + // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + // checkConsistency(tpe, tpe, classOf[Remainder]) + // } } test("Abs") { @@ -179,6 +190,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Abs(negativeIntLit), - negativeInt) checkEvaluation(Abs(positiveLongLit), positiveLong) checkEvaluation(Abs(negativeLongLit), - negativeLong) + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistency(tpe, classOf[Abs]) } @@ -196,6 +208,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MaxOf(positiveShortLit, negativeShortLit), (positiveShort).toShort) checkEvaluation(MaxOf(positiveIntLit, negativeIntLit), positiveInt) checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong) + DataTypeTestUtils.ordered.foreach { tpe => checkConsistency(tpe, tpe, classOf[MaxOf]) } @@ -220,6 +233,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(positiveShortLit, negativeShortLit), (negativeShort).toShort) checkEvaluation(MinOf(positiveIntLit, negativeIntLit), negativeInt) checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong) + DataTypeTestUtils.ordered.foreach { tpe => checkConsistency(tpe, tpe, classOf[MinOf]) } @@ -249,6 +263,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt) checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong) } + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistency(tpe, tpe, classOf[MinOf]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 4fc1c06153595..9e34561ddf257 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -45,6 +45,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt) checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong) checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistency(dt, classOf[BitwiseNot]) + } } test("BitwiseAnd") { @@ -68,6 +72,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (positiveShort & negativeShort).toShort) checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt) checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistency(dt, dt, classOf[BitwiseAnd]) + } } test("BitwiseOr") { @@ -91,6 +99,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (positiveShort | negativeShort).toShort) checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt) checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistency(dt, dt, classOf[BitwiseOr]) + } } test("BitwiseXor") { @@ -110,10 +122,13 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseXor(nullLit, Literal(1)), null) checkEvaluation(BitwiseXor(Literal(1), nullLit), null) checkEvaluation(BitwiseXor(nullLit, nullLit), null) - checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit), (positiveShort ^ negativeShort).toShort) checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt) checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistency(dt, dt, classOf[BitwiseXor]) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index d26bcdb2902ab..c4effe9df75ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -66,6 +66,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toLong, TimestampType) testIf(_.toString, StringType) + + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistency(BooleanType, dt, dt, classOf[If]) + } } test("case when") { @@ -176,6 +180,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + + DataTypeTestUtils.ordered.foreach { dt => + checkSeqConsistency(dt, classOf[Least], 2) + } } test("function greatest") { @@ -218,6 +226,9 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) - } + DataTypeTestUtils.ordered.foreach { dt => + checkSeqConsistency(dt, classOf[Greatest], 2) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f9b73f1a75e73..7ab3bd6d516e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -60,6 +60,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) + checkConsistency(DateType, classOf[DayOfYear]) } test("Year") { @@ -79,6 +80,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistency(DateType, classOf[Year]) } test("Quarter") { @@ -98,6 +100,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistency(DateType, classOf[Quarter]) } test("Month") { @@ -117,6 +120,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistency(DateType, classOf[Month]) } test("Day / DayOfMonth") { @@ -135,6 +139,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.DAY_OF_MONTH)) } } + checkConsistency(DateType, classOf[DayOfMonth]) } test("Seconds") { @@ -149,6 +154,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } + checkConsistency(TimestampType, classOf[Second]) } test("WeekOfYear") { @@ -157,6 +163,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + checkConsistency(DateType, classOf[WeekOfYear]) } test("DateFormat") { @@ -184,6 +191,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistency(TimestampType, classOf[Hour]) } test("Minute") { @@ -200,6 +208,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.MINUTE)) } } + checkConsistency(TimestampType, classOf[Minute]) } test("date_add") { @@ -218,6 +227,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627) checkEvaluation( DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910) + checkConsistency(DateType, IntegerType, classOf[DateAdd]) } test("date_sub") { @@ -236,6 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909) checkEvaluation( DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628) + checkConsistency(DateType, IntegerType, classOf[DateSub]) } test("time_add") { @@ -254,6 +265,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistency(TimestampType, CalendarIntervalType, classOf[TimeAdd]) } test("time_sub") { @@ -277,6 +289,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistency(TimestampType, CalendarIntervalType, classOf[TimeSub]) } test("add_months") { @@ -296,6 +309,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213) checkEvaluation( AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528) + checkConsistency(DateType, IntegerType, classOf[AddMonths]) } test("months_between") { @@ -320,6 +334,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(MonthsBetween(t, tnull), null) checkEvaluation(MonthsBetween(tnull, t), null) checkEvaluation(MonthsBetween(tnull, tnull), null) + checkConsistency(TimestampType, TimestampType, classOf[MonthsBetween]) } test("last_day") { @@ -337,6 +352,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) checkEvaluation(LastDay(Literal.create(null, DateType)), null) + checkConsistency(DateType, classOf[LastDay]) } test("next_day") { @@ -370,6 +386,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ToDate(Literal(Date.valueOf("2015-07-22"))), DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) checkEvaluation(ToDate(Literal.create(null, DateType)), null) + checkConsistency(DateType, classOf[ToDate]) } test("function trunc") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index da9f8497a7c5b..a46b5ca840242 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread import org.apache.spark.SparkFunSuite @@ -246,6 +247,16 @@ trait ExpressionEvalHelper extends PropertyGenerator { } } + def checkSeqConsistency(dt: DataType, clazz: Class[_], leastNumOfElements: Int = 0): Unit = { + val ctor = clazz.getDeclaredConstructor(classOf[Seq[Expression]]) + forAll (Gen.listOf(randomGen(dt))) { (literals: Seq[Literal]) => + whenever(literals.size >= leastNumOfElements) { + val expr = ctor.newInstance(literals).asInstanceOf[Expression] + cmpInterpretWithCodegen(EmptyRow, expr) + } + } + } + private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try evaluate(expr, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expr", e) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala index 11fe3397a7019..34a376f0f8b9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala @@ -42,10 +42,16 @@ trait PropertyGenerator extends GeneratorDrivenPropertyChecks with Matchers { for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType) lazy val floatLiteralGen: Gen[Literal] = - for { f <- Arbitrary.arbFloat.arbitrary } yield Literal.create(f, FloatType) + for { + f <- Gen.chooseNum(Float.MinValue / 2, Float.MaxValue / 2, + Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity) + } yield Literal.create(f, FloatType) lazy val doubleLiteralGen: Gen[Literal] = - for { d <- Arbitrary.arbDouble.arbitrary } yield Literal.create(d, DoubleType) + for { + f <- Gen.chooseNum(Double.MinValue / 2, Double.MaxValue / 2, + Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) + } yield Literal.create(f, DoubleType) // TODO: decimal type @@ -54,7 +60,7 @@ trait PropertyGenerator extends GeneratorDrivenPropertyChecks with Matchers { lazy val binaryLiteralGen: Gen[Literal] = for { ab <- Gen.listOf[Byte](Arbitrary.arbByte.arbitrary) } - yield Literal.create(ab, BinaryType) + yield Literal.create(ab.toArray, BinaryType) lazy val booleanLiteralGen: Gen[Literal] = for { b <- Arbitrary.arbBool.arbitrary } yield Literal.create(b, BooleanType) @@ -74,8 +80,8 @@ trait PropertyGenerator extends GeneratorDrivenPropertyChecks with Matchers { // for example, the `times` arguments for StringRepeat would hang the test 'forever' // if it's tested against Int.MaxValue by ScalaCheck, therefore, use values from a limited // range is more reasonable - val limitedIntegerLiteralGen: Gen[Literal] = - for {i <- Gen.choose(-100, 100)} yield Literal.create(i, IntegerType) + lazy val limitedIntegerLiteralGen: Gen[Literal] = + for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType) def randomGen(dt: DataType): Gen[Literal] = { dt match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 18144ad23a07f..b27782ebc40b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -60,6 +60,11 @@ object DataTypeTestUtils { val ordered: Set[DataType] = numericTypeWithoutDecimal + BooleanType + TimestampType + DateType + StringType + BinaryType + /** + * All the types that we can use in a property check + */ + val propertyCheckSupported: Set[DataType] = ordered + /** * Instances of all [[AtomicType]]s. */ From 2100600d7b7e7887f73560101ce21dd1340a48e4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 2 Aug 2015 16:28:15 +0800 Subject: [PATCH 3/7] Finish first pass of property check --- .../sql/catalyst/expressions/predicates.scala | 3 +- .../expressions/MathFunctionsSuite.scala | 52 ++++++++++++++++--- .../expressions/MiscFunctionsSuite.scala | 4 +- .../catalyst/expressions/PredicateSuite.scala | 19 +++++++ 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index fe7dffb815987..e9d6f4a0ddd09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -280,7 +280,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType - && left.dataType != DoubleType) { + && left.dataType != DoubleType + && left.dataType != BooleanType) { // faster version defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index d3ef4383de1fa..1fbf055dbafe9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ - class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ @@ -184,63 +183,74 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("sin") { testUnary(Sin, math.sin) + checkConsistency(DoubleType, classOf[Sin]) } test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistency(DoubleType, classOf[Asin]) } test("sinh") { testUnary(Sinh, math.sinh) + checkConsistency(DoubleType, classOf[Sinh]) } test("cos") { testUnary(Cos, math.cos) + checkConsistency(DoubleType, classOf[Cos]) } test("acos") { - for (i <- 0 to 100) { - testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistency(DoubleType, classOf[Acos]) - } + testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistency(DoubleType, classOf[Acos]) } test("cosh") { testUnary(Cosh, math.cosh) + checkConsistency(DoubleType, classOf[Cosh]) } test("tan") { testUnary(Tan, math.tan) + checkConsistency(DoubleType, classOf[Tan]) } test("atan") { testUnary(Atan, math.atan) + checkConsistency(DoubleType, classOf[Atan]) } test("tanh") { testUnary(Tanh, math.tanh) + checkConsistency(DoubleType, classOf[Tanh]) } test("toDegrees") { testUnary(ToDegrees, math.toDegrees) + checkConsistency(DoubleType, classOf[Acos]) } test("toRadians") { testUnary(ToRadians, math.toRadians) + checkConsistency(DoubleType, classOf[ToRadians]) } test("cbrt") { testUnary(Cbrt, math.cbrt) + checkConsistency(DoubleType, classOf[Cbrt]) } test("ceil") { testUnary(Ceil, math.ceil) + checkConsistency(DoubleType, classOf[Ceil]) } test("floor") { testUnary(Floor, math.floor) + checkConsistency(DoubleType, classOf[Floor]) } test("factorial") { @@ -250,37 +260,45 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + checkConsistency(IntegerType, classOf[Factorial]) } test("rint") { testUnary(Rint, math.rint) + checkConsistency(DoubleType, classOf[Rint]) } test("exp") { testUnary(Exp, math.exp) + checkConsistency(DoubleType, classOf[Exp]) } test("expm1") { testUnary(Expm1, math.expm1) + checkConsistency(DoubleType, classOf[Expm1]) } test("signum") { testUnary[Double, Double](Signum, math.signum) + checkConsistency(DoubleType, classOf[Signum]) } test("log") { testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistency(DoubleType, classOf[Log]) } test("log10") { testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistency(DoubleType, classOf[Log10]) } test("log1p") { testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) + checkConsistency(DoubleType, classOf[Log1p]) } test("bin") { @@ -301,12 +319,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) + + checkConsistency(LongType, classOf[Bin]) } test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (1 to 20).map(_ * 0.1)) testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) + checkConsistency(DoubleType, classOf[Log2]) } test("sqrt") { @@ -316,11 +337,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkNaN(Sqrt(Literal(-1.0)), EmptyRow) checkNaN(Sqrt(Literal(-1.5)), EmptyRow) + checkConsistency(DoubleType, classOf[Sqrt]) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistency(DoubleType, DoubleType, classOf[Pow]) } test("shift left") { @@ -341,6 +364,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt) checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) + + checkConsistency(IntegerType, IntegerType, classOf[ShiftLeft]) + checkConsistency(LongType, IntegerType, classOf[ShiftLeft]) } test("shift right") { @@ -361,6 +387,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt) checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) + + checkConsistency(IntegerType, IntegerType, classOf[ShiftRight]) + checkConsistency(LongType, IntegerType, classOf[ShiftRight]) } test("shift right unsigned") { @@ -389,6 +418,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { negativeLong >>> positiveInt) checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), negativeLong >>> negativeInt) + + checkConsistency(IntegerType, IntegerType, classOf[ShiftRightUnsigned]) + checkConsistency(LongType, IntegerType, classOf[ShiftRightUnsigned]) } test("hex") { @@ -403,6 +435,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on + Seq(LongType, BinaryType, StringType).foreach { dt => + checkConsistency(dt, classOf[Hex]) + } } test("unhex") { @@ -416,16 +451,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) checkEvaluation(Unhex(Literal("三重的")), null) - // scalastyle:on + checkConsistency(StringType, classOf[Unhex]) } test("hypot") { testBinary(Hypot, math.hypot) + checkConsistency(DoubleType, DoubleType, classOf[Hypot]) } test("atan2") { testBinary(Atan2, math.atan2) + checkConsistency(DoubleType, DoubleType, classOf[Atan2]) } test("binary log") { @@ -457,6 +494,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal(-1.0)), null, create_row(null)) + checkConsistency(DoubleType, DoubleType, classOf[Logarithm]) } test("round") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index b524d0af14a67..d51dd9f7dd4bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -29,6 +29,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + checkConsistency(BinaryType, classOf[Md5]) } test("sha1") { @@ -37,6 +38,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkConsistency(BinaryType, classOf[Sha1]) } test("sha2") { @@ -55,6 +57,6 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + checkConsistency(BinaryType, classOf[Crc32]) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 7beef71845e43..bf67a4ef18c98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -73,6 +73,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } + checkConsistency(BooleanType, classOf[Not]) + } + + test("AND, OR, EqualTo, EqualNullSafe consistency check") { + checkConsistency(BooleanType, BooleanType, classOf[And]) + checkConsistency(BooleanType, BooleanType, classOf[Or]) + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistency(dt, dt, classOf[EqualTo]) + checkConsistency(dt, dt, classOf[EqualNullSafe]) + } } booleanLogicTest("AND", And, @@ -180,6 +190,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) + test("BinaryComparison consistency check") { + DataTypeTestUtils.ordered.foreach { dt => + checkConsistency(dt, dt, classOf[LessThan]) + checkConsistency(dt, dt, classOf[LessThanOrEqual]) + checkConsistency(dt, dt, classOf[GreaterThan]) + checkConsistency(dt, dt, classOf[GreaterThanOrEqual]) + } + } + test("BinaryComparison: lessThan") { for (i <- 0 until smallValues.length) { checkEvaluation(LessThan(smallValues(i), largeValues(i)), true) From 4e362047b949e553f4fbbb4a00a3c0510bd7d57a Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Mon, 3 Aug 2015 10:31:27 +0800 Subject: [PATCH 4/7] address comments --- .../scala/org/apache/spark/sql/types/DataTypeTestUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index b27782ebc40b8..ed2c641d63e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -50,7 +50,7 @@ object DataTypeTestUtils { val numericTypeWithoutDecimal: Set[DataType] = integralType ++ Set(DoubleType, FloatType) /** - * Instances of all [[NumericType]]s and CalendarIntervalType + * Instances of all [[NumericType]]s and [[CalendarIntervalType]] */ val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType From 645df77947b169483095439a9379c38181a22483 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 4 Aug 2015 14:41:35 +0800 Subject: [PATCH 5/7] rename & add javadoc --- .../ArithmeticExpressionSuite.scala | 3 -- .../expressions/ExpressionEvalHelper.scala | 4 ++- ...Generator.scala => LiteralGenerator.scala} | 29 +++++++++++++++++-- 3 files changed, 30 insertions(+), 6 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{PropertyGenerator.scala => LiteralGenerator.scala} (81%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index c9a6325bcedb0..c8da5b17da23a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -165,9 +165,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) - DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistency(tpe, tpe, classOf[Remainder]) - } // TODO: the following lines would fail the test due to inconsistency result of interpret // and codegen for remainder between giant values, seems like a numeric stability issue // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a46b5ca840242..473bff74ab106 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.prop.GeneratorDrivenPropertyChecks + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -30,7 +32,7 @@ import org.apache.spark.sql.types.DataType /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends PropertyGenerator { +trait ExpressionEvalHelper extends LiteralGenerator with GeneratorDrivenPropertyChecks { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala similarity index 81% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index 34a376f0f8b9a..377ab9d7925ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PropertyGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -26,8 +26,33 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval - -trait PropertyGenerator extends GeneratorDrivenPropertyChecks with Matchers { +/** + * Property is a high-level specification of behavior that should hold for a range of data points. + * + * For example, while we are evaluating a deterministic expression for some input, we should always + * hold the property that the result never changes, regardless of how we get the result, + * via interpreted or codegen. + * + * In ScalaTest, properties are specified as functions and the data points used to check properties + * can be supplied by either tables or generators. + * + * Generator-driven property checks are performed via integration with ScalaCheck. + * + * @example {{{ + * def toTest(i: Int): Boolean = if (i % 2 == 0) true else false + * + * import org.scalacheck.Gen + * + * test ("true if param is even") { + * val evenInts = for (n <- Gen.choose(-1000, 1000)) yield 2 * n + * forAll(evenInts) { (i: Int) => + * assert (toTest(i) === true) + * } + * } + * }}} + * + */ +trait LiteralGenerator { lazy val byteLiteralGen: Gen[Literal] = for { b <- Arbitrary.arbByte.arbitrary } yield Literal.create(b, ByteType) From 963af5fb6fc61531c54caf2c72d2ea1109213fd7 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 9 Aug 2015 09:42:55 +0800 Subject: [PATCH 6/7] typo fix --- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e9d6f4a0ddd09..fe7dffb815987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -280,8 +280,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType - && left.dataType != DoubleType - && left.dataType != BooleanType) { + && left.dataType != DoubleType) { // faster version defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") } else { From 0a5bdc978d5292fd4b0569ad4bc27a7d1de22685 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 15 Aug 2015 13:43:45 +0800 Subject: [PATCH 7/7] address comments --- .../ArithmeticExpressionSuite.scala | 20 ++-- .../expressions/BitwiseFunctionsSuite.scala | 8 +- .../ConditionalExpressionSuite.scala | 6 +- .../expressions/DateExpressionsSuite.scala | 34 +++---- .../expressions/ExpressionEvalHelper.scala | 92 ++++++++++++------- .../expressions/LiteralGenerator.scala | 2 +- .../expressions/MathFunctionsSuite.scala | 74 +++++++-------- .../expressions/MiscFunctionsSuite.scala | 6 +- .../catalyst/expressions/PredicateSuite.scala | 18 ++-- 9 files changed, 145 insertions(+), 115 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index c8da5b17da23a..72285c6a24199 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -54,7 +54,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistency(tpe, tpe, classOf[Add]) + checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe) } } @@ -77,7 +77,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistency(tpe, classOf[UnaryMinus]) + checkConsistencyBetweenInterpretedAndCodegen(UnaryMinus, tpe) } } @@ -95,7 +95,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistency(tpe, tpe, classOf[Subtract]) + checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe) } } @@ -113,7 +113,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistency(tpe, tpe, classOf[Multiply]) + checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe) } } @@ -129,7 +129,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistency(tpe, tpe, classOf[Divide]) + checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) } } @@ -168,7 +168,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper // TODO: the following lines would fail the test due to inconsistency result of interpret // and codegen for remainder between giant values, seems like a numeric stability issue // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - // checkConsistency(tpe, tpe, classOf[Remainder]) + // checkConsistencyBetweenInterpretedAndCodegen(Remainder, tpe, tpe) // } } @@ -189,7 +189,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Abs(negativeLongLit), - negativeLong) DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistency(tpe, classOf[Abs]) + checkConsistencyBetweenInterpretedAndCodegen(Abs, tpe) } } @@ -207,7 +207,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong) DataTypeTestUtils.ordered.foreach { tpe => - checkConsistency(tpe, tpe, classOf[MaxOf]) + checkConsistencyBetweenInterpretedAndCodegen(MaxOf, tpe, tpe) } } @@ -232,7 +232,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong) DataTypeTestUtils.ordered.foreach { tpe => - checkConsistency(tpe, tpe, classOf[MinOf]) + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) } } @@ -262,6 +262,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistency(tpe, tpe, classOf[MinOf]) + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 9e34561ddf257..3a310c0e9a7a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -47,7 +47,7 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistency(dt, classOf[BitwiseNot]) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) } } @@ -74,7 +74,7 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistency(dt, dt, classOf[BitwiseAnd]) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) } } @@ -101,7 +101,7 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistency(dt, dt, classOf[BitwiseOr]) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) } } @@ -128,7 +128,7 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistency(dt, dt, classOf[BitwiseXor]) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index c4effe9df75ec..0df673bb9fa02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -68,7 +68,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toString, StringType) DataTypeTestUtils.propertyCheckSupported.foreach { dt => - checkConsistency(BooleanType, dt, dt, classOf[If]) + checkConsistencyBetweenInterpretedAndCodegen(If, BooleanType, dt, dt) } } @@ -182,7 +182,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) DataTypeTestUtils.ordered.foreach { dt => - checkSeqConsistency(dt, classOf[Least], 2) + checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } } @@ -228,7 +228,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) DataTypeTestUtils.ordered.foreach { dt => - checkSeqConsistency(dt, classOf[Greatest], 2) + checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 7ab3bd6d516e5..610d39e8493cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -60,7 +60,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) - checkConsistency(DateType, classOf[DayOfYear]) + checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType) } test("Year") { @@ -80,7 +80,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - checkConsistency(DateType, classOf[Year]) + checkConsistencyBetweenInterpretedAndCodegen(Year, DateType) } test("Quarter") { @@ -100,7 +100,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - checkConsistency(DateType, classOf[Quarter]) + checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType) } test("Month") { @@ -120,7 +120,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - checkConsistency(DateType, classOf[Month]) + checkConsistencyBetweenInterpretedAndCodegen(Month, DateType) } test("Day / DayOfMonth") { @@ -139,7 +139,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.DAY_OF_MONTH)) } } - checkConsistency(DateType, classOf[DayOfMonth]) + checkConsistencyBetweenInterpretedAndCodegen(DayOfMonth, DateType) } test("Seconds") { @@ -154,7 +154,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } - checkConsistency(TimestampType, classOf[Second]) + checkConsistencyBetweenInterpretedAndCodegen(Second, TimestampType) } test("WeekOfYear") { @@ -163,7 +163,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) - checkConsistency(DateType, classOf[WeekOfYear]) + checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) } test("DateFormat") { @@ -191,7 +191,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - checkConsistency(TimestampType, classOf[Hour]) + checkConsistencyBetweenInterpretedAndCodegen(Hour, TimestampType) } test("Minute") { @@ -208,7 +208,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.MINUTE)) } } - checkConsistency(TimestampType, classOf[Minute]) + checkConsistencyBetweenInterpretedAndCodegen(Minute, TimestampType) } test("date_add") { @@ -227,7 +227,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627) checkEvaluation( DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910) - checkConsistency(DateType, IntegerType, classOf[DateAdd]) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType) } test("date_sub") { @@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909) checkEvaluation( DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628) - checkConsistency(DateType, IntegerType, classOf[DateSub]) + checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, IntegerType) } test("time_add") { @@ -265,7 +265,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) - checkConsistency(TimestampType, CalendarIntervalType, classOf[TimeAdd]) + checkConsistencyBetweenInterpretedAndCodegen(TimeAdd, TimestampType, CalendarIntervalType) } test("time_sub") { @@ -289,7 +289,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) - checkConsistency(TimestampType, CalendarIntervalType, classOf[TimeSub]) + checkConsistencyBetweenInterpretedAndCodegen(TimeSub, TimestampType, CalendarIntervalType) } test("add_months") { @@ -309,7 +309,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213) checkEvaluation( AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528) - checkConsistency(DateType, IntegerType, classOf[AddMonths]) + checkConsistencyBetweenInterpretedAndCodegen(AddMonths, DateType, IntegerType) } test("months_between") { @@ -334,7 +334,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(MonthsBetween(t, tnull), null) checkEvaluation(MonthsBetween(tnull, t), null) checkEvaluation(MonthsBetween(tnull, tnull), null) - checkConsistency(TimestampType, TimestampType, classOf[MonthsBetween]) + checkConsistencyBetweenInterpretedAndCodegen(MonthsBetween, TimestampType, TimestampType) } test("last_day") { @@ -352,7 +352,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) checkEvaluation(LastDay(Literal.create(null, DateType)), null) - checkConsistency(DateType, classOf[LastDay]) + checkConsistencyBetweenInterpretedAndCodegen(LastDay, DateType) } test("next_day") { @@ -386,7 +386,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ToDate(Literal(Date.valueOf("2015-07-22"))), DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) checkEvaluation(ToDate(Literal.create(null, DateType)), null) - checkConsistency(DateType, classOf[ToDate]) + checkConsistencyBetweenInterpretedAndCodegen(ToDate, DateType) } test("function trunc") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 473bff74ab106..465f7d08aa142 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread - import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite @@ -32,7 +31,7 @@ import org.apache.spark.sql.types.DataType /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends LiteralGenerator with GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -216,51 +215,82 @@ trait ExpressionEvalHelper extends LiteralGenerator with GeneratorDrivenProperty assert(checkResult(actual, expected)) } - def checkConsistency(dt: DataType, clazz: Class[_]): Unit = { - val ctor = clazz.getDeclaredConstructor(classOf[Expression]) - forAll (randomGen(dt)) { (l: Literal) => - val expr = ctor.newInstance(l).asInstanceOf[Expression] - cmpInterpretWithCodegen(EmptyRow, expr) + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against unary expressions by feeding them arbitrary literals of `dataType`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Expression => Expression, + dataType: DataType): Unit = { + forAll (LiteralGenerator.randomGen(dataType)) { (l: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l)) } } - def checkConsistency(dt1: DataType, dt2: DataType, clazz: Class[_]): Unit = { - val ctor = clazz.getDeclaredConstructor(classOf[Expression], classOf[Expression]) + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against binary expressions by feeding them arbitrary literals of `dataType1` + * and `dataType2`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType): Unit = { forAll ( - randomGen(dt1), - randomGen(dt2) + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2) ) { (l1: Literal, l2: Literal) => - val expr = ctor.newInstance(l1, l2).asInstanceOf[Expression] - cmpInterpretWithCodegen(EmptyRow, expr) + cmpInterpretWithCodegen(EmptyRow, c(l1, l2)) } } - def checkConsistency(dt1: DataType, dt2: DataType, dt3: DataType, - clazz: Class[_]): Unit = { - val ctor = clazz.getDeclaredConstructor( - classOf[Expression], classOf[Expression], classOf[Expression]) + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against ternary expressions by feeding them arbitrary literals of `dataType1`, + * `dataType2` and `dataType3`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType, + dataType3: DataType): Unit = { forAll ( - randomGen(dt1), - randomGen(dt2), - randomGen(dt3) + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2), + LiteralGenerator.randomGen(dataType3) ) { (l1: Literal, l2: Literal, l3: Literal) => - val expr = ctor.newInstance(l1, l2, l3).asInstanceOf[Expression] - cmpInterpretWithCodegen(EmptyRow, expr) + cmpInterpretWithCodegen(EmptyRow, c(l1, l2, l3)) } } - def checkSeqConsistency(dt: DataType, clazz: Class[_], leastNumOfElements: Int = 0): Unit = { - val ctor = clazz.getDeclaredConstructor(classOf[Seq[Expression]]) - forAll (Gen.listOf(randomGen(dt))) { (literals: Seq[Literal]) => - whenever(literals.size >= leastNumOfElements) { - val expr = ctor.newInstance(literals).asInstanceOf[Expression] - cmpInterpretWithCodegen(EmptyRow, expr) + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against expressions take Seq[Expression] as input by feeding them + * arbitrary length Seq of arbitrary literal of `dataType`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Seq[Expression] => Expression, + dataType: DataType, + minNumElements: Int = 0): Unit = { + forAll (Gen.listOf(LiteralGenerator.randomGen(dataType))) { (literals: Seq[Literal]) => + whenever(literals.size >= minNumElements) { + cmpInterpretWithCodegen(EmptyRow, c(literals)) } } } private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { - val interpret = try evaluate(expr, inputRow) catch { + val interpret = try { + evaluate(expr, inputRow) + } catch { case e: Exception => fail(s"Exception evaluating $expr", e) } @@ -269,7 +299,7 @@ trait ExpressionEvalHelper extends LiteralGenerator with GeneratorDrivenProperty expr) val codegen = plan(inputRow).get(0, expr.dataType) - if (!checkResultRegardingNaN(interpret, codegen)) { + if (!compareResults(interpret, codegen)) { fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen") } } @@ -278,7 +308,7 @@ trait ExpressionEvalHelper extends LiteralGenerator with GeneratorDrivenProperty * Check the equality between result of expression and expected value, it will handle * Array[Byte] and Spread[Double]. */ - private[this] def checkResultRegardingNaN(result: Any, expected: Any): Boolean = { + private[this] def compareResults(result: Any, expected: Any): Boolean = { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index 377ab9d7925ce..ee6d25157fc08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -52,7 +52,7 @@ import org.apache.spark.unsafe.types.CalendarInterval * }}} * */ -trait LiteralGenerator { +object LiteralGenerator { lazy val byteLiteralGen: Gen[Literal] = for { b <- Arbitrary.arbByte.arbitrary } yield Literal.create(b, ByteType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 1fbf055dbafe9..90c59f240b542 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -183,74 +183,74 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("sin") { testUnary(Sin, math.sin) - checkConsistency(DoubleType, classOf[Sin]) + checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) } test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistency(DoubleType, classOf[Asin]) + checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) } test("sinh") { testUnary(Sinh, math.sinh) - checkConsistency(DoubleType, classOf[Sinh]) + checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) } test("cos") { testUnary(Cos, math.cos) - checkConsistency(DoubleType, classOf[Cos]) + checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) } test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistency(DoubleType, classOf[Acos]) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) } test("cosh") { testUnary(Cosh, math.cosh) - checkConsistency(DoubleType, classOf[Cosh]) + checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) } test("tan") { testUnary(Tan, math.tan) - checkConsistency(DoubleType, classOf[Tan]) + checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) } test("atan") { testUnary(Atan, math.atan) - checkConsistency(DoubleType, classOf[Atan]) + checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) } test("tanh") { testUnary(Tanh, math.tanh) - checkConsistency(DoubleType, classOf[Tanh]) + checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) } test("toDegrees") { testUnary(ToDegrees, math.toDegrees) - checkConsistency(DoubleType, classOf[Acos]) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) } test("toRadians") { testUnary(ToRadians, math.toRadians) - checkConsistency(DoubleType, classOf[ToRadians]) + checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) } test("cbrt") { testUnary(Cbrt, math.cbrt) - checkConsistency(DoubleType, classOf[Cbrt]) + checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } test("ceil") { testUnary(Ceil, math.ceil) - checkConsistency(DoubleType, classOf[Ceil]) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) } test("floor") { testUnary(Floor, math.floor) - checkConsistency(DoubleType, classOf[Floor]) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) } test("factorial") { @@ -260,45 +260,45 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) - checkConsistency(IntegerType, classOf[Factorial]) + checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) } test("rint") { testUnary(Rint, math.rint) - checkConsistency(DoubleType, classOf[Rint]) + checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) } test("exp") { testUnary(Exp, math.exp) - checkConsistency(DoubleType, classOf[Exp]) + checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) } test("expm1") { testUnary(Expm1, math.expm1) - checkConsistency(DoubleType, classOf[Expm1]) + checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) } test("signum") { testUnary[Double, Double](Signum, math.signum) - checkConsistency(DoubleType, classOf[Signum]) + checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) } test("log") { testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) - checkConsistency(DoubleType, classOf[Log]) + checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) } test("log10") { testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) - checkConsistency(DoubleType, classOf[Log10]) + checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) } test("log1p") { testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) - checkConsistency(DoubleType, classOf[Log1p]) + checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) } test("bin") { @@ -320,14 +320,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) - checkConsistency(LongType, classOf[Bin]) + checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) } test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (1 to 20).map(_ * 0.1)) testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) - checkConsistency(DoubleType, classOf[Log2]) + checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) } test("sqrt") { @@ -337,13 +337,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkNaN(Sqrt(Literal(-1.0)), EmptyRow) checkNaN(Sqrt(Literal(-1.5)), EmptyRow) - checkConsistency(DoubleType, classOf[Sqrt]) + checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) - checkConsistency(DoubleType, DoubleType, classOf[Pow]) + checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) } test("shift left") { @@ -365,8 +365,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) - checkConsistency(IntegerType, IntegerType, classOf[ShiftLeft]) - checkConsistency(LongType, IntegerType, classOf[ShiftLeft]) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType) } test("shift right") { @@ -388,8 +388,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) - checkConsistency(IntegerType, IntegerType, classOf[ShiftRight]) - checkConsistency(LongType, IntegerType, classOf[ShiftRight]) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) } test("shift right unsigned") { @@ -419,8 +419,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), negativeLong >>> negativeInt) - checkConsistency(IntegerType, IntegerType, classOf[ShiftRightUnsigned]) - checkConsistency(LongType, IntegerType, classOf[ShiftRightUnsigned]) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) } test("hex") { @@ -436,7 +436,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on Seq(LongType, BinaryType, StringType).foreach { dt => - checkConsistency(dt, classOf[Hex]) + checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) } } @@ -452,17 +452,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) checkEvaluation(Unhex(Literal("三重的")), null) // scalastyle:on - checkConsistency(StringType, classOf[Unhex]) + checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) } test("hypot") { testBinary(Hypot, math.hypot) - checkConsistency(DoubleType, DoubleType, classOf[Hypot]) + checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) } test("atan2") { testBinary(Atan2, math.atan2) - checkConsistency(DoubleType, DoubleType, classOf[Atan2]) + checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType) } test("binary log") { @@ -494,7 +494,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal(-1.0)), null, create_row(null)) - checkConsistency(DoubleType, DoubleType, classOf[Logarithm]) + checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } test("round") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index d51dd9f7dd4bb..75d17417e5a02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -29,7 +29,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) - checkConsistency(BinaryType, classOf[Md5]) + checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } test("sha1") { @@ -38,7 +38,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") - checkConsistency(BinaryType, classOf[Sha1]) + checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { @@ -57,6 +57,6 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) - checkConsistency(BinaryType, classOf[Crc32]) + checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index bf67a4ef18c98..54c04faddb477 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -73,15 +73,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } - checkConsistency(BooleanType, classOf[Not]) + checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) } test("AND, OR, EqualTo, EqualNullSafe consistency check") { - checkConsistency(BooleanType, BooleanType, classOf[And]) - checkConsistency(BooleanType, BooleanType, classOf[Or]) + checkConsistencyBetweenInterpretedAndCodegen(And, BooleanType, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(Or, BooleanType, BooleanType) DataTypeTestUtils.propertyCheckSupported.foreach { dt => - checkConsistency(dt, dt, classOf[EqualTo]) - checkConsistency(dt, dt, classOf[EqualNullSafe]) + checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt) } } @@ -192,10 +192,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("BinaryComparison consistency check") { DataTypeTestUtils.ordered.foreach { dt => - checkConsistency(dt, dt, classOf[LessThan]) - checkConsistency(dt, dt, classOf[LessThanOrEqual]) - checkConsistency(dt, dt, classOf[GreaterThan]) - checkConsistency(dt, dt, classOf[GreaterThanOrEqual]) + checkConsistencyBetweenInterpretedAndCodegen(LessThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(LessThanOrEqual, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThanOrEqual, dt, dt) } }