Skip to content

Commit 4d4d0de

Browse files
committed
[SPARK-19279][SQL][FOLLOW-UP] Infer Schema for Hive Serde Tables
### What changes were proposed in this pull request? `table.schema` is always not empty for partitioned tables, because `table.schema` also contains the partitioned columns, even if the original table does not have any column. This PR is to fix the issue. ### How was this patch tested? Added a test case Author: gatorsmile <[email protected]> Closes #16848 from gatorsmile/inferHiveSerdeSchema.
1 parent 0077bfc commit 4d4d0de

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ case class CatalogTable(
194194
StructType(partitionFields)
195195
}
196196

197+
/**
198+
* schema of this table's data columns
199+
*/
200+
def dataSchema: StructType = {
201+
val dataFields = schema.dropRight(partitionColumnNames.length)
202+
StructType(dataFields)
203+
}
204+
197205
/** Return the database this table was specified to belong to, assuming it exists. */
198206
def database: String = identifier.database.getOrElse {
199207
throw new AnalysisException(s"table $identifier did not specify database")

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ private[spark] object HiveUtils extends Logging {
580580
* CatalogTable.
581581
*/
582582
def inferSchema(table: CatalogTable): CatalogTable = {
583-
if (DDLUtils.isDatasourceTable(table) || table.schema.nonEmpty) {
583+
if (DDLUtils.isDatasourceTable(table) || table.dataSchema.nonEmpty) {
584584
table
585585
} else {
586586
val hiveTable = toHiveTable(table)

sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.SparkContext
2727
import org.apache.spark.sql._
2828
import org.apache.spark.sql.catalyst.TableIdentifier
2929
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
30+
import org.apache.spark.sql.execution.command.CreateTableCommand
3031
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
3132
import org.apache.spark.sql.hive.HiveExternalCatalog._
3233
import org.apache.spark.sql.hive.client.HiveClient
@@ -1308,6 +1309,49 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
13081309
}
13091310
}
13101311

1312+
test("Infer schema for Hive serde tables") {
1313+
val tableName = "tab1"
1314+
val avroSchema =
1315+
"""{
1316+
| "name": "test_record",
1317+
| "type": "record",
1318+
| "fields": [ {
1319+
| "name": "f0",
1320+
| "type": "int"
1321+
| }]
1322+
|}
1323+
""".stripMargin
1324+
1325+
Seq(true, false).foreach { isPartitioned =>
1326+
withTable(tableName) {
1327+
val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else ""
1328+
// Creates the (non-)partitioned Avro table
1329+
val plan = sql(
1330+
s"""
1331+
|CREATE TABLE $tableName
1332+
|$partitionClause
1333+
|ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
1334+
|STORED AS
1335+
| INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat'
1336+
| OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat'
1337+
|TBLPROPERTIES ('avro.schema.literal' = '$avroSchema')
1338+
""".stripMargin
1339+
).queryExecution.analyzed
1340+
1341+
assert(plan.isInstanceOf[CreateTableCommand] &&
1342+
plan.asInstanceOf[CreateTableCommand].table.dataSchema.nonEmpty)
1343+
1344+
if (isPartitioned) {
1345+
sql(s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1")
1346+
checkAnswer(spark.table(tableName), Row(1, "a"))
1347+
} else {
1348+
sql(s"INSERT OVERWRITE TABLE $tableName SELECT 1")
1349+
checkAnswer(spark.table(tableName), Row(1))
1350+
}
1351+
}
1352+
}
1353+
}
1354+
13111355
private def withDebugMode(f: => Unit): Unit = {
13121356
val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE)
13131357
try {

0 commit comments

Comments
 (0)