Skip to content

Commit 04cbf78

Browse files
committed
Apply more fixes
1 parent 0b492fd commit 04cbf78

File tree

7 files changed

+50
-66
lines changed

7 files changed

+50
-66
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -526,18 +526,15 @@ case class Least(children: Seq[Expression]) extends Expression {
526526
private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
527527

528528
override def checkInputDataTypes(): TypeCheckResult = {
529-
TypeUtils.checkTypeInputDimension(
530-
children.map(_.dataType), s"function $prettyName", requiredMinDimension = 2) match {
531-
case TypeCheckResult.TypeCheckSuccess =>
532-
if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
533-
TypeCheckResult.TypeCheckFailure(
534-
s"The expressions should all have the same type," +
535-
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
536-
} else {
537-
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
538-
}
539-
case typeCheckFailure =>
540-
typeCheckFailure
529+
if (children.length <= 1) {
530+
TypeCheckResult.TypeCheckFailure(
531+
s"input to function $prettyName requires at least two arguments")
532+
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
533+
TypeCheckResult.TypeCheckFailure(
534+
s"The expressions should all have the same type," +
535+
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
536+
} else {
537+
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
541538
}
542539
}
543540

@@ -595,18 +592,15 @@ case class Greatest(children: Seq[Expression]) extends Expression {
595592
private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
596593

597594
override def checkInputDataTypes(): TypeCheckResult = {
598-
TypeUtils.checkTypeInputDimension(
599-
children.map(_.dataType), s"function $prettyName", requiredMinDimension = 2) match {
600-
case TypeCheckResult.TypeCheckSuccess =>
601-
if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
602-
TypeCheckResult.TypeCheckFailure(
603-
s"The expressions should all have the same type," +
604-
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
605-
} else {
606-
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
607-
}
608-
case typeCheckFailure =>
609-
typeCheckFailure
595+
if (children.length <= 1) {
596+
TypeCheckResult.TypeCheckFailure(
597+
s"input to function $prettyName requires at least two arguments")
598+
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
599+
TypeCheckResult.TypeCheckFailure(
600+
s"The expressions should all have the same type," +
601+
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
602+
} else {
603+
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
610604
}
611605
}
612606

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -296,25 +296,24 @@ trait CreateNamedStructLike extends Expression {
296296
}
297297

298298
override def checkInputDataTypes(): TypeCheckResult = {
299-
TypeUtils.checkTypeInputDimension(
300-
children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1) match {
301-
case TypeCheckResult.TypeCheckSuccess =>
302-
if (children.size % 2 != 0) {
303-
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
299+
if (children.length < 1) {
300+
TypeCheckResult.TypeCheckFailure(
301+
s"input to function $prettyName requires at least one argument")
302+
} else {
303+
if (children.size % 2 != 0) {
304+
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
305+
} else {
306+
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
307+
if (invalidNames.nonEmpty) {
308+
TypeCheckResult.TypeCheckFailure(
309+
"Only foldable StringType expressions are allowed to appear at odd position, got:" +
310+
s" ${invalidNames.mkString(",")}")
311+
} else if (!names.contains(null)) {
312+
TypeCheckResult.TypeCheckSuccess
304313
} else {
305-
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
306-
if (invalidNames.nonEmpty) {
307-
TypeCheckResult.TypeCheckFailure(
308-
"Only foldable StringType expressions are allowed to appear at odd position, got:" +
309-
s" ${invalidNames.mkString(",")}")
310-
} else if (!names.contains(null)) {
311-
TypeCheckResult.TypeCheckSuccess
312-
} else {
313-
TypeCheckResult.TypeCheckFailure("Field name should not be null")
314-
}
314+
TypeCheckResult.TypeCheckFailure("Field name should not be null")
315315
}
316-
case typeCheckFailure =>
317-
typeCheckFailure
316+
}
318317
}
319318
}
320319

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils
2828
import org.apache.spark.sql.catalyst.InternalRow
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
3030
import org.apache.spark.sql.catalyst.expressions.codegen._
31-
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, TypeUtils}
31+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
3232
import org.apache.spark.sql.types._
3333
import org.apache.spark.unsafe.hash.Murmur3_x86_32
3434
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -247,8 +247,12 @@ abstract class HashExpression[E] extends Expression {
247247
override def nullable: Boolean = false
248248

249249
override def checkInputDataTypes(): TypeCheckResult = {
250-
TypeUtils.checkTypeInputDimension(
251-
children.map(_.dataType), s"function $prettyName", requiredMinDimension = 1)
250+
if (children.length < 1) {
251+
TypeCheckResult.TypeCheckFailure(
252+
s"input to function $prettyName requires at least one argument")
253+
} else {
254+
TypeCheckResult.TypeCheckSuccess
255+
}
252256
}
253257

254258
override def eval(input: InternalRow = null): Any = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
5252
override def foldable: Boolean = children.forall(_.foldable)
5353

5454
override def checkInputDataTypes(): TypeCheckResult = {
55-
val inputDataTypes = children.map(_.dataType)
56-
TypeUtils.checkTypeInputDimension(
57-
inputDataTypes, s"function $prettyName", requiredMinDimension = 1) match {
58-
case TypeCheckResult.TypeCheckSuccess =>
59-
TypeUtils.checkForSameTypeInputExpr(inputDataTypes, s"function $prettyName")
60-
case typeCheckFailure =>
61-
typeCheckFailure
55+
if (children.length < 1) {
56+
TypeCheckResult.TypeCheckFailure(
57+
s"input to function $prettyName requires at least one argument")
58+
} else {
59+
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
6260
}
6361
}
6462

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,6 @@ object TypeUtils {
5757
}
5858
}
5959

60-
def checkTypeInputDimension(types: Seq[DataType], caller: String, requiredMinDimension: Int)
61-
: TypeCheckResult = {
62-
if (types.size >= requiredMinDimension) {
63-
TypeCheckResult.TypeCheckSuccess
64-
} else {
65-
TypeCheckResult.TypeCheckFailure(
66-
s"input to $caller requires at least $requiredMinDimension " +
67-
s"argument${if (requiredMinDimension > 1) "s"}")
68-
}
69-
}
70-
7160
def getNumeric(t: DataType): Numeric[Any] =
7261
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
7362

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
155155
"input to function array should all be the same type")
156156
assertError(Coalesce(Seq('intField, 'booleanField)),
157157
"input to function coalesce should all be the same type")
158-
assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
158+
assertError(Coalesce(Nil), "function coalesce requires at least one argument")
159159
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
160160
assertError(Explode('intField),
161161
"input to function explode should be array or map type")
@@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
207207

208208
test("check types for Greatest/Least") {
209209
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
210-
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
210+
assertError(operator(Seq('booleanField)), "requires at least two arguments")
211211
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
212212
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
213213
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
471471
("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil
472472
funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
473473
val errMsg = intercept[AnalysisException] { func(df) }.getMessage
474-
assert(errMsg.contains(s"input to function $name requires at least 1 argument"))
474+
assert(errMsg.contains(s"input to function $name requires at least one argument"))
475475
}
476476

477477
val funcsMustHaveAtLeastTwoArgs =
@@ -481,7 +481,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
481481
("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil
482482
funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) =>
483483
val errMsg = intercept[AnalysisException] { func(df) }.getMessage
484-
assert(errMsg.contains(s"input to function $name requires at least 2 arguments"))
484+
assert(errMsg.contains(s"input to function $name requires at least two arguments"))
485485
}
486486
}
487487
}

0 commit comments

Comments (0)