diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CurrentUserContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CurrentUserContext.scala
index 16960db5f20e3..9dea473a34c4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CurrentUserContext.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CurrentUserContext.scala
@@ -17,8 +17,14 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.util.Utils
+
 object CurrentUserContext {
   val CURRENT_USER: InheritableThreadLocal[String] = new InheritableThreadLocal[String] {
     override protected def initialValue(): String = null
   }
+
+  def getCurrentUser: String = Option(CURRENT_USER.get()).getOrElse(Utils.getCurrentUserName())
+
+  def getCurrentUserOrEmpty: String = Option(CURRENT_USER.get()).getOrElse("")
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 4b04cfddbe8ce..26b38676b0762 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -30,7 +30,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedLeafNode}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
@@ -241,7 +241,7 @@ case class CatalogTable(
     provider: Option[String] = None,
     partitionColumnNames: Seq[String] = Seq.empty,
     bucketSpec: Option[BucketSpec] = None,
-    owner: String = "",
+    owner: String = CurrentUserContext.getCurrentUserOrEmpty,
     createTime: Long = System.currentTimeMillis,
     lastAccessTime: Long = -1,
     createVersion: String = "",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 466781fa1def7..4052ccd64965d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import java.time.{Instant, LocalDateTime}
 
-import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
+import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 /**
@@ -109,7 +108,7 @@ case class ReplaceCurrentLike(catalogManager: CatalogManager) extends Rule[Logic
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
     val currentNamespace = catalogManager.currentNamespace.quoted
     val currentCatalog = catalogManager.currentCatalog.name()
-    val currentUser = Option(CURRENT_USER.get()).getOrElse(Utils.getCurrentUserName())
+    val currentUser = CurrentUserContext.getCurrentUser
 
     plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
       case CurrentDatabase() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 72cb1e58c7ef7..f0f02c156ae6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -23,6 +23,7 @@ import java.util.Collections
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
@@ -34,7 +35,6 @@ import org.apache.spark.sql.connector.expressions.LiteralValue
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, MetadataBuilder, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.util.Utils
 
 private[sql] object CatalogV2Util {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -423,7 +423,7 @@ private[sql] object CatalogV2Util {
   }
 
   def withDefaultOwnership(properties: Map[String, String]): Map[String, String] = {
-    properties ++ Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
+    properties ++ Map(TableCatalog.PROP_OWNER -> CurrentUserContext.getCurrentUser)
   }
 
   def getTableProviderCatalog(
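
Note: the production change above is small. CurrentUserContext now owns the "who is the current user?" lookup with two fallbacks: getCurrentUser falls back to Utils.getCurrentUserName() (the process-level user, resolved from SPARK_USER / the Hadoop user), while getCurrentUserOrEmpty falls back to "", which lets CatalogTable.owner keep its old empty-string default when no session user has been installed. A minimal self-contained sketch of the pattern, with System.getProperty("user.name") standing in for Utils.getCurrentUserName():

object CurrentUserSketch {
  // Inheritable so that worker threads spawned while handling a session see
  // the same user as the thread that opened the session.
  val CURRENT_USER: InheritableThreadLocal[String] = new InheritableThreadLocal[String] {
    override protected def initialValue(): String = null
  }

  // Stand-in for Utils.getCurrentUserName(); Spark's version consults
  // SPARK_USER and Hadoop's UserGroupInformation rather than a JVM property.
  private def systemUser: String = System.getProperty("user.name")

  // Session user if set on this thread (or inherited), else the process user.
  def getCurrentUser: String = Option(CURRENT_USER.get()).getOrElse(systemUser)

  // Same lookup, but "" when unset, so "no session user" stays observable.
  def getCurrentUserOrEmpty: String = Option(CURRENT_USER.get()).getOrElse("")

  def main(args: Array[String]): Unit = {
    assert(getCurrentUserOrEmpty.isEmpty)            // nothing set yet
    CURRENT_USER.set("alice")
    val child = new Thread(() => assert(getCurrentUser == "alice"))
    child.start(); child.join()                      // child inherited "alice"
    CURRENT_USER.remove()
    println(getCurrentUser)                          // back to the process user
  }
}

Defaulting CatalogTable.owner to getCurrentUserOrEmpty rather than getCurrentUser looks deliberate: call sites with no session user keep the previous "" owner instead of silently picking up the service's own user.
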
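
Note: CURRENT_USER is thread-local state, so the finally block in this test is load-bearing; without remove(), the fake owner would leak into any later test that runs on the same reused thread. If more suites adopt the pattern, it could be factored into a small fixture. A hypothetical sketch (the withCurrentUser name is illustrative, not part of this change):

import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER

def withCurrentUser[T](user: String)(body: => T): T = {
  try {
    CURRENT_USER.set(user)
    body
  } finally {
    CURRENT_USER.remove()  // never let the fake user outlive the test
  }
}
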
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 2855d7b06f52e..ae639b272a2ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -26,6 +26,7 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
@@ -3294,6 +3295,18 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }
 
+  test("SPARK-45454: Set table owner to current_user if it is set") {
+    val testOwner = "test_table_owner"
+    try {
+      CURRENT_USER.set(testOwner)
+      spark.sql("CREATE TABLE testcat.table_name (id int) USING foo")
+      val table = catalog("testcat").asTableCatalog.loadTable(Identifier.of(Array(), "table_name"))
+      assert(table.properties.get(TableCatalog.PROP_OWNER) === testOwner)
+    } finally {
+      CURRENT_USER.remove()
+    }
+  }
+
   private def testNotSupportedV2Command(
       sqlCommand: String,
       sqlParams: String,
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 318328d71a807..0589f9de6097f 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -24,6 +24,7 @@ import org.apache.hive.service.cli.{GetInfoType, HiveSQLException, OperationHand
 
 import org.apache.spark.{ErrorMessageFormat, TaskKilled}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.internal.SQLConf
 
 trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
@@ -254,6 +255,21 @@
       assertThrows[SQLException](rs1.beforeFirst())
     }
   }
+
+  test("SPARK-45454: Set table owner to current_user") {
+    val testOwner = "test_table_owner"
+    val tableName = "t"
+    withTable(tableName) {
+      withCLIServiceClient(testOwner) { client =>
+        val sessionHandle = client.openSession(testOwner, "")
+        val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
+        val exec: String => OperationHandle = client.executeStatement(sessionHandle, _, confOverlay)
+        exec(s"CREATE TABLE $tableName(id int) using parquet")
+        val owner = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).owner
+        assert(owner === testOwner)
+      }
+    }
+  }
 }
 
 class ThriftServerWithSparkContextInBinarySuite extends ThriftServerWithSparkContextSuite {
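
Note: with this change the v1 path (CatalogTable.owner) and the v2 path (the PROP_OWNER table property via withDefaultOwnership) agree on precedence: an explicitly installed session user wins, otherwise Utils.getCurrentUserName(). The Thrift server test exercises the end-to-end variant, which relies on the server installing the openSession user into CURRENT_USER around statement execution. A short sketch against the v2 API changed above (the "alice" value is illustrative, it must run on the thread where CURRENT_USER was set, and since CatalogV2Util is private[sql] this would have to live under org.apache.spark.sql):

import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}

CURRENT_USER.set("alice")
try {
  val props = CatalogV2Util.withDefaultOwnership(Map("provider" -> "parquet"))
  assert(props(TableCatalog.PROP_OWNER) == "alice")  // session user wins
} finally {
  CURRENT_USER.remove()
}
// With CURRENT_USER unset, PROP_OWNER falls back to Utils.getCurrentUserName().
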