
Commit 689e458

alekjarmov (Alek Jarmov) authored and committed
[SPARK-52846][SQL] Add a metric in JDBCRDD for how long it takes to fetch the resultset
### What changes were proposed in this pull request?

* Create a helper function `withTimingNs`.
* Use the function to measure how long it takes to fetch data from the JDBC source.

### Why are the changes needed?

Provides better observability. For example, we already have the query execution time; with this metric we would also know whether the network (and row conversion) is the bottleneck.

### Does this PR introduce _any_ user-facing change?

Yes, users can see this new metric in the Spark UI.

### How was this patch tested?

Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51536 from alekjarmov/measure-function-utility.

Lead-authored-by: alekjarmov <[email protected]>
Co-authored-by: Alek Jarmov <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 386e464 · commit 689e458
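As a quick illustration of where the new metric surfaces, here is a minimal, hypothetical sketch (not part of this commit): it reads a table over JDBC and triggers the scan, after which the SQL tab of the Spark UI shows the JDBC scan node with both the existing "JDBC query execution time" metric and the new "JDBC remote data fetch and translation time" metric. The JDBC URL, table name, and credentials below are placeholders.

```scala
import org.apache.spark.sql.SparkSession

object JdbcFetchMetricExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("jdbc-fetch-metric-example")
      .master("local[*]")
      .getOrCreate()

    // Placeholder connection details; point these at any reachable JDBC source.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost:5432/mydb")
      .option("dbtable", "public.some_table")
      .option("user", "spark")
      .option("password", "secret")
      .load()

    // Any action materializes the scan; the per-node SQL metrics (including the new
    // fetch-and-translation timing) are then visible on the SQL tab of the Spark UI.
    df.count()

    spark.stop()
  }
}
```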

File tree

4 files changed, +69 -16 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 28 additions & 14 deletions

@@ -189,6 +189,17 @@ class JDBCRDD(
     sparkContext,
     name = "JDBC query execution time")
 
+  /**
+   * Time needed to fetch the data and transform it into Spark's InternalRow format.
+   *
+   * Usually this is spent in network transfer time, but it can be spent in transformation time
+   * as well if we are transforming some more complex datatype such as structs.
+   */
+  val fetchAndTransformToInternalRowsMetric: SQLMetric = SQLMetrics.createNanoTimingMetric(
+    sparkContext,
+    // Message that user sees does not have to leak details about conversion
+    name = "JDBC remote data fetch and translation time")
+
   private lazy val dialect = JdbcDialects.get(url)
 
   def generateJdbcQuery(partition: Option[JDBCPartition]): String = {
@@ -301,30 +312,33 @@
     stmt.setFetchSize(options.fetchSize)
     stmt.setQueryTimeout(options.queryTimeout)
 
-    val startTime = System.nanoTime
-    rs = try {
-      stmt.executeQuery()
-    } catch {
-      case e: SQLException if dialect.isSyntaxErrorBestEffort(e) =>
-        throw new SparkException(
-          errorClass = "JDBC_EXTERNAL_ENGINE_SYNTAX_ERROR.DURING_QUERY_EXECUTION",
-          messageParameters = Map("jdbcQuery" -> sqlText),
-          cause = e)
+    rs = SQLMetrics.withTimingNs(queryExecutionTimeMetric) {
+      try {
+        stmt.executeQuery()
+      } catch {
+        case e: SQLException if dialect.isSyntaxErrorBestEffort(e) =>
+          throw new SparkException(
+            errorClass = "JDBC_EXTERNAL_ENGINE_SYNTAX_ERROR.DURING_QUERY_EXECUTION",
+            messageParameters = Map("jdbcQuery" -> sqlText),
+            cause = e)
+      }
     }
-    val endTime = System.nanoTime
-
-    val executionTime = endTime - startTime
-    queryExecutionTimeMetric.add(executionTime)
 
     val rowsIterator =
-      JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)
+      JdbcUtils.resultSetToSparkInternalRows(
+        rs,
+        dialect,
+        schema,
+        inputMetrics,
+        Some(fetchAndTransformToInternalRowsMetric))
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](
       new InterruptibleIterator(context, rowsIterator), close())
   }
 
   override def getMetrics: Seq[(String, SQLMetric)] = {
     Seq(
+      "fetchAndTransformToInternalRowsNs" -> fetchAndTransformToInternalRowsMetric,
       "queryExecutionTime" -> queryExecutionTimeMetric
     )
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 14 additions & 2 deletions

@@ -46,6 +46,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
 import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -357,7 +358,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       resultSet: ResultSet,
       dialect: JdbcDialect,
       schema: StructType,
-      inputMetrics: InputMetrics): Iterator[InternalRow] = {
+      inputMetrics: InputMetrics,
+      fetchAndTransformToInternalRowsMetric: Option[SQLMetric] = None): Iterator[InternalRow] = {
     new NextIterator[InternalRow] {
       private[this] val rs = resultSet
       private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
@@ -372,7 +374,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
         }
       }
 
-      override protected def getNext(): InternalRow = {
+      private def getNextWithoutTiming: InternalRow = {
         if (rs.next()) {
           inputMetrics.incRecordsRead(1)
           var i = 0
@@ -387,6 +389,16 @@ object JdbcUtils extends Logging with SQLConfHelper {
           null.asInstanceOf[InternalRow]
         }
       }
+
+      override protected def getNext(): InternalRow = {
+        if (fetchAndTransformToInternalRowsMetric.isDefined) {
+          SQLMetrics.withTimingNs(fetchAndTransformToInternalRowsMetric.get) {
+            getNextWithoutTiming
+          }
+        } else {
+          getNextWithoutTiming
+        }
+      }
     }
   }
 

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 15 additions & 0 deletions

@@ -221,4 +221,19 @@ object SQLMetrics {
       SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value)))
     }
   }
+
+  /**
+   * Measures the time taken by the function `f` in nanoseconds and adds it to the provided metric.
+   *
+   * @param metric SQLMetric to record the time taken.
+   * @param f Function/Codeblock to execute and measure.
+   * @return The result of the function `f`.
+   */
+  def withTimingNs[T](metric: SQLMetric)(f: => T): T = {
+    val startTime = System.nanoTime()
+    val result = f
+    val endTime = System.nanoTime()
+    metric.add(endTime - startTime)
+    result
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 12 additions & 0 deletions

@@ -987,6 +987,18 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     assert(SQLMetrics.createSizeMetric(sparkContext, name = "m").toInfoUpdate.update === Some(-1))
     assert(SQLMetrics.createMetric(sparkContext, name = "m").toInfoUpdate.update === Some(0))
   }
+
+  test("withTimingNs should time and return same result") {
+    val metric = SQLMetrics.createTimingMetric(sparkContext, name = "m")
+
+    // Use a simple block that returns a value
+    val result = SQLMetrics.withTimingNs(metric) {
+      42
+    }
+
+    assert(result === 42)
+    assert(!metric.isZero, "Metric was not increased")
+  }
 }
 
 case class CustomFileCommitProtocol(
