Skip to content

Commit 4754e16

Browse files
cloud-fan authored and marmbrus committed
[SPARK-6898][SQL] completely support special chars in column names
Even if we wrap column names in backticks like `` `a#$b.c` ``, we still handle the "." inside column name specially. I think it's fragile to use a special char to split name parts, why not put name parts in `UnresolvedAttribute` directly? Author: Wenchen Fan <[email protected]> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <[email protected]> Closes apache#5511 from cloud-fan/6898 and squashes the following commits: 48e3e57 [Wenchen Fan] more style fix 820dc45 [Wenchen Fan] do not ignore newName in UnresolvedAttribute d81ad43 [Wenchen Fan] fix style 11699d6 [Wenchen Fan] completely support special chars in column names
1 parent 557a797 commit 4754e16

File tree

9 files changed

+52
-33
lines changed

9 files changed

+52
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
381381
| "(" ~> expression <~ ")"
382382
| function
383383
| dotExpressionHeader
384-
| ident ^^ UnresolvedAttribute
384+
| ident ^^ {case i => UnresolvedAttribute.quoted(i)}
385385
| signedPrimary
386386
| "~" ~> expression ^^ BitwiseNot
387387
)
388388

389389
protected lazy val dotExpressionHeader: Parser[Expression] =
390390
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
391-
case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
391+
case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest)
392392
}
393393
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,14 +297,15 @@ class Analyzer(
297297
case q: LogicalPlan =>
298298
logTrace(s"Attempting to resolve ${q.simpleString}")
299299
q transformExpressionsUp {
300-
case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) &&
300+
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
301+
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
301302
q.isInstanceOf[GroupingAnalytics] =>
302303
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
303304
q.asInstanceOf[GroupingAnalytics].gid
304-
case u @ UnresolvedAttribute(name) =>
305+
case u @ UnresolvedAttribute(nameParts) =>
305306
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
306307
val result =
307-
withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) }
308+
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
308309
logDebug(s"Resolving $u to $result")
309310
result
310311
case UnresolvedGetField(child, fieldName) if child.resolved =>
@@ -383,12 +384,12 @@ class Analyzer(
383384
child: LogicalPlan,
384385
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
385386
// Find any attributes that remain unresolved in the sort.
386-
val unresolved: Seq[String] =
387-
ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
387+
val unresolved: Seq[Seq[String]] =
388+
ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })
388389

389390
// Create a map from name, to resolved attributes, when the desired name can be found
390391
// prior to the projection.
391-
val resolved: Map[String, NamedExpression] =
392+
val resolved: Map[Seq[String], NamedExpression] =
392393
unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
393394

394395
// Construct a set that contains all of the attributes that we need to evaluate the

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ trait CheckAnalysis {
4646
operator transformExpressionsUp {
4747
case a: Attribute if !a.resolved =>
4848
if (operator.childrenResolved) {
49+
val nameParts = a match {
50+
case UnresolvedAttribute(nameParts) => nameParts
51+
case _ => Seq(a.name)
52+
}
4953
// Throw errors for specific problems with get field.
50-
operator.resolveChildren(a.name, resolver, throwErrors = true)
54+
operator.resolveChildren(nameParts, resolver, throwErrors = true)
5155
}
5256

5357
val from = operator.inputSet.map(_.name).mkString(", ")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ case class UnresolvedRelation(
4949
/**
5050
* Holds the name of an attribute that has yet to be resolved.
5151
*/
52-
case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
52+
case class UnresolvedAttribute(nameParts: Seq[String])
53+
extends Attribute with trees.LeafNode[Expression] {
54+
55+
def name: String =
56+
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
57+
5358
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
5459
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
5560
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
5964
override def newInstance(): UnresolvedAttribute = this
6065
override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
6166
override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
62-
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name)
67+
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
6368

6469
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
6570
override def eval(input: Row = null): EvaluatedType =
@@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
6873
override def toString: String = s"'$name"
6974
}
7075

76+
object UnresolvedAttribute {
77+
def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\."))
78+
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
79+
}
80+
7181
case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
7282
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
7383
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.AnalysisException
22-
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver}
22+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.QueryPlan
2525
import org.apache.spark.sql.catalyst.trees.TreeNode
2626
import org.apache.spark.sql.catalyst.trees
27-
import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
2827

2928

3029
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -111,21 +110,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
111110
* as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
112111
*/
113112
def resolveChildren(
114-
name: String,
113+
nameParts: Seq[String],
115114
resolver: Resolver,
116115
throwErrors: Boolean = false): Option[NamedExpression] =
117-
resolve(name, children.flatMap(_.output), resolver, throwErrors)
116+
resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
118117

119118
/**
120119
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
121120
* LogicalPlan. The attribute is expressed as string in the following form:
122121
* `[scope].AttributeName.[nested].[fields]...`.
123122
*/
124123
def resolve(
125-
name: String,
124+
nameParts: Seq[String],
126125
resolver: Resolver,
127126
throwErrors: Boolean = false): Option[NamedExpression] =
128-
resolve(name, output, resolver, throwErrors)
127+
resolve(nameParts, output, resolver, throwErrors)
129128

130129
/**
131130
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
@@ -135,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
135134
* See the comment above `candidates` variable in resolve() for semantics the returned data.
136135
*/
137136
private def resolveAsTableColumn(
138-
nameParts: Array[String],
137+
nameParts: Seq[String],
139138
resolver: Resolver,
140139
attribute: Attribute): Option[(Attribute, List[String])] = {
141140
assert(nameParts.length > 1)
@@ -155,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
155154
* See the comment above `candidates` variable in resolve() for semantics the returned data.
156155
*/
157156
private def resolveAsColumn(
158-
nameParts: Array[String],
157+
nameParts: Seq[String],
159158
resolver: Resolver,
160159
attribute: Attribute): Option[(Attribute, List[String])] = {
161160
if (resolver(attribute.name, nameParts.head)) {
@@ -167,13 +166,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
167166

168167
/** Performs attribute resolution given a name and a sequence of possible attributes. */
169168
protected def resolve(
170-
name: String,
169+
nameParts: Seq[String],
171170
input: Seq[Attribute],
172171
resolver: Resolver,
173172
throwErrors: Boolean): Option[NamedExpression] = {
174173

175-
val parts = name.split("\\.")
176-
177174
// A sequence of possible candidate matches.
178175
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
179176
// of parts that are to be resolved.
@@ -182,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
182179
// and the second element will be List("c").
183180
var candidates: Seq[(Attribute, List[String])] = {
184181
// If the name has 2 or more parts, try to resolve it as `table.column` first.
185-
if (parts.length > 1) {
182+
if (nameParts.length > 1) {
186183
input.flatMap { option =>
187-
resolveAsTableColumn(parts, resolver, option)
184+
resolveAsTableColumn(nameParts, resolver, option)
188185
}
189186
} else {
190187
Seq.empty
@@ -194,10 +191,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
194191
// If none of attributes match `table.column` pattern, we try to resolve it as a column.
195192
if (candidates.isEmpty) {
196193
candidates = input.flatMap { candidate =>
197-
resolveAsColumn(parts, resolver, candidate)
194+
resolveAsColumn(nameParts, resolver, candidate)
198195
}
199196
}
200197

198+
def name = UnresolvedAttribute(nameParts).name
199+
201200
candidates.distinct match {
202201
// One match, no nested fields, use it.
203202
case Seq((a, Nil)) => Some(a)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ import org.apache.spark.sql.types._
2727
import org.apache.spark.sql.catalyst.dsl.expressions._
2828
import org.apache.spark.sql.catalyst.dsl.plans._
2929

30-
import scala.collection.immutable
31-
3230
class AnalysisSuite extends FunSuite with BeforeAndAfter {
3331
val caseSensitiveCatalog = new SimpleCatalog(true)
3432
val caseInsensitiveCatalog = new SimpleCatalog(false)

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,15 @@ class DataFrame private[sql](
158158
}
159159

160160
protected[sql] def resolve(colName: String): NamedExpression = {
161-
queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
161+
queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse {
162162
throw new AnalysisException(
163163
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
164164
}
165165
}
166166

167167
protected[sql] def numericColumns: Seq[Expression] = {
168168
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
169-
queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
169+
queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get
170170
}
171171
}
172172

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ package org.apache.spark.sql
1919

2020
import org.scalatest.BeforeAndAfterAll
2121

22-
import org.apache.spark.sql.TestData._
2322
import org.apache.spark.sql.execution.GeneratedAggregate
2423
import org.apache.spark.sql.functions._
24+
import org.apache.spark.sql.TestData._
2525
import org.apache.spark.sql.test.TestSQLContext
2626
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
2727
import org.apache.spark.sql.types._
2828

29-
3029
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
3130
// Make sure the tables are loaded.
3231
TestData
@@ -1125,7 +1124,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
11251124
val data = sparkContext.parallelize(
11261125
Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
11271126
jsonRDD(data).registerTempTable("records")
1128-
sql("SELECT `key?number1` FROM records")
1127+
sql("SELECT `key?number1`, `key.number2` FROM records")
11291128
}
11301129

11311130
test("SPARK-3814 Support Bitwise & operator") {
@@ -1225,4 +1224,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
12251224
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
12261225
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
12271226
}
1227+
1228+
test("SPARK-6898: complete support for special chars in column names") {
1229+
jsonRDD(sparkContext.makeRDD(
1230+
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
1231+
.registerTempTable("t")
1232+
1233+
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
1234+
}
12281235
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
11011101
case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
11021102
nodeToExpr(qualifier) match {
11031103
case UnresolvedAttribute(qualifierName) =>
1104-
UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
1104+
UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr))
11051105
case other => UnresolvedGetField(other, attr)
11061106
}
11071107

0 commit comments

Comments (0)