@@ -17,8 +17,14 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.util.Utils
+
 object CurrentUserContext {
   val CURRENT_USER: InheritableThreadLocal[String] = new InheritableThreadLocal[String] {
     override protected def initialValue(): String = null
   }
+
+  def getCurrentUser: String = Option(CURRENT_USER.get()).getOrElse(Utils.getCurrentUserName())
+
+  def getCurrentUserOrEmpty: String = Option(CURRENT_USER.get()).getOrElse("")
 }
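
A note on the mechanism here: `InheritableThreadLocal` makes a value set on one thread visible to any threads spawned afterwards, which is what lets a session's user name propagate into worker threads. A minimal standalone sketch of that behavior (the demo object and the `"alice"` value are illustrative, not part of this change):

```scala
object InheritableThreadLocalDemo extends App {
  // Same shape as CURRENT_USER above: null until someone calls set().
  val user = new InheritableThreadLocal[String] {
    override protected def initialValue(): String = null
  }

  user.set("alice")

  // A thread created after set() inherits the parent's current value.
  val child = new Thread(() => println(s"child sees: ${user.get()}"))
  child.start()
  child.join() // prints: child sees: alice

  user.remove() // clear the thread-local when done
}
```
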

@@ -30,7 +30,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedLeafNode}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
@@ -241,7 +241,7 @@ case class CatalogTable(
     provider: Option[String] = None,
     partitionColumnNames: Seq[String] = Seq.empty,
     bucketSpec: Option[BucketSpec] = None,
-    owner: String = "",
+    owner: String = CurrentUserContext.getCurrentUserOrEmpty,
     createTime: Long = System.currentTimeMillis,
     lastAccessTime: Long = -1,
     createVersion: String = "",
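
Because `owner` is a default parameter, `CurrentUserContext.getCurrentUserOrEmpty` is re-evaluated on every `CatalogTable` construction that omits an explicit owner, so each table captures whoever is current at creation time rather than a value frozen at class-load. A small sketch of that Scala semantics (hypothetical names, nothing from this PR):

```scala
object DefaultParamDemo extends App {
  var current = "alice"
  // The default expression is evaluated on each construction, not once.
  case class Record(owner: String = current)

  println(Record().owner) // alice
  current = "bob"
  println(Record().owner) // bob
}
```
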

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import java.time.{Instant, LocalDateTime}
 
-import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
+import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 /**
@@ -109,7 +108,7 @@ case class ReplaceCurrentLike(catalogManager: CatalogManager) extends Rule[LogicalPlan] {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
     val currentNamespace = catalogManager.currentNamespace.quoted
     val currentCatalog = catalogManager.currentCatalog.name()
-    val currentUser = Option(CURRENT_USER.get()).getOrElse(Utils.getCurrentUserName())
+    val currentUser = CurrentUserContext.getCurrentUser
 
     plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
       case CurrentDatabase() =>
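
The behavior of `ReplaceCurrentLike` is unchanged by this refactor: `current_user()` still resolves to the thread-local user when set, falling back to `Utils.getCurrentUserName()`; the fallback chain has simply moved into `CurrentUserContext.getCurrentUser`. A hedged usage sketch (assumes a `spark` session whose analysis runs on a thread that sees the thread-local; this is not a test from the PR):

```scala
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER

CURRENT_USER.set("alice")
try {
  // ReplaceCurrentLike substitutes the literal "alice" during analysis.
  spark.sql("SELECT current_user()").show()
} finally {
  CURRENT_USER.remove() // always clear the thread-local to avoid leaking state
}
```
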

@@ -23,6 +23,7 @@ import java.util.Collections
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
@@ -34,7 +35,6 @@ import org.apache.spark.sql.connector.expressions.LiteralValue
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, MetadataBuilder, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.util.Utils
 
 private[sql] object CatalogV2Util {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -423,7 +423,7 @@ private[sql] object CatalogV2Util {
   }
 
   def withDefaultOwnership(properties: Map[String, String]): Map[String, String] = {
-    properties ++ Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
+    properties ++ Map(TableCatalog.PROP_OWNER -> CurrentUserContext.getCurrentUser)
   }
 
   def getTableProviderCatalog(
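
One subtlety in `withDefaultOwnership`: with `Map#++` the right-hand operand wins on duplicate keys, so the computed owner overwrites any `PROP_OWNER` the caller supplied. A quick illustration of that merge semantics (plain Scala, with `"owner"` standing in for `TableCatalog.PROP_OWNER`):

```scala
object OwnerMergeDemo extends App {
  val supplied = Map("location" -> "/tmp/t", "owner" -> "caller")
  // The right-hand side wins on duplicate keys, so "owner" becomes "alice".
  val merged = supplied ++ Map("owner" -> "alice")
  println(merged) // Map(location -> /tmp/t, owner -> alice)
}
```
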

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.types.StructType
@@ -34,4 +35,20 @@ class CatalogSuite extends AnalysisTest {
       provider = Some("parquet"))
     table.toLinkedHashMap
   }
+
+  test("SPARK-45454: Set table owner to current_user") {
+    val testOwner = "test_table_owner"
+    try {
+      CURRENT_USER.set(testOwner)
+      val table = CatalogTable(
+        identifier = TableIdentifier("tbl", Some("db1")),
+        tableType = CatalogTableType.MANAGED,
+        storage = CatalogStorageFormat.empty,
+        schema = new StructType().add("col1", "int").add("col2", "string"),
+        provider = Some("parquet"))
+      assert(table.owner === testOwner)
+    } finally {
+      CURRENT_USER.remove()
+    }
+  }
 }

@@ -26,6 +26,7 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
@@ -3294,6 +3295,18 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }
 
+  test("SPARK-45454: Set table owner to current_user if it is set") {
+    val testOwner = "test_table_owner"
+    try {
+      CURRENT_USER.set(testOwner)
+      spark.sql("CREATE TABLE testcat.table_name (id int) USING foo")
+      val table = catalog("testcat").asTableCatalog.loadTable(Identifier.of(Array(), "table_name"))
+      assert(table.properties.get(TableCatalog.PROP_OWNER) === testOwner)
+    } finally {
+      CURRENT_USER.remove()
+    }
+  }
+
   private def testNotSupportedV2Command(
       sqlCommand: String,
       sqlParams: String,

@@ -24,6 +24,7 @@ import org.apache.hive.service.cli.{GetInfoType, HiveSQLException, OperationHandle
 
 import org.apache.spark.{ErrorMessageFormat, TaskKilled}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.internal.SQLConf
 
 trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
@@ -254,6 +255,21 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
       assertThrows[SQLException](rs1.beforeFirst())
     }
   }
+
+  test("SPARK-45454: Set table owner to current_user") {
+    val testOwner = "test_table_owner"
+    val tableName = "t"
+    withTable(tableName) {
+      withCLIServiceClient(testOwner) { client =>
+        val sessionHandle = client.openSession(testOwner, "")
+        val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
+        val exec: String => OperationHandle = client.executeStatement(sessionHandle, _, confOverlay)
+        exec(s"CREATE TABLE $tableName(id int) using parquet")
+        val owner = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).owner
+        assert(owner === testOwner)
+      }
+    }
+  }
 }
 
 class ThriftServerWithSparkContextInBinarySuite extends ThriftServerWithSparkContextSuite {