diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
index 1a145c24d78c..dcc143982a4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
@@ -128,6 +128,8 @@ trait ExternalCatalog {
 
   def getTable(db: String, table: String): CatalogTable
 
+  def getTablesByName(db: String, tables: Seq[String]): Seq[CatalogTable]
+
   def tableExists(db: String, table: String): Boolean
 
   def listTables(db: String): Seq[String]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala
index 2f009be5816f..86113d3ec3ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala
@@ -138,6 +138,10 @@ class ExternalCatalogWithListener(delegate: ExternalCatalog)
     delegate.getTable(db, table)
   }
 
+  override def getTablesByName(db: String, tables: Seq[String]): Seq[CatalogTable] = {
+    delegate.getTablesByName(db, tables)
+  }
+
   override def tableExists(db: String, table: String): Boolean = {
     delegate.tableExists(db, table)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 741dc46b0738..abf69939dea1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -327,6 +327,11 @@ class InMemoryCatalog(
     catalog(db).tables(table).table
   }
 
+  override def getTablesByName(db: String, tables: Seq[String]): Seq[CatalogTable] = {
+    requireDbExists(db)
+    tables.flatMap(catalog(db).tables.get).map(_.table)
+  }
+
   override def tableExists(db: String, table: String): Boolean = synchronized {
     requireDbExists(db)
     catalog(db).tables.contains(table)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index c05f777770f3..e49e54f9bb31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -434,6 +434,34 @@ class SessionCatalog(
     externalCatalog.getTable(db, table)
   }
 
+  /**
+   * Retrieve all metadata of existing permanent tables/views. If no database is specified,
+   * assume the table/view is in the current database.
+   * All requested tables/views must belong to the same database; only those that can be
+   * retrieved are returned. For example, if none of the requested tables could be retrieved,
+   * an empty list is returned. There is no guarantee of ordering of the returned tables.
+   */
+  @throws[NoSuchDatabaseException]
+  def getTablesByName(names: Seq[TableIdentifier]): Seq[CatalogTable] = {
+    if (names.nonEmpty) {
+      val dbs = names.map(_.database.getOrElse(getCurrentDatabase))
+      if (dbs.distinct.size != 1) {
+        val tables = names.map(name => formatTableName(name.table))
+        val qualifiedTableNames = dbs.zip(tables).map { case (d, t) => QualifiedTableName(d, t)}
+        throw new AnalysisException(
+          s"Only the tables/views belonging to the same database can be retrieved. Queried " +
+          s"tables/views are $qualifiedTableNames"
+        )
+      }
+      val db = formatDatabaseName(dbs.head)
+      requireDbExists(db)
+      val tables = names.map(name => formatTableName(name.table))
+      externalCatalog.getTablesByName(db, tables)
+    } else {
+      Seq.empty
+    }
+  }
+
   /**
    * Load files stored in given path into an existing metastore table.
    * If no database is specified, assume the table is in the current database.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index b376108399c1..6b1c35094e4a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -277,6 +277,28 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
     }
   }
 
+  test("get tables by name") {
+    assert(newBasicCatalog().getTablesByName("db2", Seq("tbl1", "tbl2"))
+      .map(_.identifier.table) == Seq("tbl1", "tbl2"))
+  }
+
+  test("get tables by name when some tables do not exist") {
+    assert(newBasicCatalog().getTablesByName("db2", Seq("tbl1", "tblnotexist"))
+      .map(_.identifier.table) == Seq("tbl1"))
+  }
+
+  test("get tables by name when contains invalid name") {
+    // scalastyle:off
+    val name = "砖"
+    // scalastyle:on
+    assert(newBasicCatalog().getTablesByName("db2", Seq("tbl1", name))
+      .map(_.identifier.table) == Seq("tbl1"))
+  }
+
+  test("get tables by name when empty table list") {
+    assert(newBasicCatalog().getTablesByName("db2", Seq.empty).isEmpty)
+  }
+
   test("list tables without pattern") {
     val catalog = newBasicCatalog()
     intercept[AnalysisException] { catalog.listTables("unknown_db") }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 92f87ea796e8..5a9e4adf6d88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -509,6 +509,96 @@ abstract class SessionCatalogSuite extends AnalysisTest {
     }
   }
 
+  test("get tables by name") {
+    withBasicCatalog { catalog =>
+      assert(catalog.getTablesByName(
+        Seq(
+          TableIdentifier("tbl1", Some("db2")),
+          TableIdentifier("tbl2", Some("db2"))
+        )
+      ) == catalog.externalCatalog.getTablesByName("db2", Seq("tbl1", "tbl2")))
+      // Get table without explicitly specifying database
+      catalog.setCurrentDatabase("db2")
+      assert(catalog.getTablesByName(
+        Seq(
+          TableIdentifier("tbl1"),
+          TableIdentifier("tbl2")
+        )
+      ) == catalog.externalCatalog.getTablesByName("db2", Seq("tbl1", "tbl2")))
+    }
+  }
+
+  test("get tables by name when some tables do not exist") {
+    withBasicCatalog { catalog =>
+      assert(catalog.getTablesByName(
+        Seq(
+          TableIdentifier("tbl1", Some("db2")),
TableIdentifier("tblnotexit", Some("db2")) + ) + ) == catalog.externalCatalog.getTablesByName("db2", Seq("tbl1"))) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getTablesByName( + Seq( + TableIdentifier("tbl1"), + TableIdentifier("tblnotexit") + ) + ) == catalog.externalCatalog.getTablesByName("db2", Seq("tbl1"))) + } + } + + test("get tables by name when contains invalid name") { + // scalastyle:off + val name = "砖" + // scalastyle:on + withBasicCatalog { catalog => + assert(catalog.getTablesByName( + Seq( + TableIdentifier("tbl1", Some("db2")), + TableIdentifier(name, Some("db2")) + ) + ) == catalog.externalCatalog.getTablesByName("db2", Seq("tbl1"))) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getTablesByName( + Seq( + TableIdentifier("tbl1"), + TableIdentifier(name) + ) + ) == catalog.externalCatalog.getTablesByName("db2", Seq("tbl1"))) + } + } + + test("get tables by name when empty") { + withBasicCatalog { catalog => + assert(catalog.getTablesByName(Seq.empty) + == catalog.externalCatalog.getTablesByName("db2", Seq.empty)) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getTablesByName(Seq.empty) + == catalog.externalCatalog.getTablesByName("db2", Seq.empty)) + } + } + + test("get tables by name when tables belong to different databases") { + withBasicCatalog { catalog => + intercept[AnalysisException](catalog.getTablesByName( + Seq( + TableIdentifier("tbl1", Some("db1")), + TableIdentifier("tbl2", Some("db2")) + ) + )) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + intercept[AnalysisException](catalog.getTablesByName( + Seq( + TableIdentifier("tbl1", Some("db1")), + TableIdentifier("tbl2") + ) + )) + } + } + test("lookup table relation") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala index e0b610423cce..56f89dfeb360 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala @@ -74,8 +74,7 @@ private[hive] class SparkGetTablesOperation( val tablePattern = convertIdentifierPattern(tableName, true) matchingDbs.foreach { dbName => - catalog.listTables(dbName, tablePattern).foreach { tableIdentifier => - val catalogTable = catalog.getTableMetadata(tableIdentifier) + catalog.getTablesByName(catalog.listTables(dbName, tablePattern)).foreach { catalogTable => val tableType = tableTypeString(catalogTable.tableType) if (tableTypes == null || tableTypes.isEmpty || tableTypes.contains(tableType)) { val rowData = Array[AnyRef]( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 11a219231875..d4df35c8ec69 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -120,6 +120,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.getTable(db, table) } + private[hive] def 
getRawTablesByNames(db: String, tables: Seq[String]): Seq[CatalogTable] = { + client.getTablesByName(db, tables) + } + /** * If the given table properties contains datasource properties, throw an exception. We will do * this check when create or alter a table, i.e. when we try to write table metadata to Hive @@ -702,6 +706,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat restoreTableMetadata(getRawTable(db, table)) } + override def getTablesByName(db: String, tables: Seq[String]): Seq[CatalogTable] = withClient { + getRawTablesByNames(db, tables).map(restoreTableMetadata) + } + /** * Restores table metadata from the table properties. This method is kind of a opposite version * of [[createTable]]. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index e1280d024638..cb015d7301c1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -84,6 +84,9 @@ private[hive] trait HiveClient { /** Returns the metadata for the specified table or None if it doesn't exist. */ def getTableOption(dbName: String, tableName: String): Option[CatalogTable] + /** Returns metadata of existing permanent tables/views for given names. */ + def getTablesByName(dbName: String, tableNames: Seq[String]): Seq[CatalogTable] + /** Creates a table with the given metadata. */ def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index b8d5f2148df1..b42574f116a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -30,14 +30,17 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} -import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Order} -import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} +import org.apache.hadoop.hive.metastore.{IMetaStoreClient, TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, Table => MetaStoreApiTable} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Order, SerDeInfo, StorageDescriptor} import org.apache.hadoop.hive.ql.Driver -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.HIVE_COLUMN_ORDER_ASC import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.MetadataTypedColumnsetSerDe +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging @@ -276,6 +279,10 @@ private[hive] class HiveClientImpl( } } + private def msClient: IMetaStoreClient = { + shim.getMSC(client) + } + /** Return the 
associated Hive [[SessionState]] of this [[HiveClientImpl]] */ override def getState: SessionState = withHiveState(state) @@ -384,10 +391,26 @@ private[hive] class HiveClientImpl( Option(client.getTable(dbName, tableName, false /* do not throw exception */)) } + private def getRawTablesByName(dbName: String, tableNames: Seq[String]): Seq[HiveTable] = { + try { + msClient.getTableObjectsByName(dbName, tableNames.asJava).asScala + .map(extraFixesForNonView).map(new HiveTable(_)) + } catch { + case ex: Exception => + throw new HiveException(s"Unable to fetch tables of db $dbName", ex); + } + } + override def tableExists(dbName: String, tableName: String): Boolean = withHiveState { getRawTableOption(dbName, tableName).nonEmpty } + override def getTablesByName( + dbName: String, + tableNames: Seq[String]): Seq[CatalogTable] = withHiveState { + getRawTablesByName(dbName, tableNames).map(convertHiveTableToCatalogTable) + } + override def getTableOption( dbName: String, tableName: String): Option[CatalogTable] = withHiveState { @@ -1091,6 +1114,40 @@ private[hive] object HiveClientImpl { stats = readHiveStats(properties)) } + /** + * This is the same process copied from the method `getTable()` + * of [[org.apache.hadoop.hive.ql.metadata.Hive]] to do some extra fixes for non-views. + * Methods of extracting multiple [[HiveTable]] like `getRawTablesByName()` + * should invoke this before return. + */ + def extraFixesForNonView(tTable: MetaStoreApiTable): MetaStoreApiTable = { + // For non-views, we need to do some extra fixes + if (!(HiveTableType.VIRTUAL_VIEW.toString == tTable.getTableType)) { + // Fix the non-printable chars + val parameters = tTable.getSd.getParameters + if (parameters != null) { + val sf = parameters.get(serdeConstants.SERIALIZATION_FORMAT) + if (sf != null) { + val b: Array[Char] = sf.toCharArray + if ((b.length == 1) && (b(0) < 10)) { // ^A, ^B, ^C, ^D, \t + parameters.put(serdeConstants.SERIALIZATION_FORMAT, Integer.toString(b(0))) + } + } + } + // Use LazySimpleSerDe for MetadataTypedColumnsetSerDe. + // NOTE: LazySimpleSerDe does not support tables with a single column of col + // of type "array". This happens when the table is created using + // an earlier version of Hive. + if (classOf[MetadataTypedColumnsetSerDe].getName == + tTable.getSd.getSerdeInfo.getSerializationLib && + tTable.getSd.getColsSize > 0 && + tTable.getSd.getCols.get(0).getType.indexOf('<') == -1) { + tTable.getSd.getSerdeInfo.setSerializationLib(classOf[LazySimpleSerDe].getName) + } + } + tTable + } + /** * Reads statistics from Hive. * Note that this statistics could be overridden by Spark's statistics if that's available. 
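For context between the Hive-client changes above and the shim changes below, here is a minimal usage sketch of the new SessionCatalog.getTablesByName API. It is not part of the patch; it assumes an active SparkSession named spark and that tables tbl1 and tbl2 already exist in a database db2. This batched lookup is what lets SparkGetTablesOperation above replace its per-table getTableMetadata() round trips with a single call per database.

  import org.apache.spark.sql.catalyst.TableIdentifier

  val catalog = spark.sessionState.catalog

  // Both tables are fetched in one batched call instead of one getTableMetadata() each.
  val both = catalog.getTablesByName(Seq(
    TableIdentifier("tbl1", Some("db2")),
    TableIdentifier("tbl2", Some("db2"))))

  // Names that cannot be resolved are dropped from the result rather than raising an
  // error, and an empty input yields an empty result.
  val onlyTbl1 = catalog.getTablesByName(Seq(
    TableIdentifier("tbl1", Some("db2")),
    TableIdentifier("does_not_exist", Some("db2"))))

  // Identifiers that span more than one database are rejected with an AnalysisException.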
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 18f8c5360981..80ab6855da93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.IMetaStoreClient import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, Function => HiveFunction, FunctionType} import org.apache.hadoop.hive.metastore.api.{MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver @@ -160,6 +161,8 @@ private[client] sealed abstract class Shim { method } + def getMSC(hive: Hive): IMetaStoreClient + protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = { klass.getMethod(name, args: _*) } @@ -171,6 +174,17 @@ private[client] class Shim_v0_12 extends Shim with Logging { // deletes the underlying data along with metadata protected lazy val deleteDataInDropIndex = JBoolean.TRUE + protected lazy val getMSCMethod = { + // Since getMSC() in Hive 0.12 is private, findMethod() could not work here + val msc = classOf[Hive].getDeclaredMethod("getMSC") + msc.setAccessible(true) + msc + } + + override def getMSC(hive: Hive): IMetaStoreClient = { + getMSCMethod.invoke(hive).asInstanceOf[IMetaStoreClient] + } + private lazy val startMethod = findStaticMethod( classOf[SessionState], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9861a0af0482..4c20af9103f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -237,6 +237,33 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(client.getTableOption("default", "src").isDefined) } + test(s"$version: getTablesByName") { + assert(client.getTablesByName("default", Seq("src")).head + == client.getTableOption("default", "src").get) + } + + test(s"$version: getTablesByName when multiple tables") { + assert(client.getTablesByName("default", Seq("src", "temporary")) + .map(_.identifier.table) == Seq("src", "temporary")) + } + + test(s"$version: getTablesByName when some tables do not exist") { + assert(client.getTablesByName("default", Seq("src", "notexist")) + .map(_.identifier.table) == Seq("src")) + } + + test(s"$version: getTablesByName when contains invalid name") { + // scalastyle:off + val name = "砖" + // scalastyle:on + assert(client.getTablesByName("default", Seq("src", name)) + .map(_.identifier.table) == Seq("src")) + } + + test(s"$version: getTablesByName when empty") { + assert(client.getTablesByName("default", Seq.empty).isEmpty) + } + test(s"$version: alterTable(table: CatalogTable)") { val newTable = client.getTable("default", "src").copy(properties = Map("changed" -> "")) client.alterTable(newTable)
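As background on the Shim_v0_12.getMSC override above: Hive 0.12 declares Hive.getMSC() as private, so the shim cannot go through findMethod() (which uses Class.getMethod and only resolves public members) and instead looks the method up with getDeclaredMethod and disables the Java access check before invoking it. Below is a small self-contained sketch of that reflection pattern; the helper name is hypothetical and not taken from the patch.

  import java.lang.reflect.Method

  // Illustration only; the patch keeps this logic inline in Shim_v0_12.
  object PrivateNoArgCall {
    // getDeclaredMethod also resolves private methods declared on the given class,
    // unlike getMethod, which only sees public members.
    def invoke[T](target: AnyRef, declaringClass: Class[_], name: String): T = {
      val m: Method = declaringClass.getDeclaredMethod(name)
      m.setAccessible(true) // suppress the Java language access check
      m.invoke(target).asInstanceOf[T]
    }
  }

  // e.g. PrivateNoArgCall.invoke[IMetaStoreClient](hive, classOf[Hive], "getMSC")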