Skip to content

Commit cd539fe

Browse files
committed
Address comments.
1 parent a738943 commit cd539fe

File tree

5 files changed

+62
-22
lines changed

5 files changed

+62
-22
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import org.apache.spark.sql._
2323
import org.apache.spark.sql.catalyst.catalog._
2424
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2525
import org.apache.spark.sql.execution.datasources._
26+
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
27+
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
2628
import org.apache.spark.sql.sources.BaseRelation
2729

2830
/**
@@ -85,6 +87,13 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
8587
}
8688
}
8789

90+
// Validate column names at table-creation time for the formats that restrict them.
// The default case is required: without it, any provider other than Parquet or ORC
// (for example "json" or "csv") makes this match throw scala.MatchError.
// NOTE(review): `.get` assumes `table.provider` is defined at this point — confirm against
// the part of the enclosing method that is not visible in this hunk.
table.provider.get.toLowerCase(java.util.Locale.ROOT) match {
  case "parquet" =>
    dataSource.schema.map(_.name).foreach(ParquetSchemaConverter.checkFieldName)
  case "orc" =>
    dataSource.schema.map(_.name).foreach(OrcFileFormat.checkFieldName)
  case _ => // No field-name check is applied here for other providers.
}
96+
8897
val newTable = table.copy(
8998
schema = dataSource.schema,
9099
partitionColumnNames = partitionColumnNames,
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.datasources.orc
19+
20+
import org.apache.orc.TypeDescription
21+
22+
import org.apache.spark.sql.AnalysisException
23+
24+
/**
 * Helper for validating column names against ORC's type-string parser.
 */
private[sql] object OrcFileFormat {
  /**
   * Checks that `name` is usable as an ORC field name by asking ORC to parse a
   * one-field struct type string built from it.
   *
   * @param name candidate column name
   * @throws AnalysisException if ORC rejects the type string with an IllegalArgumentException
   */
  def checkFieldName(name: String): Unit = {
    try {
      TypeDescription.fromString(s"struct<$name:int>")
    } catch {
      case _: IllegalArgumentException =>
        // Build the message on a single line. Post-processing an interpolated multi-line
        // string with stripMargin/split("\n") would also rewrite newlines and margin
        // characters that belong to `name`, so the reported name would not match the input.
        throw new AnalysisException(
          s"""Attribute name "$name" contains invalid character(s). """ +
            "Please use alias to rename it.")
    }
  }
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ private[parquet] class ParquetSchemaConverter(
556556
}
557557
}
558558

559-
private[parquet] object ParquetSchemaConverter {
559+
private[sql] object ParquetSchemaConverter {
560560
val SPARK_PARQUET_SCHEMA_NAME = "spark_schema"
561561

562562
val EMPTY_MESSAGE: MessageType =

sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ import org.apache.hadoop.io.{NullWritable, Writable}
3232
import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter}
3333
import org.apache.hadoop.mapreduce._
3434
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
35-
import org.apache.orc.TypeDescription
3635

3736
import org.apache.spark.TaskContext
38-
import org.apache.spark.sql._
37+
import org.apache.spark.sql.SparkSession
3938
import org.apache.spark.sql.catalyst.InternalRow
4039
import org.apache.spark.sql.catalyst.expressions._
4140
import org.apache.spark.sql.execution.datasources._
41+
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
4242
import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
4343
import org.apache.spark.sql.sources._
4444
import org.apache.spark.sql.types.StructType
@@ -84,7 +84,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
8484
classOf[MapRedOutputFormat[_, _]])
8585
}
8686

87-
dataSchema.map(_.name).foreach(checkFieldName)
87+
dataSchema.map(_.name).foreach(OrcFileFormat.checkFieldName)
8888

8989
new OutputWriterFactory {
9090
override def newInstance(
@@ -172,18 +172,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
172172
}
173173
}
174174
}
175-
176-
private def checkFieldName(name: String): Unit = {
177-
try {
178-
TypeDescription.fromString(s"struct<$name:int>")
179-
} catch {
180-
case _: IllegalArgumentException =>
181-
throw new AnalysisException(
182-
s"""Attribute name "$name" contains invalid character(s).
183-
|Please use alias to rename it.
184-
""".stripMargin.split("\n").mkString(" ").trim)
185-
}
186-
}
187175
}
188176

189177
private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2001,13 +2001,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
20012001
}
20022002
}
20032003

2004-
test("SPARK-21912 Creating ORC datasource table should check invalid column names") {
2004+
test("SPARK-21912 Creating ORC/Parquet datasource table should check invalid column names") {
20052005
withTable("orc1") {
2006-
Seq(" ", "?", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
2007-
val m = intercept[AnalysisException] {
2008-
sql(s"CREATE TABLE orc1 USING ORC AS SELECT 1 `column$name`")
2009-
}.getMessage
2010-
assert(m.contains(s"contains invalid character(s)"))
2006+
// Characters that must be rejected when they appear in a column name.
val invalidChars = Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=")
for {
  badChar <- invalidChars
  format <- Seq("ORC", "PARQUET")
  // Exercise both the explicit-schema CREATE TABLE and the CTAS path, in that order.
  statement <- Seq(
    s"CREATE TABLE orc1(`column$badChar` INT) USING $format",
    s"CREATE TABLE orc1 USING $format AS SELECT 1 `column$badChar`")
} {
  val error = intercept[AnalysisException](sql(statement))
  assert(error.getMessage.contains("contains invalid character(s)"))
}
20122019
}
20132020
}

0 commit comments

Comments
 (0)