
Commit e83910f

[SPARK-26164][SQL][FOLLOWUP] WriteTaskStatsTracker should know which file the row is written to
### What changes were proposed in this pull request?

This is a follow-up of #32198. Before #32198, `WriteTaskStatsTracker.newRow` implicitly knew that the row was written to the current file; after #32198 that connection is lost. This PR adds a file path parameter to `WriteTaskStatsTracker.newRow` to bring the connection back.

### Why are the changes needed?

To avoid breaking some custom `WriteTaskStatsTracker` implementations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

N/A

Closes #32459 from cloud-fan/minor.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 33c1034 commit e83910f
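
For custom tracker authors, the practical impact is a one-line signature change. Below is a minimal sketch of how an existing implementation would adapt; the class `RowCountingStatsTracker`, the `RowCountStats` payload, and the counting logic are illustrative assumptions, not part of this commit:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{WriteTaskStats, WriteTaskStatsTracker}

// Hypothetical stats payload for the sketch below.
case class RowCountStats(numRows: Long) extends WriteTaskStats

// Hypothetical custom tracker showing only the newRow signature change.
class RowCountingStatsTracker extends WriteTaskStatsTracker {
  private var numRows = 0L

  override def newPartition(partitionValues: InternalRow): Unit = {}
  override def newFile(filePath: String): Unit = {}
  override def closeFile(filePath: String): Unit = {}

  // Before this follow-up the callback only received the row:
  //   override def newRow(row: InternalRow): Unit = { numRows += 1 }
  // After it, the tracker is also told which file the row was written to.
  override def newRow(filePath: String, row: InternalRow): Unit = {
    numRows += 1
  }

  override def getFinalStats(): WriteTaskStats = RowCountStats(numRows)
}
```

Trackers that do not care about the file can simply ignore the new parameter, as `BasicWriteTaskStatsTracker` does in the first diff below.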

4 files changed (+77, -4 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
     }
   }
 
-  override def newRow(row: InternalRow): Unit = {
+  override def newRow(filePath: String, row: InternalRow): Unit = {
     numRows += 1
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -157,7 +157,7 @@ class SingleDirectoryDataWriter(
     }
 
     currentWriter.write(record)
-    statsTrackers.foreach(_.newRow(record))
+    statsTrackers.foreach(_.newRow(currentWriter.path, record))
     recordsInFile += 1
   }
 }
@@ -301,7 +301,7 @@ abstract class BaseDynamicPartitionDataWriter(
   protected def writeRecord(record: InternalRow): Unit = {
     val outputRow = getOutputRow(record)
     currentWriter.write(outputRow)
-    statsTrackers.foreach(_.newRow(outputRow))
+    statsTrackers.foreach(_.newRow(currentWriter.path, outputRow))
     recordsInFile += 1
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala

Lines changed: 2 additions & 1 deletion
@@ -59,9 +59,10 @@ trait WriteTaskStatsTracker {
    * Process the fact that a new row to update the tracked statistics accordingly.
    * @note Keep in mind that any overhead here is per-row, obviously,
    *       so implementations should be as lightweight as possible.
+   * @param filePath Path of the file which the row is written to.
    * @param row Current data row to be processed.
    */
-  def newRow(row: InternalRow): Unit
+  def newRow(filePath: String, row: InternalRow): Unit
 
   /**
    * Returns the final statistics computed so far.
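
For reference, piecing together the diff above with the overrides exercised in the new test suite below, the task-level callback surface after this change looks roughly as follows. This is a sketch reconstructed from this commit, not the authoritative trait definition (scaladoc and any other members are omitted):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.WriteTaskStats

// Sketch of the WriteTaskStatsTracker callbacks after this change.
trait WriteTaskStatsTrackerSketch {
  def newPartition(partitionValues: InternalRow): Unit
  def newFile(filePath: String): Unit
  def closeFile(filePath: String): Unit
  def newRow(filePath: String, row: InternalRow): Unit // filePath is the new parameter
  def getFinalStats(): WriteTaskStats
}
```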
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+
+class CustomWriteTaskStatsTrackerSuite extends SparkFunSuite {
+
+  def checkFinalStats(tracker: CustomWriteTaskStatsTracker, result: Map[String, Int]): Unit = {
+    assert(tracker.getFinalStats().asInstanceOf[CustomWriteTaskStats].numRowsPerFile == result)
+  }
+
+  test("sequential file writing") {
+    val tracker = new CustomWriteTaskStatsTracker
+    tracker.newFile("a")
+    tracker.newRow("a", null)
+    tracker.newRow("a", null)
+    tracker.newFile("b")
+    checkFinalStats(tracker, Map("a" -> 2, "b" -> 0))
+  }
+
+  test("random file writing") {
+    val tracker = new CustomWriteTaskStatsTracker
+    tracker.newFile("a")
+    tracker.newRow("a", null)
+    tracker.newFile("b")
+    tracker.newRow("a", null)
+    tracker.newRow("b", null)
+    checkFinalStats(tracker, Map("a" -> 2, "b" -> 1))
+  }
+}
+
+class CustomWriteTaskStatsTracker extends WriteTaskStatsTracker {
+
+  val numRowsPerFile = mutable.Map.empty[String, Int]
+
+  override def newPartition(partitionValues: InternalRow): Unit = {}
+
+  override def newFile(filePath: String): Unit = {
+    numRowsPerFile.put(filePath, 0)
+  }
+
+  override def closeFile(filePath: String): Unit = {}
+
+  override def newRow(filePath: String, row: InternalRow): Unit = {
+    numRowsPerFile(filePath) += 1
+  }
+
+  override def getFinalStats(): WriteTaskStats = {
+    CustomWriteTaskStats(numRowsPerFile.toMap)
+  }
+}
+
+case class CustomWriteTaskStats(numRowsPerFile: Map[String, Int]) extends WriteTaskStats
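
For context on how such a tracker is wired into a write job: task-level trackers are produced by a job-level `WriteJobStatsTracker`, whose instances are shipped to executors and whose collected stats are aggregated back on the driver. The sketch below is an assumption-laden illustration, not part of this commit: the wrapper class name is hypothetical, and the exact `newTaskInstance()` / `processStats(...)` signatures should be checked against the Spark version you build against.

```scala
package org.apache.spark.sql.execution.datasources

// Hypothetical job-level companion for the CustomWriteTaskStatsTracker above
// (placed in the same package so the test classes are visible without imports).
// Assumption: WriteJobStatsTracker exposes newTaskInstance() and processStats(stats)
// with roughly these signatures around the time of this commit.
class CustomWriteJobStatsTracker extends WriteJobStatsTracker {

  // One fresh task-level tracker per write task; instances are serialized to executors.
  override def newTaskInstance(): WriteTaskStatsTracker = new CustomWriteTaskStatsTracker

  // Invoked on the driver with the stats each task returned from getFinalStats().
  override def processStats(stats: Seq[WriteTaskStats]): Unit = {
    val rowsPerFile = stats
      .map(_.asInstanceOf[CustomWriteTaskStats].numRowsPerFile)
      .reduceOption(_ ++ _)
      .getOrElse(Map.empty[String, Int])
    // Illustrative only: surface the aggregated per-file row counts somewhere useful.
    println(s"Rows written per file: $rowsPerFile")
  }
}
```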
