From fd65dd4486429b6561a24f9b60192b6d51054523 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 3 Jul 2020 13:30:34 -0700 Subject: [PATCH 1/3] Fix INSERT OVERWRITE for v2 with hidden partitions. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../spark/sql/connector/InMemoryTable.scala | 31 ++++++++++++---- .../sql/connector/DataSourceV2SQLSuite.scala | 37 ++++++++++++++++++- .../spark/sql/connector/InsertIntoTests.scala | 3 +- 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d08a6382f738..f92cf377bff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1050,12 +1050,10 @@ class Analyzer( val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) val query = addStaticPartitionColumns(r, i.query, staticPartitions) - val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && - conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC if (!i.overwrite) { AppendData.byPosition(r, query) - } else if (dynamicPartitionOverwrite) { + } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) { OverwritePartitionsDynamic.byPosition(r, query) } else { OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 3d7026e180cd..65cfbe957462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector +import java.time.ZoneId import java.util import scala.collection.JavaConverters._ @@ -25,12 +26,13 @@ import scala.collection.mutable import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.{DaysTransform, IdentityTransform, Transform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -47,8 +49,9 @@ class InMemoryTable( properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) { - throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") + if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[DaysTransform] && + !allowUnsupportedTransforms) { + throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or DaysTransform") } } @@ -67,7 +70,10 @@ class InMemoryTable( } private def getKey(row: InternalRow): Seq[Any] = { - def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = { + def extractor( + 
fieldNames: Array[String], + schema: StructType, + row: InternalRow): (Any, DataType) = { val index = schema.fieldIndex(fieldNames(0)) val value = row.toSeq(schema).apply(index) if (fieldNames.length > 1) { @@ -78,10 +84,21 @@ class InMemoryTable( throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") } } else { - value + (value, schema(index).dataType) } } - partCols.map(fieldNames => extractor(fieldNames, schema, row)) + + partitioning.map { + case IdentityTransform(ref) => + extractor(ref.fieldNames, schema, row)._1 + case DaysTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days, DateType) => + days + case (micros: Long, TimestampType) => + DateTimeUtils.microsToDays(micros, ZoneId.of("UTC")) + } + } } def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index f7f4df8f2d2e..0e4cd39ef8b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.connector +import java.sql.Timestamp +import java.time.LocalDate + import scala.collection.JavaConverters._ import org.apache.spark.SparkException @@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION} import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} @@ -2494,6 +2497,38 @@ class DataSourceV2SQLSuite } } + test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") { + def testTimestamp(daysOffset: Int): Timestamp = { + Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay()) + } + + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + val df = spark.createDataFrame(Seq( + (testTimestamp(1), "a"), + (testTimestamp(2), "b"), + (testTimestamp(3), "c"))).toDF("ts", "data") + df.createOrReplaceTempView("source_view") + + sql(s"CREATE TABLE $t1 (ts timestamp, data string) " + + s"USING $v2Format PARTITIONED BY (days(ts))") + sql(s"INSERT INTO $t1 VALUES " + + s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " + + s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view") + + val expected = spark.createDataFrame(Seq( + (testTimestamp(1), "a"), + (testTimestamp(2), "b"), + (testTimestamp(3), "c"), + (testTimestamp(4), "keep"))).toDF("ts", "data") + + verifyTable(t1, expected) + } + } + } + private def testV1Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index b88ad5218fcd..618f528a92a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -446,8 +446,7 @@ trait InsertIntoSQLOnlyTests } } - test("InsertInto: overwrite - multiple static partitions - dynamic mode") { - // Since all partitions are provided statically, this should be supported by everyone + dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") { withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { val t1 = s"${catalogAndNamespace}tbl" withTableAndData(t1) { view => From afef6ceb55e45da2fd2a5c06092ac8fd5920cfde Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 6 Jul 2020 10:52:20 -0700 Subject: [PATCH 2/3] Implement years, months, hours, and bucket transforms for tests. --- .../spark/sql/connector/InMemoryTable.scala | 47 +++++++++++++++---- .../datasources/v2/BatchScanExec.scala | 2 +- .../spark/sql/DataFrameWriterV2Suite.scala | 7 --- .../sql/connector/DataSourceV2SQLSuite.scala | 1 - 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 65cfbe957462..616fc72320ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.connector -import java.time.ZoneId +import java.time.{Instant, ZoneId} +import java.time.temporal.ChronoUnit import java.util import scala.collection.JavaConverters._ @@ -28,7 +29,7 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{DaysTransform, IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} @@ -48,11 +49,15 @@ class InMemoryTable( private val allowUnsupportedTransforms = properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean - partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform] && !t.isInstanceOf[DaysTransform] && - !allowUnsupportedTransforms) { - throw new IllegalArgumentException(s"Transform $t must be IdentityTransform or DaysTransform") - } + partitioning.foreach { + case _: IdentityTransform => + case _: YearsTransform => + case _: MonthsTransform => + case _: DaysTransform => + case _: HoursTransform => + case _: BucketTransform => + case t if !allowUnsupportedTransforms => + throw new IllegalArgumentException(s"Transform $t is not a supported transform") } // The key `Seq[Any]` is the partition values. 
@@ -69,6 +74,9 @@ class InMemoryTable( } } + private val UTC = ZoneId.of("UTC") + private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate + private def getKey(row: InternalRow): Seq[Any] = { def extractor( fieldNames: Array[String], @@ -91,13 +99,36 @@ class InMemoryTable( partitioning.map { case IdentityTransform(ref) => extractor(ref.fieldNames, schema, row)._1 + case YearsTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days: Int, DateType) => + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + } + case MonthsTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days: Int, DateType) => + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate) + } case DaysTransform(ref) => extractor(ref.fieldNames, schema, row) match { case (days, DateType) => days case (micros: Long, TimestampType) => - DateTimeUtils.microsToDays(micros, ZoneId.of("UTC")) + ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + } + case HoursTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (micros: Long, TimestampType) => + ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) } + case BucketTransform(numBuckets, ref) => + (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index e4e7887017a1..c199df676ced 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -40,7 +40,7 @@ case class BatchScanExec( override def hashCode(): Int = batch.hashCode() - override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() + @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index ac2ebd8bd748..508eefafd075 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(years($"ts")) .create() @@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(months($"ts")) .create() @@ -364,7 +362,6 @@ 
class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(days($"ts")) .create() @@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(hours($"ts")) .create() @@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo test("Create: partitioned by bucket(4, id)") { spark.table("source") .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(bucket(4, $"id")) .create() @@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", lit("America/Los_Angeles") as "timezone")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy( years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"), years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified") @@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", lit("America/Los_Angeles") as "timezone")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(bucket(4, $"ts.timezone")) .create() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 0e4cd39ef8b4..85aea3ce41ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1650,7 +1650,6 @@ class DataSourceV2SQLSuite """ |CREATE TABLE testcat.t (id int, `a.b` string) USING foo |CLUSTERED BY (`a.b`) INTO 4 BUCKETS - |OPTIONS ('allow-unsupported-transforms'=true) """.stripMargin) val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] From 2efb84cf1abab6a67718501e79d00c4f63ecc8aa Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 7 Jul 2020 13:51:00 -0700 Subject: [PATCH 3/3] Remove unnecessary configuration from updated test. 
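
The first commit in this series switched this test to the
dynamicOverwriteTest helper, which already runs the test body with
dynamic partition overwrite mode configured, so the explicit
withSQLConf wrapper inside the test was redundant.

For reference, the helper behaves roughly like the sketch below; the
real definition lives earlier in InsertIntoTests.scala and may perform
additional checks, and the signature shown is assumed from how the
helper is called in this file:

    protected def dynamicOverwriteTest(testName: String)(f: => Unit): Unit = {
      test(testName) {
        // Run the test body with dynamic partition overwrite enabled.
        withSQLConf(
            PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
          f
        }
      }
    }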
--- .../spark/sql/connector/InsertIntoTests.scala | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index 618f528a92a5..2cc7a1f99464 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -447,19 +447,17 @@ trait InsertIntoSQLOnlyTests } dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = s"${catalogAndNamespace}tbl" - withTableAndData(t1) { view => - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + - s"USING $v2Format PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") - verifyTable(t1, Seq( - (2, "a", 2), - (2, "b", 2), - (2, "c", 2), - (4, "keep", 2)).toDF("id", "data", "p")) - } + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") + verifyTable(t1, Seq( + (2, "a", 2), + (2, "b", 2), + (2, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) } }