package org.apache.spark.util.collection

import java.io._

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import com.google.common.io.ByteStreams

import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}

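/**
 * An append-only buffer of elements that spills its contents to disk when the estimated
 * in-memory size grows too large. Elements are first kept in a plain ArrayBuffer; once
 * `spark.cache.array.size` elements have been added they are moved into a size-tracked
 * vector, and whenever `maybeSpill` fires, the vector is written to a temporary local block
 * in batches of `spark.shuffle.spill.batchSize` records.
 *
 * Iteration replays the spilled batches from disk in insertion order, followed by whatever
 * is still in memory. A minimal usage sketch, using only this class's own API (it must be
 * created inside a running task, since it grabs the task's TaskContext and memory manager):
 *
 * {{{
 *   val buffer = new ExternalAppendOnlyArrayBuffer[Int]()
 *   (1 to 1000000).foreach(buffer += _)
 *   val total = buffer.iterator.size   // replays spilled data, then in-memory data
 * }}}
 */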
class ExternalAppendOnlyArrayBuffer[T: ClassTag](
    serializer: Serializer = SparkEnv.get.serializer,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    context: TaskContext = TaskContext.get())
  extends Iterable[T]
  with Serializable
  with Logging
  with Spillable[SizeTracker] {

  if (context == null) {
    throw new IllegalStateException(
      "Spillable collections should not be instantiated outside of tasks")
  }

  private var currentArray = new SizeTrackingVector[T]
  private var spilledArrays = new ArrayBuffer[DiskArrayIterator]
  private val arrayBuffered: ArrayBuffer[T] = new ArrayBuffer[T]
  private val sparkConf = SparkEnv.get.conf
  private val diskBlockManager = blockManager.diskBlockManager

  private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
  private val cacheSizeInArrayBuffer = sparkConf.getInt("spark.cache.array.size", 10000)

  // Number of bytes spilled in total
  private var _diskBytesSpilled = 0L
  def diskBytesSpilled: Long = _diskBytesSpilled

  // Write metrics for the current spill
  private var curWriteMetrics: ShuffleWriteMetrics = _
  // Number of elements appended so far
  private var _size: Int = 0
  // Number of elements returned by the external iterator, used only for progress logging
  private var _getsize: Int = 0

  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
  private val fileBufferSize =
    sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024

  private val ser = serializer.newInstance()

  def length: Int = _size

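  /**
   * Discard all spilled files and in-memory data so the buffer can be reused within the
   * same task, and release the execution memory this collection has acquired.
   */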
  override def reset(): Unit = {
    logDebug("Reset: spilled files = " + spilledArrays.length +
      ", currentArray.size = " + currentArray.size)
    super.reset()
    if (spilledArrays.nonEmpty) {
      _diskBytesSpilled = 0
      spilledArrays.foreach(_.deleteTmpFile())
      spilledArrays = new ArrayBuffer[DiskArrayIterator]
    }
    if (currentArray.size > 0) {
      currentArray = new SizeTrackingVector[T]
    }
    _size = 0
    arrayBuffered.clear()
    releaseMemory()
  }

  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()

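  /**
   * Append an element. Small collections stay in a plain ArrayBuffer; once
   * `cacheSizeInArrayBuffer` elements have been added, the buffered elements are flushed into
   * the size-tracked vector and every subsequent append checks whether the vector should be
   * spilled to disk.
   */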
  def +=(elem: T): Unit = {
    _size += 1
    addElementsRead()
    if (_size < cacheSizeInArrayBuffer) {
      arrayBuffered += elem
      return
    }
    // Move any elements still sitting in the plain buffer into the size-tracked vector.
    // This only does work once, when the cache threshold is first crossed.
    if (arrayBuffered.nonEmpty) {
      val iter = arrayBuffered.iterator
      while (iter.hasNext) {
        currentArray += iter.next()
      }
      arrayBuffered.clear()
    }
    if (currentArray.length % 100000 == 0) {
      logInfo("currentArray.length = " + currentArray.length)
    }
    val estimatedSize = currentArray.estimateSize()
    if (maybeSpill(currentArray, estimatedSize)) {
      currentArray = new SizeTrackingVector[T]
    }
    currentArray += elem
  }

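  /**
   * Spill the contents of `currentArray` to a temporary local block on disk. Elements are
   * written in batches of `serializerBatchSize` records, and the byte size of each batch is
   * recorded so the file can later be read back one batch at a time.
   */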
  override protected[this] def spill(collection: SizeTracker): Unit = {
    val (blockId, file) = diskBlockManager.createTempLocalBlock()
    logInfo("Spilling " + currentArray.length + " elements to " + file)
    curWriteMetrics = new ShuffleWriteMetrics()
    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
    var objectsWritten = 0

    // Byte size of each written batch, used by DiskArrayIterator to seek between batches
    val batchSizes = new ArrayBuffer[Long]

    // Flush the disk writer's contents to disk, and update relevant variables
    def flush(): Unit = {
      val w = writer
      writer = null
      w.commitAndClose()
      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
      objectsWritten = 0
    }

    var success = false
    try {
      val it = currentArray.iterator
      while (it.hasNext) {
        val elem = it.next()
        writer.write(null, elem)
        objectsWritten += 1

        if (objectsWritten == serializerBatchSize) {
          flush()
          curWriteMetrics = new ShuffleWriteMetrics()
          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
        }
      }
      if (objectsWritten > 0) {
        flush()
      } else if (writer != null) {
        val w = writer
        writer = null
        w.revertPartialWritesAndClose()
      }
      success = true
    } finally {
      if (!success) {
        // This code path only happens if an exception was thrown above before we set success;
        // close our stuff and let the exception be thrown further
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (file.exists()) {
          if (!file.delete()) {
            logWarning(s"Error deleting ${file}")
          }
        }
      }
    }
    spilledArrays.append(new DiskArrayIterator(file, blockId, batchSizes))
  }

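  /**
   * Return an iterator over all appended elements in insertion order. Depending on how much
   * data was added, this reads from the plain ArrayBuffer, from the in-memory size-tracked
   * vector, or from the spilled files followed by the in-memory vector.
   */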
  override def iterator: Iterator[T] = {
    logInfo("Creating iterator over " + _size + " elements")
    if (_size < cacheSizeInArrayBuffer) {
      arrayBuffered.iterator
    } else if (spilledArrays.isEmpty) {
      currentArray.iterator
    } else {
      new ExternalIterator()
    }
  }

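  /**
   * An iterator that chains the spilled on-disk batches, in the order they were spilled,
   * with the elements still held in the in-memory vector.
   */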
  private class ExternalIterator extends Iterator[T] {

    private var currentIndex = 0
    private var currentIter = spilledArrays(currentIndex)
    private val arrayIter = currentArray.iterator

    // Advance past any exhausted spilled iterators so that currentIter either has a next
    // element or is the last spilled iterator.
    private def advance(): Unit = {
      while (!currentIter.hasNext && currentIndex < spilledArrays.length - 1) {
        currentIndex += 1
        currentIter = spilledArrays(currentIndex)
      }
    }

    override def hasNext: Boolean = {
      advance()
      currentIter.hasNext || arrayIter.hasNext
    }

    override def next(): T = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
      _getsize += 1
      if (_getsize % 100000 == 0) {
        logInfo("Read " + _getsize + " elements so far (spill index " + currentIndex + ")")
      }
      if (currentIter.hasNext) currentIter.next() else arrayIter.next()
    }
  }

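  /**
   * An iterator that reads one spill file back from disk, one serialized batch at a time.
   * Batch boundaries are reconstructed from the recorded batch sizes, so each batch is read
   * through its own bounded, buffered, decompressed stream.
   */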
  private class DiskArrayIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
    extends Iterator[T] {

    // Cumulative byte offsets; length is batchSizes.length + 1
    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _)
    logDebug("batchOffsets.length = " + batchOffsets.length + ", file.length() = " + file.length()
      + ", batchOffsets.last = " + batchOffsets.last)
    assert(file.length() == batchOffsets.last,
      "File length is not equal to the last batch offset:\n" +
      s"file length = ${file.length}\n" +
      s"last batch offset = ${batchOffsets.last}\n" +
      s"all batch offsets = ${batchOffsets.mkString(",")}"
    )

    private var batchIndex = 0  // Which batch we're in
    private var fileStream: FileInputStream = null

    // An intermediate stream that reads from exactly one batch
    // This guards against pre-fetching and other arbitrary behavior of higher level streams
    private var deserializeStream = nextBatchStream()
    // Whether nextItem does NOT currently hold a valid, not-yet-returned element
    private var itemIsNull = true
    private var nextItem: T = _
    private var objectsRead = 0

    /**
     * Construct a stream that reads only from the next batch.
     */
    private def nextBatchStream(): DeserializationStream = {
      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
      // we're still in a valid batch.
      if (batchIndex < batchOffsets.length - 1) {
        if (deserializeStream != null) {
          deserializeStream.close()
          fileStream.close()
          deserializeStream = null
          fileStream = null
        }

        val start = batchOffsets(batchIndex)
        fileStream = new FileInputStream(file)
        fileStream.getChannel.position(start)
        batchIndex += 1

        val end = batchOffsets(batchIndex)

        assert(end >= start, "start = " + start + ", end = " + end +
          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
        ser.deserializeStream(compressedStream)
      } else {
        // No more batches left
        cleanup(false)
        null
      }
    }

    /**
     * Return the next element from the deserialization stream.
     *
     * If the current batch is drained, construct a stream for the next batch and read from it.
     * If no more elements are left, set `itemIsNull` so that `hasNext` reports exhaustion.
     */
    private def readNextItem(): T = {
      try {
        // Each record was written as a (null key, element) pair, so read and discard the key
        // before reading the element itself.
        deserializeStream.readKey()
        val c = deserializeStream.readValue().asInstanceOf[T]
        objectsRead += 1
        if (objectsRead == serializerBatchSize) {
          objectsRead = 0
          deserializeStream = nextBatchStream()
        }
        itemIsNull = false
        c
      } catch {
        case e: EOFException =>
          logDebug(s"Hit EOF while reading spilled data from $file", e)
          cleanup(false)
          itemIsNull = true
          nextItem
      }
    }

    override def hasNext: Boolean = {
      if (itemIsNull) {
        if (deserializeStream == null) {
          return false
        }
        nextItem = readNextItem()
      }
      !itemIsNull
    }

    override def next(): T = {
      val item = if (itemIsNull) readNextItem() else nextItem
      if (itemIsNull) {
        // readNextItem hit the end of the stream without producing an element
        throw new NoSuchElementException
      }
      itemIsNull = true
      item
    }

    private def cleanup(deleteFile: Boolean): Unit = {
      batchIndex = batchOffsets.length  // Prevent reading any other batch
      val ds = deserializeStream
      if (ds != null) {
        ds.close()
        deserializeStream = null
      }
      if (fileStream != null) {
        fileStream.close()
        fileStream = null
      }
      if (deleteFile) {
        deleteTmpFile()
      }
    }

    def deleteTmpFile(): Unit = {
      if (file.exists()) {
        if (!file.delete()) {
          logWarning(s"Error deleting ${file}")
        }
      }
    }

    // Make sure the spill file and its streams are cleaned up when the task finishes.
    context.addTaskCompletionListener(_ => cleanup(deleteFile = true))
  }
}