
Commit 064d431

ericl authored and rxin committed
[SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for Datasource tables
## What changes were proposed in this pull request?

As of current 2.1, INSERT OVERWRITE with dynamic partitions against a Datasource table will overwrite the entire table instead of only the partitions matching the static keys, as in Hive. It also doesn't respect custom partition locations.

This PR adds support for all these operations to Datasource tables managed by the Hive metastore. It is implemented as follows:
- During planning time, the full set of partitions affected by an INSERT or OVERWRITE command is read from the Hive metastore.
- The planner identifies any partitions with custom locations and includes this in the write task metadata.
- FileFormatWriter tasks refer to this custom locations map when determining where to write for dynamic partition output.
- When the write job finishes, the set of written partitions is compared against the initial set of matched partitions, and the Hive metastore is updated to reflect the newly added / removed partitions.

It was necessary to introduce a method for staging files with absolute output paths to `FileCommitProtocol`. These files are not handled by the Hadoop output committer but are moved to their final locations when the job commits.

The overwrite behavior of legacy Datasource tables is also changed: no longer will the entire table be overwritten if a partial partition spec is present.

cc cloud-fan yhuai

## How was this patch tested?

Unit tests, existing tests.

Author: Eric Liang <[email protected]>
Author: Wenchen Fan <[email protected]>

Closes #15814 from ericl/sc-5027.

(cherry picked from commit a335634)
Signed-off-by: Reynold Xin <[email protected]>
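For a concrete picture of the behavior change, here is a minimal, hypothetical example against a partitioned Datasource table; the table and column names, the `/custom/loc` path, and the `spark` session are assumptions for illustration, not part of this patch.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical Hive-enabled session; `src` is assumed to be an existing table with
// columns (value INT, hr INT). Only the described INSERT semantics come from this patch.
val spark = SparkSession.builder().appName("insert-example").enableHiveSupport().getOrCreate()

spark.sql("CREATE TABLE t (value INT, ds STRING, hr INT) USING parquet PARTITIONED BY (ds, hr)")

// Mixed static/dynamic INSERT OVERWRITE: with this patch, only partitions under
// ds='2016-11-01' are replaced (as in Hive) instead of the entire table being overwritten.
spark.sql("INSERT OVERWRITE TABLE t PARTITION (ds='2016-11-01', hr) SELECT value, hr FROM src")

// Custom partition locations are now respected: rows for this partition land under
// /custom/loc rather than <table root>/ds=2016-11-02/hr=0.
spark.sql("ALTER TABLE t ADD PARTITION (ds='2016-11-02', hr=0) LOCATION '/custom/loc'")
spark.sql("INSERT INTO TABLE t PARTITION (ds='2016-11-02', hr=0) SELECT value FROM src")
```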
1 parent c602894 commit 064d431

13 files changed: 411 additions & 73 deletions

core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala

Lines changed: 15 additions & 0 deletions
@@ -82,9 +82,24 @@ abstract class FileCommitProtocol {
    *
    * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
    * are left to the commit protocol implementation to decide.
+   *
+   * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+   * if a task is going to write out multiple files to the same dir. The file commit protocol only
+   * guarantees that files written by different tasks will not conflict.
    */
   def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
 
+  /**
+   * Similar to newTaskTempFile(), but allows files to committed to an absolute output location.
+   * Depending on the implementation, there may be weaker guarantees around adding files this way.
+   *
+   * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+   * if a task is going to write out multiple files to the same dir. The file commit protocol only
+   * guarantees that files written by different tasks will not conflict.
+   */
+  def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String
+
   /**
    * Commits a task after the writes succeed. Must be called on the executors when running tasks.
    */
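A hedged sketch of how a caller is expected to choose between the two staging methods; the helper name and the `customPartitionLocations` map are illustrative (the real call site is FileFormatWriter, per the commit description):

```scala
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.internal.io.FileCommitProtocol

// Hypothetical helper: pick a staging path for one output file of a dynamic-partition write.
// `partitionDir` is a fragment such as "ds=2016-11-01/hr=10"; `customPartitionLocations`
// maps such fragments to user-assigned absolute locations.
def stagingPathFor(
    committer: FileCommitProtocol,
    taskContext: TaskAttemptContext,
    partitionDir: String,
    customPartitionLocations: Map[String, String],
    ext: String): String = {
  customPartitionLocations.get(partitionDir) match {
    // The partition has a custom location: stage against the absolute directory so the
    // protocol can move the file there on job commit.
    case Some(absoluteDir) => committer.newTaskTempFileAbsPath(taskContext, absoluteDir, ext)
    // Default layout: stage relative to the job output path as before.
    case None => committer.newTaskTempFile(taskContext, Some(partitionDir), ext)
  }
}
```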

core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala

Lines changed: 56 additions & 7 deletions
@@ -17,7 +17,9 @@
 
 package org.apache.spark.internal.io
 
-import java.util.Date
+import java.util.{Date, UUID}
+
+import scala.collection.mutable
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -42,17 +44,26 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
   /** OutputCommitter from Hadoop is not serializable so marking it transient. */
   @transient private var committer: OutputCommitter = _
 
+  /**
+   * Tracks files staged by this task for absolute output paths. These outputs are not managed by
+   * the Hadoop OutputCommitter, so we must move these to their final locations on job commit.
+   *
+   * The mapping is from the temp output path to the final desired output path of the file.
+   */
+  @transient private var addedAbsPathFiles: mutable.Map[String, String] = null
+
+  /**
+   * The staging directory for all files committed with absolute output paths.
+   */
+  private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId)
+
   protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
     context.getOutputFormatClass.newInstance().getOutputCommitter(context)
   }
 
   override def newTaskTempFile(
       taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
-    // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
-    // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
-    // the file name is fine and won't overflow.
-    val split = taskContext.getTaskAttemptID.getTaskID.getId
-    val filename = f"part-$split%05d-$jobId$ext"
+    val filename = getFilename(taskContext, ext)
 
     val stagingDir: String = committer match {
       // For FileOutputCommitter it has its own staging path called "work path".
@@ -67,6 +78,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
     }
   }
 
+  override def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+    val filename = getFilename(taskContext, ext)
+    val absOutputPath = new Path(absoluteDir, filename).toString
+
+    // Include a UUID here to prevent file collisions for one task writing to different dirs.
+    // In principle we could include hash(absoluteDir) instead but this is simpler.
+    val tmpOutputPath = new Path(
+      absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString
+
+    addedAbsPathFiles(tmpOutputPath) = absOutputPath
+    tmpOutputPath
+  }
+
+  private def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+    // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+    // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
+    // the file name is fine and won't overflow.
+    val split = taskContext.getTaskAttemptID.getTaskID.getId
+    f"part-$split%05d-$jobId$ext"
+  }
+
   override def setupJob(jobContext: JobContext): Unit = {
     // Setup IDs
     val jobId = SparkHadoopWriter.createJobID(new Date, 0)
@@ -87,25 +120,41 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
 
   override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
     committer.commitJob(jobContext)
+    val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]])
+      .foldLeft(Map[String, String]())(_ ++ _)
+    logDebug(s"Committing files staged for absolute locations $filesToMove")
+    val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+    for ((src, dst) <- filesToMove) {
+      fs.rename(new Path(src), new Path(dst))
+    }
+    fs.delete(absPathStagingDir, true)
   }
 
   override def abortJob(jobContext: JobContext): Unit = {
     committer.abortJob(jobContext, JobStatus.State.FAILED)
+    val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+    fs.delete(absPathStagingDir, true)
   }
 
   override def setupTask(taskContext: TaskAttemptContext): Unit = {
     committer = setupCommitter(taskContext)
     committer.setupTask(taskContext)
+    addedAbsPathFiles = mutable.Map[String, String]()
   }
 
   override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
     val attemptId = taskContext.getTaskAttemptID
     SparkHadoopMapRedUtil.commitTask(
       committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
-    EmptyTaskCommitMessage
+    new TaskCommitMessage(addedAbsPathFiles.toMap)
   }
 
   override def abortTask(taskContext: TaskAttemptContext): Unit = {
     committer.abortTask(taskContext)
+    // best effort cleanup of other staged files
+    for ((src, _) <- addedAbsPathFiles) {
+      val tmp = new Path(src)
+      tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false)
+    }
   }
 }
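For intuition, a small standalone sketch of the bookkeeping above: each task ships its `addedAbsPathFiles` map back in a `TaskCommitMessage`, and `commitJob` folds them into one rename list. The paths below are made up.

```scala
// Hypothetical per-task payloads, as returned via TaskCommitMessage(addedAbsPathFiles.toMap):
// temp staging path under "_temporary-<jobId>" -> final absolute destination.
val taskPayloads: Seq[Map[String, String]] = Seq(
  Map("/warehouse/t/_temporary-job1/uuidA-part-00000-job1.parquet" ->
    "/custom/loc/part-00000-job1.parquet"),
  Map("/warehouse/t/_temporary-job1/uuidB-part-00001-job1.parquet" ->
    "/custom/loc/part-00001-job1.parquet"))

// Same fold as commitJob above.
val filesToMove = taskPayloads.foldLeft(Map[String, String]())(_ ++ _)
// commitJob then renames each (src, dst) pair and deletes the staging directory.
filesToMove.foreach { case (src, dst) => println(s"rename $src -> $dst") }
```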

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 4 additions & 8 deletions
@@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
     val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
 
-    val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty)
+    val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
     if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
       throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " +
         "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
     }
     val overwrite = ctx.OVERWRITE != null
-    val overwritePartition =
-      if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) {
-        Some(partitionKeys.map(t => (t._1, t._2.get)))
-      } else {
-        None
-      }
+    val staticPartitionKeys: Map[String, String] =
+      partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get))
 
     InsertIntoTable(
       UnresolvedRelation(tableIdent, None),
       partitionKeys,
       query,
-      OverwriteOptions(overwrite, overwritePartition),
+      OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty),
       ctx.EXISTS != null)
   }
 
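A standalone sketch of the static/dynamic split performed above, using a made-up partition spec (keys without a value are dynamic):

```scala
// Hypothetical parsed spec for: PARTITION (ds='2016-11-01', hr)
val partitionKeys: Map[String, Option[String]] = Map("ds" -> Some("2016-11-01"), "hr" -> None)

val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
val staticPartitionKeys: Map[String, String] =
  partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get))

// dynamicPartitionKeys.keySet == Set("hr"); staticPartitionKeys == Map("ds" -> "2016-11-01")
```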

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 6 additions & 4 deletions
@@ -350,13 +350,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
  * Options for writing new data into a table.
  *
  * @param enabled whether to overwrite existing data in the table.
- * @param specificPartition only data in the specified partition will be overwritten.
+ * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions
+ *                            that match this partial partition spec. If empty, all partitions
+ *                            will be overwritten.
  */
 case class OverwriteOptions(
     enabled: Boolean,
-    specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) {
-  if (specificPartition.isDefined) {
-    assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.")
+    staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) {
+  if (staticPartitionKeys.nonEmpty) {
+    assert(enabled, "Overwrite must be enabled when specifying specific partitions.")
   }
 }
 
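A hedged usage sketch of the revised case class; the partition values are made up:

```scala
import org.apache.spark.sql.catalyst.plans.logical.OverwriteOptions

// Overwrite only partitions matching ds=2016-11-01; remaining partition columns are dynamic.
val overwriteOneDay =
  OverwriteOptions(enabled = true, staticPartitionKeys = Map("ds" -> "2016-11-01"))

// Overwrite everything (no static keys), and a plain append.
val overwriteAll = OverwriteOptions(enabled = true)
val append = OverwriteOptions(enabled = false)

// OverwriteOptions(enabled = false, Map("ds" -> "2016-11-01")) would trip the assert above.
```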

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest {
         OverwriteOptions(
           overwrite,
           if (overwrite && partition.nonEmpty) {
-            Some(partition.map(kv => (kv._1, kv._2.get)))
+            partition.map(kv => (kv._1, kv._2.get))
           } else {
-            None
+            Map.empty
           }),
         ifNotExists)
 

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 11 additions & 9 deletions
@@ -417,15 +417,17 @@ case class DataSource(
         // will be adjusted within InsertIntoHadoopFsRelation.
         val plan =
           InsertIntoHadoopFsRelationCommand(
-            outputPath,
-            columns,
-            bucketSpec,
-            format,
-            _ => Unit, // No existing table needs to be refreshed.
-            options,
-            data.logicalPlan,
-            mode,
-            catalogTable)
+            outputPath = outputPath,
+            staticPartitionKeys = Map.empty,
+            customPartitionLocations = Map.empty,
+            partitionColumns = columns,
+            bucketSpec = bucketSpec,
+            fileFormat = format,
+            refreshFunction = _ => Unit, // No existing table needs to be refreshed.
+            options = options,
+            query = data.logicalPlan,
+            mode = mode,
+            catalogTable = catalogTable)
         sparkSession.sessionState.executePlan(plan).toRdd
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
         copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 67 additions & 27 deletions
@@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }
 
-      val overwritingSinglePartition =
-        overwrite.specificPartition.isDefined &&
+      val partitionSchema = query.resolve(
+        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+      val partitionsTrackedByCatalog =
         t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
+        l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
         l.catalogTable.get.tracksPartitionsInCatalog
 
-      val effectiveOutputPath = if (overwritingSinglePartition) {
-        val partition = t.sparkSession.sessionState.catalog.getPartition(
-          l.catalogTable.get.identifier, overwrite.specificPartition.get)
-        new Path(partition.location)
-      } else {
-        outputPath
-      }
-
-      val effectivePartitionSchema = if (overwritingSinglePartition) {
-        Nil
-      } else {
-        query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+      var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
+      var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+
+      // When partitions are tracked by the catalog, compute all custom partition locations that
+      // may be relevant to the insertion job.
+      if (partitionsTrackedByCatalog) {
+        val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
+          l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys))
+        initialMatchingPartitions = matchingPartitions.map(_.spec)
+        customPartitionLocations = getCustomPartitionLocations(
+          t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
       }
 
+      // Callback for updating metastore partition metadata after the insertion job completes.
+      // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand
      def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
-        if (l.catalogTable.isDefined && updatedPartitions.nonEmpty &&
-          l.catalogTable.get.partitionColumnNames.nonEmpty &&
-          l.catalogTable.get.tracksPartitionsInCatalog) {
-          val metastoreUpdater = AlterTableAddPartitionCommand(
-            l.catalogTable.get.identifier,
-            updatedPartitions.map(p => (p, None)),
-            ifNotExists = true)
-          metastoreUpdater.run(t.sparkSession)
+        if (partitionsTrackedByCatalog) {
+          val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
+          if (newPartitions.nonEmpty) {
+            AlterTableAddPartitionCommand(
+              l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
+              ifNotExists = true).run(t.sparkSession)
+          }
+          if (overwrite.enabled) {
+            val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
+            if (deletedPartitions.nonEmpty) {
+              AlterTableDropPartitionCommand(
+                l.catalogTable.get.identifier, deletedPartitions.toSeq,
+                ifExists = true, purge = true).run(t.sparkSession)
+            }
+          }
        }
         t.location.refresh()
       }
 
       val insertCmd = InsertIntoHadoopFsRelationCommand(
-        effectiveOutputPath,
-        effectivePartitionSchema,
+        outputPath,
+        if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty,
+        customPartitionLocations,
+        partitionSchema,
         t.bucketSpec,
         t.fileFormat,
         refreshPartitionsCallback,
@@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
 
       insertCmd
   }
+
+  /**
+   * Given a set of input partitions, returns those that have locations that differ from the
+   * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
+   * the user.
+   *
+   * @return a mapping from partition specs to their custom locations
+   */
+  private def getCustomPartitionLocations(
+      spark: SparkSession,
+      table: CatalogTable,
+      basePath: Path,
+      partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
+    val hadoopConf = spark.sessionState.newHadoopConf
+    val fs = basePath.getFileSystem(hadoopConf)
+    val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+    partitions.flatMap { p =>
+      val defaultLocation = qualifiedBasePath.suffix(
+        "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
+      val catalogLocation = new Path(p.location).makeQualified(
+        fs.getUri, fs.getWorkingDirectory).toString
+      if (catalogLocation != defaultLocation) {
+        Some(p.spec -> catalogLocation)
+      } else {
+        None
+      }
+    }.toMap
+  }
 }
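A minimal standalone sketch of the default-vs-custom location comparison in `getCustomPartitionLocations`, using Hadoop `Path` directly with made-up paths (no Spark catalog involved):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

val basePath = new Path("/warehouse/t")              // hypothetical table root
val fs = basePath.getFileSystem(new Configuration()) // default (local) filesystem for illustration
val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)

// Default layout for partition ds=2016-11-02/hr=0 under the table root.
val defaultLocation = qualifiedBasePath.suffix("/ds=2016-11-02/hr=0").toString
// Location recorded in the metastore for that partition (user-assigned here).
val catalogLocation =
  new Path("/custom/loc").makeQualified(fs.getUri, fs.getWorkingDirectory).toString

// Differing locations mark the partition as having a custom location, so its files are
// staged via newTaskTempFileAbsPath and moved there on job commit.
val hasCustomLocation = catalogLocation != defaultLocation // true
```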
