
Commit 2c9d5ef

[SPARK-21463] Allow userSpecifiedSchema to override partition inference performed by MetadataLogFileIndex
## What changes were proposed in this pull request?

When using the MetadataLogFileIndex to read back a table, we don't respect the user-provided schema as the proper column types. This can lead to issues such as strings that look like dates being read back as DateType, or columns meant to be LongType being read back as IntegerType simply because no value large enough to require a long is present.

## How was this patch tested?

Unit tests and manual tests.

Author: Burak Yavuz <[email protected]>

Closes #18676 from brkyvz/stream-partitioning.
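For context, a minimal read-back sketch of the scenario this fix targets, assuming a directory written by a file streaming sink and partitioned by a string column named `time`; the path and column names are illustrative, not taken from the patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val spark = SparkSession.builder().appName("spark-21463-sketch").getOrCreate()

// Declare the partition column explicitly so inference cannot retype it,
// e.g. a date-like string stays StringType instead of becoming DateType.
val schema = new StructType()
  .add("time", StringType)   // partition column written by partitionBy("time")
  .add("value", IntegerType)

val readBack = spark.read.schema(schema).parquet("/path/to/streaming/sink/output")
readBack.printSchema()  // with this change, `time` is reported as StringType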
1 parent 8cd9cdf commit 2c9d5ef

File tree: 4 files changed, +69 / -10 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 26 additions & 7 deletions
@@ -96,6 +96,24 @@ case class DataSource(
       bucket.sortColumnNames, "in the sort definition", equality)
     }
 
+  /**
+   * In the read path, only managed tables by Hive provide the partition columns properly when
+   * initializing this class. All other file based data sources will try to infer the partitioning,
+   * and then cast the inferred types to user specified dataTypes if the partition columns exist
+   * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
+   * inconsistent data types as reported in SPARK-21463.
+   * @param fileIndex A FileIndex that will perform partition inference
+   * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
+   */
+  private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = {
+    val resolved = fileIndex.partitionSchema.map { partitionField =>
+      // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
+      userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
+        partitionField)
+    }
+    StructType(resolved)
+  }
+
   /**
    * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
    * it. In the read path, only managed tables by Hive provide the partition columns properly when
@@ -139,12 +157,7 @@
     val partitionSchema = if (partitionColumns.isEmpty) {
       // Try to infer partitioning, because no DataSource in the read path provides the partitioning
       // columns properly unless it is a Hive DataSource
-      val resolved = tempFileIndex.partitionSchema.map { partitionField =>
-        // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
-        userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
-          partitionField)
-      }
-      StructType(resolved)
+      combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex)
     } else {
       // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred
      // partitioning
@@ -336,7 +349,13 @@
             caseInsensitiveOptions.get("path").toSeq ++ paths,
             sparkSession.sessionState.newHadoopConf()) =>
         val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
-        val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath)
+        val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None)
+        val fileCatalog = if (userSpecifiedSchema.nonEmpty) {
+          val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog)
+          new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema))
+        } else {
+          tempFileCatalog
+        }
         val dataSchema = userSpecifiedSchema.orElse {
           format.inferSchema(
             sparkSession,
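
The new helper's resolution rule can be seen in isolation with a small standalone sketch; the field names, the DateType-vs-StringType example, and the simple case-insensitive resolver below are illustrative stand-ins, not the exact analyzer resolver DataSource uses:

import org.apache.spark.sql.types._

// Partition schema as inference might produce it from directory names like time=2017-01-01.
val inferredPartitionSchema = StructType(Seq(StructField("time", DateType)))

// Schema the user passed to spark.read.schema(...); it declares `time` as a plain string.
val userSpecifiedSchema: Option[StructType] =
  Some(StructType(Seq(StructField("time", StringType), StructField("value", IntegerType))))

// Stand-in for the session's column-name resolver (`equality` in DataSource).
val equality: (String, String) => Boolean = _.equalsIgnoreCase(_)

// Same rule as combineInferredAndUserSpecifiedPartitionSchema: prefer the user's field
// when the names match, otherwise keep the inferred field.
val resolved = inferredPartitionSchema.map { partitionField =>
  userSpecifiedSchema
    .flatMap(_.find(f => equality(f.name, partitionField.name)))
    .getOrElse(partitionField)
}

println(StructType(resolved))  // `time` resolves to StringType, as the user declared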

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ class FileStreamSource(
   private def allFilesUsingMetadataLogFileIndex() = {
     // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a
     // non-glob path
-    new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles()
+    new MetadataLogFileIndex(sparkSession, qualifiedBasePath, None).allFiles()
   }
 
   /**

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala

Lines changed: 9 additions & 2 deletions
@@ -23,14 +23,21 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.types.StructType
 
 
 /**
  * A [[FileIndex]] that generates the list of files to processing by reading them from the
  * metadata log files generated by the [[FileStreamSink]].
+ *
+ * @param userPartitionSchema an optional partition schema that will be use to provide types for
+ *                            the discovered partitions
  */
-class MetadataLogFileIndex(sparkSession: SparkSession, path: Path)
-  extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) {
+class MetadataLogFileIndex(
+    sparkSession: SparkSession,
+    path: Path,
+    userPartitionSchema: Option[StructType])
+  extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) {
 
   private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
   logInfo(s"Reading streaming file log from $metadataDirectory")
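
To illustrate the widened constructor, here is a hedged sketch of the two-pass pattern the DataSource read path above follows; MetadataLogFileIndex is an internal class, and the base path and partition schema here are illustrative assumptions:

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MetadataLogFileIndex
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().appName("metadata-log-index-sketch").getOrCreate()
val basePath = new Path("/path/to/streaming/sink/output")

// Pass 1: no user schema, so partition column types come purely from inference.
val inferredIndex = new MetadataLogFileIndex(spark, basePath, None)

// Pass 2: rebuild the index with the partition column typed as the caller declared it,
// mirroring what DataSource does when userSpecifiedSchema is non-empty.
val partitionSchema = new StructType().add("time", StringType)
val typedIndex = new MetadataLogFileIndex(spark, basePath, Option(partitionSchema))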

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala

Lines changed: 33 additions & 0 deletions
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
+import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -1022,4 +1023,36 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       }
     }
   }
+
+  test("SPARK-21463: MetadataLogFileIndex should respect userSpecifiedSchema for partition cols") {
+    withTempDir { tempDir =>
+      val output = new File(tempDir, "output").toString
+      val checkpoint = new File(tempDir, "chkpoint").toString
+      try {
+        val stream = MemoryStream[(String, Int)]
+        val df = stream.toDS().toDF("time", "value")
+        val sq = df.writeStream
+          .option("checkpointLocation", checkpoint)
+          .format("parquet")
+          .partitionBy("time")
+          .start(output)
+
+        stream.addData(("2017-01-01-00", 1), ("2017-01-01-01", 2))
+        sq.processAllAvailable()
+
+        val schema = new StructType()
+          .add("time", StringType)
+          .add("value", IntegerType)
+        val readBack = spark.read.schema(schema).parquet(output)
+        assert(readBack.schema.toSet === schema.toSet)
+
+        checkAnswer(
+          readBack,
+          Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2))
+        )
+      } finally {
+        spark.streams.active.foreach(_.stop())
+      }
+    }
+  }
 }
