diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ec6e6ba0f091..423bf66a24d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } @@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 98c4cbee38de..d9eeb5358ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") + override def checkInputDataTypes(): TypeCheckResult = { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -93,7 +94,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -119,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); + ctx.addMutableState("UnsafeArrayData", arrayDataName, "") val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.") + TypeCheckResult.TypeCheckFailure( + s"$prettyName expects a positive even number of arguments.") } else if (keys.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + - "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + TypeCheckResult.TypeCheckFailure( + "The given keys of function map should all be the same type, but they are " + + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else if (values.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " + - "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + TypeCheckResult.TypeCheckFailure( + "The given values of function map should all be the same type, but they are " + + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else { TypeCheckResult.TypeCheckSuccess } @@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(NullType), - valueType = values.headOption.map(_.dataType).getOrElse(NullType), + keyType = keys.headOption.map(_.dataType).getOrElse(StringType), + valueType = values.headOption.map(_.dataType).getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } @@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.size % 2 != 0) { + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") + } else if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ffd0e64d86cf..2476fc962a6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression { override def nullable: Boolean = false override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0866b8d791e0..1b625141d56a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children == Nil) { - TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 30459f173ab5..30725773a37b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { "input to function array should all be the same type") assertError(Coalesce(Seq('intField, 'booleanField)), "input to function coalesce should all be the same type") - assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Coalesce(Nil), "function coalesce requires at least one argument") assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") @@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for Greatest/Least") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('booleanField)), "requires at least two arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 839cbf42024e..66890c19d3b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1566,10 +1566,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "greatest requires at least 2 arguments.") - Greatest(exprs.map(_.expr)) - } + def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) } /** * Returns the greatest value of the list of column names, skipping null values. @@ -1673,10 +1670,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "least requires at least 2 arguments.") - Least(exprs.map(_.expr)) - } + def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) } /** * Returns the least value of the list of column names, skipping null values. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0e9a2c6cf7de..0681b9cbeb1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + + test("SPARK-21281 use string types by default if array and map have no argument") { + val ds = spark.range(1) + var expectedSchema = new StructType() + .add("x", ArrayType(StringType, containsNull = false), nullable = false) + assert(ds.select(array().as("x")).schema == expectedSchema) + expectedSchema = new StructType() + .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) + assert(ds.select(map().as("x")).schema == expectedSchema) + } + + test("SPARK-21281 fails if functions have no argument") { + val df = Seq(1).toDF("a") + + val funcsMustHaveAtLeastOneArg = + ("coalesce", (df: DataFrame) => df.select(coalesce())) :: + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: + ("named_struct", (df: DataFrame) => df.select(struct())) :: + ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: + ("hash", (df: DataFrame) => df.select(hash())) :: + ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil + funcsMustHaveAtLeastOneArg.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least one argument")) + } + + val funcsMustHaveAtLeastTwoArgs = + ("greatest", (df: DataFrame) => df.select(greatest())) :: + ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: + ("least", (df: DataFrame) => df.select(least())) :: + ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil + funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least two arguments")) + } + } } object DataFrameFunctionsSuite {