
Commit b0c00f4

Author: hongshen (committed)
External spilling when joining a lot of rows with the same key
1 parent 633d63a commit b0c00f4

File tree: 4 files changed (+403, -26 lines)

Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
package org.apache.spark.util.collection

import java.io._
import java.util.Comparator

import scala.collection.BufferedIterator
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import com.google.common.io.ByteStreams

import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
import org.apache.spark.executor.ShuffleWriteMetrics

class ExternalAppendOnlyArrayBuffer[T: ClassTag](
    serializer: Serializer = SparkEnv.get.serializer,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    context: TaskContext = TaskContext.get())
  extends Iterable[T]
  with Serializable
  with Logging
  with Spillable[SizeTracker] {

  if (context == null) {
    throw new IllegalStateException(
      "Spillable collections should not be instantiated outside of tasks")
  }

  private var currentArray = new SizeTrackingVector[T]
  private var spilledArrays = new ArrayBuffer[DiskArrayIterator]
  private val arrayBuffered: ArrayBuffer[T] = new ArrayBuffer[T]
  private val sparkConf = SparkEnv.get.conf
  private val diskBlockManager = blockManager.diskBlockManager

  private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
  private val cacheSizeInArrayBuffer = sparkConf.getInt("spark.cache.array.size", 10000)

  // Number of bytes spilled in total
  private var _diskBytesSpilled = 0L
  def diskBytesSpilled: Long = _diskBytesSpilled

  // Write metrics for current spill
  private var curWriteMetrics: ShuffleWriteMetrics = _
  private var _size: Int = 0
  private var _getsize: Int = 0

  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
  private val fileBufferSize =
    sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024

  private val ser = serializer.newInstance()

  def length: Int = _size

  override def reset(): Unit = {
    logDebug("Reset:" + spilledArrays.length + " currentArray.size:" + currentArray.size)
    super.reset()
    if (spilledArrays.length > 0) {
      _diskBytesSpilled = 0
      spilledArrays.map(iter => iter.deleteTmpFile)
      spilledArrays = new ArrayBuffer[DiskArrayIterator]
    }
    if (currentArray.size > 0) {
      currentArray = new SizeTrackingVector[T]
    }
    _size = 0
    arrayBuffered.clear
    releaseMemory()
  }

  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()

  def +=(elem: T): Unit = {
    _size += 1
    addElementsRead()
    if (_size < cacheSizeInArrayBuffer) {
      arrayBuffered += elem
      return
    }
    if (arrayBuffered.length > 0) {
      val iter = arrayBuffered.iterator
      while (iter.hasNext) {
        currentArray += iter.next
      }
      arrayBuffered.clear
    }
    if (currentArray.length % 100000 == 0) {
      logInfo("currentArray.len:" + currentArray.length)
    }
    val estimatedSize = currentArray.estimateSize()
    if (maybeSpill(currentArray, estimatedSize)) {
      currentArray = new SizeTrackingVector[T]
    }
    currentArray += elem
  }

  override protected[this] def spill(collection: SizeTracker): Unit = {
    val (blockId, file) = diskBlockManager.createTempLocalBlock()
    logInfo("Spill len:" + currentArray.length + " file:" + file)
    curWriteMetrics = new ShuffleWriteMetrics()
    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
    var objectsWritten = 0

    val batchSizes = new ArrayBuffer[Long]

    // Flush the disk writer's contents to disk, and update relevant variables
    def flush(): Unit = {
      val w = writer
      writer = null
      w.commitAndClose()
      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
      objectsWritten = 0
    }

    var success = false
    try {
      val it = currentArray.iterator
      while (it.hasNext) {
        val elem = it.next()
        writer.write(null, elem)
        objectsWritten += 1

        if (objectsWritten == serializerBatchSize) {
          flush()
          curWriteMetrics = new ShuffleWriteMetrics()
          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
        }
      }
      if (objectsWritten > 0) {
        flush()
      } else if (writer != null) {
        val w = writer
        writer = null
        w.revertPartialWritesAndClose()
      }
      success = true
    } finally {
      if (!success) {
        // This code path only happens if an exception was thrown above before we set success;
        // close our stuff and let the exception be thrown further
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (file.exists()) {
          if (!file.delete()) {
            logWarning(s"Error deleting ${file}")
          }
        }
      }
    }
    spilledArrays.append(new DiskArrayIterator(file, blockId, batchSizes))
  }

  override def iterator: Iterator[T] = {
    logInfo("Match size:" + _size)
    if (_size < cacheSizeInArrayBuffer) {
      arrayBuffered.iterator
    } else if (0 == spilledArrays.length) {
      currentArray.iterator
    } else {
      new ExternalIterator()
    }
  }

  private class ExternalIterator extends Iterator[T] {

    var currentIndex = 0
    var currentIter = spilledArrays(currentIndex)
    val arrayIter = currentArray.iterator
    override def hasNext: Boolean = {
      if (currentIter.hasNext) {
        return true
      } else if (currentIndex < spilledArrays.length - 1) {
        currentIndex += 1
        currentIter = spilledArrays(currentIndex)
        return true
      } else {
        arrayIter.hasNext
      }
    }

    override def next(): T = {
      _getsize += 1
      if (currentIter.hasNext) {
        val tmp: T = currentIter.next
        if (_getsize % 100000 == 0) {
          logInfo("Getsize" + currentIndex + ":" + _getsize)
        }
        tmp
      } else if (currentIndex < spilledArrays.length - 1) {
        currentIndex += 1
        currentIter = spilledArrays(currentIndex)
        _getsize = 0
        val tmp: T = currentIter.next
        if (_getsize % 100000 == 0) {
          logInfo("Getsize" + currentIndex + ":" + _getsize)
        }
        tmp
      } else {
        val tmp: T = arrayIter.next
        if (_getsize % 100000 == 0) {
          logInfo("Getsize array:" + _getsize)
        }
        tmp
      }
    }
  }

  private class DiskArrayIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
    extends Iterator[T] {

    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
    logDebug("BatchOffsets:" + batchOffsets.length + " file.length():" + file.length()
      + " batchOffsets.last:" + batchOffsets.last)
    assert(file.length() == batchOffsets.last,
      "File length is not equal to the last batch offset:\n" +
      s" file length = ${file.length}\n" +
      s" last batch offset = ${batchOffsets.last}\n" +
      s" all batch offsets = ${batchOffsets.mkString(",")}"
    )

    private var batchIndex = 0 // Which batch we're in
    private var fileStream: FileInputStream = null

    // An intermediate stream that reads from exactly one batch
    // This guards against pre-fetching and other arbitrary behavior of higher level streams
    private var deserializeStream = nextBatchStream()
    private var itemIsNull = true
    private var nextItem: T = _
    private var objectsRead = 0

    /**
     * Construct a stream that reads only from the next batch.
     */
    private def nextBatchStream(): DeserializationStream = {
      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
      // we're still in a valid batch.
      if (batchIndex < batchOffsets.length - 1) {
        if (deserializeStream != null) {
          deserializeStream.close()
          fileStream.close()
          deserializeStream = null
          fileStream = null
        }

        val start = batchOffsets(batchIndex)
        fileStream = new FileInputStream(file)
        fileStream.getChannel.position(start)
        batchIndex += 1

        val end = batchOffsets(batchIndex)

        assert(end >= start, "start = " + start + ", end = " + end +
          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
        ser.deserializeStream(compressedStream)
      } else {
        // No more batches left
        cleanup(false)
        null
      }
    }

    /**
     * Return the next T pair from the deserialization stream.
     *
     * If the current batch is drained, construct a stream for the next batch and read from it.
     * If no more element are left, return null.
     */
    private def readNextItem(): T = {
      try {
        val k = deserializeStream.readKey()
        val c = deserializeStream.readValue().asInstanceOf[T]
        objectsRead += 1
        if (objectsRead == serializerBatchSize) {
          objectsRead = 0
          deserializeStream = nextBatchStream()
        }
        itemIsNull = false
        c
      } catch {
        case e: EOFException =>
          e.printStackTrace()
          cleanup(false)
          itemIsNull = true
          nextItem
      }
    }

    override def hasNext: Boolean = {
      if (itemIsNull) {
        if (deserializeStream == null) {
          return false
        }
        nextItem = readNextItem()
      }
      !itemIsNull
    }

    override def next(): T = {
      val item = if (itemIsNull) readNextItem() else nextItem
      if (itemIsNull) {
        throw new NoSuchElementException
      }
      itemIsNull = true
      item
    }

    private def cleanup(deleteFile: Boolean) {
      batchIndex = batchOffsets.length // Prevent reading any other batch
      val ds = deserializeStream
      if (ds != null) {
        ds.close()
        deserializeStream = null
      }
      if (fileStream != null) {
        fileStream.close()
        fileStream = null
      }
      if (deleteFile) {
        deleteTmpFile
      }
    }

    def deleteTmpFile() {
      if (file.exists()) {
        if (!file.delete()) {
          logWarning(s"Error deleting ${file}")
        }
      }
    }

    context.addTaskCompletionListener(context => cleanup(true))
  }
}
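
Taken together, the class is an append-only buffer that falls back to disk: small inputs stay in arrayBuffered, larger ones are tracked in a SizeTrackingVector that maybeSpill() flushes to a temp local block in serialized batches, and iteration replays the spilled batches before the in-memory tail. Below is a minimal usage sketch, assuming it runs inside a Spark task (the default constructor arguments pull the serializer, block manager, and TaskContext from the running environment); the element type, loop bound, and counting logic are illustrative only, not part of the commit.

// Illustrative only: append the rows that match one hot join key, then replay them.
// Once the tracked size of the in-memory vector passes the task's memory threshold,
// maybeSpill() writes it to a temp local block in batches of spark.shuffle.spill.batchSize.
val rows = new ExternalAppendOnlyArrayBuffer[String]()
var i = 0
while (i < 2000000) {
  rows += ("row_" + i) // below spark.cache.array.size the elements stay in arrayBuffered
  i += 1
}

// Iteration reads any spilled batches back from disk first, then the in-memory tail.
var replayed = 0
val it = rows.iterator
while (it.hasNext) {
  it.next()
  replayed += 1
}
assert(replayed == rows.length)

rows.reset() // delete temp spill files and release the acquired execution memory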

core/src/main/scala/org/apache/spark/util/collection/Spillable.scala

Lines changed: 8 additions & 0 deletions
@@ -65,6 +65,14 @@ private[spark] trait Spillable[C] extends Logging {
   // Number of spills
   private[this] var _spillCount = 0

+  protected def reset(): Unit = {
+    releaseMemory()
+    //myMemoryThreshold = initialMemoryThreshold
+    _elementsRead = 0
+    _memoryBytesSpilled = 0L
+    _spillCount = 0
+  }
+
   /**
    * Spills the current in-memory collection to disk if needed. Attempts to acquire more
    * memory before spilling.
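
The reset() hook added to Spillable (and overridden in the new class above) is what lets a task reuse a single spillable buffer across join keys instead of allocating a new collection per key. A sketch of that reuse pattern follows, assuming it is called from inside a running task; groupedRows, the element type, and the probe step are hypothetical placeholders, not part of the commit.

import org.apache.spark.util.collection.ExternalAppendOnlyArrayBuffer

// Hypothetical reuse pattern: one buffer for the whole task, cleared after each join key.
def bufferPerKey(groupedRows: Iterator[(Int, Iterator[String])]): Unit = {
  val matches = new ExternalAppendOnlyArrayBuffer[String]()
  groupedRows.foreach { case (_, rowsForKey) =>
    rowsForKey.foreach(matches += _)
    // ... probe the other side of the join against matches.iterator here ...
    matches.reset() // delete temp spill files and zero the Spillable counters before the next key
  }
}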
