
Commit 48f8fd4

JoshRosen authored and rxin committed
[SPARK-9023] [SQL] Followup for #7456 (Efficiency improvements for UnsafeRows in Exchange)
This patch addresses code review feedback from #7456.

Author: Josh Rosen <[email protected]>

Closes #7551 from JoshRosen/unsafe-exchange-followup and squashes the following commits:

76dbdf8 [Josh Rosen] Add comments + more methods to UnsafeRowSerializer
3d7a1f2 [Josh Rosen] Add writeToStream() method to UnsafeRow
1 parent 67570be · commit 48f8fd4

3 files changed: +161 -23 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 33 additions & 0 deletions
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.io.IOException;
+import java.io.OutputStream;
+
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -371,6 +374,36 @@ public InternalRow copy() {
     }
   }
 
+  /**
+   * Write this UnsafeRow's underlying bytes to the given OutputStream.
+   *
+   * @param out the stream to write to.
+   * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the
+   *                    output stream. If this row is backed by an on-heap byte array, then this
+   *                    buffer will not be used and may be null.
+   */
+  public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
+    if (baseObject instanceof byte[]) {
+      int offsetInByteArray = (int) (baseOffset - PlatformDependent.BYTE_ARRAY_OFFSET);
+      out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
+    } else {
+      int dataRemaining = sizeInBytes;
+      long rowReadPosition = baseOffset;
+      while (dataRemaining > 0) {
+        int toTransfer = Math.min(writeBuffer.length, dataRemaining);
+        PlatformDependent.copyMemory(
+          baseObject,
+          rowReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        out.write(writeBuffer, 0, toTransfer);
+        rowReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
+    }
+  }
+
   @Override
   public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
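The new method gives Exchange (and the serializer below) a single entry point for dumping a row's bytes: on-heap rows are written directly from their backing byte[], while off-heap rows are copied out in writeBuffer-sized chunks. As a quick illustration, here is a minimal usage sketch in Scala, mirroring the projection setup from the new UnsafeRowSuite; the one-column schema and the values are invented for the example:

    import java.io.ByteArrayOutputStream

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.IntegerType

    // Build an on-heap UnsafeRow via UnsafeProjection, then dump its bytes to a stream.
    val projection = UnsafeProjection.create(Seq(IntegerType))
    val unsafeRow = projection.apply(InternalRow(42))

    val baos = new ByteArrayOutputStream()
    // On-heap rows take the byte[] fast path, so the scratch buffer may be null.
    unsafeRow.writeToStream(baos, null)
    assert(baos.toByteArray.length == unsafeRow.getSizeInBytes)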

sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala

Lines changed: 57 additions & 23 deletions
@@ -49,8 +49,16 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
 
 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
 
+  /**
+   * Marks the end of a stream written with [[serializeStream()]].
+   */
   private[this] val EOF: Int = -1
 
+  /**
+   * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
+   * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
+   * The end of the stream is denoted by a record with the special length `EOF` (-1).
+   */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
     private[this] val dOut: DataOutputStream = new DataOutputStream(out)
@@ -59,32 +67,31 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       val row = value.asInstanceOf[UnsafeRow]
       assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
       dOut.writeInt(row.getSizeInBytes)
-      var dataRemaining: Int = row.getSizeInBytes
-      val baseObject = row.getBaseObject
-      var rowReadPosition: Long = row.getBaseOffset
-      while (dataRemaining > 0) {
-        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
-        PlatformDependent.copyMemory(
-          baseObject,
-          rowReadPosition,
-          writeBuffer,
-          PlatformDependent.BYTE_ARRAY_OFFSET,
-          toTransfer)
-        out.write(writeBuffer, 0, toTransfer)
-        rowReadPosition += toTransfer
-        dataRemaining -= toTransfer
-      }
+      row.writeToStream(out, writeBuffer)
       this
     }
+
     override def writeKey[T: ClassTag](key: T): SerializationStream = {
+      // The key is only needed on the map side when computing partition ids. It does not need to
+      // be shuffled.
       assert(key.isInstanceOf[Int])
       this
     }
-    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+
+    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
+      // This method is never called by shuffle code.
       throw new UnsupportedOperationException
-    override def writeObject[T: ClassTag](t: T): SerializationStream =
+    }
+
+    override def writeObject[T: ClassTag](t: T): SerializationStream = {
+      // This method is never called by shuffle code.
       throw new UnsupportedOperationException
-    override def flush(): Unit = dOut.flush()
+    }
+
+    override def flush(): Unit = {
+      dOut.flush()
+    }
+
     override def close(): Unit = {
       writeBuffer = null
       dOut.writeInt(EOF)
@@ -95,6 +102,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
   override def deserializeStream(in: InputStream): DeserializationStream = {
     new DeserializationStream {
       private[this] val dIn: DataInputStream = new DataInputStream(in)
+      // 1024 is a default buffer size; this buffer will grow to accommodate larger rows
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
       private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
@@ -126,14 +134,40 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
           }
         }
       }
-      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
-      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def close(): Unit = dIn.close()
+
+      override def asIterator: Iterator[Any] = {
+        // This method is never called by shuffle code.
+        throw new UnsupportedOperationException
+      }
+
+      override def readKey[T: ClassTag](): T = {
+        // We skipped serialization of the key in writeKey(), so just return a dummy value since
+        // this is going to be discarded anyways.
+        null.asInstanceOf[T]
+      }
+
+      override def readValue[T: ClassTag](): T = {
+        val rowSize = dIn.readInt()
+        if (rowBuffer.length < rowSize) {
+          rowBuffer = new Array[Byte](rowSize)
+        }
+        ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+        row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+        row.asInstanceOf[T]
+      }
+
+      override def readObject[T: ClassTag](): T = {
+        // This method is never called by shuffle code.
+        throw new UnsupportedOperationException
+      }
+
+      override def close(): Unit = {
+        dIn.close()
+      }
     }
   }
 
+  // These methods are never called by shuffle code.
   override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
   override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
    throw new UnsupportedOperationException
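The stream format documented on serializeStream is plain length-prefixed framing: DataOutputStream.writeInt emits each record's length as a 4-byte big-endian integer ("high byte first"), the record's bytes follow, and a final length of -1 (EOF) terminates the stream; on the read side a single buffer is reused and grows only when a record exceeds its current capacity. Here is a standalone Scala sketch of that framing, independent of Spark's classes; the helper names writeFramed and readFramed are invented for illustration:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    val EOF = -1

    // Each record: a 4-byte big-endian length, then exactly that many bytes.
    def writeFramed(records: Seq[Array[Byte]]): Array[Byte] = {
      val baos = new ByteArrayOutputStream()
      val dOut = new DataOutputStream(baos)
      records.foreach { rec =>
        dOut.writeInt(rec.length) // DataOutputStream writes ints high byte first
        dOut.write(rec)
      }
      dOut.writeInt(EOF) // sentinel length marks the end of the stream
      dOut.flush()
      baos.toByteArray
    }

    // Reuse one buffer, growing it only when a record is larger than any seen so far,
    // mirroring how deserializeStream grows rowBuffer.
    def readFramed(bytes: Array[Byte]): Seq[Array[Byte]] = {
      val dIn = new DataInputStream(new ByteArrayInputStream(bytes))
      var buffer = new Array[Byte](1024)
      val records = Seq.newBuilder[Array[Byte]]
      var size = dIn.readInt()
      while (size != EOF) {
        if (buffer.length < size) buffer = new Array[Byte](size)
        dIn.readFully(buffer, 0, size)
        records += java.util.Arrays.copyOf(buffer, size)
        size = dIn.readInt()
      }
      records.result()
    }

    val roundTripped = readFramed(writeFramed(Seq("ab".getBytes, "cde".getBytes)))
    assert(roundTripped.map(new String(_)) == Seq("ab", "cde"))

Unlike this sketch, the real deserializer avoids even the per-record copy: it repoints a single mutable UnsafeRow at the reused buffer, which is why each row must be consumed before the next call to readValue().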
sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.memory.MemoryAllocator
+import org.apache.spark.unsafe.types.UTF8String
+
+class UnsafeRowSuite extends SparkFunSuite {
+  test("writeToStream") {
+    val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
+    val arrayBackedUnsafeRow: UnsafeRow =
+      UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+    assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
+    val bytesFromArrayBackedRow: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      arrayBackedUnsafeRow.writeToStream(baos, null)
+      baos.toByteArray
+    }
+    val bytesFromOffheapRow: Array[Byte] = {
+      val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes)
+      try {
+        PlatformDependent.copyMemory(
+          arrayBackedUnsafeRow.getBaseObject,
+          arrayBackedUnsafeRow.getBaseOffset,
+          offheapRowPage.getBaseObject,
+          offheapRowPage.getBaseOffset,
+          arrayBackedUnsafeRow.getSizeInBytes
+        )
+        val offheapUnsafeRow: UnsafeRow = new UnsafeRow()
+        offheapUnsafeRow.pointTo(
+          offheapRowPage.getBaseObject,
+          offheapRowPage.getBaseOffset,
+          3, // num fields
+          arrayBackedUnsafeRow.getSizeInBytes,
+          null // object pool
+        )
+        assert(offheapUnsafeRow.getBaseObject === null)
+        val baos = new ByteArrayOutputStream()
+        val writeBuffer = new Array[Byte](1024)
+        offheapUnsafeRow.writeToStream(baos, writeBuffer)
+        baos.toByteArray
+      } finally {
+        MemoryAllocator.UNSAFE.free(offheapRowPage)
+      }
+    }
+
+    assert(bytesFromArrayBackedRow === bytesFromOffheapRow)
+  }
+}
