From e08bfc19cd41f52ad37329ceaeea092aee3e8b4c Mon Sep 17 00:00:00 2001
From: "Vayda, Oleksandr: IT (PRG)"
Date: Tue, 22 May 2018 17:46:26 +0200
Subject: [PATCH 1/2] SPARK-24350 "array_position" error fix

---
 .../catalyst/expressions/collectionOperations.scala    | 10 ++++++++--
 .../org/apache/spark/sql/DataFrameFunctionsSuite.scala |  4 ++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 7da4c3cc6b9f..21308bcec035 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1394,8 +1394,14 @@ case class ArrayPosition(left: Expression, right: Expression)
     TypeUtils.getInterpretedOrdering(right.dataType)
 
   override def dataType: DataType = LongType
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    val elementType = left.dataType match {
+      case t: ArrayType => t.elementType
+      case _ => AnyDataType
+    }
+    Seq(ArrayType, elementType)
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
     super.checkInputDataTypes() match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index df23e07e441a..1802ce0efc2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -708,6 +708,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
       Seq(Row(1L), Row(1L))
     )
+
+    intercept[AnalysisException] {
+      Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
+    }
   }
 
   test("element_at function") {

From 8e8f62385b3a81cd21b5fd033a5dee9bc7d8a040 Mon Sep 17 00:00:00 2001
From: "Vayda, Oleksandr: IT (PRG)"
Date: Tue, 22 May 2018 19:05:25 +0200
Subject: [PATCH 2/2] SPARK-24350 + error message check

---
 .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1802ce0efc2a..afa952528631 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -709,9 +709,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(1L), Row(1L))
     )
 
-    intercept[AnalysisException] {
+    val e = intercept[AnalysisException] {
       Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
     }
+    assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
   }
 
   test("element_at function") {