
Commit fd65dd4

Fix INSERT OVERWRITE for v2 with hidden partitions.
1 parent 42f01e3 commit fd65dd4

4 files changed: +62, -13 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 3 deletions
@@ -1050,12 +1050,10 @@ class Analyzer(
 
         val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
         val query = addStaticPartitionColumns(r, i.query, staticPartitions)
-        val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
-          conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
 
         if (!i.overwrite) {
           AppendData.byPosition(r, query)
-        } else if (dynamicPartitionOverwrite) {
+        } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
           OverwritePartitionsDynamic.byPosition(r, query)
         } else {
           OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
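
The practical effect of the change above: for a v2 relation, the overwrite plan is now chosen purely from the session's partition-overwrite mode, rather than by comparing the number of declared partition columns with the number of statically supplied partition values. A minimal sketch of the two modes from the SQL side (the catalog, table, and view names are hypothetical; spark.sql.sources.partitionOverwriteMode is the key behind PARTITION_OVERWRITE_MODE):

// Assumes a running SparkSession `spark`, a v2 catalog `testcat`, and pre-existing table/view.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
// Now always planned as OverwritePartitionsDynamic, even if every partition value is static.
spark.sql("INSERT OVERWRITE TABLE testcat.ns.tbl PARTITION (p = 1) SELECT data FROM src")

spark.conf.set("spark.sql.sources.partitionOverwriteMode", "static")
// Planned as OverwriteByExpression, with a delete filter built from the static partition values.
spark.sql("INSERT OVERWRITE TABLE testcat.ns.tbl PARTITION (p = 1) SELECT data FROM src")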

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 24 additions & 7 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.connector
 
+import java.time.ZoneId
 import java.util
 
 import scala.collection.JavaConverters._
@@ -25,12 +26,13 @@ import scala.collection.mutable
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
+import org.apache.spark.sql.connector.expressions.{DaysTransform, IdentityTransform, Transform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -47,8 +49,9 @@ class InMemoryTable(
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
   partitioning.foreach { t =>
-    if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
-      throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
+    if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[DaysTransform] &&
+        !allowUnsupportedTransforms) {
+      throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or DaysTransform")
     }
   }
 
@@ -67,7 +70,10 @@ class InMemoryTable(
   }
 
   private def getKey(row: InternalRow): Seq[Any] = {
-    def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
+    def extractor(
+        fieldNames: Array[String],
+        schema: StructType,
+        row: InternalRow): (Any, DataType) = {
       val index = schema.fieldIndex(fieldNames(0))
       val value = row.toSeq(schema).apply(index)
       if (fieldNames.length > 1) {
@@ -78,10 +84,21 @@ class InMemoryTable(
          throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
        }
      } else {
-        (value)
+        (value, schema(index).dataType)
      }
    }
-    partCols.map(fieldNames => extractor(fieldNames, schema, row))
+
+    partitioning.map {
+      case IdentityTransform(ref) =>
+        extractor(ref.fieldNames, schema, row)._1
+      case DaysTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days, DateType) =>
+            days
+          case (micros: Long, TimestampType) =>
+            DateTimeUtils.microsToDays(micros, ZoneId.of("UTC"))
+        }
+    }
   }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
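
On the DaysTransform branch added above, a DateType partition source is already a day ordinal and is used as-is, while a TimestampType source is folded from microseconds to whole days in UTC via Spark's internal DateTimeUtils. A standalone sketch of that conversion (not part of the commit; it uses only the calls visible in the diff plus DateTimeUtils.instantToMicros):

import java.time.{Instant, ZoneId}
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Timestamps are held internally as microseconds since the epoch; days(ts) buckets rows
// by reducing those microseconds to a day number, pinned to UTC as in getKey above.
val micros = DateTimeUtils.instantToMicros(Instant.parse("2020-01-03T12:34:56Z"))
val dayKey = DateTimeUtils.microsToDays(micros, ZoneId.of("UTC"))  // 18264 days since 1970-01-01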

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 36 additions & 1 deletion
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.connector
 
+import java.sql.Timestamp
+import java.time.LocalDate
+
 import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkException
@@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION}
 import org.apache.spark.sql.internal.connector.SimpleTableProvider
 import org.apache.spark.sql.sources.SimpleScanSource
 import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
@@ -2494,6 +2497,38 @@
     }
   }
 
+  test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") {
+    def testTimestamp(daysOffset: Int): Timestamp = {
+      Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay())
+    }
+
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      val t1 = s"${catalogAndNamespace}tbl"
+      withTable(t1) {
+        val df = spark.createDataFrame(Seq(
+          (testTimestamp(1), "a"),
+          (testTimestamp(2), "b"),
+          (testTimestamp(3), "c"))).toDF("ts", "data")
+        df.createOrReplaceTempView("source_view")
+
+        sql(s"CREATE TABLE $t1 (ts timestamp, data string) " +
+          s"USING $v2Format PARTITIONED BY (days(ts))")
+        sql(s"INSERT INTO $t1 VALUES " +
+          s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " +
+          s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')")
+        sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view")
+
+        val expected = spark.createDataFrame(Seq(
+          (testTimestamp(1), "a"),
+          (testTimestamp(2), "b"),
+          (testTimestamp(3), "c"),
+          (testTimestamp(4), "keep"))).toDF("ts", "data")
+
+        verifyTable(t1, expected)
+      }
+    }
+  }
+
   private def testV1Command(sqlCommand: String, sqlParams: String): Unit = {
     val e = intercept[AnalysisException] {
       sql(s"$sqlCommand $sqlParams")

sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala

Lines changed: 1 addition & 2 deletions
@@ -446,8 +446,7 @@ trait InsertIntoSQLOnlyTests
     }
   }
 
-  test("InsertInto: overwrite - multiple static partitions - dynamic mode") {
-    // Since all partitions are provided statically, this should be supported by everyone
+  dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") {
     withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
       val t1 = s"${catalogAndNamespace}tbl"
       withTableAndData(t1) { view =>
