Commit 3d6b68b

gengliangwang authored and cloud-fan committed
[SPARK-25313][SQL] Fix regression in FileFormatWriter output names
## What changes were proposed in this pull request?

Consider the following example:
```
val location = "/tmp/t"
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql(s"CREATE TABLE tbl2(ID long) USING parquet LOCATION '$location'")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
println(spark.read.parquet(location).schema)
spark.table("tbl2").show()
```
The output column name in the schema will be `id` instead of `ID`, so the last query returns nothing from `tbl2`.

With debug logging enabled we can see that the output name is changed from `ID` to `id`: the `outputColumns` of `InsertIntoHadoopFsRelationCommand` are rewritten by the `RemoveRedundantAliases` rule.

![wechatimg5](https://user-images.githubusercontent.com/1097932/44947871-6299f200-ae46-11e8-9c96-d45fe368206c.jpeg)

![wechatimg4](https://user-images.githubusercontent.com/1097932/44947866-56ae3000-ae46-11e8-8923-8b3bbe060075.jpeg)

**To guarantee correctness**, we should change the output columns from `Seq[Attribute]` to `Seq[String]`, so that the names cannot be replaced by the optimizer. I will fix the project-elimination related rules in #22311 after this one.

## How was this patch tested?

Unit test.

Closes #22320 from gengliangwang/fixOutputSchema.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3e03303 commit 3d6b68b
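
The idea behind the fix can be shown with a small self-contained sketch (plain Scala; the `Attribute` case class below is a toy stand-in for Catalyst's attribute type, not Spark's actual class): the command records only the analysis-time column names and re-applies them to whatever attributes the optimized query exposes, so an optimizer rename cannot leak into the written schema.

```scala
// Toy stand-in for Catalyst's Attribute; only what the sketch needs.
final case class Attribute(name: String, dataType: String) {
  def withName(newName: String): Attribute = copy(name = newName)
}

object OutputNamesSketch {
  // Mirrors the idea behind DataWritingCommand.logicalPlanOutputWithNames:
  // re-apply the analysis-time names onto the (possibly renamed) query output.
  def outputWithNames(queryOutput: Seq[Attribute], names: Seq[String]): Seq[Attribute] = {
    require(queryOutput.length == names.length,
      "The number of names must match the number of output attributes.")
    queryOutput.zip(names).map { case (attr, name) => attr.withName(name) }
  }

  def main(args: Array[String]): Unit = {
    // At analysis time the INSERT query exposed the column as "ID" ...
    val analyzedNames = Seq("ID")
    // ... but after optimization (e.g. RemoveRedundantAliases) the attribute is lower-cased.
    val optimizedOutput = Seq(Attribute("id", "bigint"))
    // The writer still sees the user-facing name "ID".
    println(outputWithNames(optimizedOutput, analyzedNames)) // List(Attribute(ID,bigint))
  }
}
```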

11 files changed (+189, −25 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala

Lines changed: 40 additions & 3 deletions
```diff
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.datasources.FileFormatWriter
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration

 /**
@@ -41,8 +42,12 @@ trait DataWritingCommand extends Command {

   override final def children: Seq[LogicalPlan] = query :: Nil

-  // Output columns of the analyzed input query plan
-  def outputColumns: Seq[Attribute]
+  // Output column names of the analyzed input query plan.
+  def outputColumnNames: Seq[String]
+
+  // Output columns of the analyzed input query plan.
+  def outputColumns: Seq[Attribute] =
+    DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames)

   lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics

@@ -53,3 +58,35 @@ trait DataWritingCommand extends Command {

   def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
 }
+
+object DataWritingCommand {
+  /**
+   * Returns output attributes with provided names.
+   * The length of provided names should be the same of the length of [[LogicalPlan.output]].
+   */
+  def logicalPlanOutputWithNames(
+      query: LogicalPlan,
+      names: Seq[String]): Seq[Attribute] = {
+    // Save the output attributes to a variable to avoid duplicated function calls.
+    val outputAttributes = query.output
+    assert(outputAttributes.length == names.length,
+      "The length of provided names doesn't match the length of output attributes.")
+    outputAttributes.zip(names).map { case (attr, outputName) =>
+      attr.withName(outputName)
+    }
+  }
+
+  /**
+   * Returns schema of logical plan with provided names.
+   * The length of provided names should be the same of the length of [[LogicalPlan.schema]].
+   */
+  def logicalPlanSchemaWithNames(
+      query: LogicalPlan,
+      names: Seq[String]): StructType = {
+    assert(query.schema.length == names.length,
+      "The length of provided names doesn't match the length of query schema.")
+    StructType(query.schema.zip(names).map { case (structField, outputName) =>
+      structField.copy(name = outputName)
+    })
+  }
+}
```
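
A minimal usage sketch of the two helpers added above, assuming a Spark build that includes this commit; the snippet (object name, local SparkSession setup) is illustrative and not part of the patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.command.DataWritingCommand

object SchemaWithNamesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

    // An analyzed plan whose output column is the lower-case "id".
    val plan = spark.range(3).toDF("id").queryExecution.analyzed

    // Re-apply the user-facing name "ID", as the writing commands now do.
    val attrs  = DataWritingCommand.logicalPlanOutputWithNames(plan, Seq("ID"))
    val schema = DataWritingCommand.logicalPlanSchemaWithNames(plan, Seq("ID"))

    println(attrs.map(_.name))        // List(ID)
    println(schema.fieldNames.toSeq)  // List(ID)

    spark.stop()
  }
}
```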

sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -139,7 +139,7 @@ case class CreateDataSourceTableAsSelectCommand(
     table: CatalogTable,
     mode: SaveMode,
     query: LogicalPlan,
-    outputColumns: Seq[Attribute])
+    outputColumnNames: Seq[String])
   extends DataWritingCommand {

   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
@@ -214,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand(
       catalogTable = if (tableExists) Some(table) else None)

     try {
-      dataSource.writeAndRead(mode, query, outputColumns, physicalPlan)
+      dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan)
     } catch {
       case ex: AnalysisException =>
         logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 10 additions & 6 deletions
```diff
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -450,7 +451,7 @@ case class DataSource(
       mode = mode,
       catalogTable = catalogTable,
       fileIndex = fileIndex,
-      outputColumns = data.output)
+      outputColumnNames = data.output.map(_.name))
   }

   /**
@@ -460,9 +461,9 @@ case class DataSource(
    * @param mode The save mode for this writing.
    * @param data The input query plan that produces the data to be written. Note that this plan
    *             is analyzed and optimized.
-   * @param outputColumns The original output columns of the input query plan. The optimizer may not
-   *                      preserve the output column's names' case, so we need this parameter
-   *                      instead of `data.output`.
+   * @param outputColumnNames The original output column names of the input query plan. The
+   *                          optimizer may not preserve the output column's names' case, so we need
+   *                          this parameter instead of `data.output`.
    * @param physicalPlan The physical plan of the input query plan. We should run the writing
    *                     command with this physical plan instead of creating a new physical plan,
    *                     so that the metrics can be correctly linked to the given physical plan and
@@ -471,8 +472,9 @@ case class DataSource(
   def writeAndRead(
       mode: SaveMode,
       data: LogicalPlan,
-      outputColumns: Seq[Attribute],
+      outputColumnNames: Seq[String],
       physicalPlan: SparkPlan): BaseRelation = {
+    val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
     if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
       throw new AnalysisException("Cannot save interval data type into external storage.")
     }
@@ -495,7 +497,9 @@ case class DataSource(
             s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
         }
       }
-      val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns)
+      val resolved = cmd.copy(
+        partitionColumns = resolvedPartCols,
+        outputColumnNames = outputColumnNames)
       resolved.run(sparkSession, physicalPlan)
       // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
       copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
     case CreateTable(tableDesc, mode, Some(query))
         if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
       DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema))
-      CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output)
+      CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name))

     case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
         parts, query, overwrite, false) if parts.isEmpty =>
@@ -209,7 +209,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
         mode,
         table,
         Some(t.location),
-        actualQuery.output)
+        actualQuery.output.map(_.name))
     }
   }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -56,14 +56,14 @@ case class InsertIntoHadoopFsRelationCommand(
     mode: SaveMode,
     catalogTable: Option[CatalogTable],
     fileIndex: Option[FileIndex],
-    outputColumns: Seq[Attribute])
+    outputColumnNames: Seq[String])
   extends DataWritingCommand {
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
     // Most formats don't do well with duplicate columns, so lets not allow that
-    SchemaUtils.checkSchemaColumnNameDuplication(
-      query.schema,
+    SchemaUtils.checkColumnNameDuplication(
+      outputColumnNames,
       s"when inserting into $outputPath",
       sparkSession.sessionState.conf.caseSensitiveAnalysis)
```
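
The duplicate-column check above now runs against the user-facing `outputColumnNames` instead of the possibly optimizer-rewritten `query.schema`. A simplified standalone sketch of such a name-duplication check (plain Scala; an illustration of the idea, not Spark's `SchemaUtils` implementation):

```scala
object DuplicateNameCheckSketch {
  // Simplified stand-in for a column-name duplication check: fail if any
  // column name appears more than once, honoring case sensitivity.
  def checkColumnNameDuplication(
      columnNames: Seq[String],
      context: String,
      caseSensitive: Boolean): Unit = {
    val normalized = if (caseSensitive) columnNames else columnNames.map(_.toLowerCase)
    val duplicates = normalized.groupBy(identity).collect { case (name, ns) if ns.size > 1 => name }
    require(duplicates.isEmpty,
      s"Found duplicate column(s) $context: ${duplicates.mkString(", ")}")
  }

  def main(args: Array[String]): Unit = {
    // Distinct names pass regardless of case sensitivity.
    checkColumnNameDuplication(Seq("ID", "name"), "when inserting into /tmp/t", caseSensitive = false)
    // "ID" and "id" collide under case-insensitive analysis.
    try {
      checkColumnNameDuplication(Seq("ID", "id"), "when inserting into /tmp/t", caseSensitive = false)
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
```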

sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala

Lines changed: 74 additions & 0 deletions
```diff
@@ -805,6 +805,80 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     }
   }

+  test("Insert overwrite table command should output correct schema: basic") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).toDF("id")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
+        spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
+        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
+        val identifier = TableIdentifier("tbl2")
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Insert overwrite table command should output correct schema: complex") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
+        spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
+          "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
+        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 FROM view1")
+        val identifier = TableIdentifier("tbl2")
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(
+          StructField("COL1", LongType, true),
+          StructField("COL3", IntegerType, true),
+          StructField("COL2", IntegerType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Create table as select command should output correct schema: basic") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).toDF("id")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
+        spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
+        val identifier = TableIdentifier("tbl2")
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Create table as select command should output correct schema: complex") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
+        spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
+          "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
+        val identifier = TableIdentifier("tbl2")
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(
+          StructField("COL1", LongType, true),
+          StructField("COL3", IntegerType, true),
+          StructField("COL2", IntegerType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
   test("use Spark jobs to list files") {
     withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") {
       withTempDir { dir =>
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -149,22 +149,22 @@ object HiveAnalysis extends Rule[LogicalPlan] {
     case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists)
         if DDLUtils.isHiveTable(r.tableMeta) =>
       InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite,
-        ifPartitionNotExists, query.output)
+        ifPartitionNotExists, query.output.map(_.name))

     case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
       DDLUtils.checkDataColNames(tableDesc)
       CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

     case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
       DDLUtils.checkDataColNames(tableDesc)
-      CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode)
+      CreateHiveTableAsSelectCommand(tableDesc, query, query.output.map(_.name), mode)

     case InsertIntoDir(isLocal, storage, provider, child, overwrite)
         if DDLUtils.isHiveTable(provider) =>
       val outputPath = new Path(storage.locationUri.get)
       if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath)

-      InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output)
+      InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output.map(_.name))
   }
 }
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala

Lines changed: 5 additions & 4 deletions
```diff
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand
 case class CreateHiveTableAsSelectCommand(
     tableDesc: CatalogTable,
     query: LogicalPlan,
-    outputColumns: Seq[Attribute],
+    outputColumnNames: Seq[String],
     mode: SaveMode)
   extends DataWritingCommand {

@@ -63,13 +63,14 @@ case class CreateHiveTableAsSelectCommand(
         query,
         overwrite = false,
         ifPartitionNotExists = false,
-        outputColumns = outputColumns).run(sparkSession, child)
+        outputColumnNames = outputColumnNames).run(sparkSession, child)
     } else {
       // TODO ideally, we should get the output data ready first and then
       // add the relation into catalog, just in case of failure occurs while data
       // processing.
       assert(tableDesc.schema.isEmpty)
-      catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false)
+      val schema = DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames)
+      catalog.createTable(tableDesc.copy(schema = schema), ignoreIfExists = false)

       try {
         // Read back the metadata of the table which was created just now.
@@ -82,7 +83,7 @@ case class CreateHiveTableAsSelectCommand(
           query,
           overwrite = true,
           ifPartitionNotExists = false,
-          outputColumns = outputColumns).run(sparkSession, child)
+          outputColumnNames = outputColumnNames).run(sparkSession, child)
       } catch {
         case NonFatal(e) =>
           // drop the created table.
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -57,7 +57,7 @@ case class InsertIntoHiveDirCommand(
     storage: CatalogStorageFormat,
     query: LogicalPlan,
     overwrite: Boolean,
-    outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
+    outputColumnNames: Seq[String]) extends SaveAsHiveFile {

   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
     assert(storage.locationUri.nonEmpty)
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ case class InsertIntoHiveTable(
     query: LogicalPlan,
     overwrite: Boolean,
     ifPartitionNotExists: Boolean,
-    outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
+    outputColumnNames: Seq[String]) extends SaveAsHiveFile {

   /**
    * Inserts all the rows in the table into Hive. Row objects are properly serialized with the
```
