Skip to content

Commit c21171e

Browse files
committed
Ensure the resolver is used for field lookups and ensure that case insensitive resolution is still case preserving.
1 parent d4320f1 commit c21171e

File tree

5 files changed

+61
-6
lines changed

5 files changed

+61
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
5454
override def newInstance = this
5555
override def withNullability(newNullability: Boolean) = this
5656
override def withQualifiers(newQualifiers: Seq[String]) = this
57+
override def withName(newName: String) = UnresolvedAttribute(name)
5758

5859
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
5960
override def eval(input: Row = null): EvaluatedType =
@@ -97,6 +98,7 @@ case class Star(
9798
override def newInstance = this
9899
override def withNullability(newNullability: Boolean) = this
99100
override def withQualifiers(newQualifiers: Seq[String]) = this
101+
override def withName(newName: String) = this
100102

101103
def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
102104
val expandedAttributes: Seq[Attribute] = table match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ abstract class Attribute extends NamedExpression {
5959

6060
def withNullability(newNullability: Boolean): Attribute
6161
def withQualifiers(newQualifiers: Seq[String]): Attribute
62+
def withName(newName: String): Attribute
6263

6364
def toAttribute = this
6465
def newInstance: Attribute
@@ -86,7 +87,6 @@ case class Alias(child: Expression, name: String)
8687
override def dataType = child.dataType
8788
override def nullable = child.nullable
8889

89-
9090
override def toAttribute = {
9191
if (resolved) {
9292
AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
@@ -144,6 +144,14 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
144144
}
145145
}
146146

147+
override def withName(newName: String): AttributeReference = {
148+
if (name == newName) {
149+
this
150+
} else {
151+
AttributeReference(newName, dataType, nullable)(exprId, qualifiers)
152+
}
153+
}
154+
147155
/**
148156
* Returns a copy of this [[AttributeReference]] with new qualifiers.
149157
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
9595
resolver: Resolver): Option[NamedExpression] = {
9696

9797
val parts = name.split("\\.")
98+
9899
// Collect all attributes that are output by this nodes children where either the first part
99100
// matches the name or where the first part matches the scope and the second part matches the
100101
// name. Return these matches along with any remaining parts, which represent dotted access to
@@ -109,25 +110,62 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
109110
}
110111

111112
if (resolver(option.name, remainingParts.head)) {
112-
(option, remainingParts.tail.toList) :: Nil
113+
// Preserve the case of the user's attribute reference.
114+
(option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil
113115
} else {
114116
Nil
115117
}
116118
}
117119

118120
options.distinct match {
119-
case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
121+
// One match, no nested fields, use it.
122+
case Seq((a, Nil)) => Some(a)
123+
120124
// One match, but we also need to extract the requested nested field.
121125
case Seq((a, nestedFields)) =>
122-
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
126+
val aliased =
127+
Alias(
128+
resolveNesting(nestedFields, a, resolver),
129+
nestedFields.last)() // Preserve the case of the user's field access.
130+
Some(aliased)
131+
132+
// No matches.
123133
case Seq() =>
124134
logTrace(s"Could not find $name in ${input.mkString(", ")}")
125-
None // No matches.
135+
None
136+
137+
// More than one match.
126138
case ambiguousReferences =>
127139
throw new TreeNodeException(
128140
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
129141
}
130142
}
143+
144+
/**
145+
 * Given a list of successive nested field accesses, and a base expression, attempt to resolve
146+
* the actual field lookups on this expression.
147+
*/
148+
private def resolveNesting(
149+
nestedFields: List[String],
150+
expression: Expression,
151+
resolver: Resolver): Expression = {
152+
153+
(nestedFields, expression.dataType) match {
154+
case (Nil, _) => expression
155+
case (requestedField :: rest, StructType(fields)) =>
156+
val actualField = fields.filter(f => resolver(f.name, requestedField))
157+
actualField match {
158+
case Seq() =>
159+
sys.error(
160+
s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
161+
case Seq(singleMatch) =>
162+
resolveNesting(rest, GetField(expression, singleMatch.name), resolver)
163+
case multipleMatches =>
164+
sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}")
165+
}
166+
case (_, dt) => sys.error(s"Can't access nested field in type $dt")
167+
}
168+
}
131169
}
132170

133171
/**
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class HiveResolutionSuite extends HiveComparisonTest {
3636
createQueryTest("database.table table.attr",
3737
"SELECT src.key FROM default.src ORDER BY key LIMIT 1")
3838

39+
createQueryTest("database.table table.attr case insensitive",
40+
"SELECT SRC.Key FROM Default.Src ORDER BY key LIMIT 1")
41+
3942
createQueryTest("alias.attr",
4043
"SELECT a.key FROM src a ORDER BY key LIMIT 1")
4144

@@ -56,7 +59,10 @@ class HiveResolutionSuite extends HiveComparisonTest {
5659
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
5760
.registerTempTable("caseSensitivityTest")
5861

59-
sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
62+
val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
63+
assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"),
64+
"The output schema did not preserve the case of the query.")
65+
query.collect()
6066
}
6167

6268
ignore("case insensitivity with scala reflection joins") {

0 commit comments

Comments (0)