Skip to content

Commit 3e2e187

Browse files
yhuaidavies
authored andcommitted
[SPARK-11738] [SQL] Making ArrayType orderable
https://issues.apache.org/jira/browse/SPARK-11738 Author: Yin Huai <[email protected]> Closes #9718 from yhuai/makingArrayOrderable.
1 parent 64e5551 commit 3e2e187

File tree

14 files changed

+335
-94
lines changed

14 files changed

+335
-94
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,32 +137,14 @@ trait CheckAnalysis {
137137
case e => e.children.foreach(checkValidAggregateExpression)
138138
}
139139

140-
def checkSupportedGroupingDataType(
141-
expressionString: String,
142-
dataType: DataType): Unit = dataType match {
143-
case BinaryType =>
144-
failAnalysis(s"expression $expressionString cannot be used in " +
145-
s"grouping expression because it is in binary type or its inner field is " +
146-
s"in binary type")
147-
case a: ArrayType =>
148-
failAnalysis(s"expression $expressionString cannot be used in " +
149-
s"grouping expression because it is in array type or its inner field is " +
150-
s"in array type")
151-
case m: MapType =>
152-
failAnalysis(s"expression $expressionString cannot be used in " +
153-
s"grouping expression because it is in map type or its inner field is " +
154-
s"in map type")
155-
case s: StructType =>
156-
s.fields.foreach { f =>
157-
checkSupportedGroupingDataType(expressionString, f.dataType)
158-
}
159-
case udt: UserDefinedType[_] =>
160-
checkSupportedGroupingDataType(expressionString, udt.sqlType)
161-
case _ => // OK
162-
}
163-
164140
def checkValidGroupingExprs(expr: Expression): Unit = {
165-
checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
141+
// Check if the data type of expr is orderable.
142+
if (!RowOrdering.isOrderable(expr.dataType)) {
143+
failAnalysis(
144+
s"expression ${expr.prettyString} cannot be used as a grouping expression " +
145+
s"because its data type ${expr.dataType.simpleString} is not a orderable " +
146+
s"data type.")
147+
}
166148

167149
if (!expr.deterministic) {
168150
// This is just a sanity check, our analysis rule PullOutNondeterministic should

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,49 @@ class CodeGenContext {
267267
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
268268
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
269269
case NullType => "0"
270+
case array: ArrayType =>
271+
val elementType = array.elementType
272+
val elementA = freshName("elementA")
273+
val isNullA = freshName("isNullA")
274+
val elementB = freshName("elementB")
275+
val isNullB = freshName("isNullB")
276+
val compareFunc = freshName("compareArray")
277+
val minLength = freshName("minLength")
278+
val funcCode: String =
279+
s"""
280+
public int $compareFunc(ArrayData a, ArrayData b) {
281+
int lengthA = a.numElements();
282+
int lengthB = b.numElements();
283+
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
284+
for (int i = 0; i < $minLength; i++) {
285+
boolean $isNullA = a.isNullAt(i);
286+
boolean $isNullB = b.isNullAt(i);
287+
if ($isNullA && $isNullB) {
288+
// Nothing
289+
} else if ($isNullA) {
290+
return -1;
291+
} else if ($isNullB) {
292+
return 1;
293+
} else {
294+
${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
295+
${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
296+
int comp = ${genComp(elementType, elementA, elementB)};
297+
if (comp != 0) {
298+
return comp;
299+
}
300+
}
301+
}
302+
303+
if (lengthA < lengthB) {
304+
return -1;
305+
} else if (lengthA > lengthB) {
306+
return 1;
307+
}
308+
return 0;
309+
}
310+
"""
311+
addNewFunction(compareFunc, funcCode)
312+
s"this.$compareFunc($c1, $c2)"
270313
case schema: StructType =>
271314
val comparisons = GenerateOrdering.genComparisons(this, schema)
272315
val compareFunc = freshName("compareStruct")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
6868
private lazy val lt: Comparator[Any] = {
6969
val ordering = base.dataType match {
7070
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
71+
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
7172
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
7273
}
7374

@@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
9091
private lazy val gt: Comparator[Any] = {
9192
val ordering = base.dataType match {
9293
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
94+
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
9395
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
9496
}
9597

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
4848
dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
4949
case dt: AtomicType if order.direction == Descending =>
5050
dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
51+
case a: ArrayType if order.direction == Ascending =>
52+
a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
53+
case a: ArrayType if order.direction == Descending =>
54+
a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
5155
case s: StructType if order.direction == Ascending =>
5256
s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
5357
case s: StructType if order.direction == Descending =>
@@ -86,6 +90,8 @@ object RowOrdering {
8690
case NullType => true
8791
case dt: AtomicType => true
8892
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
93+
case array: ArrayType => isOrderable(array.elementType)
94+
case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
8995
case _ => false
9096
}
9197

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ object TypeUtils {
5757
def getInterpretedOrdering(t: DataType): Ordering[Any] = {
5858
t match {
5959
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
60+
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
6061
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
6162
}
6263
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ private[sql] object TypeCollection {
8484
* Types that can be ordered/compared. In the long run we should probably make this a trait
8585
* that can be mixed into each data type, and perhaps create an [[AbstractDataType]].
8686
*/
87+
// TODO: Should we consolidate this with RowOrdering.isOrderable?
8788
val Ordered = TypeCollection(
8889
BooleanType,
8990
ByteType, ShortType, IntegerType, LongType,

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
package org.apache.spark.sql.types
1919

20+
import org.apache.spark.sql.catalyst.util.ArrayData
2021
import org.json4s.JsonDSL._
2122

2223
import org.apache.spark.annotation.DeveloperApi
2324

25+
import scala.math.Ordering
26+
2427

2528
object ArrayType extends AbstractDataType {
2629
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
@@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
8184
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
8285
f(this) || elementType.existsRecursively(f)
8386
}
87+
88+
@transient
89+
private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
90+
private[this] val elementOrdering: Ordering[Any] = elementType match {
91+
case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
92+
case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
93+
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
94+
case other =>
95+
throw new IllegalArgumentException(s"Type $other does not support ordered operations")
96+
}
97+
98+
def compare(x: ArrayData, y: ArrayData): Int = {
99+
val leftArray = x
100+
val rightArray = y
101+
val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
102+
var i = 0
103+
while (i < minLength) {
104+
val isNullLeft = leftArray.isNullAt(i)
105+
val isNullRight = rightArray.isNullAt(i)
106+
if (isNullLeft && isNullRight) {
107+
// Do nothing.
108+
} else if (isNullLeft) {
109+
return -1
110+
} else if (isNullRight) {
111+
return 1
112+
} else {
113+
val comp =
114+
elementOrdering.compare(
115+
leftArray.get(i, elementType),
116+
rightArray.get(i, elementType))
117+
if (comp != 0) {
118+
return comp
119+
}
120+
}
121+
i += 1
122+
}
123+
if (leftArray.numElements() < rightArray.numElements()) {
124+
return -1
125+
} else if (leftArray.numElements() > rightArray.numElements()) {
126+
return 1
127+
} else {
128+
return 0
129+
}
130+
}
131+
}
84132
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
2323
import org.apache.spark.sql.catalyst.plans.Inner
2424
import org.apache.spark.sql.catalyst.dsl.expressions._
2525
import org.apache.spark.sql.catalyst.dsl.plans._
26-
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
26+
import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData}
2727
import org.apache.spark.sql.types._
2828

2929
import scala.beans.{BeanProperty, BeanInfo}
@@ -53,21 +53,29 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
5353
}
5454

5555
@BeanInfo
56-
private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
56+
private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int])
5757

5858
private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
5959

60-
override def sqlType: DataType = ArrayType(IntegerType)
60+
override def sqlType: DataType = MapType(IntegerType, IntegerType)
6161

62-
override def serialize(obj: Any): ArrayData = {
62+
override def serialize(obj: Any): MapData = {
6363
obj match {
64-
case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
64+
case groupableData: UngroupableData =>
65+
val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
66+
val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
67+
new ArrayBasedMapData(keyArray, valueArray)
6568
}
6669
}
6770

6871
override def deserialize(datum: Any): UngroupableData = {
6972
datum match {
70-
case data: Array[Int] => UngroupableData(data)
73+
case data: MapData =>
74+
val keyArray = data.keyArray().array
75+
val valueArray = data.valueArray().array
76+
assert(keyArray.length == valueArray.length)
77+
val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]]
78+
UngroupableData(mapData)
7179
}
7280
}
7381

@@ -154,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest {
154162

155163
errorTest(
156164
"sorting by unsupported column types",
157-
listRelation.orderBy('list.asc),
158-
"sort" :: "type" :: "array<int>" :: Nil)
165+
mapRelation.orderBy('map.asc),
166+
"sort" :: "type" :: "map<int,int>" :: Nil)
159167

160168
errorTest(
161169
"non-boolean filters",
@@ -259,32 +267,33 @@ class AnalysisErrorSuite extends AnalysisTest {
259267
case true =>
260268
assertAnalysisSuccess(plan, true)
261269
case false =>
262-
assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
270+
assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
263271
}
264-
265272
}
266273

267274
val supportedDataTypes = Seq(
268-
StringType,
275+
StringType, BinaryType,
269276
NullType, BooleanType,
270277
ByteType, ShortType, IntegerType, LongType,
271278
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
272279
DateType, TimestampType,
280+
ArrayType(IntegerType),
273281
new StructType()
274282
.add("f1", FloatType, nullable = true)
275283
.add("f2", StringType, nullable = true),
284+
new StructType()
285+
.add("f1", FloatType, nullable = true)
286+
.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
276287
new GroupableUDT())
277288
supportedDataTypes.foreach { dataType =>
278289
checkDataType(dataType, shouldSuccess = true)
279290
}
280291

281292
val unsupportedDataTypes = Seq(
282-
BinaryType,
283-
ArrayType(IntegerType),
284293
MapType(StringType, LongType),
285294
new StructType()
286295
.add("f1", FloatType, nullable = true)
287-
.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
296+
.add("f2", MapType(StringType, LongType), nullable = true),
288297
new UngroupableUDT())
289298
unsupportedDataTypes.foreach { dataType =>
290299
checkDataType(dataType, shouldSuccess = false)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ import org.apache.spark.sql.catalyst.dsl.plans._
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.expressions.aggregate._
2626
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
27-
import org.apache.spark.sql.types.{TypeCollection, StringType}
27+
import org.apache.spark.sql.types.{LongType, TypeCollection, StringType}
2828

2929
class ExpressionTypeCheckingSuite extends SparkFunSuite {
3030

3131
val testRelation = LocalRelation(
3232
'intField.int,
3333
'stringField.string,
3434
'booleanField.boolean,
35-
'complexField.array(StringType))
35+
'arrayField.array(StringType),
36+
'mapField.map(StringType, LongType))
3637

3738
def assertError(expr: Expression, errorMessage: String): Unit = {
3839
val e = intercept[AnalysisException] {
@@ -90,9 +91,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
9091
assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type")
9192
assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type")
9293

93-
assertError(MaxOf('complexField, 'complexField),
94+
assertError(MaxOf('mapField, 'mapField),
9495
s"requires ${TypeCollection.Ordered.simpleString} type")
95-
assertError(MinOf('complexField, 'complexField),
96+
assertError(MinOf('mapField, 'mapField),
9697
s"requires ${TypeCollection.Ordered.simpleString} type")
9798
}
9899

@@ -109,31 +110,31 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
109110
assertSuccess(EqualTo('intField, 'booleanField))
110111
assertSuccess(EqualNullSafe('intField, 'booleanField))
111112

112-
assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
113-
assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
113+
assertErrorForDifferingTypes(EqualTo('intField, 'mapField))
114+
assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField))
114115
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
115116
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
116117
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
117118
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
118119

119-
assertError(LessThan('complexField, 'complexField),
120+
assertError(LessThan('mapField, 'mapField),
120121
s"requires ${TypeCollection.Ordered.simpleString} type")
121-
assertError(LessThanOrEqual('complexField, 'complexField),
122+
assertError(LessThanOrEqual('mapField, 'mapField),
122123
s"requires ${TypeCollection.Ordered.simpleString} type")
123-
assertError(GreaterThan('complexField, 'complexField),
124+
assertError(GreaterThan('mapField, 'mapField),
124125
s"requires ${TypeCollection.Ordered.simpleString} type")
125-
assertError(GreaterThanOrEqual('complexField, 'complexField),
126+
assertError(GreaterThanOrEqual('mapField, 'mapField),
126127
s"requires ${TypeCollection.Ordered.simpleString} type")
127128

128129
assertError(If('intField, 'stringField, 'stringField),
129130
"type of predicate expression in If should be boolean")
130131
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
131132

132133
assertError(
133-
CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
134+
CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
134135
"THEN and ELSE expressions should all be same type or coercible to a common type")
135136
assertError(
136-
CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
137+
CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
137138
"THEN and ELSE expressions should all be same type or coercible to a common type")
138139
assertError(
139140
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
@@ -147,9 +148,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
147148
// We will cast String to Double for sum and average
148149
assertSuccess(Sum('stringField))
149150
assertSuccess(Average('stringField))
151+
assertSuccess(Min('arrayField))
150152

151-
assertError(Min('complexField), "min does not support ordering on type")
152-
assertError(Max('complexField), "max does not support ordering on type")
153+
assertError(Min('mapField), "min does not support ordering on type")
154+
assertError(Max('mapField), "max does not support ordering on type")
153155
assertError(Sum('booleanField), "function sum requires numeric type")
154156
assertError(Average('booleanField), "function average requires numeric type")
155157
}
@@ -184,7 +186,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
184186

185187
assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
186188
assertError(Round('intField, 'booleanField), "requires int type")
187-
assertError(Round('intField, 'complexField), "requires int type")
189+
assertError(Round('intField, 'mapField), "requires int type")
188190
assertError(Round('booleanField, 'intField), "requires numeric type")
189191
}
190192
}

0 commit comments

Comments
 (0)