From f8dd1a6008acc82b15fb63607c1ccd84e5487039 Mon Sep 17 00:00:00 2001
From: jinxing
Date: Sat, 13 Oct 2018 00:55:37 +0800
Subject: [PATCH] [SPARK-25724] Add sorting functionality in MapType.

---
 .../expressions/codegen/CodeGenerator.scala   | 44 +++++++++
 .../sql/catalyst/expressions/ordering.scala   |  4 +
 .../spark/sql/catalyst/util/TypeUtils.scala   |  3 +
 .../org/apache/spark/sql/types/MapType.scala  | 85 +++++++++++++++++
 .../catalyst/expressions/OrderingSuite.scala  | 95 +++++++++++++++----
 5 files changed, 210 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d5857e060a2c4..2269ace230f94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -606,6 +606,7 @@ class CodegenContext {
     case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
     case struct: StructType => genComp(struct, c1, c2) + " == 0"
+    case map: MapType if map.isOrdered => genComp(map, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
     case NullType => "false"
     case _ =>
@@ -677,6 +678,49 @@ class CodegenContext {
          }
        """
       s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
+
+    case mapType @ MapType(keyType, valueType, _) if mapType.isOrdered =>
+      val compareFunc = freshName("compareMap")
+      val funcCode: String =
+        s"""
+          public int $compareFunc(MapData a, MapData b) {
+            int lengthA = a.numElements();
+            int lengthB = b.numElements();
+            ArrayData aKeys = a.keyArray();
+            ArrayData aValues = a.valueArray();
+            ArrayData bKeys = b.keyArray();
+            ArrayData bValues = b.valueArray();
+            int minLength = (lengthA > lengthB) ? lengthB : lengthA;
+            for (int i = 0; i < minLength; i++) {
+              ${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
+              ${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
+              int comp = ${genComp(keyType, "keyA", "keyB")};
+              if (comp != 0) {
+                return comp;
+              }
+              boolean isNullA = aValues.isNullAt(i);
+              boolean isNullB = bValues.isNullAt(i);
+              if (isNullA && isNullB) {
+                // Both values are null: move on to the next entry.
+              } else if (isNullA) {
+                return -1;
+              } else if (isNullB) {
+                return 1;
+              } else {
+                ${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
+                ${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
+                comp = ${genComp(valueType, "valueA", "valueB")};
+                if (comp != 0) {
+                  return comp;
+                }
+              }
+            }
+            return lengthA - lengthB;
+          }
+        """
+      val compareMapFunc = addNewFunction(compareFunc, funcCode)
+      s"$compareMapFunc($c1, $c2)"
+
     case schema: StructType =>
       val comparisons = GenerateOrdering.genComparisons(this, schema)
       val compareFunc = freshName("compareStruct")
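The generated Java above implements the same entry-wise algorithm as the interpreted ordering added to MapType.scala below: keys are compared first, then values (a null value sorts before any non-null value), and when one map is an entry-wise prefix of the other, the shorter map sorts first. The following self-contained Scala sketch is illustration only and not part of the patch; the object name is mine, Option models a nullable map value, and Int stands in for any orderable key/value type.

    // Reference semantics of the generated comparator, over plain Scala collections.
    // Precondition (mirroring `isOrdered`): both inputs are already sorted by the
    // same key ordering.
    object MapCompareSketch {
      def compareMaps(a: Seq[(Int, Option[Int])], b: Seq[(Int, Option[Int])]): Int = {
        val minLength = math.min(a.length, b.length)
        var i = 0
        while (i < minLength) {
          val keyComp = Integer.compare(a(i)._1, b(i)._1)
          if (keyComp != 0) return keyComp            // 1. keys decide first
          (a(i)._2, b(i)._2) match {
            case (None, None) =>                      // both values null: keep scanning
            case (None, _) => return -1               // a null value sorts first
            case (_, None) => return 1
            case (Some(va), Some(vb)) =>
              val valueComp = Integer.compare(va, vb)
              if (valueComp != 0) return valueComp    // 2. then values
          }
          i += 1
        }
        a.length - b.length                           // 3. equal prefix: shorter map first
      }
    }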
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index e24a3de3cfdbe..9b3210af670ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -53,6 +53,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
             a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
           case a: ArrayType if order.direction == Descending =>
             a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case m: MapType if m.isOrdered && order.direction == Ascending =>
+            m.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case m: MapType if m.isOrdered && order.direction == Descending =>
+            m.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
           case s: StructType if order.direction == Ascending =>
             s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
           case s: StructType if order.direction == Descending =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 76218b459ef0d..787c6f0acb918 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -59,8 +59,11 @@ object TypeUtils {
     t match {
       case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
       case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case m: MapType if m.isOrdered => m.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
+      case other =>
+        throw new IllegalArgumentException(s"Type $other does not support ordered operations")
     }
   }
 
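Together, these two hunks wire maps into the interpreted ordering path: TypeUtils.getInterpretedOrdering now returns MapType.interpretedOrdering for ordered maps (and explicitly rejects types without an ordering instead of failing with a MatchError), and InterpretedOrdering handles both sort directions. Below is a minimal sketch of how the path could be exercised. Illustration only: since setOrdered and interpretedOrdering are private[sql], a snippet like this would have to live under org.apache.spark.sql (e.g. in a catalyst suite), and the object and helper names are mine.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedOrdering}
    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
    import org.apache.spark.sql.types.{IntegerType, MapType}

    object InterpretedMapOrderingSketch {
      def main(args: Array[String]): Unit = {
        val mapType = MapType(IntegerType, IntegerType)
        mapType.setOrdered(true) // without this, getInterpretedOrdering rejects MapType

        // Build a single-column row holding already-sorted map entries.
        def mapRow(entries: (Int, Int)*): InternalRow = {
          val keys = new GenericArrayData(entries.map(_._1).toArray[Any])
          val values = new GenericArrayData(entries.map(_._2).toArray[Any])
          InternalRow(new ArrayBasedMapData(keys, values))
        }

        val ordering = new InterpretedOrdering(
          BoundReference(0, mapType, nullable = true).asc :: Nil)
        // {1 -> 2} < {1 -> 3}: the keys tie, so the values decide.
        assert(ordering.compare(mapRow(1 -> 2), mapRow(1 -> 3)) < 0)
      }
    }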
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 594e155268bf6..82e920f0fdb43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.sql.catalyst.util.{MapData, TypeUtils}
 
 /**
  * The data type for Maps. Keys in a map are not allowed to have `null` values.
@@ -73,6 +74,90 @@ case class MapType(
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
   }
+
+  private[this] class OrderedWrapper {
+    var isOrdered: Boolean = false
+  }
+
+  private[this] lazy val orderedWrapper: OrderedWrapper = new OrderedWrapper()
+
+  private[sql] def setOrdered(b: Boolean): Unit = {
+    orderedWrapper.isOrdered = b
+  }
+
+  // Indicates whether this map is "ordered", i.e. its entries are sorted by key.
+  // It only makes sense to compare two maps when both of them are "ordered".
+  // This flag is used internally when performing ordering operations, e.g. when
+  // sorting values of `MapType`.
+  private[sql] def isOrdered: Boolean = orderedWrapper.isOrdered
+
+  // This is used to sort the entries of a map.
+  @transient
+  private[sql] lazy val interpretedKeyOrdering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(keyType)
+
+  @transient
+  private[this] lazy val interpretedValueOrdering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(valueType)
+
+  @transient
+  private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] {
+    val keyOrdering = interpretedKeyOrdering
+    val valueOrdering = interpretedValueOrdering
+
+    // How (left: MapData, right: MapData) are compared:
+    // 1. Precondition: the entries of `left` and `right` are already sorted by the
+    //    same key ordering;
+    // 2. Compare entries pairwise, say entryA(keyA, valueA) from `left` against
+    //    entryB(keyB, valueB) from `right`:
+    //    a. entryA is greater than entryB if keyA is greater than keyB, and vice versa;
+    //    b. if the keys are equal, entryA is greater than entryB if valueA is greater
+    //       than valueB (a null value is smaller than any non-null value), and vice versa;
+    // 3. If one map is an entry-wise prefix of the other, the `MapData` with more
+    //    entries is the greater one.
+    def compare(left: MapData, right: MapData): Int = {
+      val leftKeys = left.keyArray()
+      val leftValues = left.valueArray()
+      val rightKeys = right.keyArray()
+      val rightValues = right.valueArray()
+      val minLength = scala.math.min(leftKeys.numElements(), rightKeys.numElements())
+      var i = 0
+      while (i < minLength) {
+        val keyComp = keyOrdering.compare(leftKeys.get(i, keyType), rightKeys.get(i, keyType))
+        if (keyComp != 0) {
+          return keyComp
+        }
+
+        val isNullLeft = leftValues.isNullAt(i)
+        val isNullRight = rightValues.isNullAt(i)
+        if (isNullLeft && isNullRight) {
+          // Both values are null: move on to the next entry.
+        } else if (isNullLeft) {
+          return -1
+        } else if (isNullRight) {
+          return 1
+        } else {
+          val comp = valueOrdering.compare(
+            leftValues.get(i, valueType), rightValues.get(i, valueType))
+          if (comp != 0) {
+            return comp
+          }
+        }
+        i += 1
+      }
+      val diff = left.numElements() - right.numElements()
+      if (diff < 0) {
+        -1
+      } else if (diff > 0) {
+        1
+      } else {
+        0
+      }
+    }
+  }
+
+  override def toString: String = {
+    s"MapType(${keyType.toString},${valueType.toString},${valueContainsNull.toString})"
+  }
 }
 
 /**
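Two properties of interpretedOrdering are worth calling out: it never sorts entries itself (the pre-sorted precondition is entirely the caller's responsibility), and its null and length rules are observable directly at the MapData level. A short demonstration follows. Illustration only: since interpretedOrdering is private[sql] this must run inside the sql package, and the object and helper names are mine.

    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
    import org.apache.spark.sql.types.{IntegerType, MapType}

    object MapOrderingDemo {
      def mapData(entries: (Any, Any)*): MapData =
        new ArrayBasedMapData(
          new GenericArrayData(entries.map(_._1).toArray),
          new GenericArrayData(entries.map(_._2).toArray))

      def main(args: Array[String]): Unit = {
        val ord = MapType(IntegerType, IntegerType).interpretedOrdering
        // Equal prefix: the map with more entries is greater.
        assert(ord.compare(mapData(1 -> 2), mapData(1 -> 2, 3 -> 4)) < 0)
        // A null value sorts before any non-null value under an equal key.
        assert(ord.compare(mapData(1 -> null), mapData(1 -> 2)) < 0)
        // Keys are compared before values.
        assert(ord.compare(mapData(2 -> 0), mapData(1 -> 9)) > 0)
      }
    }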
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
index d0604b8eb7675..baba67437ca9c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateOrdering, LazilyGeneratedOrdering}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 
 class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -36,28 +37,59 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
       val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
       val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow]
       val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow]
-      Seq(Ascending, Descending).foreach { direction =>
-        val sortOrder = direction match {
-          case Ascending => BoundReference(0, dataType, nullable = true).asc
-          case Descending => BoundReference(0, dataType, nullable = true).desc
-        }
-        val expectedCompareResult = direction match {
-          case Ascending => signum(expected)
-          case Descending => -1 * signum(expected)
-        }
+      compareByMultipleOrderings(rowA, rowB, dataType, expected)
+    }
+  }
 
-        val kryo = new KryoSerializer(new SparkConf).newInstance()
-        val intOrdering = new InterpretedOrdering(sortOrder :: Nil)
-        val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil)
-        val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering))
-        val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering))
-
-        Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering).foreach { ordering =>
-          assert(ordering.compare(rowA, rowA) === 0)
-          assert(ordering.compare(rowB, rowB) === 0)
-          assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult)
-          assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult)
-        }
+  def compareMaps(a: Map[Integer, Integer], b: Map[Integer, Integer], expected: Int): Unit = {
+    test(s"compare two maps: a = $a, b = $b") {
+      val dataType = MapType(IntegerType, IntegerType)
+      dataType.setOrdered(true)
+
+      val rowA = new SpecificInternalRow(Seq(dataType))
+      rowA.update(0, genOrderedMapData(a, dataType))
+
+      val rowB = new SpecificInternalRow(Seq(dataType))
+      rowB.update(0, genOrderedMapData(b, dataType))
+
+      compareByMultipleOrderings(rowA, rowB, dataType, expected)
+    }
+  }
+
+  def genOrderedMapData(m: Map[Integer, Integer], dataType: MapType): MapData = {
+    // Sort entries by key, descending. The direction is arbitrary: the map ordering
+    // only requires that both sides of a comparison are sorted the same way.
+    val sortedEntries = m.toArray.sortWith { case (entry0, entry1) =>
+      dataType.interpretedKeyOrdering.compare(entry0._1, entry1._1) > 0
+    }
+    val keys = new GenericArrayData(sortedEntries.map(_._1))
+    val values = new GenericArrayData(sortedEntries.map(_._2))
+    new ArrayBasedMapData(keys, values)
+  }
+
+  def compareByMultipleOrderings(
+      rowA: InternalRow, rowB: InternalRow, dataType: DataType, expected: Int): Unit = {
+    Seq(Ascending, Descending).foreach { direction =>
+      val sortOrder = direction match {
+        case Ascending => BoundReference(0, dataType, nullable = true).asc
+        case Descending => BoundReference(0, dataType, nullable = true).desc
+      }
+      val expectedCompareResult = direction match {
+        case Ascending => signum(expected)
+        case Descending => -1 * signum(expected)
+      }
+
+      val kryo = new KryoSerializer(new SparkConf).newInstance()
+      val intOrdering = new InterpretedOrdering(sortOrder :: Nil)
+      val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil)
+      val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering))
+      val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering))
+
+      Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering).foreach { ordering =>
+        assert(ordering.compare(rowA, rowA) === 0)
+        assert(ordering.compare(rowB, rowB) === 0)
+        assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult)
+        assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult)
       }
     }
   }
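Note that genOrderedMapData sorts keys in descending order (compare(...) > 0 puts the larger key first). As the comment in the helper says, only consistency between the two sides matters; for illustration (not part of the patch):

    val entries = Map(3 -> 4, 1 -> 2).toArray
    val descending = entries.sortWith(_._1 > _._1) // Array((3,4), (1,2)): what genOrderedMapData builds
    val ascending  = entries.sortWith(_._1 < _._1) // Array((1,2), (3,4)): equally valid, if used on both sides

The expectations in the hunk below hold under either direction.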
@@ -86,6 +118,27 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
   compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0)
   compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1)
 
+  // Two maps have the same size.
+  compareMaps(Map[Integer, Integer](), Map[Integer, Integer](), 0)
+  compareMaps(Map((1, 2)), Map((1, 2)), 0)
+  compareMaps(Map((3, 4), (1, 2)), Map((3, 4), (1, 2)), 0)
+  compareMaps(Map((3, 4), (1, 2)), Map((1, 2), (3, 4)), 0)
+  compareMaps(Map((5, 6), (3, 4), (1, 2)), Map((1, 2), (5, 6), (3, 4)), 0)
+  compareMaps(Map((3, 4), (1, 2)), Map((2, 4), (1, 2)), 1)
+  compareMaps(Map((3, 4), (1, 2)), Map((3, 5), (1, 2)), -1)
+  compareMaps(Map((3, 4), (1, 2)), Map((3, 3), (2, 5)), 1)
+
+  // Two maps have different sizes.
+  compareMaps(Map((1, 2)), Map[Integer, Integer](), 1)
+  compareMaps(Map((3, 4), (1, 2), (0, 0)), Map((3, 4), (1, 2)), 1)
+  compareMaps(Map((3, 4), (1, 2), (0, 0)), Map((4, 4), (1, 2)), -1)
+  compareMaps(Map((3, 4), (1, 2), (0, 0)), Map((3, 5), (1, 2)), -1)
+
+  // Maps having nulls.
+  compareMaps(Map((1, null)), Map((1, null)), 0)
+  compareMaps(Map((1, 2)), Map((1, null)), 1)
+  compareMaps(Map((3, 4), (1, 2)), Map((3, 4), (1, null), (0, 0)), 1)
+
   // Test GenerateOrdering for all common types. For each type, we construct random input rows that
   // contain two columns of that type, then for pairs of randomly-generated rows we check that
   // GenerateOrdering agrees with RowOrdering.
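To read the expectations above, recall that both arguments are first entry-sorted by genOrderedMapData. Tracing one case with the MapCompareSketch object from the note after the CodeGenerator.scala diff (illustration only):

    // compareMaps(Map((3, 4), (1, 2)), Map((3, 3), (2, 5)), 1), entries in
    // descending key order as genOrderedMapData would produce them:
    println(MapCompareSketch.compareMaps(
      Seq(3 -> Some(4), 1 -> Some(2)),   // Map((3, 4), (1, 2))
      Seq(3 -> Some(3), 2 -> Some(5))))  // Map((3, 3), (2, 5))
    // The keys tie at 3 and the values compare 4 > 3, so the result is positive,
    // matching the expected value 1.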