From 1e29eab8fea5152a7608847dffbb5beb01742b8e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Jul 2017 00:07:13 +0900 Subject: [PATCH 1/5] Strict argument checks for array() and map() --- .../expressions/complexTypeCreator.scala | 13 +++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 37 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 98c4cbee38de..530f49914db2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -41,8 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") + override def checkInputDataTypes(): TypeCheckResult = { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") + } + } override def dataType: ArrayType = { ArrayType( @@ -168,7 +173,9 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children.size % 2 != 0) { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.") } else if (keys.map(_.dataType).distinct.length > 1) { TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0e9a2c6cf7de..b356ffe110d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -448,6 +448,43 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + + test("SPARK-21281 fails if functions have no argument") { + var errMsg = intercept[AnalysisException] { + spark.range(1).select(array()) + }.getMessage + assert(errMsg.contains("due to data type mismatch: input to function coalesce cannot be empty")) + + errMsg = intercept[AnalysisException] { + spark.range(1).select(map()) + }.getMessage + assert(errMsg.contains("due to data type mismatch: input to function coalesce cannot be empty")) + + // spark.range(1).select(coalesce()) + errMsg = intercept[AnalysisException] { + spark.range(1).select(coalesce()) + }.getMessage + assert(errMsg.contains("due to data type mismatch: input to function coalesce cannot be empty")) + + // This hits java.lang.AssertionError + // spark.range(1).select(struct()) + + errMsg = intercept[IllegalArgumentException] { + spark.range(1).select(greatest()) + }.getMessage + assert(errMsg.contains("requirement failed: greatest requires at least 2 
arguments")) + + errMsg = intercept[IllegalArgumentException] { + spark.range(1).select(least()) + }.getMessage + assert(errMsg.contains("requirement failed: least requires at least 2 arguments")) + + errMsg = intercept[AnalysisException] { + spark.range(1).select(hash()) + }.getMessage + assert(errMsg.contains( + "due to data type mismatch: function hash requires at least one argument")) + } } object DataFrameFunctionsSuite { From 04b904b8eb6189255e19a62ade2dce990d5c6e79 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Jul 2017 17:00:57 +0900 Subject: [PATCH 2/5] Apply more fixes --- .../sql/catalyst/expressions/arithmetic.scala | 40 ++++++----- .../expressions/complexTypeCreator.scala | 69 +++++++++++-------- .../spark/sql/catalyst/expressions/hash.scala | 9 +-- .../expressions/nullExpressions.scala | 6 +- .../spark/sql/catalyst/util/TypeUtils.scala | 36 +++++++--- .../org/apache/spark/sql/functions.scala | 10 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 61 ++++++++-------- 7 files changed, 122 insertions(+), 109 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ec6e6ba0f091..2c00a0df014e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -526,14 +526,18 @@ case class Least(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") - } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkTypeInputDimension( + children.map(_.dataType), s"function $prettyName", requiredMinDimension = 2) match { + case TypeCheckResult.TypeCheckSuccess => + if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + TypeCheckResult.TypeCheckFailure( + s"The expressions should all have the same type," + + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") + } else { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } + case typeCheckFailure => + typeCheckFailure } } @@ -591,14 +595,18 @@ case class Greatest(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") - } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkTypeInputDimension( + children.map(_.dataType), s"function $prettyName", requiredMinDimension = 2) match { + case TypeCheckResult.TypeCheckSuccess => + if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + 
TypeCheckResult.TypeCheckFailure( + s"The expressions should all have the same type," + + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") + } else { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } + case typeCheckFailure => + typeCheckFailure } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 530f49914db2..4584fafc256d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -42,11 +42,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children == Nil) { - TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") - } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") - } + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } override def dataType: ArrayType = { @@ -173,18 +169,25 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children == Nil) { - TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") - } else if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + - "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " + - "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkTypeInputDimension( + children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1) match { + case TypeCheckResult.TypeCheckSuccess => + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure( + s"$prettyName expects a positive even number of arguments.") + } else if (keys.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure( + "The given keys of function map should all be the same type, but they are " + + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else if (values.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure( + "The given values of function map should all be the same type, but they are " + + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeCheckResult.TypeCheckSuccess + } + case typeCheckFailure => + typeCheckFailure } } @@ -299,19 +302,25 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") - } else { - val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) - if (invalidNames.nonEmpty) { - TypeCheckResult.TypeCheckFailure( - "Only 
foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") - } else if (!names.contains(null)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure("Field name should not be null") - } + TypeUtils.checkTypeInputDimension( + children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1) match { + case TypeCheckResult.TypeCheckSuccess => + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") + } else { + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) + if (invalidNames.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") + } else if (!names.contains(null)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Field name should not be null") + } + } + case typeCheckFailure => + typeCheckFailure } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ffd0e64d86cf..f928917c9ab0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,7 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -247,11 +247,8 @@ abstract class HashExpression[E] extends Expression { override def nullable: Boolean = false override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") - } else { - TypeCheckResult.TypeCheckSuccess - } + TypeUtils.checkTypeInputDimension( + children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1) } override def eval(input: InternalRow = null): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0866b8d791e0..f9d7a4213af1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -52,11 +52,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children == Nil) { - TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") - } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") - } + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } override def dataType: DataType = children.head.dataType diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 7101ca5a17de..8da87514a86a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -42,18 +42,34 @@ object TypeUtils {
   }
 
   def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
-    if (types.size <= 1) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      val firstType = types.head
-      types.foreach { t =>
-        if (!t.sameType(firstType)) {
-          return TypeCheckResult.TypeCheckFailure(
-            s"input to $caller should all be the same type, but it's " +
-              types.map(_.simpleString).mkString("[", ", ", "]"))
+    checkTypeInputDimension(types, caller, requiredMinDimension = 1) match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        if (types.size == 1) {
+          TypeCheckResult.TypeCheckSuccess
+        } else {
+          val firstType = types.head
+          types.foreach { t =>
+            if (!t.sameType(firstType)) {
+              return TypeCheckResult.TypeCheckFailure(
+                s"input to $caller should all be the same type, but it's " +
+                  types.map(_.simpleString).mkString("[", ", ", "]"))
+            }
+          }
+          TypeCheckResult.TypeCheckSuccess
         }
-      }
+      case typeCheckFailure =>
+        typeCheckFailure
+    }
+  }
+
+  def checkTypeInputDimension(types: Seq[DataType], caller: String, requiredMinDimension: Int)
+    : TypeCheckResult = {
+    if (types.size >= requiredMinDimension) {
       TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to $caller requires at least $requiredMinDimension " +
+        s"argument${if (requiredMinDimension > 1) "s" else ""}")
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 839cbf42024e..66890c19d3b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1566,10 +1566,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def greatest(exprs: Column*): Column = withExpr {
-    require(exprs.length > 1, "greatest requires at least 2 arguments.")
-    Greatest(exprs.map(_.expr))
-  }
+  def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) }
 
   /**
    * Returns the greatest value of the list of column names, skipping null values.
@@ -1673,10 +1670,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def least(exprs: Column*): Column = withExpr {
-    require(exprs.length > 1, "least requires at least 2 arguments.")
-    Least(exprs.map(_.expr))
-  }
+  def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) }
 
   /**
    * Returns the least value of the list of column names, skipping null values.
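
The require() calls deleted above were the old guard: they threw an eager
java.lang.IllegalArgumentException while the Column was still being built, and
only for the Scala/Java API; "least()" in SQL text instead reached
Least.checkInputDataTypes() and failed during analysis with a different
message. Moving the arity check into checkInputDataTypes() gives both entry
points one analysis-time error path. A minimal sketch of the resulting
behavior, assuming a SparkSession named `spark` with this patch applied
(exception text abbreviated):

  import org.apache.spark.sql.AnalysisException
  import org.apache.spark.sql.functions.least

  try {
    // Dataset analysis is eager, so this throws before any action runs.
    spark.range(1).select(least())
  } catch {
    case e: AnalysisException =>
      // "cannot resolve 'least()' due to data type mismatch:
      //  input to function least requires at least 2 arguments"
      println(e.getMessage)
  }

  // The SQL path now fails with the same message:
  //   spark.range(1).selectExpr("least()")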
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b356ffe110d0..55b8a89e3899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -450,40 +450,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("SPARK-21281 fails if functions have no argument") { - var errMsg = intercept[AnalysisException] { - spark.range(1).select(array()) - }.getMessage - assert(errMsg.contains("due to data type mismatch: input to function coalesce cannot be empty")) - - errMsg = intercept[AnalysisException] { - spark.range(1).select(map()) - }.getMessage - assert(errMsg.contains("due to data type mismatch: input to function coalesce cannot be empty")) - - // spark.range(1).select(coalesce()) - errMsg = intercept[AnalysisException] { - spark.range(1).select(coalesce()) - }.getMessage - assert(errMsg.contains("due to data type mismatch: input to function coalesce cannot be empty")) - - // This hits java.lang.AssertionError - // spark.range(1).select(struct()) - - errMsg = intercept[IllegalArgumentException] { - spark.range(1).select(greatest()) - }.getMessage - assert(errMsg.contains("requirement failed: greatest requires at least 2 arguments")) - - errMsg = intercept[IllegalArgumentException] { - spark.range(1).select(least()) - }.getMessage - assert(errMsg.contains("requirement failed: least requires at least 2 arguments")) - - errMsg = intercept[AnalysisException] { - spark.range(1).select(hash()) - }.getMessage - assert(errMsg.contains( - "due to data type mismatch: function hash requires at least one argument")) + val df = Seq(1).toDF("a") + + val funcsMustHaveAtLeastOneArg = + ("array", (df: DataFrame) => df.select(array())) :: + ("array", (df: DataFrame) => df.selectExpr("array()")) :: + ("map", (df: DataFrame) => df.select(map())) :: + ("map", (df: DataFrame) => df.selectExpr("map()")) :: + ("coalesce", (df: DataFrame) => df.select(coalesce())) :: + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: + ("named_struct", (df: DataFrame) => df.select(struct())) :: + ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: + ("hash", (df: DataFrame) => df.select(hash())) :: + ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil + funcsMustHaveAtLeastOneArg.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least 1 argument")) + } + + val funcsMustHaveAtLeastTwoArgs = + ("greatest", (df: DataFrame) => df.select(greatest())) :: + ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: + ("least", (df: DataFrame) => df.select(least())) :: + ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil + funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least 2 arguments")) + } } } From 0b492fd5fe33e32256271394237b494254505a99 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Jul 2017 18:32:54 +0900 Subject: [PATCH 3/5] Use string types by default if array and map have no argument --- .../expressions/complexTypeCreator.scala | 42 ++++++++----------- .../expressions/nullExpressions.scala | 9 +++- .../spark/sql/catalyst/util/TypeUtils.scala | 27 
+++++------- .../spark/sql/DataFrameFunctionsSuite.scala | 14 +++++-- 4 files changed, 47 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4584fafc256d..26b11abb350c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -47,7 +47,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -94,7 +94,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -120,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); + ctx.addMutableState("UnsafeArrayData", arrayDataName, "") val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -169,32 +169,26 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - TypeUtils.checkTypeInputDimension( - children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1) match { - case TypeCheckResult.TypeCheckSuccess => - if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure( - s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure( - "The given keys of function map should all be the same type, but they are " + - keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure( - "The given values of function map should all be the same type, but they are " + - values.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else { - TypeCheckResult.TypeCheckSuccess - } - case typeCheckFailure => - typeCheckFailure + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure( + s"$prettyName expects a positive even number of arguments.") + } else if (keys.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure( + "The given keys of function map should all be the same type, but they are " + + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else if (values.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure( + "The given values of function map should all be the same type, but they are " + + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeCheckResult.TypeCheckSuccess } } override def dataType: DataType = { MapType( - 
keyType = keys.headOption.map(_.dataType).getOrElse(NullType), - valueType = values.headOption.map(_.dataType).getOrElse(NullType), + keyType = keys.headOption.map(_.dataType).getOrElse(StringType), + valueType = values.headOption.map(_.dataType).getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index f9d7a4213af1..9ae4c35b92d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -52,7 +52,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + val inputDataTypes = children.map(_.dataType) + TypeUtils.checkTypeInputDimension( + inputDataTypes, s"function $prettyName", requiredMinDimension = 1) match { + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForSameTypeInputExpr(inputDataTypes, s"function $prettyName") + case typeCheckFailure => + typeCheckFailure + } } override def dataType: DataType = children.head.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 8da87514a86a..faa98af969a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -42,23 +42,18 @@ object TypeUtils { } def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { - checkTypeInputDimension(types, caller, requiredMinDimension = 1) match { - case TypeCheckResult.TypeCheckSuccess => - if (types.size == 1) { - TypeCheckResult.TypeCheckSuccess - } else { - val firstType = types.head - types.foreach { t => - if (!t.sameType(firstType)) { - return TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) - } - } - TypeCheckResult.TypeCheckSuccess + if (types.size <= 1) { + TypeCheckResult.TypeCheckSuccess + } else { + val firstType = types.head + types.foreach { t => + if (!t.sameType(firstType)) { + return TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's " + + types.map(_.simpleString).mkString("[", ", ", "]")) } - case typeCheckFailure => - typeCheckFailure + } + TypeCheckResult.TypeCheckSuccess } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 55b8a89e3899..aeef4948bc2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -449,14 +449,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + test("SPARK-21281 use string types by default if array and map have no argument") { + val ds = spark.range(1) + var expectedSchema = new StructType() + .add("x", ArrayType(StringType, containsNull 
= false), nullable = false) + assert(ds.select(array().as("x")).schema == expectedSchema) + expectedSchema = new StructType() + .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) + assert(ds.select(map().as("x")).schema == expectedSchema) + } + test("SPARK-21281 fails if functions have no argument") { val df = Seq(1).toDF("a") val funcsMustHaveAtLeastOneArg = - ("array", (df: DataFrame) => df.select(array())) :: - ("array", (df: DataFrame) => df.selectExpr("array()")) :: - ("map", (df: DataFrame) => df.select(map())) :: - ("map", (df: DataFrame) => df.selectExpr("map()")) :: ("coalesce", (df: DataFrame) => df.select(coalesce())) :: ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: ("named_struct", (df: DataFrame) => df.select(struct())) :: From 04cbf784944aff0c9a1c195b4facb5cc97d0fe8a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 5 Jul 2017 10:32:29 +0900 Subject: [PATCH 4/5] Apply more fixes --- .../sql/catalyst/expressions/arithmetic.scala | 42 ++++++++----------- .../expressions/complexTypeCreator.scala | 33 +++++++-------- .../spark/sql/catalyst/expressions/hash.scala | 10 +++-- .../expressions/nullExpressions.scala | 12 +++--- .../spark/sql/catalyst/util/TypeUtils.scala | 11 ----- .../ExpressionTypeCheckingSuite.scala | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 4 +- 7 files changed, 50 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2c00a0df014e..423bf66a24d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -526,18 +526,15 @@ case class Least(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - TypeUtils.checkTypeInputDimension( - children.map(_.dataType), s"function $prettyName", requiredMinDimension = 2) match { - case TypeCheckResult.TypeCheckSuccess => - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") - } else { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") - } - case typeCheckFailure => - typeCheckFailure + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + TypeCheckResult.TypeCheckFailure( + s"The expressions should all have the same type," + + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") + } else { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } @@ -595,18 +592,15 @@ case class Greatest(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - TypeUtils.checkTypeInputDimension( - children.map(_.dataType), s"function $prettyName", requiredMinDimension = 2) match { - case TypeCheckResult.TypeCheckSuccess => - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have 
the same type," + - s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") - } else { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") - } - case typeCheckFailure => - typeCheckFailure + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + TypeCheckResult.TypeCheckFailure( + s"The expressions should all have the same type," + + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") + } else { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 26b11abb350c..5af2d2f0fde9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -296,25 +296,24 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - TypeUtils.checkTypeInputDimension( - children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1) match { - case TypeCheckResult.TypeCheckSuccess => - if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") + } else { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") + } else { + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) + if (invalidNames.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") + } else if (!names.contains(null)) { + TypeCheckResult.TypeCheckSuccess } else { - val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) - if (invalidNames.nonEmpty) { - TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") - } else if (!names.contains(null)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure("Field name should not be null") - } + TypeCheckResult.TypeCheckFailure("Field name should not be null") } - case typeCheckFailure => - typeCheckFailure + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index f928917c9ab0..2476fc962a6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,7 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -247,8 +247,12 @@ abstract class HashExpression[E] extends Expression {
   override def nullable: Boolean = false
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    TypeUtils.checkTypeInputDimension(
-      children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1)
+    if (children.length < 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least one argument")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
   }
 
   override def eval(input: InternalRow = null): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 9ae4c35b92d0..1b625141d56a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -52,13 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    val inputDataTypes = children.map(_.dataType)
-    TypeUtils.checkTypeInputDimension(
-      inputDataTypes, s"function $prettyName", requiredMinDimension = 1) match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        TypeUtils.checkForSameTypeInputExpr(inputDataTypes, s"function $prettyName")
-      case typeCheckFailure =>
-        typeCheckFailure
+    if (children.length < 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least one argument")
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index faa98af969a0..7101ca5a17de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -57,17 +57,6 @@ object TypeUtils {
     }
   }
 
-  def checkTypeInputDimension(types: Seq[DataType], caller: String, requiredMinDimension: Int)
-    : TypeCheckResult = {
-    if (types.size >= requiredMinDimension) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        s"input to $caller requires at least $requiredMinDimension " +
-        s"argument${if (requiredMinDimension > 1) "s" else ""}")
-    }
-  }
-
   def getNumeric(t: DataType): Numeric[Any] =
     t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 30459f173ab5..30725773a37b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       "input to function array should all be the same type")
     assertError(Coalesce(Seq('intField, 'booleanField)),
       "input to function coalesce should all be the same type")
-    assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
+    
assertError(Coalesce(Nil), "function coalesce requires at least one argument") assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") @@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for Greatest/Least") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('booleanField)), "requires at least two arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index aeef4948bc2a..0681b9cbeb1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -471,7 +471,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => val errMsg = intercept[AnalysisException] { func(df) }.getMessage - assert(errMsg.contains(s"input to function $name requires at least 1 argument")) + assert(errMsg.contains(s"input to function $name requires at least one argument")) } val funcsMustHaveAtLeastTwoArgs = @@ -481,7 +481,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => val errMsg = intercept[AnalysisException] { func(df) }.getMessage - assert(errMsg.contains(s"input to function $name requires at least 2 arguments")) + assert(errMsg.contains(s"input to function $name requires at least two arguments")) } } } From f71dc27d8555c7e39a67c256b73fc0836141eafb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 6 Jul 2017 09:29:54 +0900 Subject: [PATCH 5/5] Apply comments --- .../expressions/complexTypeCreator.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5af2d2f0fde9..d9eeb5358ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -299,20 +299,18 @@ trait CreateNamedStructLike extends Expression { if (children.length < 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least one argument") + } else if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { - if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) + if (invalidNames.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") + } else if 
(!names.contains(null)) { + TypeCheckResult.TypeCheckSuccess } else { - val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) - if (invalidNames.nonEmpty) { - TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") - } else if (!names.contains(null)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure("Field name should not be null") - } + TypeCheckResult.TypeCheckFailure("Field name should not be null") } } }
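
Where the series ends up, as a compact sketch (assuming a SparkSession named
`spark` built with these patches; the expected strings below are paraphrased
from the tests above):

  import org.apache.spark.sql.AnalysisException
  import org.apache.spark.sql.functions.{array, coalesce, map}

  // array() and map() remain legal with zero arguments, but now default to
  // string element/key/value types instead of NullType:
  spark.range(1).select(array().as("x")).schema.simpleString
  // -> struct<x:array<string>>
  spark.range(1).select(map().as("x")).schema.simpleString
  // -> struct<x:map<string,string>>

  // coalesce(), named_struct() and hash() require at least one argument,
  // greatest() and least() at least two. All of them now fail during
  // analysis, instead of the earlier mix of an eager IllegalArgumentException
  // (greatest, least) and a java.lang.AssertionError (struct):
  try spark.range(1).selectExpr("coalesce()") catch {
    case e: AnalysisException =>
      // "... input to function coalesce requires at least one argument"
      println(e.getMessage)
  }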