Skip to content

Commit 90f1add

Browse files
committed
[SPARK-9736][SQL] JoinedRow.anyNull should delegate to the underlying rows.
1 parent ebfd91c commit 90f1add

File tree

1 file changed

+144
-0
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+144
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.types._
22+
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
23+
24+
25+
/**
26+
* A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
27+
* be instantiated once per thread and reused.
28+
*/
29+
class JoinedRow extends InternalRow {
30+
private[this] var row1: InternalRow = _
31+
private[this] var row2: InternalRow = _
32+
33+
def this(left: InternalRow, right: InternalRow) = {
34+
this()
35+
row1 = left
36+
row2 = right
37+
}
38+
39+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
40+
def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
41+
row1 = r1
42+
row2 = r2
43+
this
44+
}
45+
46+
/** Updates this JoinedRow by updating its left base row. Returns itself. */
47+
def withLeft(newLeft: InternalRow): InternalRow = {
48+
row1 = newLeft
49+
this
50+
}
51+
52+
/** Updates this JoinedRow by updating its right base row. Returns itself. */
53+
def withRight(newRight: InternalRow): InternalRow = {
54+
row2 = newRight
55+
this
56+
}
57+
58+
override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
59+
assert(fieldTypes.length == row1.numFields + row2.numFields)
60+
val (left, right) = fieldTypes.splitAt(row1.numFields)
61+
row1.toSeq(left) ++ row2.toSeq(right)
62+
}
63+
64+
override def numFields: Int = row1.numFields + row2.numFields
65+
66+
override def get(i: Int, dt: DataType): AnyRef =
67+
if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt)
68+
69+
override def isNullAt(i: Int): Boolean =
70+
if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
71+
72+
override def getBoolean(i: Int): Boolean =
73+
if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
74+
75+
override def getByte(i: Int): Byte =
76+
if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
77+
78+
override def getShort(i: Int): Short =
79+
if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
80+
81+
override def getInt(i: Int): Int =
82+
if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
83+
84+
override def getLong(i: Int): Long =
85+
if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
86+
87+
override def getFloat(i: Int): Float =
88+
if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
89+
90+
override def getDouble(i: Int): Double =
91+
if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
92+
93+
override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
94+
if (i < row1.numFields) {
95+
row1.getDecimal(i, precision, scale)
96+
} else {
97+
row2.getDecimal(i - row1.numFields, precision, scale)
98+
}
99+
}
100+
101+
override def getUTF8String(i: Int): UTF8String =
102+
if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
103+
104+
override def getBinary(i: Int): Array[Byte] =
105+
if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
106+
107+
override def getArray(i: Int): ArrayData =
108+
if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields)
109+
110+
override def getInterval(i: Int): CalendarInterval =
111+
if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields)
112+
113+
override def getMap(i: Int): MapData =
114+
if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields)
115+
116+
override def getStruct(i: Int, numFields: Int): InternalRow = {
117+
if (i < row1.numFields) {
118+
row1.getStruct(i, numFields)
119+
} else {
120+
row2.getStruct(i - row1.numFields, numFields)
121+
}
122+
}
123+
124+
override def anyNull: Boolean = row1.anyNull || row2.anyNull
125+
126+
override def copy(): InternalRow = {
127+
val copy1 = row1.copy()
128+
val copy2 = row2.copy()
129+
new JoinedRow(copy1, copy2)
130+
}
131+
132+
override def toString: String = {
133+
// Make sure toString never throws NullPointerException.
134+
if ((row1 eq null) && (row2 eq null)) {
135+
"[ empty row ]"
136+
} else if (row1 eq null) {
137+
row2.toString
138+
} else if (row2 eq null) {
139+
row1.toString
140+
} else {
141+
s"{${row1.toString} + ${row2.toString}}"
142+
}
143+
}
144+
}

0 commit comments

Comments
 (0)