
Commit a01f6b3

address comments
1 parent 7c1a8e5 commit a01f6b3

7 files changed: 60 additions, 43 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database '
 class NoSuchTableException(db: String, table: String)
   extends AnalysisException(s"Table or view '$table' not found in database '$db'")
 
+class NoSuchTempViewException(table: String)
+  extends AnalysisException(s"Temporary view '$table' not found")
+
 class NoSuchPartitionException(
     db: String,
     table: String,
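
Note: the new exception subclasses AnalysisException, so call sites that already catch the parent type keep working; only callers that want the finer-grained case need the new class. A minimal sketch (the try/catch harness is illustrative, not part of this diff), assuming spark-catalyst from this commit is on the classpath:

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.NoSuchTempViewException

try {
  throw new NoSuchTempViewException("view1")
} catch {
  // The parent type also matches the new subclass.
  case e: AnalysisException => println(e.getMessage) // Temporary view 'view1' not found
}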

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 12 additions & 21 deletions
@@ -246,36 +246,27 @@ class SessionCatalog(
   }
 
   /**
-   * Retrieve the metadata of an existing metastore table/view or a temporary view.
-   * If no database is specified, we check whether the corresponding temporary view exists.
-   * If the temporary view does not exist, we assume the table/view is in the current database.
-   * If still not found in the database then a [[NoSuchTableException]] is thrown.
+   * Retrieve the metadata of an existing temporary view.
+   * If the temporary view does not exist, a [[NoSuchTempViewException]] is thrown.
    */
-  def getTableMetadata(name: TableIdentifier): CatalogTable = {
-    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
-    val table = formatTableName(name.table)
-    val tid = TableIdentifier(table)
-    if (isTemporaryTable(name)) {
-      CatalogTable(
-        identifier = tid,
-        tableType = CatalogTableType.VIEW,
-        storage = CatalogStorageFormat.empty,
-        schema = tempTables(table).output.toStructType,
-        properties = Map(),
-        viewText = None)
-    } else {
-      requireDbExists(db)
-      requireTableExists(TableIdentifier(table, Some(db)))
-      externalCatalog.getTable(db, table)
+  def getTempViewMetadata(name: String): CatalogTable = {
+    val table = formatTableName(name)
+    if (!tempTables.contains(table)) {
+      throw new NoSuchTempViewException(table)
     }
+    CatalogTable(
+      identifier = TableIdentifier(table),
+      tableType = CatalogTableType.VIEW,
+      storage = CatalogStorageFormat.empty,
+      schema = tempTables(table).output.toStructType)
   }
 
   /**
    * Retrieve the metadata of an existing permanent table/view. If no database is specified,
    * assume the table/view is in the current database. If the specified table/view is not found
    * in the database then a [[NoSuchTableException]] is thrown.
    */
-  def getNonTempTableMetadata(name: TableIdentifier): CatalogTable = {
+  def getTableMetadata(name: TableIdentifier): CatalogTable = {
    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
    val table = formatTableName(name.table)
    requireDbExists(db)
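
Note: after this split, getTableMetadata no longer falls back to temporary views, so callers that may be handed an unqualified temp-view name have to branch on isTemporaryTable first. A caller-side sketch of the resulting pattern (the helper name lookupMetadata is hypothetical, not part of this diff), mirroring the command changes below:

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}

// Hypothetical helper showing the post-commit lookup pattern.
def lookupMetadata(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = {
  if (catalog.isTemporaryTable(name)) {
    // Temp views are session-local and keyed by unqualified name.
    catalog.getTempViewMetadata(name.table)
  } else {
    // Metastore tables/views; throws NoSuchTableException if absent.
    catalog.getTableMetadata(name)
  }
}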

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala

Lines changed: 10 additions & 3 deletions
@@ -444,7 +444,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     assert(!catalog.tableExists(TableIdentifier("view1", Some("default"))))
   }
 
-  test("getTableMetadata on temporary views") {
+  test("getTableMetadata and getTempViewMetadata on temporary views") {
     val catalog = new SessionCatalog(newBasicCatalog())
     val tempTable = Range(1, 10, 2, 10)
     val m = intercept[AnalysisException] {
@@ -457,9 +457,16 @@ class SessionCatalogSuite extends SparkFunSuite {
     }.getMessage
     assert(m2.contains("Table or view 'view1' not found in database 'default'"))
 
+    intercept[NoSuchTempViewException] {
+      catalog.getTempViewMetadata("view1")
+    }.getMessage
+
     catalog.createTempView("view1", tempTable, overrideIfExists = false)
-    assert(catalog.getTableMetadata(TableIdentifier("view1")).identifier.table == "view1")
-    assert(catalog.getTableMetadata(TableIdentifier("view1")).schema(0).name == "id")
+    assert(catalog.getTempViewMetadata("view1").identifier === TableIdentifier("view1"))
+
+    intercept[NoSuchTableException] {
+      catalog.getTableMetadata(TableIdentifier("view1"))
+    }.getMessage
 
     val m3 = intercept[AnalysisException] {
       catalog.getTableMetadata(TableIdentifier("view1", Some("default")))

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 6 additions & 6 deletions
@@ -265,7 +265,7 @@ case class AlterTableUnsetPropertiesCommand(
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
     DDLUtils.verifyAlterTableType(catalog, tableName, isView)
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
 
     if (!ifExists) {
       propKeys.foreach { k =>
@@ -305,7 +305,7 @@ case class AlterTableSerDePropertiesCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     // For datasource tables, disallow setting serde or specifying partition
     if (partSpec.isDefined && DDLUtils.isDatasourceTable(table)) {
       throw new AnalysisException("Operation not allowed: ALTER TABLE SET " +
@@ -354,7 +354,7 @@ case class AlterTableAddPartitionCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     if (DDLUtils.isDatasourceTable(table)) {
       throw new AnalysisException(
         "ALTER TABLE ADD PARTITION is not allowed for tables defined using the datasource API")
@@ -414,7 +414,7 @@ case class AlterTableDropPartitionCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     if (DDLUtils.isDatasourceTable(table)) {
       throw new AnalysisException(
         "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API")
@@ -468,7 +468,7 @@ case class AlterTableRecoverPartitionsCommand(
 
   override def run(spark: SparkSession): Seq[Row] = {
     val catalog = spark.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     val qualifiedName = table.identifier.quotedString
 
     if (DDLUtils.isDatasourceTable(table)) {
@@ -645,7 +645,7 @@ case class AlterTableSetLocationCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     partitionSpec match {
       case Some(spec) =>
         // Partition spec is specified, so we set the location only for this partition
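
Note: one observable consequence of these renames is that ALTER TABLE on a temp-view name should now fail at metadata lookup with the metastore-style message, since getTableMetadata no longer resolves temp views. A behavior sketch under that assumption, using a local session:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("sketch").getOrCreate()
spark.range(10).createOrReplaceTempView("v")
// Expected to throw AnalysisException: "Table or view 'v' not found in database 'default'".
spark.sql("ALTER TABLE v SET LOCATION '/tmp/x'")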

sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala

Lines changed: 22 additions & 10 deletions
@@ -64,7 +64,11 @@ case class CreateTableLikeCommand(
         s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'")
     }
 
-    val sourceTableDesc = catalog.getTableMetadata(sourceTable)
+    val sourceTableDesc = if (catalog.isTemporaryTable(sourceTable)) {
+      catalog.getTempViewMetadata(sourceTable.table)
+    } else {
+      catalog.getTableMetadata(sourceTable)
+    }
 
     // Storage format
     val newStorage =
@@ -176,7 +180,11 @@ case class AlterTableRenameCommand(
       }
     }
     // For datasource tables, we also need to update the "path" serde property
-    val table = catalog.getTableMetadata(oldName)
+    val table = if (catalog.isTemporaryTable(oldName)) {
+      catalog.getTempViewMetadata(oldName.table)
+    } else {
+      catalog.getTableMetadata(oldName)
+    }
     if (DDLUtils.isDatasourceTable(table) && table.tableType == CatalogTableType.MANAGED) {
       val newPath = catalog.defaultTablePath(newTblName)
       val newTable = table.withNewStorage(
@@ -214,7 +222,7 @@ case class LoadDataCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val targetTable = catalog.getNonTempTableMetadata(table)
+    val targetTable = catalog.getTableMetadata(table)
     val qualifiedName = targetTable.identifier.quotedString
 
     if (targetTable.tableType == CatalogTableType.VIEW) {
@@ -333,7 +341,7 @@ case class TruncateTableCommand(
 
   override def run(spark: SparkSession): Seq[Row] = {
     val catalog = spark.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     val qualifiedName = table.identifier.quotedString
 
     if (table.tableType == CatalogTableType.EXTERNAL) {
@@ -592,13 +600,19 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio
  * SHOW COLUMNS (FROM | IN) table_identifier [(FROM | IN) database];
  * }}}
  */
-case class ShowColumnsCommand(table: TableIdentifier) extends RunnableCommand {
+case class ShowColumnsCommand(tableName: TableIdentifier) extends RunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("col_name", StringType, nullable = false)() :: Nil
   }
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    sparkSession.sessionState.catalog.getTableMetadata(table).schema.map { c =>
+    val catalog = sparkSession.sessionState.catalog
+    val table = if (catalog.isTemporaryTable(tableName)) {
+      catalog.getTempViewMetadata(tableName.table)
+    } else {
+      catalog.getTableMetadata(tableName)
+    }
+    table.schema.map { c =>
       Row(c.name)
     }
   }
@@ -634,7 +648,7 @@ case class ShowPartitionsCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val table = catalog.getNonTempTableMetadata(tableName)
+    val table = catalog.getTableMetadata(tableName)
     val qualifiedName = table.identifier.quotedString
 
     /**
@@ -686,9 +700,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
-    val db = table.database.getOrElse(catalog.getCurrentDatabase)
-    val qualifiedName = TableIdentifier(table.table, Some(db))
-    val tableMetadata = catalog.getTableMetadata(qualifiedName)
+    val tableMetadata = catalog.getTableMetadata(table)
 
     // TODO: unify this after we unify the CREATE TABLE syntax for hive serde and data source table.
     val stmt = if (DDLUtils.isDatasourceTable(tableMetadata)) {
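
Note: the isTemporaryTable branch gives SHOW COLUMNS (and CREATE TABLE LIKE) a working path for temp views via getTempViewMetadata. A quick usage sketch, assuming a local session built against this commit:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("sketch").getOrCreate()
spark.range(10).toDF("id").createOrReplaceTempView("v")
spark.sql("SHOW COLUMNS IN v").show() // single "col_name" row: id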

sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala

Lines changed: 6 additions & 1 deletion
@@ -151,7 +151,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   }
 
   private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = {
-    val tableMetadata = sessionCatalog.getTableMetadata(tableIdentifier)
+    val tableMetadata = if (sessionCatalog.isTemporaryTable(tableIdentifier)) {
+      sessionCatalog.getTempViewMetadata(tableIdentifier.table)
+    } else {
+      sessionCatalog.getTableMetadata(tableIdentifier)
+    }
+
     val partitionColumnNames = tableMetadata.partitionColumnNames.toSet
     val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet
     val columns = tableMetadata.schema.map { c =>
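
Note: the same branch lets the public Catalog API cover temp views, under the assumption that the String overload of listColumns routes through this private listColumns. A usage sketch with a local session:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("sketch").getOrCreate()
spark.range(10).toDF("id").createOrReplaceTempView("v")
spark.catalog.listColumns("v").show() // one row describing column "id"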

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala

Lines changed: 1 addition & 2 deletions
@@ -671,8 +671,7 @@ class HiveDDLSuite
       .createTempView(sourceViewName)
     sql(s"CREATE TABLE $targetTabName LIKE $sourceViewName")
 
-    val sourceTable = spark.sessionState.catalog.getTableMetadata(
-      TableIdentifier(sourceViewName, None))
+    val sourceTable = spark.sessionState.catalog.getTempViewMetadata(sourceViewName)
     val targetTable = spark.sessionState.catalog.getTableMetadata(
       TableIdentifier(targetTabName, Some("default")))
 