@@ -115,8 +115,8 @@ private[spark] class UnsafeShuffleWriter[K, V](
     val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
-    var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    var currentPagePosition: Long = currentPage.getBaseOffset
+    var currentPage: MemoryBlock = null
+    var currentPagePosition: Long = PAGE_SIZE
 
     def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
       if (spaceRequired > PAGE_SIZE) {
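
The two changed initializers make page allocation lazy: `currentPage` now starts out null and `currentPagePosition` starts at `PAGE_SIZE`, so the first call to `ensureSpaceInDataPage` sees no free space and has to allocate the first data page itself. The diff only shows the size check at the top of that method, so the following self-contained Scala sketch of the lazy-allocation pattern is an assumption about the rest of its body, with plain byte arrays standing in for `MemoryBlock` pages and the task memory manager:

```scala
// Illustrative sketch only; the real code uses memoryManager.allocatePage and MemoryBlock.
object LazyPageAllocationSketch {
  val PAGE_SIZE: Int = 1024 * 1024 * 1
  private val allocatedPages = scala.collection.mutable.ArrayBuffer.empty[Array[Byte]]
  private var currentPage: Array[Byte] = null
  private var currentPagePosition: Int = PAGE_SIZE // starts "full", so the first record triggers allocation

  def ensureSpaceInDataPage(spaceRequired: Int): Unit = {
    require(spaceRequired <= PAGE_SIZE, s"Record of size $spaceRequired exceeds page size $PAGE_SIZE")
    if (currentPage == null || currentPagePosition + spaceRequired > PAGE_SIZE) {
      currentPage = new Array[Byte](PAGE_SIZE) // stands in for memoryManager.allocatePage(PAGE_SIZE)
      allocatedPages += currentPage            // tracked so freeMemory() can later release every page
      currentPagePosition = 0
    }
  }
}
```
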
@@ -143,6 +143,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       serBufferSerStream.flush()
 
       val serializedRecordSize = byteBuffer.position()
+      assert(serializedRecordSize > 0)
       // TODO: we should run the partition extraction function _now_, at insert time, rather than
       // requiring it to be stored alongside the data, since this may lead to double storage
       val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
@@ -152,17 +153,17 @@ private[spark] class UnsafeShuffleWriter[K, V](
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
-      println("The stored record length is " + byteBuffer.position())
+      println("The stored record length is " + serializedRecordSize)
       PlatformDependent.UNSAFE.putLong(
-        currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
+        currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
       currentPagePosition += 8
       PlatformDependent.copyMemory(
         serArray,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         currentPage.getBaseObject,
         currentPagePosition,
-        byteBuffer.position())
-      currentPagePosition += byteBuffer.position()
+        serializedRecordSize)
+      currentPagePosition += serializedRecordSize
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)
 
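
For reference, the three writes above lay each record out in the data page as an 8-byte partition id, an 8-byte record length, and then the serialized bytes. The sketch below shows the same layout with `java.nio.ByteBuffer` in place of `PlatformDependent.UNSAFE`; it is an illustration, not code from the patch (note that `ByteBuffer` defaults to big-endian order, whereas `Unsafe` writes in native byte order):

```scala
import java.nio.ByteBuffer

object RecordLayoutSketch {
  // Writes one record using the layout produced by the putLong/copyMemory calls above:
  // [partitionId: 8 bytes][recordLength: 8 bytes][recordLength serialized bytes]
  def writeRecord(page: ByteBuffer, partitionId: Long, serialized: Array[Byte]): Unit = {
    page.putLong(partitionId)              // 8-byte partition id header
    page.putLong(serialized.length.toLong) // 8-byte record length header
    page.put(serialized)                   // the serialized record payload
  }
}
```
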
@@ -195,10 +196,12 @@ private[spark] class UnsafeShuffleWriter[K, V](
     }
 
     def switchToPartition(newPartition: Int): Unit = {
+      assert(newPartition > currentPartition, s"new partition $newPartition should be > $currentPartition")
       if (currentPartition != -1) {
         closePartition()
         prevPartitionLength = partitionLengths(currentPartition)
       }
+      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
       out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
     }
@@ -214,11 +217,11 @@ private[spark] class UnsafeShuffleWriter[K, V](
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
-      var i: Int = 0
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
+      var i: Int = 0
       while (i < recordLength) {
         out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
         i += 1
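
The read side mirrors the write layout: partition id at the record's base offset, record length at offset + 8, and the serialized bytes starting at offset + 16, which the loop above copies to the output stream one byte at a time. A self-contained sketch of the same read with `ByteBuffer` (again an illustration, not the patch's code):

```scala
import java.nio.ByteBuffer

object RecordReadSketch {
  // Reads back a record written with the [id][length][bytes] layout described above.
  def readRecord(page: ByteBuffer, baseOffset: Int): (Long, Array[Byte]) = {
    val partitionId = page.getLong(baseOffset)      // first 8 bytes
    val recordLength = page.getLong(baseOffset + 8) // next 8 bytes
    val bytes = new Array[Byte](recordLength.toInt)
    var i = 0
    while (i < recordLength) {                      // byte-by-byte, like the write loop in the diff
      bytes(i) = page.get(baseOffset + 16 + i)
      i += 1
    }
    (partitionId, bytes)
  }
}
```
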
@@ -241,6 +244,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
+  private def freeMemory(): Unit = {
+    val iter = allocatedPages.iterator()
+    while (iter.hasNext) {
+      memoryManager.freePage(iter.next())
+      iter.remove()
+    }
+  }
+
   /** Close this writer, passing along whether the map completed */
   override def stop(success: Boolean): Option[MapStatus] = {
     println("Stopping unsafeshufflewriter")
@@ -249,6 +260,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
         None
       } else {
         stopping = true
+        freeMemory()
         if (success) {
           Option(mapStatus)
         } else {
@@ -258,24 +270,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
         }
       }
     } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (!allocatedPages.isEmpty) {
-        val iter = allocatedPages.iterator()
-        while (iter.hasNext) {
-          memoryManager.freePage(iter.next())
-          iter.remove()
-        }
-        val startTime = System.nanoTime()
-        // sorter.stop()
-        context.taskMetrics().shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - startTime))
-      }
+      freeMemory()
+      val startTime = System.nanoTime()
+      context.taskMetrics().shuffleWriteMetrics.foreach(
+        _.incShuffleWriteTime(System.nanoTime - startTime))
     }
   }
 }
 
-
-
 private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager {
 
   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)