-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types #21361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
35370cd
8c9553f
b7d7249
a819a97
71916d9
c74ddd7
3f8624b
6315775
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -504,6 +504,9 @@ case class ArrayContains(left: Expression, right: Expression) | |
|
|
||
| override def dataType: DataType = BooleanType | ||
|
|
||
| @transient private lazy val ordering: Ordering[Any] = | ||
| TypeUtils.getInterpretedOrdering(right.dataType) | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = right.dataType match { | ||
| case NullType => Seq.empty | ||
| case _ => left.dataType match { | ||
|
|
@@ -520,7 +523,7 @@ case class ArrayContains(left: Expression, right: Expression) | |
| TypeCheckResult.TypeCheckFailure( | ||
| "Arguments must be an array followed by a value of same type as the array members") | ||
| } else { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -533,7 +536,7 @@ case class ArrayContains(left: Expression, right: Expression) | |
| arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => | ||
| if (v == null) { | ||
| hasNull = true | ||
| } else if (v == value) { | ||
| } else if (ordering.equiv(v, value)) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously does this work for Map? No?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MapType is not supported in comparison, even |
||
| return true | ||
| } | ||
| ) | ||
|
|
@@ -582,11 +585,7 @@ case class ArraysOverlap(left: Expression, right: Expression) | |
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { | ||
| case TypeCheckResult.TypeCheckSuccess => | ||
| if (RowOrdering.isOrderable(elementType)) { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } else { | ||
| TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") | ||
| } | ||
| TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also a general suggestion. For these refactoring, we should do it in a separate PR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, I'll keep this in mind for the future, thanks. |
||
| case failure => failure | ||
| } | ||
|
|
||
|
|
@@ -1238,13 +1237,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast | |
| case class ArrayPosition(left: Expression, right: Expression) | ||
| extends BinaryExpression with ImplicitCastInputTypes { | ||
|
|
||
| @transient private lazy val ordering: Ordering[Any] = | ||
| TypeUtils.getInterpretedOrdering(right.dataType) | ||
|
|
||
| override def dataType: DataType = LongType | ||
| override def inputTypes: Seq[AbstractDataType] = | ||
| Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| super.checkInputDataTypes() match { | ||
| case f: TypeCheckResult.TypeCheckFailure => f | ||
| case TypeCheckResult.TypeCheckSuccess => | ||
| TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") | ||
| } | ||
| } | ||
|
|
||
| override def nullSafeEval(arr: Any, value: Any): Any = { | ||
| arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => | ||
| if (v == value) { | ||
| if (v != null && ordering.equiv(v, value)) { | ||
| return (i + 1).toLong | ||
| } | ||
| ) | ||
|
|
@@ -1293,6 +1303,9 @@ case class ArrayPosition(left: Expression, right: Expression) | |
| since = "2.4.0") | ||
| case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { | ||
|
|
||
| @transient private lazy val ordering: Ordering[Any] = | ||
| TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) | ||
|
|
||
| override def dataType: DataType = left.dataType match { | ||
| case ArrayType(elementType, _) => elementType | ||
| case MapType(_, valueType, _) => valueType | ||
|
|
@@ -1307,6 +1320,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
| ) | ||
| } | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| super.checkInputDataTypes() match { | ||
| case f: TypeCheckResult.TypeCheckFailure => f | ||
| case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => | ||
| TypeUtils.checkForOrderingExpr( | ||
| left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") | ||
| case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess | ||
| } | ||
| } | ||
|
|
||
| override def nullable: Boolean = true | ||
|
|
||
| override def nullSafeEval(value: Any, ordinal: Any): Any = { | ||
|
|
@@ -1331,7 +1354,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
| } | ||
| } | ||
| case _: MapType => | ||
| getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) | ||
| getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -134,6 +134,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
|
|
||
| checkEvaluation(ArrayContains(a3, Literal("")), null) | ||
| checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) | ||
|
|
||
| // binary | ||
| val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), | ||
| ArrayType(BinaryType)) | ||
| val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), | ||
| ArrayType(BinaryType)) | ||
| val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), | ||
| ArrayType(BinaryType)) | ||
| val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), | ||
| ArrayType(BinaryType)) | ||
| val be = Literal.create(Array[Byte](1, 2), BinaryType) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto, binary type is not complex type |
||
| val nullBinary = Literal.create(null, BinaryType) | ||
|
|
||
| checkEvaluation(ArrayContains(b0, be), true) | ||
| checkEvaluation(ArrayContains(b1, be), false) | ||
| checkEvaluation(ArrayContains(b0, nullBinary), null) | ||
| checkEvaluation(ArrayContains(b2, be), null) | ||
| checkEvaluation(ArrayContains(b3, be), true) | ||
|
|
||
| // complex data types | ||
| val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), | ||
| ArrayType(ArrayType(IntegerType))) | ||
| val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), | ||
| ArrayType(ArrayType(IntegerType))) | ||
| val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) | ||
| checkEvaluation(ArrayContains(aa0, aae), true) | ||
| checkEvaluation(ArrayContains(aa1, aae), false) | ||
| } | ||
|
|
||
| test("ArraysOverlap") { | ||
|
|
@@ -349,6 +376,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
|
|
||
| checkEvaluation(ArrayPosition(a3, Literal("")), null) | ||
| checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) | ||
|
|
||
| val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), | ||
| ArrayType(ArrayType(IntegerType))) | ||
| val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), | ||
| ArrayType(ArrayType(IntegerType))) | ||
| val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) | ||
| checkEvaluation(ArrayPosition(aa0, aae), 1L) | ||
| checkEvaluation(ArrayPosition(aa1, aae), 0L) | ||
| } | ||
|
|
||
| test("elementAt") { | ||
|
|
@@ -386,7 +421,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) | ||
| val m2 = Literal.create(null, MapType(StringType, StringType)) | ||
|
|
||
| checkEvaluation(ElementAt(m0, Literal(1.0)), null) | ||
| assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) | ||
|
|
||
| checkEvaluation(ElementAt(m0, Literal("d")), null) | ||
|
|
||
|
|
@@ -397,6 +432,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(ElementAt(m0, Literal("c")), null) | ||
|
|
||
| checkEvaluation(ElementAt(m2, Literal("a")), null) | ||
|
|
||
| // test binary type as keys | ||
| val mb0 = Literal.create( | ||
| Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), | ||
| MapType(BinaryType, StringType)) | ||
| val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) | ||
|
|
||
| checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) | ||
|
|
||
| checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) | ||
| checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") | ||
| checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) | ||
| } | ||
|
|
||
| test("Concat") { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then in
checkInputDataTypeswe should check if there isOrderingforright.dataType. Otherwise for example MapType will throw a match error.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, thanks