@@ -1050,12 +1050,10 @@ class Analyzer(

val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
val query = addStaticPartitionColumns(r, i.query, staticPartitions)
val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

if (!i.overwrite) {
AppendData.byPosition(r, query)
} else if (dynamicPartitionOverwrite) {
} else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
OverwritePartitionsDynamic.byPosition(r, query)
} else {
OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
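Note on the Analyzer hunk above: the `partCols.size > staticPartitions.size` guard is dropped, so an `INSERT OVERWRITE` is planned as `OverwritePartitionsDynamic` whenever the session is in dynamic partition-overwrite mode, even if every named partition column is pinned statically. That matters for hidden partitions such as `days(ts)`, which never appear in the static partition spec. Below is a minimal standalone sketch of the resulting plan selection; it is not code from this PR, and the names only mirror the Spark plan classes involved.

```scala
// Hedged sketch of the plan choice after this change (names mirror, but are not,
// the Spark logical plans).
object PlanChoiceSketch {
  sealed trait OverwriteMode
  case object Static extends OverwriteMode
  case object Dynamic extends OverwriteMode

  def choosePlan(isOverwrite: Boolean, mode: OverwriteMode): String =
    if (!isOverwrite) "AppendData"
    // Previously this branch also required at least one non-static partition column;
    // now dynamic mode alone is enough, so hidden transform partitions are covered.
    else if (mode == Dynamic) "OverwritePartitionsDynamic"
    else "OverwriteByExpression"

  def main(args: Array[String]): Unit = {
    // Every named partition column given statically, but the mode is dynamic:
    println(choosePlan(isOverwrite = true, mode = Dynamic)) // OverwritePartitionsDynamic
  }
}
```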
@@ -17,6 +17,8 @@

package org.apache.spark.sql.connector

import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit
import java.util

import scala.collection.JavaConverters._
@@ -25,12 +27,13 @@ import scala.collection.mutable
import org.scalatest.Assertions._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -46,10 +49,15 @@ class InMemoryTable(
private val allowUnsupportedTransforms =
properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean

partitioning.foreach { t =>
if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
}
partitioning.foreach {
case _: IdentityTransform =>
case _: YearsTransform =>
case _: MonthsTransform =>
case _: DaysTransform =>
case _: HoursTransform =>
case _: BucketTransform =>
case t if !allowUnsupportedTransforms =>
throw new IllegalArgumentException(s"Transform $t is not a supported transform")
}

// The key `Seq[Any]` is the partition values.
@@ -66,8 +74,14 @@
}
}

private val UTC = ZoneId.of("UTC")
private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate

private def getKey(row: InternalRow): Seq[Any] = {
def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
def extractor(
fieldNames: Array[String],
schema: StructType,
row: InternalRow): (Any, DataType) = {
val index = schema.fieldIndex(fieldNames(0))
val value = row.toSeq(schema).apply(index)
if (fieldNames.length > 1) {
@@ -78,10 +92,44 @@
throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
}
} else {
value
(value, schema(index).dataType)
}
}
partCols.map(fieldNames => extractor(fieldNames, schema, row))

partitioning.map {
case IdentityTransform(ref) =>
extractor(ref.fieldNames, schema, row)._1
case YearsTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (days: Int, DateType) =>
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
case (micros: Long, TimestampType) =>
val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
}
case MonthsTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (days: Int, DateType) =>
ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
case (micros: Long, TimestampType) =>
val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
}
case DaysTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (days, DateType) =>
days
case (micros: Long, TimestampType) =>
ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
}
case HoursTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (micros: Long, TimestampType) =>
ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
}
case BucketTransform(numBuckets, ref) =>
(extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets
}
}

def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
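The new `getKey` above derives hidden partition values with `java.time`: years and months are counted from the epoch date, days and hours from `Instant.EPOCH`, and buckets are a non-negative hash modulo the bucket count. The following self-contained sketch (not test code from this PR) reproduces that `ChronoUnit` arithmetic outside Spark for the `2019-06-01 10:00:00` timestamp used throughout the suites, assuming UTC.

```scala
import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

// Hedged sketch: mirrors the arithmetic in the new getKey for a TimestampType column.
object HiddenPartitionValuesSketch {
  private val UTC = ZoneId.of("UTC")
  private val EpochDate = Instant.EPOCH.atZone(UTC).toLocalDate

  def main(args: Array[String]): Unit = {
    val ts = Instant.parse("2019-06-01T10:00:00Z")
    val localDate = ts.atZone(UTC).toLocalDate
    println(ChronoUnit.YEARS.between(EpochDate, localDate))   // 49      -> years(ts)
    println(ChronoUnit.MONTHS.between(EpochDate, localDate))  // 593     -> months(ts)
    println(ChronoUnit.DAYS.between(Instant.EPOCH, ts))       // 18048   -> days(ts)
    println(ChronoUnit.HOURS.between(Instant.EPOCH, ts))      // 433162  -> hours(ts)
    // bucket(4, col) in getKey: (value.hashCode() & Integer.MAX_VALUE) % 4
  }
}
```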
@@ -40,7 +40,7 @@ case class BatchScanExec(

override def hashCode(): Int = batch.hashCode()

override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
@transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()

override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

@@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(years($"ts"))
.create()

@@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(months($"ts"))
.create()

@@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(days($"ts"))
.create()

@@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(hours($"ts"))
.create()

@@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
test("Create: partitioned by bucket(4, id)") {
spark.table("source")
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(bucket(4, $"id"))
.create()

@@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
lit("America/Los_Angeles") as "timezone"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(
years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
@@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
lit("America/Los_Angeles") as "timezone"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(bucket(4, $"ts.timezone"))
.create()

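With InMemoryTable accepting the hidden transforms natively, the `allow-unsupported-transforms` table property is no longer needed in the tests above. A hedged usage sketch of what the writer call now looks like; it assumes the suite's setup (`spark` session, `import spark.implicits._` for the `$` syntax, the `source` table, and the `testcat` in-memory catalog), so it is not runnable on its own.

```scala
import org.apache.spark.sql.functions.{bucket, days, lit}

// Hedged sketch mirroring the tests above: no table property required any more.
spark.table("source")
  .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
  .writeTo("testcat.table_name")
  .partitionedBy(days($"ts"), bucket(4, $"id"))
  .create()
```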
@@ -17,6 +17,9 @@

package org.apache.spark.sql.connector

import java.sql.Timestamp
import java.time.LocalDate

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
@@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION}
import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.SimpleScanSource
import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
@@ -1647,7 +1650,6 @@ class DataSourceV2SQLSuite
"""
|CREATE TABLE testcat.t (id int, `a.b` string) USING foo
|CLUSTERED BY (`a.b`) INTO 4 BUCKETS
|OPTIONS ('allow-unsupported-transforms'=true)
""".stripMargin)

val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog]
@@ -2494,6 +2496,38 @@ class DataSourceV2SQLSuite
}
}

test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") {
def testTimestamp(daysOffset: Int): Timestamp = {
Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay())
}

withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
val t1 = s"${catalogAndNamespace}tbl"
withTable(t1) {
val df = spark.createDataFrame(Seq(
(testTimestamp(1), "a"),
(testTimestamp(2), "b"),
(testTimestamp(3), "c"))).toDF("ts", "data")
df.createOrReplaceTempView("source_view")

sql(s"CREATE TABLE $t1 (ts timestamp, data string) " +
s"USING $v2Format PARTITIONED BY (days(ts))")
sql(s"INSERT INTO $t1 VALUES " +
s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " +
s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')")
sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view")

val expected = spark.createDataFrame(Seq(
(testTimestamp(1), "a"),
(testTimestamp(2), "b"),
(testTimestamp(3), "c"),
(testTimestamp(4), "keep"))).toDF("ts", "data")

verifyTable(t1, expected)
}
}
}

private def testV1Command(sqlCommand: String, sqlParams: String): Unit = {
val e = intercept[AnalysisException] {
sql(s"$sqlCommand $sqlParams")
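For reference on the SPARK-32168 test added above: under dynamic mode, only the hidden `days(ts)` partitions produced by the source view are overwritten. The worked example below is a hedged sketch assuming the timestamps resolve to UTC midnights; with another session time zone the numeric partition values shift, but the grouping, and therefore which rows survive, is the same.

```scala
import java.time.{Instant, LocalDate, ZoneOffset}
import java.time.temporal.ChronoUnit

// Hedged sketch of the days(ts) partition ids the test touches, assuming UTC.
object Spark32168PartitionsSketch {
  private def daysPartition(d: LocalDate): Long =
    ChronoUnit.DAYS.between(Instant.EPOCH, d.atStartOfDay(ZoneOffset.UTC).toInstant)

  def main(args: Array[String]): Unit = {
    // The source view writes Jan 2-4, 2020 -> partitions 18263, 18264, 18265 ...
    Seq(2, 3, 4).foreach(d => println(daysPartition(LocalDate.of(2020, 1, d))))
    // ... so the pre-inserted 'dummy' row (Jan 3 -> 18264) is replaced, while
    // 'keep' (Jan 5 -> 18266) sits in an untouched partition and survives.
    println(daysPartition(LocalDate.of(2020, 1, 5))) // 18266
  }
}
```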
@@ -446,21 +446,18 @@ trait InsertIntoSQLOnlyTests
}
}

test("InsertInto: overwrite - multiple static partitions - dynamic mode") {
// Since all partitions are provided statically, this should be supported by everyone
withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
val t1 = s"${catalogAndNamespace}tbl"
withTableAndData(t1) { view =>
sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
s"USING $v2Format PARTITIONED BY (id, p)")
sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view")
verifyTable(t1, Seq(
(2, "a", 2),
(2, "b", 2),
(2, "c", 2),
(4, "keep", 2)).toDF("id", "data", "p"))
}
dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") {
val t1 = s"${catalogAndNamespace}tbl"
withTableAndData(t1) { view =>
sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
s"USING $v2Format PARTITIONED BY (id, p)")
sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view")
verifyTable(t1, Seq(
(2, "a", 2),
(2, "b", 2),
(2, "c", 2),
(4, "keep", 2)).toDF("id", "data", "p"))
}
}
