diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
index 6d061fce06919..5cdd48834c674 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, AlterTableDropPartition, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec
+import org.apache.spark.sql.util.PartitioningUtils.{castPartitionValues, normalizePartitionSpec}
 
 /**
  * Resolve [[UnresolvedPartitionSpec]] to [[ResolvedPartitionSpec]] in partition related commands.
@@ -33,41 +33,40 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case r @ AlterTableAddPartition(
-        ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _) =>
-      r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema()))
+        ResolvedTable(_, _, table: SupportsPartitionManagement), partitionSpec, _) =>
+      r.copy(parts = resolvePartitionSpecs(table, partitionSpec))
 
     case r @ AlterTableDropPartition(
-        ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs, _, _, _) =>
-      r.copy(parts = resolvePartitionSpecs(table.name, partSpecs, table.partitionSchema()))
+        ResolvedTable(_, _, table: SupportsPartitionManagement), partitionSpec, _, _, _) =>
+      r.copy(parts = resolvePartitionSpecs(table, partitionSpec))
   }
 
   private def resolvePartitionSpecs(
-      tableName: String,
-      partSpecs: Seq[PartitionSpec],
-      partSchema: StructType): Seq[ResolvedPartitionSpec] =
-    partSpecs.map {
+      table: SupportsPartitionManagement,
+      partitionSpec: Seq[PartitionSpec]): Seq[ResolvedPartitionSpec] =
+    partitionSpec.map {
       case unresolvedPartSpec: UnresolvedPartitionSpec =>
         ResolvedPartitionSpec(
-          convertToPartIdent(tableName, unresolvedPartSpec.spec, partSchema),
+          convertToPartIdent(table, unresolvedPartSpec.spec),
           unresolvedPartSpec.location)
       case resolvedPartitionSpec: ResolvedPartitionSpec =>
         resolvedPartitionSpec
     }
 
   private def convertToPartIdent(
-      tableName: String,
-      partitionSpec: TablePartitionSpec,
-      partSchema: StructType): InternalRow = {
+      table: SupportsPartitionManagement,
+      partitionSpec: TablePartitionSpec): InternalRow = {
+    val partitionSchema = table.partitionSchema()
     val normalizedSpec = normalizePartitionSpec(
       partitionSpec,
-      partSchema.map(_.name),
-      tableName,
+      partitionSchema.map(_.name),
+      table.name,
       conf.resolver)
 
-    val partValues = partSchema.map { part =>
-      val raw = normalizedSpec.get(part.name).orNull
-      Cast(Literal.create(raw, StringType), part.dataType, Some(conf.sessionLocalTimeZone)).eval()
-    }
-    InternalRow.fromSeq(partValues)
+    castPartitionValues(
+      normalizedSpec,
+      partitionSchema,
+      table.properties().asScala.toMap,
+      conf.sessionLocalTimeZone)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index ee7216e93ebb5..486af001d713b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, ExprId}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.util._
@@ -38,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
+import org.apache.spark.sql.util.PartitioningUtils.castPartitionValues
 
 /**
  * A function defined in the catalog.
@@ -149,18 +149,8 @@ case class CatalogTablePartition(
 
   /**
    * Given the partition schema, returns a row with that schema holding the partition values.
    */
-  def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = {
-    val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties)
-    val timeZoneId = caseInsensitiveProperties.getOrElse(
-      DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId)
-    InternalRow.fromSeq(partitionSchema.map { field =>
-      val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) {
-        null
-      } else {
-        spec(field.name)
-      }
-      Cast(Literal(partValue), field.dataType, Option(timeZoneId)).eval()
-    })
+  def toRow(partitionSchema: StructType, defaultTimeZoneId: String): InternalRow = {
+    castPartitionValues(spec, partitionSchema, storage.properties, defaultTimeZoneId)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala
index 586aa6c59164f..3f2370146a88f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala
@@ -18,9 +18,15 @@
 package org.apache.spark.sql.util
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.types.StructType
 
-object PartitioningUtils {
+private[sql] object PartitioningUtils {
   /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
@@ -44,4 +50,25 @@ object PartitioningUtils {
 
     normalizedPartSpec.toMap
   }
+
+  /**
+   * Given the partition schema, returns a row with that schema holding the partition values.
+   */
+  def castPartitionValues(
+      spec: TablePartitionSpec,
+      partitionSchema: StructType,
+      properties: Map[String, String],
+      defaultTimeZoneId: String): InternalRow = {
+    val caseInsensitiveProperties = CaseInsensitiveMap(properties)
+    val timeZoneId = caseInsensitiveProperties.getOrElse(
+      DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)
+    InternalRow.fromSeq(partitionSchema.map { field =>
+      val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) {
+        null
+      } else {
+        spec(field.name)
+      }
+      Cast(Literal(partValue), field.dataType, Option(timeZoneId)).eval()
+    })
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala
index 4cacd5ec2b49e..6e68719056ce7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala
@@ -243,4 +243,22 @@ class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase {
       assert(!partTable.partitionExists(expectedPartition))
     }
   }
+
+  test("SPARK-33529: handle __HIVE_DEFAULT_PARTITION__") {
+    val t = "testpart.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (part0 string) USING foo PARTITIONED BY (part0)")
+      val partTable = catalog("testpart")
+        .asTableCatalog
+        .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl"))
+        .asPartitionable
+      val expectedPartition = InternalRow.fromSeq(Seq[Any](null))
+      assert(!partTable.partitionExists(expectedPartition))
+      val partSpec = "PARTITION (part0 = '__HIVE_DEFAULT_PARTITION__')"
+      sql(s"ALTER TABLE $t ADD $partSpec")
+      assert(partTable.partitionExists(expectedPartition))
+      spark.sql(s"ALTER TABLE $t DROP $partSpec")
+      assert(!partTable.partitionExists(expectedPartition))
+    }
+  }
 }
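For reference, below is a rough, standalone sketch (not part of the patch) of what the consolidated helper does: `__HIVE_DEFAULT_PARTITION__` is mapped to a null partition value, and every other string value is cast to its partition column's type using the resolved time zone. The object name, package, column names, and values are illustrative only; the package is chosen under org.apache.spark.sql because PartitioningUtils becomes private[sql] in this change, and behavior assumes default (non-ANSI) cast semantics.

package org.apache.spark.sql.example  // hypothetical package; needed to reach the private[sql] helper

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.util.PartitioningUtils

object CastPartitionValuesSketch {
  def main(args: Array[String]): Unit = {
    // Two partition columns: a string column and an int column.
    val partitionSchema = StructType(Seq(
      StructField("part0", StringType),
      StructField("part1", IntegerType)))
    // The Hive default partition name stands in for a null value; "42" is a plain string.
    val spec = Map("part0" -> "__HIVE_DEFAULT_PARTITION__", "part1" -> "42")
    val row: InternalRow = PartitioningUtils.castPartitionValues(
      spec, partitionSchema, properties = Map.empty, defaultTimeZoneId = "UTC")
    assert(row.isNullAt(0))      // default partition name resolved to null
    assert(row.getInt(1) == 42)  // string partition value cast to the column's int type
  }
}

This is also the behavior the new AlterTablePartitionV2SQLSuite test exercises end to end through ALTER TABLE ... ADD/DROP PARTITION.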