Skip to content

Commit 3cd3d51

Browse files
rxin authored and CodingCat committed
[SPARK-9736] [SQL] JoinedRow.anyNull should delegate to the underlying rows.
JoinedRow.anyNull currently loops through every field to check for null, which is inefficient if the underlying rows are UnsafeRows. It should just delegate to the underlying implementation. Author: Reynold Xin <[email protected]> Closes apache#8027 from rxin/SPARK-9736 and squashes the following commits: 03a2e92 [Reynold Xin] Include all files. 90f1add [Reynold Xin] [SPARK-9736][SQL] JoinedRow.anyNull should delegate to the underlying rows.
1 parent 6dee300 commit 3cd3d51

File tree

4 files changed

+156
-129
lines changed

4 files changed

+156
-129
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,7 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
3737
def copy(): InternalRow
3838

3939
/** Returns true if there are any NULL values in this row. */
40-
def anyNull: Boolean = {
41-
val len = numFields
42-
var i = 0
43-
while (i < len) {
44-
if (isNullAt(i)) { return true }
45-
i += 1
46-
}
47-
false
48-
}
40+
def anyNull: Boolean
4941

5042
/* ---------------------- utility methods for Scala ---------------------- */
5143

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.types._
22+
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
23+
24+
25+
/**
 * A mutable view that presents two [[InternalRow]]s as one row whose fields are the
 * concatenation of both. Designed to be allocated once per thread and repeatedly
 * re-pointed at new input rows via [[apply]], [[withLeft]] and [[withRight]].
 */
class JoinedRow extends InternalRow {
  // The two underlying rows; ordinals [0, leftRow.numFields) resolve against leftRow,
  // the rest against rightRow (shifted by leftRow.numFields).
  private[this] var leftRow: InternalRow = _
  private[this] var rightRow: InternalRow = _

  def this(left: InternalRow, right: InternalRow) = {
    this()
    leftRow = left
    rightRow = right
  }

  /** Re-points this JoinedRow at two new base rows. Returns itself. */
  def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
    leftRow = r1
    rightRow = r2
    this
  }

  /** Replaces only the left base row. Returns itself. */
  def withLeft(newLeft: InternalRow): InternalRow = {
    leftRow = newLeft
    this
  }

  /** Replaces only the right base row. Returns itself. */
  def withRight(newRight: InternalRow): InternalRow = {
    rightRow = newRight
    this
  }

  override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
    assert(fieldTypes.length == leftRow.numFields + rightRow.numFields)
    val (left, right) = fieldTypes.splitAt(leftRow.numFields)
    leftRow.toSeq(left) ++ rightRow.toSeq(right)
  }

  override def numFields: Int = leftRow.numFields + rightRow.numFields

  override def get(i: Int, dt: DataType): AnyRef = {
    val split = leftRow.numFields
    if (i < split) leftRow.get(i, dt) else rightRow.get(i - split, dt)
  }

  override def isNullAt(i: Int): Boolean = {
    val split = leftRow.numFields
    if (i < split) leftRow.isNullAt(i) else rightRow.isNullAt(i - split)
  }

  override def getBoolean(i: Int): Boolean = {
    val split = leftRow.numFields
    if (i < split) leftRow.getBoolean(i) else rightRow.getBoolean(i - split)
  }

  override def getByte(i: Int): Byte = {
    val split = leftRow.numFields
    if (i < split) leftRow.getByte(i) else rightRow.getByte(i - split)
  }

  override def getShort(i: Int): Short = {
    val split = leftRow.numFields
    if (i < split) leftRow.getShort(i) else rightRow.getShort(i - split)
  }

  override def getInt(i: Int): Int = {
    val split = leftRow.numFields
    if (i < split) leftRow.getInt(i) else rightRow.getInt(i - split)
  }

  override def getLong(i: Int): Long = {
    val split = leftRow.numFields
    if (i < split) leftRow.getLong(i) else rightRow.getLong(i - split)
  }

  override def getFloat(i: Int): Float = {
    val split = leftRow.numFields
    if (i < split) leftRow.getFloat(i) else rightRow.getFloat(i - split)
  }

  override def getDouble(i: Int): Double = {
    val split = leftRow.numFields
    if (i < split) leftRow.getDouble(i) else rightRow.getDouble(i - split)
  }

  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
    val split = leftRow.numFields
    if (i < split) {
      leftRow.getDecimal(i, precision, scale)
    } else {
      rightRow.getDecimal(i - split, precision, scale)
    }
  }

  override def getUTF8String(i: Int): UTF8String = {
    val split = leftRow.numFields
    if (i < split) leftRow.getUTF8String(i) else rightRow.getUTF8String(i - split)
  }

  override def getBinary(i: Int): Array[Byte] = {
    val split = leftRow.numFields
    if (i < split) leftRow.getBinary(i) else rightRow.getBinary(i - split)
  }

  override def getArray(i: Int): ArrayData = {
    val split = leftRow.numFields
    if (i < split) leftRow.getArray(i) else rightRow.getArray(i - split)
  }

  override def getInterval(i: Int): CalendarInterval = {
    val split = leftRow.numFields
    if (i < split) leftRow.getInterval(i) else rightRow.getInterval(i - split)
  }

  override def getMap(i: Int): MapData = {
    val split = leftRow.numFields
    if (i < split) leftRow.getMap(i) else rightRow.getMap(i - split)
  }

  override def getStruct(i: Int, numFields: Int): InternalRow = {
    val split = leftRow.numFields
    if (i < split) {
      leftRow.getStruct(i, numFields)
    } else {
      rightRow.getStruct(i - split, numFields)
    }
  }

  // Delegate to the underlying rows so UnsafeRow's fast null check is used (SPARK-9736).
  override def anyNull: Boolean = leftRow.anyNull || rightRow.anyNull

  override def copy(): InternalRow = {
    val copiedLeft = leftRow.copy()
    val copiedRight = rightRow.copy()
    new JoinedRow(copiedLeft, copiedRight)
  }

  override def toString: String = {
    // Guard against partially-initialized state so toString never throws NullPointerException.
    (leftRow eq null, rightRow eq null) match {
      case (true, true) => "[ empty row ]"
      case (true, false) => rightRow.toString
      case (false, true) => leftRow.toString
      case _ => s"{${leftRow.toString} + ${rightRow.toString}}"
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 0 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -169,122 +169,3 @@ object FromUnsafeProjection {
169169
GenerateSafeProjection.generate(exprs)
170170
}
171171
}
172-
173-
/**
 * A mutable wrapper presenting two [[InternalRow]]s as a single concatenated row.
 * Designed to be instantiated once per thread and reused by re-pointing it at new
 * base rows.
 */
class JoinedRow extends InternalRow {
  private[this] var row1: InternalRow = _
  private[this] var row2: InternalRow = _

  def this(left: InternalRow, right: InternalRow) = {
    this()
    row1 = left
    row2 = right
  }

  /** Re-points this JoinedRow at two new base rows. Returns itself. */
  def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
    row1 = r1
    row2 = r2
    this
  }

  /** Replaces only the left base row. Returns itself. */
  def withLeft(newLeft: InternalRow): InternalRow = {
    row1 = newLeft
    this
  }

  /** Replaces only the right base row. Returns itself. */
  def withRight(newRight: InternalRow): InternalRow = {
    row2 = newRight
    this
  }

  // True when joined ordinal i addresses a field of the left row.
  @inline private[this] def inLeft(i: Int): Boolean = i < row1.numFields

  // Translates a joined ordinal into an ordinal of the right row.
  @inline private[this] def rightOrdinal(i: Int): Int = i - row1.numFields

  override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
    assert(fieldTypes.length == row1.numFields + row2.numFields)
    val (left, right) = fieldTypes.splitAt(row1.numFields)
    row1.toSeq(left) ++ row2.toSeq(right)
  }

  override def numFields: Int = row1.numFields + row2.numFields

  override def get(i: Int, dt: DataType): AnyRef =
    if (inLeft(i)) row1.get(i, dt) else row2.get(rightOrdinal(i), dt)

  override def isNullAt(i: Int): Boolean =
    if (inLeft(i)) row1.isNullAt(i) else row2.isNullAt(rightOrdinal(i))

  override def getBoolean(i: Int): Boolean =
    if (inLeft(i)) row1.getBoolean(i) else row2.getBoolean(rightOrdinal(i))

  override def getByte(i: Int): Byte =
    if (inLeft(i)) row1.getByte(i) else row2.getByte(rightOrdinal(i))

  override def getShort(i: Int): Short =
    if (inLeft(i)) row1.getShort(i) else row2.getShort(rightOrdinal(i))

  override def getInt(i: Int): Int =
    if (inLeft(i)) row1.getInt(i) else row2.getInt(rightOrdinal(i))

  override def getLong(i: Int): Long =
    if (inLeft(i)) row1.getLong(i) else row2.getLong(rightOrdinal(i))

  override def getFloat(i: Int): Float =
    if (inLeft(i)) row1.getFloat(i) else row2.getFloat(rightOrdinal(i))

  override def getDouble(i: Int): Double =
    if (inLeft(i)) row1.getDouble(i) else row2.getDouble(rightOrdinal(i))

  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal =
    if (inLeft(i)) {
      row1.getDecimal(i, precision, scale)
    } else {
      row2.getDecimal(rightOrdinal(i), precision, scale)
    }

  override def getUTF8String(i: Int): UTF8String =
    if (inLeft(i)) row1.getUTF8String(i) else row2.getUTF8String(rightOrdinal(i))

  override def getBinary(i: Int): Array[Byte] =
    if (inLeft(i)) row1.getBinary(i) else row2.getBinary(rightOrdinal(i))

  override def getArray(i: Int): ArrayData =
    if (inLeft(i)) row1.getArray(i) else row2.getArray(rightOrdinal(i))

  override def getInterval(i: Int): CalendarInterval =
    if (inLeft(i)) row1.getInterval(i) else row2.getInterval(rightOrdinal(i))

  override def getMap(i: Int): MapData =
    if (inLeft(i)) row1.getMap(i) else row2.getMap(rightOrdinal(i))

  override def getStruct(i: Int, numFields: Int): InternalRow =
    if (inLeft(i)) {
      row1.getStruct(i, numFields)
    } else {
      row2.getStruct(rightOrdinal(i), numFields)
    }

  override def copy(): InternalRow = new JoinedRow(row1.copy(), row2.copy())

  override def toString: String = {
    // Make sure toString never throws NullPointerException, even before rows are set.
    if (row1 eq null) {
      if (row2 eq null) "[ empty row ]" else row2.toString
    } else if (row2 eq null) {
      row1.toString
    } else {
      s"{${row1.toString} + ${row2.toString}}"
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,17 @@ trait BaseGenericInternalRow extends InternalRow {
4949
override def getMap(ordinal: Int): MapData = getAs(ordinal)
5050
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
5151

52-
override def toString(): String = {
52+
override def anyNull: Boolean = {
53+
val len = numFields
54+
var i = 0
55+
while (i < len) {
56+
if (isNullAt(i)) { return true }
57+
i += 1
58+
}
59+
false
60+
}
61+
62+
override def toString: String = {
5363
if (numFields == 0) {
5464
"[empty row]"
5565
} else {

0 commit comments

Comments
 (0)