Skip to content

Commit ce755e2

Browse files
committed
address review comments
1 parent 6fba1ee commit ce755e2

File tree

2 files changed

+80
-76
lines changed

2 files changed

+80
-76
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 74 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4043,7 +4043,7 @@ object ArrayUnion {
40434043
array2, without duplicates.
40444044
""",
40454045
examples = """
4046-
Examples:Fun
4046+
Examples:
40474047
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
40484048
array(1, 3)
40494049
""",
@@ -4060,81 +4060,89 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
40604060
@transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = {
40614061
if (elementTypeSupportEquals) {
40624062
(array1, array2) =>
4063-
val hs = new OpenHashSet[Any]
4064-
val hsResult = new OpenHashSet[Any]
4065-
var foundNullElement = false
4066-
var i = 0
4067-
while (i < array2.numElements()) {
4068-
if (array2.isNullAt(i)) {
4069-
foundNullElement = true
4070-
} else {
4071-
val elem = array2.get(i, elementType)
4072-
hs.add(elem)
4073-
}
4074-
i += 1
4075-
}
4076-
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
4077-
i = 0
4078-
while (i < array1.numElements()) {
4079-
if (array1.isNullAt(i)) {
4080-
if (foundNullElement) {
4081-
arrayBuffer += null
4082-
foundNullElement = false
4063+
if (array1.numElements() != 0 && array2.numElements() != 0) {
4064+
val hs = new OpenHashSet[Any]
4065+
val hsResult = new OpenHashSet[Any]
4066+
var foundNullElement = false
4067+
var i = 0
4068+
while (i < array2.numElements()) {
4069+
if (array2.isNullAt(i)) {
4070+
foundNullElement = true
4071+
} else {
4072+
val elem = array2.get(i, elementType)
4073+
hs.add(elem)
40834074
}
4084-
} else {
4085-
val elem = array1.get(i, elementType)
4086-
if (hs.contains(elem) && !hsResult.contains(elem)) {
4087-
arrayBuffer += elem
4088-
hsResult.add(elem)
4075+
i += 1
4076+
}
4077+
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
4078+
i = 0
4079+
while (i < array1.numElements()) {
4080+
if (array1.isNullAt(i)) {
4081+
if (foundNullElement) {
4082+
arrayBuffer += null
4083+
foundNullElement = false
4084+
}
4085+
} else {
4086+
val elem = array1.get(i, elementType)
4087+
if (hs.contains(elem) && !hsResult.contains(elem)) {
4088+
arrayBuffer += elem
4089+
hsResult.add(elem)
4090+
}
40894091
}
4092+
i += 1
40904093
}
4091-
i += 1
4094+
new GenericArrayData(arrayBuffer)
4095+
} else {
4096+
new GenericArrayData(Seq.empty)
40924097
}
4093-
new GenericArrayData(arrayBuffer)
40944098
} else {
40954099
(array1, array2) =>
4096-
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
4097-
var alreadySeenNull = false
4098-
var i = 0
4099-
while (i < array1.numElements()) {
4100-
var found = false
4101-
val elem1 = array1.get(i, elementType)
4102-
if (array1.isNullAt(i)) {
4103-
if (!alreadySeenNull) {
4100+
if (array1.numElements() != 0 && array2.numElements() != 0) {
4101+
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
4102+
var alreadySeenNull = false
4103+
var i = 0
4104+
while (i < array1.numElements()) {
4105+
var found = false
4106+
val elem1 = array1.get(i, elementType)
4107+
if (array1.isNullAt(i)) {
4108+
if (!alreadySeenNull) {
4109+
var j = 0
4110+
while (!found && j < array2.numElements()) {
4111+
found = array2.isNullAt(j)
4112+
j += 1
4113+
}
4114+
// array2 is scanned at most once for a null element
4115+
alreadySeenNull = true
4116+
}
4117+
} else {
41044118
var j = 0
41054119
while (!found && j < array2.numElements()) {
4106-
found = array2.isNullAt(j)
4107-
j += 1
4108-
}
4109-
// array2 is scanned at most once for a null element
4110-
alreadySeenNull = true
4111-
}
4112-
} else {
4113-
var j = 0
4114-
while (!found && j < array2.numElements()) {
4115-
if (!array2.isNullAt(j)) {
4116-
val elem2 = array2.get(j, elementType)
4117-
if (ordering.equiv(elem1, elem2)) {
4118-
// check whether elem1 is already stored in arrayBuffer
4119-
var foundArrayBuffer = false
4120-
var k = 0
4121-
while (!foundArrayBuffer && k < arrayBuffer.size) {
4122-
val va = arrayBuffer(k)
4123-
foundArrayBuffer = (va != null) && ordering.equiv(va, elem1)
4124-
k += 1
4120+
if (!array2.isNullAt(j)) {
4121+
val elem2 = array2.get(j, elementType)
4122+
if (ordering.equiv(elem1, elem2)) {
4123+
// check whether elem1 is already stored in arrayBuffer
4124+
var foundArrayBuffer = false
4125+
var k = 0
4126+
while (!foundArrayBuffer && k < arrayBuffer.size) {
4127+
val va = arrayBuffer(k)
4128+
foundArrayBuffer = (va != null) && ordering.equiv(va, elem1)
4129+
k += 1
4130+
}
4131+
found = !foundArrayBuffer
41254132
}
4126-
found = !foundArrayBuffer
41274133
}
4134+
j += 1
41284135
}
4129-
j += 1
41304136
}
4137+
if (found) {
4138+
arrayBuffer += elem1
4139+
}
4140+
i += 1
41314141
}
4132-
if (found) {
4133-
arrayBuffer += elem1
4134-
}
4135-
i += 1
4142+
new GenericArrayData(arrayBuffer)
4143+
} else {
4144+
new GenericArrayData(Seq.empty)
41364145
}
4137-
new GenericArrayData(arrayBuffer)
41384146
}
41394147
}
41404148

@@ -4162,9 +4170,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
41624170
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
41634171
val hashSet = ctx.freshName("hashSet")
41644172
val hashSetResult = ctx.freshName("hashSetResult")
4165-
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
4173+
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
41664174
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
4167-
val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
41684175

41694176
def withArray2NullCheck(body: String): String =
41704177
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
@@ -4250,8 +4257,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
42504257
|for (int $i = 0; $i < $array2.numElements(); $i++) {
42514258
| $writeArray2ToHashSet
42524259
|}
4253-
|$arrayBuilderClass $builder =
4254-
| ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
4260+
|$arrayBuilderClass $builder = new $arrayBuilderClass();
42554261
|int $size = 0;
42564262
|for (int $i = 0; $i < $array1.numElements(); $i++) {
42574263
| $processArray1
@@ -4396,9 +4402,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
43964402
val openHashSet = classOf[OpenHashSet[_]].getName
43974403
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
43984404
val hashSet = ctx.freshName("hashSet")
4399-
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
4405+
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
44004406
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
4401-
val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
44024407

44034408
def withArray2NullCheck(body: String): String =
44044409
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
@@ -4474,8 +4479,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
44744479
|for (int $i = 0; $i < $array2.numElements(); $i++) {
44754480
| $writeArray2ToHashSet
44764481
|}
4477-
|$arrayBuilderClass $builder =
4478-
| ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
4482+
|$arrayBuilderClass $builder = new $arrayBuilderClass();
44794483
|int $size = 0;
44804484
|for (int $i = 0; $i < $array1.numElements(); $i++) {
44814485
| $processArray1

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,26 +1679,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
16791679
val df6 = Seq((null, null)).toDF("a", "b")
16801680
intercept[AnalysisException] {
16811681
df6.select(array_intersect($"a", $"b"))
1682-
}
1682+
}.getMessage.contains("data type mismatch")
16831683
intercept[AnalysisException] {
16841684
df6.selectExpr("array_intersect(a, b)")
1685-
}
1685+
}.getMessage.contains("data type mismatch")
16861686

16871687
val df7 = Seq((Array(1), Array("a"))).toDF("a", "b")
16881688
intercept[AnalysisException] {
16891689
df7.select(array_intersect($"a", $"b"))
1690-
}
1690+
}.getMessage.contains("data type mismatch")
16911691
intercept[AnalysisException] {
16921692
df7.selectExpr("array_intersect(a, b)")
1693-
}
1693+
}.getMessage.contains("data type mismatch")
16941694

16951695
val df8 = Seq((null, Array("a"))).toDF("a", "b")
16961696
intercept[AnalysisException] {
16971697
df8.select(array_intersect($"a", $"b"))
1698-
}
1698+
}.getMessage.contains("data type mismatch")
16991699
intercept[AnalysisException] {
17001700
df8.selectExpr("array_intersect(a, b)")
1701-
}
1701+
}.getMessage.contains("data type mismatch")
17021702
}
17031703

17041704
test("transform function - array for primitive type not containing null") {

0 commit comments

Comments
 (0)