Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression {

override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
} else {
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
}
}

Expand Down Expand Up @@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression {

override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
} else {
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {

override def foldable: Boolean = children.forall(_.foldable)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
override def checkInputDataTypes(): TypeCheckResult = {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}

override def dataType: ArrayType = {
ArrayType(
children.headOption.map(_.dataType).getOrElse(NullType),
children.headOption.map(_.dataType).getOrElse(StringType),
containsNull = children.exists(_.nullable))
}

Expand Down Expand Up @@ -93,7 +94,7 @@ private [sql] object GenArrayData {
if (!ctx.isPrimitiveType(elementType)) {
val genericArrayClass = classOf[GenericArrayData].getName
ctx.addMutableState("Object[]", arrayName,
s"$arrayName = new Object[${numElements}];")
s"$arrayName = new Object[$numElements];")

val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
Expand All @@ -119,7 +120,7 @@ private [sql] object GenArrayData {
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
ctx.addMutableState("UnsafeArrayData", arrayDataName, "")

val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
Expand Down Expand Up @@ -169,22 +170,25 @@ case class CreateMap(children: Seq[Expression]) extends Expression {

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
TypeCheckResult.TypeCheckFailure(
s"$prettyName expects a positive even number of arguments.")
} else if (keys.map(_.dataType).distinct.length > 1) {
TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
"type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
TypeCheckResult.TypeCheckFailure(
"The given keys of function map should all be the same type, but they are " +
keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else if (values.map(_.dataType).distinct.length > 1) {
TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
"type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
TypeCheckResult.TypeCheckFailure(
"The given values of function map should all be the same type, but they are " +
values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def dataType: DataType = {
MapType(
keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
valueType = values.headOption.map(_.dataType).getOrElse(NullType),
keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
valueType = values.headOption.map(_.dataType).getOrElse(StringType),
valueContainsNull = values.exists(_.nullable))
}

Expand Down Expand Up @@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression {
}

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not related to what this PR claims to do. What's the reason behind this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a behavior change, and it caused a problem in #22373.

Copy link
Member Author

@maropu maropu Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, but I don't remember the reason exactly.
I looked over this PR again, and I also think the modification is not related to this PR, so it's OK to revert this part.

} else if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
} else {
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
"Only foldable StringType expressions are allowed to appear at odd position, got:" +
s" ${invalidNames.mkString(",")}")
s" ${invalidNames.mkString(",")}")
} else if (!names.contains(null)) {
TypeCheckResult.TypeCheckSuccess
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression {
override def nullable: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
} else {
TypeCheckResult.TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)

override def checkInputDataTypes(): TypeCheckResult = {
if (children == Nil) {
TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
} else {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
"input to function array should all be the same type")
assertError(Coalesce(Seq('intField, 'booleanField)),
"input to function coalesce should all be the same type")
assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
assertError(Coalesce(Nil), "function coalesce requires at least one argument")
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
Expand Down Expand Up @@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

test("check types for Greatest/Least") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
assertError(operator(Seq('booleanField)), "requires at least two arguments")
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
}
Expand Down
10 changes: 2 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1566,10 +1566,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def greatest(exprs: Column*): Column = withExpr {
require(exprs.length > 1, "greatest requires at least 2 arguments.")
Greatest(exprs.map(_.expr))
}
def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) }

/**
* Returns the greatest value of the list of column names, skipping null values.
Expand Down Expand Up @@ -1673,10 +1670,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def least(exprs: Column*): Column = withExpr {
require(exprs.length > 1, "least requires at least 2 arguments.")
Least(exprs.map(_.expr))
}
def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) }

/**
* Returns the least value of the list of column names, skipping null values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
rand(Random.nextLong()), randn(Random.nextLong())
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
}

test("SPARK-21281 use string types by default if array and map have no argument") {
  val ds = spark.range(1)

  // array() with no arguments is expected to yield a non-nullable array of strings.
  val arraySchema = new StructType()
    .add("x", ArrayType(StringType, containsNull = false), nullable = false)
  assert(ds.select(array().as("x")).schema == arraySchema)

  // map() with no arguments is expected to yield a non-nullable string-to-string map.
  val mapSchema = new StructType()
    .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false)
  assert(ds.select(map().as("x")).schema == mapSchema)
}

test("SPARK-21281 fails if functions have no argument") {
  val df = Seq(1).toDF("a")

  // Each entry pairs the function name expected in the error message with an invocation
  // (via the Column API or a SQL expression string) that passes no arguments.
  val needAtLeastOneArg = Seq[(String, DataFrame => DataFrame)](
    "coalesce" -> (_.select(coalesce())),
    "coalesce" -> (_.selectExpr("coalesce()")),
    "named_struct" -> (_.select(struct())),
    "named_struct" -> (_.selectExpr("named_struct()")),
    "hash" -> (_.select(hash())),
    "hash" -> (_.selectExpr("hash()")))
  for ((name, invoke) <- needAtLeastOneArg) {
    val message = intercept[AnalysisException] { invoke(df) }.getMessage
    assert(message.contains(s"input to function $name requires at least one argument"))
  }

  // Same structure for the functions that must be rejected with the two-argument message.
  val needAtLeastTwoArgs = Seq[(String, DataFrame => DataFrame)](
    "greatest" -> (_.select(greatest())),
    "greatest" -> (_.selectExpr("greatest()")),
    "least" -> (_.select(least())),
    "least" -> (_.selectExpr("least()")))
  for ((name, invoke) <- needAtLeastTwoArgs) {
    val message = intercept[AnalysisException] { invoke(df) }.getMessage
    assert(message.contains(s"input to function $name requires at least two arguments"))
  }
}
}

object DataFrameFunctionsSuite {
Expand Down