diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index b43d282bd434..f32107203454 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -163,8 +163,7 @@ case class DataSource(
       // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred
       // partitioning
       if (userSpecifiedSchema.isEmpty) {
-        val inferredPartitions = tempFileIndex.partitionSchema
-        inferredPartitions
+        tempFileIndex.partitionSchema
       } else {
         val partitionFields = partitionColumns.map { partitionColumn =>
           userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse {
@@ -357,7 +356,11 @@ case class DataSource(
         } else {
           tempFileCatalog
         }
-        val dataSchema = userSpecifiedSchema.orElse {
+
+        val partitionSchema = fileCatalog.partitionSchema
+        val dataSchema = userSpecifiedSchema.map { schema =>
+          StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name))))
+        }.orElse {
           format.inferSchema(
             sparkSession,
             caseInsensitiveOptions,
@@ -370,7 +373,7 @@ case class DataSource(
 
         HadoopFsRelation(
           fileCatalog,
-          partitionSchema = fileCatalog.partitionSchema,
+          partitionSchema = partitionSchema,
           dataSchema = dataSchema,
           bucketSpec = None,
           format,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
index 9a08524476ba..566eed5962a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
@@ -59,8 +59,8 @@ case class HadoopFsRelation(
         overlappedPartCols += getColName(partitionField) -> partitionField
       }
     }
-    StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
-      partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))
+    StructType(dataSchema.filterNot(f => overlappedPartCols.contains(getColName(f))) ++
+      partitionSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)))
   }
 
   def partitionSchemaOption: Option[StructType] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index f79b92b804c7..cfbd915396d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -519,7 +519,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
             i <- 1 to 10
             pi <- Seq(1, 2)
             ps <- Seq("foo", "bar")
-          } yield Row(i, pi, i.toString, ps))
+          } yield Row(i, i.toString, pi, ps))
 
         checkAnswer(
           sql("SELECT intField, pi FROM t"),
@@ -534,14 +534,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
           for {
             i <- 1 to 10
             ps <- Seq("foo", "bar")
-          } yield Row(i, 1, i.toString, ps))
+          } yield Row(i, i.toString, 1, ps))
 
         checkAnswer(
           sql("SELECT * FROM t WHERE ps = 'foo'"),
           for {
             i <- 1 to 10
             pi <- Seq(1, 2)
-          } yield Row(i, pi, i.toString, "foo"))
+          } yield Row(i, i.toString, pi, "foo"))
       }
     }
   }
@@ -608,14 +608,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
             i <- 1 to 10
             pi <- Seq(1, 2)
             ps <- Seq("foo", null.asInstanceOf[String])
-          } yield Row(i, pi, i.toString, ps))
+          } yield Row(i, i.toString, pi, ps))
 
         checkAnswer(
           sql("SELECT * FROM t WHERE ps IS NULL"),
           for {
             i <- 1 to 10
             pi <- Seq(1, 2)
-          } yield Row(i, pi, i.toString, null))
+          } yield Row(i, i.toString, pi, null))
       }
     }
   }
@@ -1019,7 +1019,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
         val path = dir.getCanonicalPath
         val df = Seq((1L, 2.0)).toDF("a", "b")
         df.write.parquet(s"$path/a=1")
-        checkAnswer(spark.read.parquet(s"$path"), Seq(Row(1, 2.0)))
+        // partition columns are always at the end of the schema.
+        assert(spark.read.parquet(s"$path").columns === Array("b", "a"))
+        checkAnswer(spark.read.parquet(s"$path"), Seq(Row(2.0, 1)))
       }
     }
   }
@@ -1048,7 +1050,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
       checkAnswer(
         readBack,
-        Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2))
+        Seq(Row(1, "2017-01-01-00"), Row(2, "2017-01-01-01"))
       )
     } finally {
       spark.streams.active.foreach(_.stop())
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index f0f2c493498b..59f1bd7c6a6f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -164,13 +164,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
           }
         }
 
-        val (dataSchema, updatedTable) =
-          inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
+        val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
 
         val fsRelation = HadoopFsRelation(
           location = fileIndex,
           partitionSchema = partitionSchema,
-          dataSchema = dataSchema,
+          dataSchema = updatedTable.dataSchema,
           bucketSpec = None,
           fileFormat = fileFormat,
           options = options)(sparkSession = sparkSession)
@@ -191,13 +190,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         fileFormatClass,
         None)
       val logicalRelation = cached.getOrElse {
-        val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat)
+        val updatedTable = inferIfNeeded(relation, options, fileFormat)
         val created =
           LogicalRelation(
             DataSource(
               sparkSession = sparkSession,
               paths = rootPath.toString :: Nil,
-              userSpecifiedSchema = Option(dataSchema),
+              userSpecifiedSchema = Option(updatedTable.schema),
               bucketSpec = None,
               options = options,
               className = fileType).resolveRelation(),
@@ -220,11 +219,15 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
     result.copy(output = newOutput)
   }
 
+  /**
+   * Infer the data schema from files if needed, and return a `CatalogTable` with the corrected
+   * table schema.
+   */
   private def inferIfNeeded(
       relation: HiveTableRelation,
       options: Map[String, String],
       fileFormat: FileFormat,
-      fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = {
+      fileIndexOpt: Option[FileIndex] = None): CatalogTable = {
     val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode
     val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase
     val tableName = relation.tableMeta.identifier.unquotedString
@@ -241,21 +244,22 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
           sparkSession,
           options,
           fileIndex.listFiles(Nil, Nil).flatMap(_.files))
-        .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _))
+        .map(mergeWithMetastoreSchema(relation.tableMeta.dataSchema, _))
 
       inferredSchema match {
-        case Some(schema) =>
+        case Some(dataSchema) =>
+          val actualSchema = StructType(dataSchema ++ relation.tableMeta.partitionSchema)
           if (inferenceMode == INFER_AND_SAVE) {
-            updateCatalogSchema(relation.tableMeta.identifier, schema)
+            updateCatalogSchema(relation.tableMeta.identifier, actualSchema)
           }
-          (schema, relation.tableMeta.copy(schema = schema))
+          relation.tableMeta.copy(schema = actualSchema)
        case None =>
           logWarning(s"Unable to infer schema for table $tableName from file format " +
             s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.")
-          (relation.tableMeta.schema, relation.tableMeta)
+          relation.tableMeta
       }
     } else {
-      (relation.tableMeta.schema, relation.tableMeta)
+      relation.tableMeta
     }
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
index e71aba72c31f..6e1a1c1481c3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
@@ -86,7 +86,6 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi
             |location "${dir.toURI}"""".stripMargin)
         spark.sql("msck repair table test")
 
-        val df = spark.sql("select * from test")
         assert(sql("select * from test").count() == 5)
 
         def deleteRandomFile(): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
index d1ce3f1e2f05..c27070d50bf8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
@@ -146,7 +146,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
             i <- 1 to 10
             pi <- Seq(1, 2)
             ps <- Seq("foo", "bar")
-          } yield Row(i, pi, i.toString, ps))
+          } yield Row(i, i.toString, pi, ps))
 
         checkAnswer(
           sql("SELECT intField, pi FROM t"),
@@ -161,14 +161,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
           for {
             i <- 1 to 10
             ps <- Seq("foo", "bar")
-          } yield Row(i, 1, i.toString, ps))
+          } yield Row(i, i.toString, 1, ps))
 
         checkAnswer(
           sql("SELECT * FROM t WHERE ps = 'foo'"),
           for {
             i <- 1 to 10
             pi <- Seq(1, 2)
-          } yield Row(i, pi, i.toString, "foo"))
+          } yield Row(i, i.toString, pi, "foo"))
       }
     }
   }
@@ -240,14 +240,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
             i <- 1 to 10
             pi <- Seq(1, 2)
             ps <- Seq("foo", null.asInstanceOf[String])
-          } yield Row(i, pi, i.toString, ps))
+          } yield Row(i, i.toString, pi, ps))
 
         checkAnswer(
           sql("SELECT * FROM t WHERE ps IS NULL"),
           for {
             i <- 1 to 10
             pi <- Seq(1, 2)
-          } yield Row(i, pi, i.toString, null))
+          } yield Row(i, i.toString, pi, null))
       }
     }
   }
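
Illustrative sketch (not part of the patch): the standalone program below reproduces the behavior the updated ParquetPartitionDiscoverySuite asserts, namely that discovered partition columns are now appended after the data columns when reading back a partitioned directory. The object name, the use of a local SparkSession, and the temp-directory helper are assumptions for the example only.

// Minimal sketch, assuming a local Spark installation on the classpath.
import org.apache.spark.sql.SparkSession

object PartitionColumnOrderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Write a dataset under ".../a=1" so that "a" is discovered as a partition column.
    val base = java.nio.file.Files.createTempDirectory("partition-order").toString
    Seq((1L, 2.0)).toDF("a", "b").write.parquet(s"$base/a=1")

    val df = spark.read.parquet(base)
    // With this change the partition column "a" comes after the data column "b",
    // i.e. the schema reads back as (b, a) rather than (a, b).
    assert(df.columns.sameElements(Array("b", "a")))
    df.show()

    spark.stop()
  }
}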