Skip to content

Commit 31eaabc

Browse files
committed
Lots of TODO and doc cleanup.
1 parent a95291e commit 31eaabc

File tree

5 files changed

+141
-129
lines changed

5 files changed

+141
-129
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@
3333
import org.apache.spark.sql.types.UTF8String;
3434
import org.apache.spark.unsafe.PlatformDependent;
3535
import org.apache.spark.unsafe.bitset.BitSetMethods;
36-
import org.apache.spark.unsafe.string.UTF8StringMethods;
37-
38-
// TODO: pick a better name for this class, since this is potentially confusing.
39-
// Maybe call it UnsafeMutableRow?
4036

4137
/**
4238
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -58,6 +54,7 @@ public final class UnsafeRow implements MutableRow {
5854

5955
private Object baseObject;
6056
private long baseOffset;
57+
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
6158
private int numFields;
6259
/** The width of the null tracking bit set, in bytes */
6360
private int bitSetWidthInBytes;
@@ -74,7 +71,7 @@ private long getFieldOffset(int ordinal) {
7471
}
7572

7673
public static int calculateBitSetWidthInBytes(int numFields) {
77-
return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
74+
return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
7875
}
7976

8077
/**
@@ -211,7 +208,6 @@ public void setFloat(int ordinal, float value) {
211208

212209
@Override
213210
public void setString(int ordinal, String value) {
214-
// TODO: need to ensure that array has been suitably sized.
215211
throw new UnsupportedOperationException();
216212
}
217213

@@ -240,23 +236,14 @@ public Object get(int i) {
240236
assertIndexIsValid(i);
241237
assert (schema != null) : "Schema must be defined when calling generic get() method";
242238
final DataType dataType = schema.fields()[i].dataType();
243-
// The ordering of these `if` statements is intentional: internally, it looks like this only
244-
// gets invoked in JoinedRow when trying to access UTF8String columns. It's extremely unlikely
245-
// that internal code will call this on non-string-typed columns, but we support that anyways
246-
// just for the sake of completeness.
247-
// TODO: complete this for the remaining types?
239+
// UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
240+
// get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
241+
// separate the internal and external row interfaces, then internal code can fetch strings via
242+
// a new getUTF8String() method and we'll be able to remove this method.
248243
if (isNullAt(i)) {
249244
return null;
250245
} else if (dataType == StringType) {
251246
return getUTF8String(i);
252-
} else if (dataType == IntegerType) {
253-
return getInt(i);
254-
} else if (dataType == LongType) {
255-
return getLong(i);
256-
} else if (dataType == DoubleType) {
257-
return getDouble(i);
258-
} else if (dataType == FloatType) {
259-
return getFloat(i);
260247
} else {
261248
throw new UnsupportedOperationException();
262249
}
@@ -319,7 +306,7 @@ public UTF8String getUTF8String(int i) {
319306
final byte[] strBytes = new byte[stringSizeInBytes];
320307
PlatformDependent.copyMemory(
321308
baseObject,
322-
baseOffset + offsetToStringSize + 8, // The +8 is to skip past the size to get the data,
309+
baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
323310
strBytes,
324311
PlatformDependent.BYTE_ARRAY_OFFSET,
325312
stringSizeInBytes
@@ -335,31 +322,26 @@ public String getString(int i) {
335322

336323
@Override
337324
public BigDecimal getDecimal(int i) {
338-
// TODO
339325
throw new UnsupportedOperationException();
340326
}
341327

342328
@Override
343329
public Date getDate(int i) {
344-
// TODO
345330
throw new UnsupportedOperationException();
346331
}
347332

348333
@Override
349334
public <T> Seq<T> getSeq(int i) {
350-
// TODO
351335
throw new UnsupportedOperationException();
352336
}
353337

354338
@Override
355339
public <T> List<T> getList(int i) {
356-
// TODO
357340
throw new UnsupportedOperationException();
358341
}
359342

360343
@Override
361344
public <K, V> Map<K, V> getMap(int i) {
362-
// TODO
363345
throw new UnsupportedOperationException();
364346
}
365347

@@ -370,19 +352,16 @@ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fi
370352

371353
@Override
372354
public <K, V> java.util.Map<K, V> getJavaMap(int i) {
373-
// TODO
374355
throw new UnsupportedOperationException();
375356
}
376357

377358
@Override
378359
public Row getStruct(int i) {
379-
// TODO
380360
throw new UnsupportedOperationException();
381361
}
382362

383363
@Override
384364
public <T> T getAs(int i) {
385-
// TODO
386365
throw new UnsupportedOperationException();
387366
}
388367

@@ -398,7 +377,6 @@ public int fieldIndex(String name) {
398377

399378
@Override
400379
public Row copy() {
401-
// TODO
402380
throw new UnsupportedOperationException();
403381
}
404382

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 104 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,88 @@ import org.apache.spark.sql.types._
2121
import org.apache.spark.unsafe.PlatformDependent
2222
import org.apache.spark.unsafe.array.ByteArrayMethods
2323

24-
/** Write a column into an UnsafeRow */
24+
/**
 * Encodes Rows into the UnsafeRow binary format. Instances are NOT thread-safe, because every
 * call to `writeRow` re-targets one shared [[UnsafeRow]] pointer.
 *
 * @param fieldTypes the data types of the row's columns, in column order.
 */
class UnsafeRowConverter(fieldTypes: Array[DataType]) {

  /** Creates a converter for the column types of `schema`, taken in field order. */
  def this(schema: StructType) = this(schema.fields.map(_.dataType))

  /** Pointer that is re-aimed at the destination memory on each `writeRow` call */
  private[this] val unsafeRow = new UnsafeRow()

  /** One column encoder per field, indexed by column ordinal */
  private[this] val writers: Array[UnsafeColumnWriter[Any]] =
    fieldTypes.map { dataType =>
      UnsafeColumnWriter.forType(dataType).asInstanceOf[UnsafeColumnWriter[Any]]
    }

  /** Bytes taken by the null bitmap plus one 8-byte slot per field */
  private[this] val fixedLengthSize: Int =
    UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + (8 * fieldTypes.length)

  /**
   * Computes how many bytes are needed to encode `row`.
   *
   * @param row the row that is going to be converted
   * @return the fixed-length size plus the sizes reported by the column writers for every
   *         non-null field
   */
  def getSizeRequirement(row: Row): Int = {
    val numFields = writers.length
    var totalSize = fixedLengthSize
    var ordinal = 0
    while (ordinal < numFields) {
      // Null fields contribute nothing beyond their fixed-length slot.
      if (!row.isNullAt(ordinal)) {
        totalSize += writers(ordinal).getSize(row(ordinal))
      }
      ordinal += 1
    }
    totalSize
  }

  /**
   * Writes `row` in UnsafeRow format at the given destination address.
   *
   * @param row the row to convert
   * @param baseObject the base object of the destination address
   * @param baseOffset the base offset of the destination address
   * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
   */
  def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
    val numFields = writers.length
    unsafeRow.pointTo(baseObject, baseOffset, numFields, null)
    // Variable-length data is appended right after the fixed-length portion of the row.
    var cursor: Int = fixedLengthSize
    var ordinal = 0
    while (ordinal < numFields) {
      if (!row.isNullAt(ordinal)) {
        val bytesAppended = writers(ordinal).write(
          row(ordinal),
          ordinal,
          unsafeRow,
          baseObject,
          baseOffset,
          cursor)
        cursor += bytesAppended
      } else {
        unsafeRow.setNullAt(ordinal)
        // TODO: type-specific null value writing?
      }
      ordinal += 1
    }
    cursor
  }

}
93+
94+
/**
95+
* Function for writing a column into an UnsafeRow.
96+
*/
2597
private abstract class UnsafeColumnWriter[T] {
2698
/**
2799
* Write a value into an UnsafeRow.
28100
*
29101
* @param value the value to write
30102
* @param columnNumber what column to write it to
31103
* @param row a pointer to the unsafe row
32-
* @param baseObject
33-
* @param baseOffset
104+
* @param baseObject the base object of the target row's address
105+
* @param baseOffset the base offset of the target row's address
34106
* @param appendCursor the offset from the start of the unsafe row to the end of the row;
35107
* used for calculating where variable-length data should be written
36108
* @return the number of variable-length bytes written
@@ -50,6 +122,12 @@ private abstract class UnsafeColumnWriter[T] {
50122
}
51123

52124
private object UnsafeColumnWriter {
125+
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
126+
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
127+
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
128+
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
129+
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
130+
53131
def forType(dataType: DataType): UnsafeColumnWriter[_] = {
54132
dataType match {
55133
case IntegerType => IntUnsafeColumnWriter
@@ -63,34 +141,7 @@ private object UnsafeColumnWriter {
63141
}
64142
}
65143

66-
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
67-
def getSize(value: UTF8String): Int = {
68-
// round to nearest word
69-
val numBytes = value.getBytes.length
70-
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
71-
}
72-
73-
override def write(
74-
value: UTF8String,
75-
columnNumber: Int,
76-
row: UnsafeRow,
77-
baseObject: Object,
78-
baseOffset: Long,
79-
appendCursor: Int): Int = {
80-
val numBytes = value.getBytes.length
81-
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
82-
PlatformDependent.copyMemory(
83-
value.getBytes,
84-
PlatformDependent.BYTE_ARRAY_OFFSET,
85-
baseObject,
86-
baseOffset + appendCursor + 8,
87-
numBytes
88-
)
89-
row.setLong(columnNumber, appendCursor)
90-
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
91-
}
92-
}
93-
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
144+
// ------------------------------------------------------------------------------------------------
94145

95146
private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
96147
def getSize(value: T): Int = 0
@@ -108,7 +159,6 @@ private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrite
108159
0
109160
}
110161
}
111-
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
112162

113163
private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
114164
override def write(
@@ -122,7 +172,6 @@ private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrit
122172
0
123173
}
124174
}
125-
private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
126175

127176
private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Float] {
128177
override def write(
@@ -136,7 +185,6 @@ private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWri
136185
0
137186
}
138187
}
139-
private case object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
140188

141189
private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Double] {
142190
override def write(
@@ -150,55 +198,29 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
150198
0
151199
}
152200
}
153-
private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
154201

155-
class UnsafeRowConverter(fieldTypes: Array[DataType]) {
156-
157-
def this(schema: StructType) {
158-
this(schema.fields.map(_.dataType))
159-
}
160-
161-
private[this] val unsafeRow = new UnsafeRow()
162-
163-
private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
164-
fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
165-
}
166-
167-
private[this] val fixedLengthSize: Int =
168-
(8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
169-
170-
def getSizeRequirement(row: Row): Int = {
171-
var fieldNumber = 0
172-
var variableLengthFieldSize: Int = 0
173-
while (fieldNumber < writers.length) {
174-
if (!row.isNullAt(fieldNumber)) {
175-
variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber))
176-
}
177-
fieldNumber += 1
178-
}
179-
fixedLengthSize + variableLengthFieldSize
202+
/**
 * Column writer for UTF8String values. A string is stored in the variable-length part of the
 * row as an 8-byte length followed by the string's bytes, padded up to a whole number of words;
 * the column's fixed-length slot holds the offset of that region from the start of the row.
 */
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {

  /**
   * @param value the string to be written
   * @return the bytes the string occupies in the variable-length region: 8 bytes for the length
   *         plus the word-rounded size of its contents
   */
  def getSize(value: UTF8String): Int = {
    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length)
  }

  /**
   * Writes `value` at `appendCursor` and records that offset in the column's slot.
   *
   * @param value the string to write
   * @param columnNumber the column whose slot receives the offset of the string data
   * @param row a pointer to the unsafe row being written
   * @param baseObject the base object of the target row's address
   * @param baseOffset the base offset of the target row's address
   * @param appendCursor the offset from the start of the row at which to write the string
   * @return the number of variable-length bytes written (same value as `getSize(value)`)
   */
  override def write(
      value: UTF8String,
      columnNumber: Int,
      row: UnsafeRow,
      baseObject: Object,
      baseOffset: Long,
      appendCursor: Int): Int = {
    // Fetch the bytes once and reuse them for both the length and the copy.
    val bytes = value.getBytes
    val numBytes = bytes.length
    val paddedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
    val dataOffset = baseOffset + appendCursor + 8 // The `+ 8` skips past the length word
    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
    if (paddedSize > numBytes) {
      // Zero the last word before copying so that the padding bytes after the string are
      // deterministic even when the destination memory is being reused. This relies on
      // roundNumberOfBytesToNearestWord rounding up to a multiple of 8.
      PlatformDependent.UNSAFE.putLong(baseObject, dataOffset + paddedSize - 8, 0L)
    }
    PlatformDependent.copyMemory(
      bytes,
      PlatformDependent.BYTE_ARRAY_OFFSET,
      baseObject,
      dataOffset,
      numBytes
    )
    row.setLong(columnNumber, appendCursor)
    8 + paddedSize
  }
}

0 commit comments

Comments
 (0)