
Commit 30e3fcb

rdblue authored and dongjoon-hyun committed
[SPARK-32168][SQL] Fix hidden partitioning correctness bug in SQL overwrite
### What changes were proposed in this pull request?

When converting an `INSERT OVERWRITE` query to a v2 overwrite plan, Spark attempts to detect when a dynamic overwrite and a static overwrite will produce the same result, so it can use the static overwrite. Spark incorrectly detects when dynamic and static overwrites are equivalent when there are hidden partitions, such as `days(ts)`.

This updates the analyzer rule `ResolveInsertInto` to always use a dynamic overwrite when the mode is dynamic, and a static overwrite when the mode is static. This avoids the problem by not trying to determine whether the two plans are equivalent and always using the one that corresponds to the partition overwrite mode.

### Why are the changes needed?

This is a correctness bug. If a table has hidden partitions, all of the values for those partitions are dropped instead of dynamically overwriting the changed partitions.

This only affects SQL commands (not `DataFrameWriter`) writing to tables that have hidden partitions. It is also only a problem when the partition overwrite mode is dynamic.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the correctness bug detailed above.

### How was this patch tested?

* This updates the in-memory table to support a hidden partition transform, `days`, and adds a test case to `DataSourceV2SQLSuite` in which the table uses this hidden partition function. This test fails without the fix to `ResolveInsertInto`.
* This updates the test case `InsertInto: overwrite - multiple static partitions - dynamic mode` in `InsertIntoTests`. The result of the SQL command is unchanged, but the SQL command will now use a dynamic overwrite, so the test now uses `dynamicOverwriteTest`.

Closes #28993 from rdblue/fix-insert-overwrite-v2-conversion.

Authored-by: Ryan Blue <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 3bb1ac5)
Signed-off-by: Dongjoon Hyun <[email protected]>
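To make the affected scenario concrete, the sketch below shows the kind of query sequence this fix changes. It is illustrative only: the catalog, namespace, table name, and `USING foo` provider are hypothetical stand-ins for any v2 catalog whose tables support the hidden `days(ts)` transform (such as the test `InMemoryTable`).

```scala
// Illustrative sketch, not taken from the patch; names and the data source are hypothetical.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

spark.sql(
  """CREATE TABLE testcat.ns.events (ts TIMESTAMP, data STRING)
    |USING foo
    |PARTITIONED BY (days(ts))""".stripMargin)

spark.sql(
  """INSERT INTO testcat.ns.events VALUES
    |  (TIMESTAMP '2020-01-03 00:00:00', 'dummy'),
    |  (TIMESTAMP '2020-01-05 00:00:00', 'keep')""".stripMargin)

// A dynamic overwrite should replace only the day partitions present in the new
// rows (Jan 2-4) and keep the Jan 5 partition. Before this fix, the command was
// planned as a static overwrite and every hidden partition, including Jan 5,
// was dropped.
spark.sql(
  """INSERT OVERWRITE TABLE testcat.ns.events VALUES
    |  (TIMESTAMP '2020-01-02 00:00:00', 'a'),
    |  (TIMESTAMP '2020-01-03 00:00:00', 'b'),
    |  (TIMESTAMP '2020-01-04 00:00:00', 'c')""".stripMargin)
```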
1 parent ac2c6cd commit 30e3fcb

File tree

6 files changed: +107 -37 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 3 deletions
@@ -1041,12 +1041,10 @@ class Analyzer(
 
       val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
       val query = addStaticPartitionColumns(r, i.query, staticPartitions)
-      val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
-        conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
 
       if (!i.overwrite) {
         AppendData.byPosition(r, query)
-      } else if (dynamicPartitionOverwrite) {
+      } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
         OverwritePartitionsDynamic.byPosition(r, query)
       } else {
         OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
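Viewed as plain logic, the rule change is small: the planner no longer tries to guess whether a static overwrite would be equivalent to a dynamic one, and instead follows the session's partition overwrite mode directly. A simplified, standalone sketch of the before/after decision (illustrative names, not the actual Analyzer code):

```scala
// Simplified illustration of the ResolveInsertInto decision change; not Spark's code.
sealed trait OverwritePlan
case object DynamicOverwrite extends OverwritePlan
case object StaticOverwrite extends OverwritePlan

// Before: prefer a static overwrite whenever the count-based check suggests the
// two plans are equivalent. With hidden partitions such as days(ts), that check
// is wrong and data in untouched partitions is silently dropped.
def chooseBefore(partCols: Int, staticPartitions: Int, dynamicMode: Boolean): OverwritePlan =
  if (partCols > staticPartitions && dynamicMode) DynamicOverwrite else StaticOverwrite

// After: the partition overwrite mode alone decides which plan is used.
def chooseAfter(dynamicMode: Boolean): OverwritePlan =
  if (dynamicMode) DynamicOverwrite else StaticOverwrite
```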

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 57 additions & 9 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector
 
+import java.time.{Instant, ZoneId}
+import java.time.temporal.ChronoUnit
 import java.util
 
 import scala.collection.JavaConverters._
@@ -25,12 +27,13 @@ import scala.collection.mutable
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -46,10 +49,15 @@ class InMemoryTable(
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
-  partitioning.foreach { t =>
-    if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
-      throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
-    }
+  partitioning.foreach {
+    case _: IdentityTransform =>
+    case _: YearsTransform =>
+    case _: MonthsTransform =>
+    case _: DaysTransform =>
+    case _: HoursTransform =>
+    case _: BucketTransform =>
+    case t if !allowUnsupportedTransforms =>
+      throw new IllegalArgumentException(s"Transform $t is not a supported transform")
   }
 
   // The key `Seq[Any]` is the partition values.
@@ -66,8 +74,14 @@
     }
   }
 
+  private val UTC = ZoneId.of("UTC")
+  private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate
+
   private def getKey(row: InternalRow): Seq[Any] = {
-    def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
+    def extractor(
+        fieldNames: Array[String],
+        schema: StructType,
+        row: InternalRow): (Any, DataType) = {
       val index = schema.fieldIndex(fieldNames(0))
       val value = row.toSeq(schema).apply(index)
       if (fieldNames.length > 1) {
@@ -78,10 +92,44 @@ class InMemoryTable(
            throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
        }
      } else {
-        value
+        (value, schema(index).dataType)
      }
    }
-    partCols.map(fieldNames => extractor(fieldNames, schema, row))
+
+    partitioning.map {
+      case IdentityTransform(ref) =>
+        extractor(ref.fieldNames, schema, row)._1
+      case YearsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+        }
+      case MonthsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
+        }
+      case DaysTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days, DateType) =>
+            days
+          case (micros: Long, TimestampType) =>
+            ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+        }
+      case HoursTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (micros: Long, TimestampType) =>
+            ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+        }
+      case BucketTransform(numBuckets, ref) =>
+        (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets
+    }
   }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
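For reference, the hidden transforms added here reduce a column value to a coarser partition key; `days(ts)`, for example, maps a timestamp to the number of whole days since the Unix epoch. A standalone sketch of that computation with plain `java.time` (illustrative only; the table above uses Spark's internal `DateTimeUtils` helpers):

```scala
import java.time.Instant
import java.time.temporal.ChronoUnit

// Illustration of the days(ts) partition key: whole days between the Unix epoch
// and the timestamp, so all rows from the same UTC day share one partition.
def daysPartitionKey(microsSinceEpoch: Long): Long = {
  val instant = Instant.EPOCH.plus(microsSinceEpoch, ChronoUnit.MICROS)
  ChronoUnit.DAYS.between(Instant.EPOCH, instant)
}

// Example: 2020-01-03 00:00:00 UTC falls in partition 18264.
```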

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ case class BatchScanExec(
 
   override def hashCode(): Int = batch.hashCode()
 
-  override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
+  @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
 
   override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
 
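The `@transient` annotation keeps the planned input partitions out of the serialized plan; a deserialized copy simply recomputes the lazy val on first access. A minimal standalone sketch of that Scala behaviour (illustrative, unrelated to Spark's classes):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Minimal sketch: a @transient lazy val is not written out during Java
// serialization and is recomputed lazily by the deserialized copy.
class Plan(val name: String) extends Serializable {
  @transient lazy val partitions: Seq[Int] = {
    println(s"planning partitions for $name")
    Seq(0, 1, 2)
  }
}

object TransientLazyDemo extends App {
  def roundTrip[T](obj: T): T = {
    val out = new ByteArrayOutputStream()
    new ObjectOutputStream(out).writeObject(obj)
    new ObjectInputStream(new ByteArrayInputStream(out.toByteArray)).readObject().asInstanceOf[T]
  }

  val plan = new Plan("scan")
  plan.partitions            // computed and cached here
  val copy = roundTrip(plan) // the cached value is not serialized
  copy.partitions            // recomputed on first access after deserialization
}
```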

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala

Lines changed: 0 additions & 7 deletions
@@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(years($"ts"))
       .create()
 
@@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(months($"ts"))
       .create()
 
@@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(days($"ts"))
       .create()
 
@@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.table("source")
       .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(hours($"ts"))
       .create()
 
@@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
   test("Create: partitioned by bucket(4, id)") {
     spark.table("source")
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(bucket(4, $"id"))
       .create()
 
@@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
         lit("America/Los_Angeles") as "timezone"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(
         years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
         years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
@@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
         lit("America/Los_Angeles") as "timezone"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(bucket(4, $"ts.timezone"))
       .create()
 
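With the in-memory test catalog now accepting these transforms natively, the `allow-unsupported-transforms` escape hatch is unnecessary, and a table with a hidden partition can be declared straight from the `DataFrameWriterV2` API. A sketch mirroring the updated tests (assumes the suite's `testcat` catalog and `source` table, plus `spark.implicits._` in scope):

```scala
import org.apache.spark.sql.functions.{days, lit}

// Mirrors the updated suite: no "allow-unsupported-transforms" property needed.
spark.table("source")
  .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
  .writeTo("testcat.table_name")
  .partitionedBy(days($"ts"))
  .create()
```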

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 36 additions & 2 deletions
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.connector
 
+import java.sql.Timestamp
+import java.time.LocalDate
+
 import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkException
@@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION}
 import org.apache.spark.sql.internal.connector.SimpleTableProvider
 import org.apache.spark.sql.sources.SimpleScanSource
 import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
@@ -1630,7 +1633,6 @@ class DataSourceV2SQLSuite
       """
         |CREATE TABLE testcat.t (id int, `a.b` string) USING foo
         |CLUSTERED BY (`a.b`) INTO 4 BUCKETS
-        |OPTIONS ('allow-unsupported-transforms'=true)
       """.stripMargin)
 
     val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog]
@@ -2476,6 +2478,38 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") {
+    def testTimestamp(daysOffset: Int): Timestamp = {
+      Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay())
+    }
+
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      val t1 = s"${catalogAndNamespace}tbl"
+      withTable(t1) {
+        val df = spark.createDataFrame(Seq(
+          (testTimestamp(1), "a"),
+          (testTimestamp(2), "b"),
+          (testTimestamp(3), "c"))).toDF("ts", "data")
+        df.createOrReplaceTempView("source_view")
+
+        sql(s"CREATE TABLE $t1 (ts timestamp, data string) " +
+            s"USING $v2Format PARTITIONED BY (days(ts))")
+        sql(s"INSERT INTO $t1 VALUES " +
+            s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " +
+            s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')")
+        sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view")
+
+        val expected = spark.createDataFrame(Seq(
+          (testTimestamp(1), "a"),
+          (testTimestamp(2), "b"),
+          (testTimestamp(3), "c"),
+          (testTimestamp(4), "keep"))).toDF("ts", "data")
+
+        verifyTable(t1, expected)
+      }
+    }
+  }
+
   private def testV1Command(sqlCommand: String, sqlParams: String): Unit = {
     val e = intercept[AnalysisException] {
       sql(s"$sqlCommand $sqlParams")
sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala

Lines changed: 12 additions & 15 deletions
@@ -446,21 +446,18 @@ trait InsertIntoSQLOnlyTests
     }
   }
 
-  test("InsertInto: overwrite - multiple static partitions - dynamic mode") {
-    // Since all partitions are provided statically, this should be supported by everyone
-    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
-      val t1 = s"${catalogAndNamespace}tbl"
-      withTableAndData(t1) { view =>
-        sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
-          s"USING $v2Format PARTITIONED BY (id, p)")
-        sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
-        sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view")
-        verifyTable(t1, Seq(
-          (2, "a", 2),
-          (2, "b", 2),
-          (2, "c", 2),
-          (4, "keep", 2)).toDF("id", "data", "p"))
-      }
+  dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") {
+    val t1 = s"${catalogAndNamespace}tbl"
+    withTableAndData(t1) { view =>
+      sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
+        s"USING $v2Format PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
+      sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view")
+      verifyTable(t1, Seq(
+        (2, "a", 2),
+        (2, "b", 2),
+        (2, "c", 2),
+        (4, "keep", 2)).toDF("id", "data", "p"))
    }
  }
 
