From 88478b70c364a09bbbf064b5932215a9c62297ef Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 May 2021 00:28:55 +0800 Subject: [PATCH 1/2] WriteTaskStatsTracker should know which file the row is written to --- .../sql/execution/datasources/BasicWriteStatsTracker.scala | 2 +- .../sql/execution/datasources/FileFormatDataWriter.scala | 4 ++-- .../spark/sql/execution/datasources/WriteStatsTracker.scala | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 4f60a9d4c8c0..160ee6dc8d55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -151,7 +151,7 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } } - override def newRow(row: InternalRow): Unit = { + override def newRow(filePath: String, row: InternalRow): Unit = { numRows += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 8230737a61ca..7e5a8cce2783 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -157,7 +157,7 @@ class SingleDirectoryDataWriter( } currentWriter.write(record) - statsTrackers.foreach(_.newRow(record)) + statsTrackers.foreach(_.newRow(currentWriter.path, record)) recordsInFile += 1 } } @@ -301,7 +301,7 @@ abstract class BaseDynamicPartitionDataWriter( protected def writeRecord(record: InternalRow): Unit = { val outputRow = getOutputRow(record) currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) + statsTrackers.foreach(_.newRow(currentWriter.path, outputRow)) recordsInFile += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala index aaf866bced86..f58aa33be869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala @@ -59,9 +59,10 @@ trait WriteTaskStatsTracker { * Process the fact that a new row to update the tracked statistics accordingly. * @note Keep in mind that any overhead here is per-row, obviously, * so implementations should be as lightweight as possible. + * @param filePath Path of the file which the row is written to. * @param row Current data row to be processed. */ - def newRow(row: InternalRow): Unit + def newRow(filePath: String, row: InternalRow): Unit /** * Returns the final statistics computed so far. From 8e9f6cb8d5b19792fc408c7b9fe9bcc77a4a56d7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 May 2021 11:59:17 +0800 Subject: [PATCH 2/2] add test --- .../CustomWriteTaskStatsTrackerSuite.scala | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala new file mode 100644 index 000000000000..82d873a2cd81 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CustomWriteTaskStatsTrackerSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow + +class CustomWriteTaskStatsTrackerSuite extends SparkFunSuite { + + def checkFinalStats(tracker: CustomWriteTaskStatsTracker, result: Map[String, Int]): Unit = { + assert(tracker.getFinalStats().asInstanceOf[CustomWriteTaskStats].numRowsPerFile == result) + } + + test("sequential file writing") { + val tracker = new CustomWriteTaskStatsTracker + tracker.newFile("a") + tracker.newRow("a", null) + tracker.newRow("a", null) + tracker.newFile("b") + checkFinalStats(tracker, Map("a" -> 2, "b" -> 0)) + } + + test("random file writing") { + val tracker = new CustomWriteTaskStatsTracker + tracker.newFile("a") + tracker.newRow("a", null) + tracker.newFile("b") + tracker.newRow("a", null) + tracker.newRow("b", null) + checkFinalStats(tracker, Map("a" -> 2, "b" -> 1)) + } +} + +class CustomWriteTaskStatsTracker extends WriteTaskStatsTracker { + + val numRowsPerFile = mutable.Map.empty[String, Int] + + override def newPartition(partitionValues: InternalRow): Unit = {} + + override def newFile(filePath: String): Unit = { + numRowsPerFile.put(filePath, 0) + } + + override def closeFile(filePath: String): Unit = {} + + override def newRow(filePath: String, row: InternalRow): Unit = { + numRowsPerFile(filePath) += 1 + } + + override def getFinalStats(): WriteTaskStats = { + CustomWriteTaskStats(numRowsPerFile.toMap) + } +} + +case class CustomWriteTaskStats(numRowsPerFile: Map[String, Int]) extends WriteTaskStats