From 35370cd32e0fbb0cd777865a575960cdb9d80cd5 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 11:17:12 +0200 Subject: [PATCH 1/8] [SPARK-24313][SQL] Fix array_contains interpreted evaluation for complex types --- .../expressions/collectionOperations.scala | 5 +++- .../CollectionExpressionsSuite.scala | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c82db839438e..8186fc468fa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -504,6 +504,9 @@ case class ArrayContains(left: Expression, right: Expression) override def dataType: DataType = BooleanType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq.empty case _ => left.dataType match { @@ -533,7 +536,7 @@ case class ArrayContains(left: Expression, right: Expression) arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == null) { hasNull = true - } else if (v == value) { + } else if (ordering.equiv(v, value)) { return true } ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6ae1ac18c4dc..59127ed5d7cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -134,6 +134,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) } test("ArraysOverlap") { From 8c9553f5d4018d9ccf01ff8a68af2e2a29b8214f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 11:57:32 +0200 Subject: [PATCH 2/8] add orderable check 
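The interpreted path now relies on TypeUtils.getInterpretedOrdering, which is
only defined for orderable types, so reject values of non-orderable types
(e.g. maps) at type-check time instead of failing at runtime.

For context, a minimal sketch (plain Scala, not part of this patch) of why
the previous `v == value` comparison was broken for binary data: JVM arrays
compare by reference, so two distinct but equal byte arrays were never
considered equal.

    val a = Array[Byte](1, 2)
    val b = Array[Byte](1, 2)
    a == b                         // false: arrays use reference equality
    java.util.Arrays.equals(a, b)  // true: element-wise comparison, which is
                                   // what the interpreted ordering provides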
--- .../sql/catalyst/expressions/collectionOperations.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8186fc468fa9..48c3cdf58638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -523,7 +523,12 @@ case class ArrayContains(left: Expression, right: Expression) TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { - TypeCheckResult.TypeCheckSuccess + if (RowOrdering.isOrderable(right.dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"${right.dataType.simpleString} cannot be used in comparison.") + } } } From b7d7249edfde57dd76bb7150e064a1070510c98f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 13:02:32 +0200 Subject: [PATCH 3/8] fix array position --- .../expressions/collectionOperations.scala | 18 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 8 ++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 48c3cdf58638..1b22e9fadd6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1246,13 +1246,29 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + if (RowOrdering.isOrderable(right.dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"${right.dataType.simpleString} cannot be used in comparison.") + } + } + } + override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { + if (v != null && ordering.equiv(v, value)) { return (i + 1).toLong } ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 59127ed5d7cb..04a8912abe12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -375,6 +375,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) 
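     // NOTE (illustrative, not part of the original diff): ArrayPosition is
     // 1-based and evaluates to 0L when the value is absent, so the
     // complex-type checks added below expect 1L for a hit and 0L for a miss.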
     checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+
+    val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+      ArrayType(ArrayType(IntegerType)))
+    val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+    checkEvaluation(ArrayPosition(aa0, aae), 1L)
+    checkEvaluation(ArrayPosition(aa1, aae), 0L)
   }

   test("elementAt") {

From a819a97ad3a8c9967c30213890ffba5050e62c91 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 18 May 2018 13:56:55 +0200
Subject: [PATCH 4/8] fix element_at and GetMapValue

---
 .../expressions/collectionOperations.scala    | 17 ++++++++++++-
 .../expressions/complexTypeExtractors.scala   | 24 +++++++++++++++----
 .../CollectionExpressionsSuite.scala          | 15 +++++++++++-
 .../optimizer/complexTypesSuite.scala         | 13 ++++++++++
 4 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 1b22e9fadd6d..1c8a84d0838f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1317,6 +1317,9 @@ case class ArrayPosition(left: Expression, right: Expression)
   since = "2.4.0")
 case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {

+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
+
   override def dataType: DataType = left.dataType match {
     case ArrayType(elementType, _) => elementType
     case MapType(_, valueType, _) => valueType
@@ -1331,6 +1334,18 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
     )
   }

+  override def checkInputDataTypes(): TypeCheckResult = {
+    def mapKeyType = left.dataType.asInstanceOf[MapType].keyType
+    super.checkInputDataTypes() match {
+      case f: TypeCheckResult.TypeCheckFailure => f
+      case TypeCheckResult.TypeCheckSuccess
+          if left.dataType.isInstanceOf[MapType] && !RowOrdering.isOrderable(mapKeyType) =>
+        TypeCheckResult.TypeCheckFailure(
+          s"${mapKeyType.simpleString} cannot be used in comparison.")
+      case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   override def nullable: Boolean = true

   override def nullSafeEval(value: Any, ordinal: Any): Any = {
@@ -1355,7 +1370,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
       }
     }
     case _: MapType =>
-      getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
+      getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
   }
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 3fba52d74545..9d3f37d0c96a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import
org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,24 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + if (RowOrdering.isOrderable(keyType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"${keyType.simpleString} cannot be used in comparison.") + } + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +379,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. 
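  // NOTE (illustrative, not part of the original diff): the change below
  // threads the interpreted ordering into getValueEval, so map keys are
  // compared with ordering.equiv instead of ==, letting binary keys match
  // by value rather than by reference.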
override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 04a8912abe12..6beaf7603a1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -420,7 +420,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -431,6 +431,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test complex types as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) + } test("Concat") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d49558..06866fe0c10e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support complex types as map keys") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } From 71916d93cdb300d1935b12994f1121c30b97d23f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 15:22:39 +0200 Subject: [PATCH 5/8] address comments --- .../expressions/collectionOperations.scala | 28 ++++--------------- .../expressions/complexTypeExtractors.scala | 7 +---- 2 files changed, 7 insertions(+), 28 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1c8a84d0838f..01e435185300 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -523,12 +523,7 @@ case class ArrayContains(left: Expression, right: Expression) TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { - if (RowOrdering.isOrderable(right.dataType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"${right.dataType.simpleString} cannot be used in comparison.") - } + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -590,11 +585,7 @@ case class ArraysOverlap(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - if (RowOrdering.isOrderable(elementType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") - } + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") case failure => failure } @@ -1257,12 +1248,7 @@ case class ArrayPosition(left: Expression, right: Expression) super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess => - if (RowOrdering.isOrderable(right.dataType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"${right.dataType.simpleString} cannot be used in comparison.") - } + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -1335,13 +1321,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } override def checkInputDataTypes(): TypeCheckResult = { - def mapKeyType = left.dataType.asInstanceOf[MapType].keyType super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess - if left.dataType.isInstanceOf[MapType] && !RowOrdering.isOrderable(mapKeyType) => - TypeCheckResult.TypeCheckFailure( - s"${mapKeyType.simpleString} cannot be used in comparison.") + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr( + left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9d3f37d0c96a..99671d5b863c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -354,12 +354,7 @@ case class GetMapValue(child: Expression, key: Expression) super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess => - if (RowOrdering.isOrderable(keyType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"${keyType.simpleString} cannot be used in 
comparison.") - } + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") } } From c74ddd793d1e34b66a6439ebf0d45ff128136d3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 20 May 2018 13:44:40 +0200 Subject: [PATCH 6/8] address comments --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 6 +++--- .../spark/sql/catalyst/optimizer/complexTypesSuite.scala | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6beaf7603a1f..ed4a96843d00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -135,7 +135,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) - // complex data types + // binary val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), @@ -153,6 +153,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(b2, be), null) checkEvaluation(ArrayContains(b3, be), true) + // complex data types val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), ArrayType(ArrayType(IntegerType))) val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), @@ -432,7 +433,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m2, Literal("a")), null) - // test complex types as keys + // test binary type as keys val mb0 = Literal.create( Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), MapType(BinaryType, StringType)) @@ -443,7 +444,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) - } test("Concat") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 06866fe0c10e..5452e72b3864 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -440,7 +440,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkRule(originalQuery, correctAnswer) } - test("SPARK-24313: support complex types as map keys") { + test("SPARK-24313: support binary type as map keys in GetMapValue") { val mb0 = Literal.create( Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), MapType(BinaryType, StringType)) From 3f8624bf23124a1a030b4c895aebcb494ff62896 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 21 May 2018 15:54:22 +0200 Subject: [PATCH 7/8] add end to end test --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++++ 1 file changed, 4 
insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee750..8459e4392935 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -74,6 +74,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) + + // SPARK-24313: access binary keys + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } test("table scan") { From 63157756fde2b4e9dc80fce0296da68067c57422 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 21 May 2018 16:14:30 +0200 Subject: [PATCH 8/8] address comment --- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8459e4392935..1cc8cb3874c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -74,10 +74,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) - - // SPARK-24313: access binary keys - val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) - checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } test("table scan") { @@ -2269,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } }
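NOTE (illustrative, not part of the series): a quick way to exercise the fix
from spark-shell, assuming a build that includes these patches; the calls
below (map, lit, getItem) are the same ones used in the DataFrameSuite test
added above.

    import org.apache.spark.sql.functions.{lit, map}

    val df = spark.range(1)
      .select(map(lit(Array[Byte](1.toByte)), lit(1)).getItem(Array[Byte](1.toByte)))
    df.show()  // expected: a single row containing 1

    // array_contains over nested arrays now compares elements by value in
    // the interpreted path as well:
    spark.sql("SELECT array_contains(array(array(1, 2)), array(1, 2))").show()
    // expected: true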