
Commit ee77923

wakun authored and GitHub Enterprise committed
[CARMEL-6185] Expose row count for RepeatableIterator (#1074)
* [CARMEL-6185] Expose row count for RepeatableIterator
* fix code style
* Fix code style
* Fix UT
* Update code
* Update code
1 parent e0ff530 commit ee77923
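
In outline, the commit counts output records in TaskMetrics on the executor, forwards the count with each taskSucceeded callback, sums it per job in IterableJobWaiter, and hands the total to the returned RepeatableIterator so callers can read the row count without consuming the data. The plain-Scala sketch below illustrates only that last idea; CountedIterator and ArrayBackedIterator are illustrative names, not classes from this commit.

// Illustrative sketch only (not code from this commit): an iterator that carries
// a precomputed row count so callers can read the size without draining it.
abstract class CountedIterator[T](val rowCount: Long) extends Iterator[T]

class ArrayBackedIterator[T](data: Array[T])
  extends CountedIterator[T](data.length.toLong) {
  private var pos = 0
  override def hasNext: Boolean = pos < data.length
  override def next(): T = { val v = data(pos); pos += 1; v }
}

object CountedIteratorDemo {
  def main(args: Array[String]): Unit = {
    val it = new ArrayBackedIterator(Array("a", "b", "c"))
    // rowCount is known up front, while the elements remain unread.
    println(it.rowCount) // 3
  }
}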

12 files changed, +87 -39 lines changed

core/src/main/scala/org/apache/spark/InternalAccumulator.scala

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ private[spark] object InternalAccumulator {
   val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses"
   val PRUNED_STATS = "index.prunedStats"
   val TEST_ACCUM = METRICS_PREFIX + "testAccumulator"
+  val RECORDS_OUTPUT = OUTPUT_METRICS_PREFIX + "recordsOutput"

   // scalastyle:off

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 8 additions & 0 deletions
@@ -57,6 +57,7 @@ class TaskMetrics private[spark] () extends Serializable {
   private val _peakExecutionMemory = new LongAccumulator
   private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
   private val _prunedStats = new PrunedMetricsAccum
+  private val _recordsOutput = new LongAccumulator

   def prunedStats: PrunedMetricsAccum = _prunedStats
   /**
@@ -113,6 +114,11 @@ class TaskMetrics private[spark] () extends Serializable {
    */
   def peakExecutionMemory: Long = _peakExecutionMemory.sum

+  /**
+   * Total number of records output.
+   */
+  def recordsOutput: Long = _recordsOutput.sum
+
   /**
    * Storage statuses of any blocks that have been updated as a result of this task.
    *
@@ -152,6 +158,7 @@ class TaskMetrics private[spark] () extends Serializable {
   private[spark] def setPrunedStats(v: List[PrunedStats]): Unit = {
     _prunedStats.setValue(v)
   }
+  private[spark] def setRecordsOutput(v: Long): Unit = _recordsOutput.setValue(v)

   /**
    * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
@@ -226,6 +233,7 @@ class TaskMetrics private[spark] () extends Serializable {
     PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
     UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
     PRUNED_STATS -> _prunedStats,
+    RECORDS_OUTPUT -> _recordsOutput,
     shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
     shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
     shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,

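The TaskMetrics hunks above follow the existing accumulator pattern: a private Long counter, a public read-only getter, a private[spark] setter, and a registration of the counter under a metric name. A rough, framework-free sketch of that shape is below; AtomicLong stands in for Spark's LongAccumulator and the names are illustrative.

import java.util.concurrent.atomic.AtomicLong

// Sketch of the metric shape used above; AtomicLong is a stand-in, not what Spark uses.
class SketchTaskMetrics {
  private val _recordsOutput = new AtomicLong(0L)

  /** Total number of records output by this task. */
  def recordsOutput: Long = _recordsOutput.get()

  /** Setter is private[spark] in the real code; simplified here. */
  def setRecordsOutput(v: Long): Unit = _recordsOutput.set(v)

  /** Counters registered by name, mirroring the name -> accumulator map in the last hunk. */
  def nameToCounters: Map[String, AtomicLong] =
    Map("output.recordsOutput" -> _recordsOutput)
}
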
core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 4 additions & 4 deletions
@@ -324,12 +324,12 @@ abstract class RDD[T: ClassTag](
    * expansion rate = (number of output rows in the task) / (number of input rows in task).
    */
   final def expansionLimitedIterator(split: Partition, context: TaskContext): Iterator[T] = {
-    val innerItrator = iterator(split, context)
+    val innerIterator = iterator(split, context)
     if (maxExpandRate > 0) {
       new Iterator[T] {
         private var output = 0
         override def hasNext: Boolean = {
-          innerItrator.hasNext
+          innerIterator.hasNext
         }
         override def next(): T = {
           output += 1
@@ -345,11 +345,11 @@
             }
             output = 0
           }
-          innerItrator.next()
+          innerIterator.next()
         }
       }
     } else {
-      innerItrator
+      innerIterator
     }
   }

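The doc comment in the first RDD hunk defines expansion rate = (number of output rows) / (number of input rows); when maxExpandRate > 0 the task's iterator is wrapped so runaway row expansion can be detected. The full check lies outside the shown context, so the sketch below is a simplified illustration of the ratio test, not the actual implementation.

// Simplified illustration of an expansion-rate guard (not the Spark code):
// abort once output rows / input rows exceeds maxRate.
def limitExpansion[T](input: Iterator[T], expand: T => Iterator[T], maxRate: Double): Iterator[T] = {
  var inputRows = 0L
  var outputRows = 0L
  input.flatMap { row =>
    inputRows += 1
    expand(row).map { out =>
      outputRows += 1
      if (outputRows.toDouble / inputRows > maxRate) {
        throw new IllegalStateException(
          s"expansion rate ${outputRows.toDouble / inputRows} exceeded limit $maxRate")
      }
      out
    }
  }
}

// Example: each input row expands to 3 rows, tripping a limit of 2.0 on the third output.
// limitExpansion(Iterator(1, 2, 3), (x: Int) => Iterator.fill(3)(x), maxRate = 2.0).toList
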
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -1942,7 +1942,7 @@ private[spark] class DAGScheduler(
             // taskSucceeded runs some user code that might throw an exception. Make sure
             // we are resilient against that.
             try {
-              job.listener.taskSucceeded(rt.outputId, event.result)
+              job.listener.taskSucceeded(rt.outputId, event.result, event.taskMetrics)
             } catch {
               case e: Throwable if !Utils.isFatalError(e) =>
                 // TODO: Perhaps we want to mark the resultStage as failed?

core/src/main/scala/org/apache/spark/scheduler/IterableJobWaiter.scala

Lines changed: 9 additions & 3 deletions
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag

+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.Logging

 /**
@@ -48,6 +49,7 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
   // to hold the result data as spilled files or in memory array
   private var spilledResultData: Option[Array[SpilledPartitionResult]] = None
   private val resultData: Array[U] = new Array[U](totalTasks)
+  private var allRowCount: Long = 0L

   // indicate whether the in memory result data has been cleaned after
   // the result data is spilled to disk
@@ -66,7 +68,11 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
     dagScheduler.cancelJob(jobId, None)
   }

-  override def taskSucceeded(index: Int, result: Any): Unit = {
+  override def taskSucceeded(index: Int, result: Any): Unit =
+    taskSucceeded(index, result, new TaskMetrics)
+
+  override def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit = {
+    allRowCount += taskMetrics.recordsOutput
     result match {
       case spilledPartitionResult: Array[SpilledPartitionResult] =>
         spilledResultData = Some(spilledPartitionResult)
@@ -109,12 +115,12 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
     if (spilledResultData.nonEmpty) {
       logInfo(s"Return result as a SpilledResultIterator for job $jobId " +
         s"with files ${spilledResultData.get.map(_.file.getPath).mkString(",")}")
-      SpilledResultIterator[U, R](spilledResultData.get, resultConverter,
+      SpilledResultIterator[U, R](spilledResultData.get, resultConverter, allRowCount,
        dagScheduler.sc.conf.getBoolean("spark.sql.thriftserver.cleanShareResultFiles", false),
        dagScheduler.sc.conf.getBoolean("spark.sql.thriftserver.shareResult", true))
     } else {
       logInfo(s"Return result as a SimpleRepeatableIterator for job $jobId")
-      SimpleRepeatableIterator(resultData, resultConverter)
+      SimpleRepeatableIterator(resultData, resultConverter, allRowCount)
     }
   }

core/src/main/scala/org/apache/spark/scheduler/JobListener.scala

Lines changed: 5 additions & 0 deletions
@@ -17,12 +17,17 @@

 package org.apache.spark.scheduler

+import org.apache.spark.executor.TaskMetrics
+
 /**
  * Interface used to listen for job completion or failure events after submitting a job to the
  * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
  * job fails (and no further taskSucceeded events will happen).
  */
 private[spark] trait JobListener {
   def taskSucceeded(index: Int, result: Any): Unit
+
+  def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit =
+    taskSucceeded(index, result)
   def jobFailed(exception: Exception): Unit
 }

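The JobListener change uses a source-compatible overload: the new three-argument taskSucceeded gets a default body that forwards to the old one, so existing listeners keep compiling while the DAGScheduler can now pass task metrics. A small sketch of the pattern, with illustrative names:

// Sketch of the compatibility pattern above; Metrics and the listener names are illustrative.
final case class Metrics(recordsOutput: Long)

trait SketchJobListener {
  def taskSucceeded(index: Int, result: Any): Unit

  // New, wider callback; defaults to the old behaviour for listeners that ignore metrics.
  def taskSucceeded(index: Int, result: Any, metrics: Metrics): Unit =
    taskSucceeded(index, result)
}

// A listener that cares about metrics overrides the wide form and sums row counts,
// much like IterableJobWaiter accumulates allRowCount.
class RowCountingListener extends SketchJobListener {
  var totalRows = 0L
  override def taskSucceeded(index: Int, result: Any): Unit =
    taskSucceeded(index, result, Metrics(0L))
  override def taskSucceeded(index: Int, result: Any, metrics: Metrics): Unit =
    totalRows += metrics.recordsOutput
}
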
core/src/main/scala/org/apache/spark/scheduler/TaskResultStore.scala

Lines changed: 22 additions & 11 deletions
@@ -180,22 +180,28 @@ private[spark] case class SpilledPartitionResult(
 /**
  * the interface for iterator which supports read from the start again
  */
-private[spark] trait RepeatableIterator[T] extends Iterator[T] {
+private[spark] abstract class RepeatableIterator[T](_rowCount: Long) extends Iterator[T] {

   def backToStart(): Unit

   def close(): Unit

   def copy(): RepeatableIterator[T]
+
+  def rowCount(): Long = _rowCount
+
+  // Do not call this method, length will < 0 if rowCount > Int.MAX_VALUE
+  override def length: Int = rowCount.toInt
 }

 /**
  * the simple iterator implementation to read from the in memory data
  */
 private[spark] case class SimpleRepeatableIterator[T, U](
     originData: Array[U],
-    resultConverter: U => Iterator[T])
-  extends RepeatableIterator[T] {
+    resultConverter: U => Iterator[T],
+    _rowCount: Long)
+  extends RepeatableIterator[T](_rowCount) {

   private var it: Iterator[T] = originData.iterator.flatMap(resultConverter)

@@ -210,24 +216,29 @@ private[spark] case class SimpleRepeatableIterator[T, U](
   override def close(): Unit = {
   }

-  // length calculation is time consuming
-  override def length: Int = {
-    originData.iterator.flatMap(resultConverter).length
-  }
-
   override def copy(): RepeatableIterator[T] = {
-    new SimpleRepeatableIterator[T, U](originData, resultConverter)
+    new SimpleRepeatableIterator[T, U](originData, resultConverter, rowCount)
   }
 }

 /**
  * The iterator implementation to read from spilled files
+ * data of spilledResults: Array[SpilledPartitionResult]
+ *   file                   blockId           offset  length
+ *   /data/yarn/tmp/file1,  "temp_local_001", 0,      100
+ *   /data/yarn/tmp/file1,  "temp_local_002", 100,    200
+ *   /data/yarn/tmp/file2,  "temp_local_003", 0,      400
+ *
+ * nextBatchStream() will clean the temp file and then read a new SpilledPartitionResult.
+ * readNextBatch() will convert the partition result to currentBatch: Iterator[R]
  */
 private[spark] case class SpilledResultIterator[U, R](
     spilledResults: Array[SpilledPartitionResult],
     converter: U => Iterator[R],
+    _rowCount: Long,
     cleanShareResultFiles: Boolean = false,
-    override val isTraversableAgain: Boolean) extends RepeatableIterator[R] with Logging {
+    override val isTraversableAgain: Boolean)
+  extends RepeatableIterator[R](_rowCount) with Logging {

   private val serializer = SparkEnv.get.serializer.newInstance()
   private val serializerManager = SparkEnv.get.serializerManager
@@ -378,7 +389,7 @@ private[spark] case class SpilledResultIterator[U, R](
   }

   override def copy(): RepeatableIterator[R] = {
-    new SpilledResultIterator[U, R](spilledResults, converter, cleanShareResultFiles,
+    new SpilledResultIterator[U, R](spilledResults, converter, rowCount, cleanShareResultFiles,
       isTraversableAgain)
   }
 }
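
A note on the length override in RepeatableIterator: rowCount is a Long, but Iterator.length returns an Int, so the commit's comment warns that length goes negative once rowCount exceeds Int.MaxValue. The REPL-style snippet below only illustrates why the narrowing conversion wraps.

// Illustration of the Int narrowing behind the "length will be < 0" warning.
val bigCount: Long = Int.MaxValue.toLong + 1
assert(bigCount.toInt == Int.MinValue) // wraps around, so a length based on toInt turns negative

// rowCount() is therefore the safe way to ask for the size: it stays a Long and,
// unlike the default Iterator.length, does not consume the iterator.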

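The new doc comment on SpilledResultIterator describes the spilled layout: each SpilledPartitionResult is a (file, blockId, offset, length) slice, and several slices may share one temp file. The REPL-style fragment below only restates that layout with an illustrative case class; the real reading and cleanup logic (nextBatchStream, readNextBatch) is not reproduced here.

import java.io.File

// Illustrative restatement of the spilled layout from the doc comment; SpilledSlice
// is a made-up name and carries none of the real read/cleanup behaviour.
final case class SpilledSlice(file: File, blockId: String, offset: Long, length: Long)

val slices = Seq(
  SpilledSlice(new File("/data/yarn/tmp/file1"), "temp_local_001", 0L, 100L),
  SpilledSlice(new File("/data/yarn/tmp/file1"), "temp_local_002", 100L, 200L),
  SpilledSlice(new File("/data/yarn/tmp/file2"), "temp_local_003", 0L, 400L))

// Slices are consumed in order: each one is deserialized into a batch of rows,
// and only then is the next slice (possibly in another file) opened.
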
core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 21 additions & 14 deletions
@@ -2093,104 +2093,111 @@ private[spark] object JsonProtocolSuite extends Assertions {
 | },
 | {
 |   "ID": 11,
-|   "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
+|   "Name": "${RECORDS_OUTPUT}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 12,
-|   "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
+|   "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 13,
-|   "Name": "${shuffleRead.REMOTE_BYTES_READ}",
+|   "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 14,
-|   "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}",
+|   "Name": "${shuffleRead.REMOTE_BYTES_READ}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 15,
-|   "Name": "${shuffleRead.LOCAL_BYTES_READ}",
+|   "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 16,
-|   "Name": "${shuffleRead.FETCH_WAIT_TIME}",
+|   "Name": "${shuffleRead.LOCAL_BYTES_READ}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 17,
-|   "Name": "${shuffleRead.RECORDS_READ}",
+|   "Name": "${shuffleRead.FETCH_WAIT_TIME}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 18,
-|   "Name": "${shuffleWrite.BYTES_WRITTEN}",
+|   "Name": "${shuffleRead.RECORDS_READ}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 19,
-|   "Name": "${shuffleWrite.RECORDS_WRITTEN}",
+|   "Name": "${shuffleWrite.BYTES_WRITTEN}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 20,
-|   "Name": "${shuffleWrite.WRITE_TIME}",
+|   "Name": "${shuffleWrite.RECORDS_WRITTEN}",
 |   "Update": 0,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
 |   "ID": 21,
+|   "Name": "${shuffleWrite.WRITE_TIME}",
+|   "Update": 0,
+|   "Internal": true,
+|   "Count Failed Values": true
+| },
+| {
+|   "ID": 22,
 |   "Name": "${input.BYTES_READ}",
 |   "Update": 2100,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
-|   "ID": 22,
+|   "ID": 23,
 |   "Name": "${input.RECORDS_READ}",
 |   "Update": 21,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
-|   "ID": 23,
+|   "ID": 24,
 |   "Name": "${output.BYTES_WRITTEN}",
 |   "Update": 1200,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
-|   "ID": 24,
+|   "ID": 25,
 |   "Name": "${output.RECORDS_WRITTEN}",
 |   "Update": 12,
 |   "Internal": true,
 |   "Count Failed Values": true
 | },
 | {
-|   "ID": 25,
+|   "ID": 26,
 |   "Name": "$TEST_ACCUM",
 |   "Update": 0,
 |   "Internal": true,

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 4 additions & 2 deletions
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.AbstractIterator
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}

-import org.apache.spark.{broadcast, SparkEnv, TaskKilledException}
+import org.apache.spark.{broadcast, SparkEnv, TaskContext, TaskKilledException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
@@ -362,6 +362,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
       out.writeInt(-1)
       out.flush()
       out.close()
+      TaskContext.get().taskMetrics().setRecordsOutput(count)
       Iterator((count, bos.toByteArray))
     }
   }
@@ -544,7 +545,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
       throw new IllegalArgumentException(s"Limit cannot exceed threshold ${conf.limitMaxRows}")
     }
     logInfo(s"Return limit result as a SimpleRepeatableIterator.")
-    SimpleRepeatableIterator[R, InternalRow](executeTake(n), row => Iterator(proj(row)))
+    val array = executeTake(n)
+    SimpleRepeatableIterator[R, InternalRow](array, row => Iterator(proj(row)), array.length)
   }

   /**
