Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ class CodegenContext {
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case map: MapType if map.isOrdered => genComp(map, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case NullType => "false"
case _ =>
Expand Down Expand Up @@ -677,6 +678,49 @@ class CodegenContext {
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"

case mapType @ MapType(keyType, valueType, valueContainsNull) if mapType.isOrdered =>
val compareFunc = freshName("compareMap")
val funcCode: String =
s"""
public int $compareFunc(MapData a, MapData b) {
int lengthA = a.numElements();
int lengthB = b.numElements();
ArrayData aKeys = a.keyArray();
ArrayData aValues = a.valueArray();
ArrayData bKeys = b.keyArray();
ArrayData bValues = b.valueArray();
int minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < minLength; i++) {
${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
int comp = ${genComp(keyType, "keyA", "keyB")};
if (comp != 0) {
return comp;
}
boolean isNullA = aValues.isNullAt(i);
boolean isNullB = bValues.isNullAt(i);
if (isNullA && isNullB) {
// Nothing
} else if (isNullA) {
return -1;
} else if (isNullB) {
return 1;
} else {
${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
comp = ${genComp(valueType, "valueA", "valueB")};
if (comp != 0) {
return comp;
}
}
}
return lengthA - lengthB;
}
"""
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"

case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case a: ArrayType if order.direction == Descending =>
a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case m: MapType if m.isOrdered && order.direction == Ascending =>
m.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case m: MapType if m.isOrdered && order.direction == Descending =>
m.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to care about this ordering direction? We just need comparable maps?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this is not necessary, but just to make the logic complete.
#9718 did the same thing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean we can't remove this? If not necessary, better to remove it off.

case s: StructType if order.direction == Ascending =>
s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case s: StructType if order.direction == Descending =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ object TypeUtils {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case m: MapType if m.isOrdered => m.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
case other =>
throw new IllegalArgumentException(s"Type $other does not support ordered operations")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.util.{MapData, TypeUtils}

/**
* The data type for Maps. Keys in a map are not allowed to have `null` values.
Expand Down Expand Up @@ -73,6 +74,90 @@ case class MapType(
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
}

private[this] class OrderedWrapper {
var isOrdered: Boolean = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer not to make this mutable if we can. That can be a source of some pretty weird errors if we move from an unordered to an ordered map. Why do you need this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for quick reply :)

Actually I'm not pretty sure about this.
If we do it like below

case class MapType(
  keyType: DataType,
  valueType: DataType,
  valueContainsNull: Boolean,
  ordered: Boolean)

The ordered will be spread to lots places in the code (especially in the ...match ... case ... ) and users can will also see it. But I think ordered is a pretty internal parameter/characteristic and only used when sorting map. So I try to make it private and lazy created.

}

private[this] lazy val orderedWrapper: OrderedWrapper = new OrderedWrapper()

private[sql] def setOrdered(b: Boolean): Unit = {
orderedWrapper.isOrdered = b
}

// Indicates if a map is itself "ordered". It makes sense to compare two
// maps only when they are themselves "ordered", i.e. entries of the map are sorted.
// This parameter is used by internal when doing ordering operation, e.g. sort
// values of `MapType`.
private[sql] def isOrdered(): Boolean = orderedWrapper.isOrdered

// This is used to sort the entries of a map.
@transient
private[sql] lazy val interpretedKeyOrdering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(keyType)

@transient
private[this] lazy val interpretedValueOrdering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(valueType)

@transient
private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] {
val keyOrdering = interpretedKeyOrdering
val valueOrdering = interpretedValueOrdering

// The approach to compare (left: MapData, right: MapData):
// 1. The precondition is that entries inside `left` and `right` are already sorted themselves;
// 2. Compare entries from `left` and `right`, say entryA(keyA, valueA) is from `left` and
// entryB(keyB, valueB) is from `right`:
// a. entryA is bigger than entryB if keyA is bigger than keyB and vice versa;
// b. entryA is bigger than entryB if keyA equals to keyB and valueA is bigger than
// valueB and vice versa;
// 3. If entries from the head equals to each other between `left` and `right`, the `MapData`
// with more entries is bigger.
def compare(left: MapData, right: MapData): Int = {
val leftKeys = left.keyArray()
val leftValues = left.valueArray()
val rightKeys = right.keyArray()
val rightValues = right.valueArray()
val minLength = scala.math.min(leftKeys.numElements(), rightKeys.numElements())
var i = 0
while (i < minLength) {
val keyComp = keyOrdering.compare(leftKeys.get(i, keyType), rightKeys.get(i, keyType))
if (keyComp != 0) {
return keyComp
}

val isNullLeft = leftValues.isNullAt(i)
val isNullRight = rightValues.isNullAt(i)
if (isNullLeft && isNullRight) {
// Do nothing.
} else if (isNullLeft) {
return -1
} else if (isNullRight) {
return 1
} else {
val comp = valueOrdering.compare(
leftValues.get(i, valueType), rightValues.get(i, valueType))
if (comp != 0) {
return comp
}
}
i += 1
}
val diff = left.numElements() - right.numElements()
if (diff < 0) {
-1
} else if (diff > 0) {
1
} else {
0
}
}
}

override def toString: String = {
s"MapType(${keyType.toString},${valueType.toString},${valueContainsNull.toString})"
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateOrdering, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand All @@ -36,28 +37,59 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow]
val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow]
Seq(Ascending, Descending).foreach { direction =>
val sortOrder = direction match {
case Ascending => BoundReference(0, dataType, nullable = true).asc
case Descending => BoundReference(0, dataType, nullable = true).desc
}
val expectedCompareResult = direction match {
case Ascending => signum(expected)
case Descending => -1 * signum(expected)
}
compareByMultipleOrderings(rowA, rowB, dataType, expected)
}
}

val kryo = new KryoSerializer(new SparkConf).newInstance()
val intOrdering = new InterpretedOrdering(sortOrder :: Nil)
val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil)
val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering))
val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering))

Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering).foreach { ordering =>
assert(ordering.compare(rowA, rowA) === 0)
assert(ordering.compare(rowB, rowB) === 0)
assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult)
assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult)
}
def compareMaps(a: Map[Integer, Integer], b: Map[Integer, Integer], expected: Int): Unit = {
test(s"compare two maps: a = $a, b = $b") {
val dataType = MapType(IntegerType, IntegerType)
dataType.setOrdered(true)

val rowA = new SpecificInternalRow(Seq(dataType))
rowA.update(0, genOrderedMapData(a, dataType))

val rowB = new SpecificInternalRow(Seq(dataType))
rowB.update(0, genOrderedMapData(b, dataType))

compareByMultipleOrderings(rowA, rowB, dataType, expected)
}
}

def genOrderedMapData(m: Map[Integer, Integer], dataType: MapType): MapData = {
val sortedEntries = m.toArray.sortWith {
case (entry0, entry1) =>
if (dataType.interpretedKeyOrdering.compare(entry0._1, entry1._1) == 1) { true }
else { false }
}
val keys = new GenericArrayData(sortedEntries.map(_._1))
val values = new GenericArrayData(sortedEntries.map(_._2))
new ArrayBasedMapData(keys, values)
}

def compareByMultipleOrderings(
rowA: InternalRow, rowB: InternalRow, dataType: DataType, expected: Int): Unit = {
Seq(Ascending, Descending).foreach { direction =>
val sortOrder = direction match {
case Ascending => BoundReference(0, dataType, nullable = true).asc
case Descending => BoundReference(0, dataType, nullable = true).desc
}
val expectedCompareResult = direction match {
case Ascending => signum(expected)
case Descending => -1 * signum(expected)
}

val kryo = new KryoSerializer(new SparkConf).newInstance()
val intOrdering = new InterpretedOrdering(sortOrder :: Nil)
val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil)
val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering))
val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering))

Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering).foreach { ordering =>
assert(ordering.compare(rowA, rowA) === 0)
assert(ordering.compare(rowB, rowB) === 0)
assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult)
assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult)
}
}
}
Expand Down Expand Up @@ -86,6 +118,27 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0)
compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1)

// Two maps have the same size.
compareMaps(Map[Integer, Integer](), Map[Integer, Integer](), 0)
compareMaps(Map((1, 2)), Map((1, 2)), 0)
compareMaps(Map((3, 4), (1, 2)), Map((3, 4), (1, 2)), 0)
compareMaps(Map((3, 4), (1, 2)), Map((1, 2), (3, 4)), 0)
compareMaps(Map((5, 6), (3, 4), (1, 2)), Map((1, 2), (5, 6), (3, 4)), 0)
compareMaps(Map((3, 4), (1, 2)), Map((2, 4), (1, 2)), 1)
compareMaps(Map((3, 4), (1, 2)), Map((3, 5), (1, 2)), -1)
compareMaps(Map((3, 4), (1, 2)), Map((3, 3), (2, 5)), 1)

// Two maps have different sizes.
compareMaps(Map((1, 2)), Map[Integer, Integer](), 1)
compareMaps(Map((3, 4), (1, 2), (0, 0)), Map((3, 4), (1, 2)), 1)
compareMaps(Map((3, 4), (1, 2), (0, 0)), Map((4, 4), (1, 2)), -1)
compareMaps(Map((3, 4), (1, 2), (0, 0)), Map((3, 5), (1, 2)), -1)

// Maps having nulls.
compareMaps(Map((1, null)), Map((1, null)), 0)
compareMaps(Map((1, 2)), Map((1, null)), 1)
compareMaps(Map((3, 4), (1, 2)), Map((3, 4), (1, null), (0, 0)), 1)

// Test GenerateOrdering for all common types. For each type, we construct random input rows that
// contain two columns of that type, then for pairs of randomly-generated rows we check that
// GenerateOrdering agrees with RowOrdering.
Expand Down