
Commit c27a616

imback82 authored and cloud-fan committed
[SPARK-30612][SQL] Resolve qualified column name with v2 tables
### What changes were proposed in this pull request?

This PR fixes the issue where queries with qualified columns such as `SELECT t.a FROM t` would fail to resolve for v2 tables. It allows qualified column names in a query as follows:

```sql
SELECT testcat.ns1.ns2.tbl.foo FROM testcat.ns1.ns2.tbl
SELECT ns1.ns2.tbl.foo FROM testcat.ns1.ns2.tbl
SELECT ns2.tbl.foo FROM testcat.ns1.ns2.tbl
SELECT tbl.foo FROM testcat.ns1.ns2.tbl
```

### Why are the changes needed?

This is a bug: column names in queries against v2 tables currently cannot be qualified at all.

### Does this PR introduce any user-facing change?

Yes, users can now qualify column names for v2 tables.

### How was this patch tested?

Added new tests.

Closes #27391 from imback82/qualified_col.

Authored-by: Terry Kim <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent aebabf0 · commit c27a616

14 files changed (+342, -74 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (4 additions, 1 deletion)

```diff
@@ -799,6 +799,7 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
       case u: UnresolvedRelation =>
         lookupV2Relation(u.multipartIdentifier)
+          .map(SubqueryAlias(u.multipartIdentifier, _))
           .getOrElse(u)

       case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) =>
@@ -923,7 +924,9 @@ class Analyzer(
           case v1Table: V1Table =>
             v1SessionCatalog.getRelation(v1Table.v1Table)
           case table =>
-            DataSourceV2Relation.create(table, Some(catalog), Some(ident))
+            SubqueryAlias(
+              identifier,
+              DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
         }
         val key = catalog.name +: ident.namespace :+ ident.name
         Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull))
```
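To see what the wrapping buys, here is a minimal sketch, not part of the patch: `LocalRelation` stands in for the resolved `DataSourceV2Relation`, and the `SubqueryAlias(Seq[String], LogicalPlan)` overload is the one added in basicLogicalOperators.scala below. Every suffix of the table's multi-part identifier now works as a column qualifier, matching the queries in the PR description.

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias}
import org.apache.spark.sql.types.IntegerType

// LocalRelation stands in for the resolved v2 relation the analyzer wraps.
val relation = LocalRelation(AttributeReference("foo", IntegerType)())
val aliased = SubqueryAlias(Seq("testcat", "ns1", "ns2", "tbl"), relation)

// Each output attribute carries the full multi-part qualifier, so every
// suffix of it resolves the column, as in the PR description above.
Seq(Seq("tbl", "foo"),
    Seq("ns2", "tbl", "foo"),
    Seq("ns1", "ns2", "tbl", "foo"),
    Seq("testcat", "ns1", "ns2", "tbl", "foo")).foreach { parts =>
  assert(aliased.output.resolve(parts, caseInsensitiveResolution).isDefined)
}
```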

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala (0 additions, 2 deletions)

```diff
@@ -236,8 +236,6 @@ case class AttributeReference(
     val qualifier: Seq[String] = Seq.empty[String])
   extends Attribute with Unevaluable {

-  // currently can only handle qualifier of length 2
-  require(qualifier.length <= 2)
   /**
    * Returns true iff the expression id is the same for both attributes.
    */
```
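Dropping the `require` is the enabling change for the rest of the PR; a quick sketch (assuming the patched class) shows that an attribute can now carry a catalog-qualified, more-than-two-part qualifier:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Previously rejected by require(qualifier.length <= 2); now representable.
val attr = AttributeReference("foo", IntegerType)(
  qualifier = Seq("testcat", "ns1", "ns2", "tbl"))
assert(attr.qualifier.length == 4)
```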

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala (77 additions, 6 deletions)

```diff
@@ -23,7 +23,6 @@ import com.google.common.collect.Maps

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StructField, StructType}

 /**
@@ -153,13 +152,19 @@ package object expressions {
       unique(grouped)
     }

-    /** Perform attribute resolution given a name and a resolver. */
-    def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
+    /** Returns true if all qualifiers in `attrs` have 2 or less parts. */
+    @transient private val hasTwoOrLessQualifierParts: Boolean =
+      attrs.forall(_.qualifier.length <= 2)
+
+    /** Match attributes for the case where all qualifiers in `attrs` have 2 or less parts. */
+    private def matchWithTwoOrLessQualifierParts(
+        nameParts: Seq[String],
+        resolver: Resolver): (Seq[Attribute], Seq[String]) = {
       // Collect matching attributes given a name and a lookup.
       def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
-        candidates.toSeq.flatMap(_.collect {
+        candidates.getOrElse(Nil).collect {
           case a if resolver(a.name, name) => a.withName(name)
-        })
+        }
       }

       // Find matches for the given name assuming that the 1st two parts are qualifier
@@ -204,13 +209,79 @@ package object expressions {

       // If none of attributes match database.table.column pattern or
       // `table.column` pattern, we try to resolve it as a column.
-      val (candidates, nestedFields) = matches match {
+      matches match {
         case (Seq(), _) =>
           val name = nameParts.head
           val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
           (attributes, nameParts.tail)
         case _ => matches
       }
+    }
+
+    /**
+     * Match attributes for the case where at least one qualifier in `attrs` has more than 2 parts.
+     */
+    private def matchWithThreeOrMoreQualifierParts(
+        nameParts: Seq[String],
+        resolver: Resolver): (Seq[Attribute], Seq[String]) = {
+      // Returns true if the `short` qualifier is a subset of the last elements of
+      // `long` qualifier. For example, Seq("a", "b") is a subset of Seq("a", "a", "b"),
+      // but not a subset of Seq("a", "b", "b").
+      def matchQualifier(short: Seq[String], long: Seq[String]): Boolean = {
+        (long.length >= short.length) &&
+          long.takeRight(short.length)
+            .zip(short)
+            .forall(x => resolver(x._1, x._2))
+      }
+
+      // Collect attributes that match the given name and qualifier.
+      // A match occurs if
+      // 1) the given name matches the attribute's name according to the resolver.
+      // 2) the given qualifier is a subset of the attribute's qualifier.
+      def collectMatches(
+          name: String,
+          qualifier: Seq[String],
+          candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
+        candidates.getOrElse(Nil).collect {
+          case a if resolver(name, a.name) && matchQualifier(qualifier, a.qualifier) =>
+            a.withName(name)
+        }
+      }
+
+      // Iterate each string in `nameParts` in a reverse order and try to match the attributes
+      // considering the current string as the attribute name. For example, if `nameParts` is
+      // Seq("a", "b", "c"), the match will be performed in the following order:
+      // 1) name = "c", qualifier = Seq("a", "b")
+      // 2) name = "b", qualifier = Seq("a")
+      // 3) name = "a", qualifier = Seq()
+      // Note that the match is performed in the reverse order in order to match the longest
+      // qualifier as possible. If a match is found, the remaining portion of `nameParts`
+      // is also returned as nested fields.
+      var candidates: Seq[Attribute] = Nil
+      var nestedFields: Seq[String] = Nil
+      var i = nameParts.length - 1
+      while (i >= 0 && candidates.isEmpty) {
+        val name = nameParts(i)
+        candidates = collectMatches(
+          name,
+          nameParts.take(i),
+          direct.get(name.toLowerCase(Locale.ROOT)))
+        if (candidates.nonEmpty) {
+          nestedFields = nameParts.takeRight(nameParts.length - i - 1)
+        }
+        i -= 1
+      }
+
+      (candidates, nestedFields)
+    }
+
+    /** Perform attribute resolution given a name and a resolver. */
+    def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
+      val (candidates, nestedFields) = if (hasTwoOrLessQualifierParts) {
+        matchWithTwoOrLessQualifierParts(nameParts, resolver)
+      } else {
+        matchWithThreeOrMoreQualifierParts(nameParts, resolver)
+      }

       def name = UnresolvedAttribute(nameParts).name
       candidates match {
```
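To make the new matching order concrete, here is a small self-contained sketch (a hypothetical `splits` helper, not code from the patch) that enumerates the (qualifier, name, nested fields) splits in the order `matchWithThreeOrMoreQualifierParts` tries them:

```scala
// Hypothetical helper illustrating the reverse-order split; not part of the patch.
def splits(nameParts: Seq[String]): Seq[(Seq[String], String, Seq[String])] =
  (nameParts.length - 1 to 0 by -1).map { i =>
    (nameParts.take(i), nameParts(i), nameParts.drop(i + 1))
  }

// For ns1.ns2.tbl.a.aa the attempts are, in order:
//   (List(ns1, ns2, tbl, a), aa,  List())             <- longest qualifier first
//   (List(ns1, ns2, tbl),    a,   List(aa))
//   (List(ns1, ns2),         tbl, List(a, aa))
//   (List(ns1),              ns2, List(tbl, a, aa))
//   (List(),                 ns1, List(ns2, tbl, a, aa))
splits(Seq("ns1", "ns2", "tbl", "a", "aa")).foreach(println)
```

The first split whose qualifier is a suffix of some attribute's qualifier wins, which is exactly the "match the longest qualifier" behavior exercised by the new test suite below.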

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala (9 additions, 7 deletions)

```diff
@@ -49,19 +49,21 @@ sealed trait IdentifierWithDatabase {

 /**
  * Encapsulates an identifier that is either a alias name or an identifier that has table
- * name and optionally a database name.
+ * name and a qualifier.
  * The SubqueryAlias node keeps track of the qualifier using the information in this structure
- * @param identifier - Is an alias name or a table name
- * @param database - Is a database name and is optional
+ * @param name - Is an alias name or a table name
+ * @param qualifier - Is a qualifier
  */
-case class AliasIdentifier(identifier: String, database: Option[String])
-  extends IdentifierWithDatabase {
+case class AliasIdentifier(name: String, qualifier: Seq[String]) {
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+  def this(identifier: String) = this(identifier, Seq())

-  def this(identifier: String) = this(identifier, None)
+  override def toString: String = (qualifier :+ name).quoted
 }

 object AliasIdentifier {
-  def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier)
+  def apply(name: String): AliasIdentifier = new AliasIdentifier(name)
 }

 /**
```
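A short sketch (assuming the patched class) of what the new `toString` yields; `quoted` from `CatalogV2Implicits` joins the parts with dots and backtick-quotes only where needed:

```scala
import org.apache.spark.sql.catalyst.AliasIdentifier

// Qualifier parts and the name are joined into one dotted identifier.
val id = AliasIdentifier("tbl", Seq("testcat", "ns1", "ns2"))
assert(id.toString == "testcat.ns1.ns2.tbl")
```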

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala (12 additions, 5 deletions)

```diff
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.types._
 import org.apache.spark.util.random.RandomSampler

@@ -849,18 +850,18 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi
 /**
  * Aliased subquery.
  *
- * @param name the alias identifier for this subquery.
+ * @param identifier the alias identifier for this subquery.
  * @param child the logical plan of this subquery.
  */
 case class SubqueryAlias(
-    name: AliasIdentifier,
+    identifier: AliasIdentifier,
     child: LogicalPlan)
   extends OrderPreservingUnaryNode {

-  def alias: String = name.identifier
+  def alias: String = identifier.name

   override def output: Seq[Attribute] = {
-    val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias))
+    val qualifierList = identifier.qualifier :+ alias
     child.output.map(_.withQualifier(qualifierList))
   }
   override def doCanonicalize(): LogicalPlan = child.canonicalized
@@ -877,7 +878,13 @@ object SubqueryAlias {
       identifier: String,
       database: String,
       child: LogicalPlan): SubqueryAlias = {
-    SubqueryAlias(AliasIdentifier(identifier, Some(database)), child)
+    SubqueryAlias(AliasIdentifier(identifier, Seq(database)), child)
+  }
+
+  def apply(
+      multipartIdentifier: Seq[String],
+      child: LogicalPlan): SubqueryAlias = {
+    SubqueryAlias(AliasIdentifier(multipartIdentifier.last, multipartIdentifier.init), child)
   }
 }
 /**
```
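As a sanity check, a sketch (assuming the patch; `LocalRelation` is just a convenient child plan): the existing database-based overload keeps its old output qualifier, while the new multi-part overload splits the identifier into `init` (qualifier) and `last` (alias):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias}
import org.apache.spark.sql.types.IntegerType

val child = LocalRelation(AttributeReference("c", IntegerType)())

// Old-style database qualifier: output qualified as Seq("db", "t"), as before.
val v1 = SubqueryAlias("t", "db", child)
assert(v1.output.head.qualifier == Seq("db", "t"))

// New multi-part overload: alias is the last part, the rest is the qualifier.
val v2 = SubqueryAlias(Seq("cat", "ns", "t"), child)
assert(v2.alias == "t" && v2.output.head.qualifier == Seq("cat", "ns", "t"))
```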

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala (2 additions, 1 deletion)

Because `AliasIdentifier` no longer extends `IdentifierWithDatabase` (see identifiers.scala above), it needs its own entry in `TreeNode`'s whitelist of JSON-serializable field types:

```diff
@@ -27,7 +27,7 @@ import org.json4s.JsonAST._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._

-import org.apache.spark.sql.catalyst.IdentifierWithDatabase
+import org.apache.spark.sql.catalyst.{AliasIdentifier, IdentifierWithDatabase}
 import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
 import org.apache.spark.sql.catalyst.errors._
@@ -780,6 +780,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       case exprId: ExprId => true
       case field: StructField => true
       case id: IdentifierWithDatabase => true
+      case alias: AliasIdentifier => true
       case join: JoinType => true
       case spec: BucketSpec => true
       case catalog: CatalogTable => true
```
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala (new file, 137 additions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class AttributeResolutionSuite extends SparkFunSuite {
  val resolver = caseInsensitiveResolution

  test("basic attribute resolution with namespaces") {
    val attrs = Seq(
      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t1")),
      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "ns3", "t2")))

    // Try to match attribute reference with name "a" with qualifier "ns1.ns2.t1".
    Seq(Seq("t1", "a"), Seq("ns2", "t1", "a"), Seq("ns1", "ns2", "t1", "a")).foreach { nameParts =>
      attrs.resolve(nameParts, resolver) match {
        case Some(attr) => assert(attr.semanticEquals(attrs(0)))
        case _ => fail()
      }
    }

    // Non-matching cases
    Seq(Seq("ns1", "ns2", "t1"), Seq("ns2", "a")).foreach { nameParts =>
      assert(attrs.resolve(nameParts, resolver).isEmpty)
    }
  }

  test("attribute resolution where table and attribute names are the same") {
    val attrs = Seq(AttributeReference("t", IntegerType)(qualifier = Seq("ns1", "ns2", "t")))
    // Matching cases
    Seq(
      Seq("t"), Seq("t", "t"), Seq("ns2", "t", "t"), Seq("ns1", "ns2", "t", "t")
    ).foreach { nameParts =>
      attrs.resolve(nameParts, resolver) match {
        case Some(attr) => assert(attr.semanticEquals(attrs(0)))
        case _ => fail()
      }
    }

    // Non-matching case
    assert(attrs.resolve(Seq("ns1", "ns2", "t"), resolver).isEmpty)
  }

  test("attribute resolution ambiguity at the attribute name level") {
    val attrs = Seq(
      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t1")),
      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t2")))

    val ex = intercept[AnalysisException] {
      attrs.resolve(Seq("a"), resolver)
    }
    assert(ex.getMessage.contains(
      "Reference 'a' is ambiguous, could be: ns1.t1.a, ns1.ns2.t2.a."))
  }

  test("attribute resolution ambiguity at the qualifier level") {
    val attrs = Seq(
      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t")),
      AttributeReference("a", IntegerType)(qualifier = Seq("ns2", "ns1", "t")))

    val ex = intercept[AnalysisException] {
      attrs.resolve(Seq("ns1", "t", "a"), resolver)
    }
    assert(ex.getMessage.contains(
      "Reference 'ns1.t.a' is ambiguous, could be: ns1.t.a, ns2.ns1.t.a."))
  }

  test("attribute resolution with nested fields") {
    val attrType = StructType(Seq(StructField("aa", IntegerType), StructField("bb", IntegerType)))
    val attrs = Seq(AttributeReference("a", attrType)(qualifier = Seq("ns1", "t")))

    val resolved = attrs.resolve(Seq("ns1", "t", "a", "aa"), resolver)
    resolved match {
      case Some(Alias(_, name)) => assert(name == "aa")
      case _ => fail()
    }

    val ex = intercept[AnalysisException] {
      attrs.resolve(Seq("ns1", "t", "a", "cc"), resolver)
    }
    assert(ex.getMessage.contains("No such struct field cc in aa, bb"))
  }

  test("attribute resolution with case insensitive resolver") {
    val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t")))
    attrs.resolve(Seq("Ns1", "T", "A"), caseInsensitiveResolution) match {
      case Some(attr) => assert(attr.semanticEquals(attrs(0)) && attr.name == "A")
      case _ => fail()
    }
  }

  test("attribute resolution with case sensitive resolver") {
    val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t")))
    assert(attrs.resolve(Seq("Ns1", "T", "A"), caseSensitiveResolution).isEmpty)
    assert(attrs.resolve(Seq("ns1", "t", "A"), caseSensitiveResolution).isEmpty)
    attrs.resolve(Seq("ns1", "t", "a"), caseSensitiveResolution) match {
      case Some(attr) => assert(attr.semanticEquals(attrs(0)))
      case _ => fail()
    }
  }

  test("attribute resolution should try to match the longest qualifier") {
    // We have two attributes:
    // 1) "a.b" where "a" is the name and "b" is the nested field.
    // 2) "a.b.a" where "b" is the name, left-side "a" is the qualifier and the right-side "a"
    //    is the nested field.
    // When "a.b" is resolved, "b" is tried first as the name, so it is resolved to #2 attribute.
    val a1Type = StructType(Seq(StructField("b", IntegerType)))
    val a2Type = StructType(Seq(StructField("a", IntegerType)))
    val attrs = Seq(
      AttributeReference("a", a1Type)(),
      AttributeReference("b", a2Type)(qualifier = Seq("a")))
    attrs.resolve(Seq("a", "b"), resolver) match {
      case Some(attr) => assert(attr.semanticEquals(attrs(1)))
      case _ => fail()
    }
  }
}
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala (6 additions, 4 deletions)

```diff
@@ -433,10 +433,11 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {

     // Converts AliasIdentifier to JSON
     assertJSON(
-      AliasIdentifier("alias"),
+      AliasIdentifier("alias", Seq("ns1", "ns2")),
       JObject(
         "product-class" -> JString(classOf[AliasIdentifier].getName),
-        "identifier" -> "alias"))
+        "name" -> "alias",
+        "qualifier" -> "[ns1, ns2]"))

     // Converts SubqueryAlias to JSON
     assertJSON(
@@ -445,8 +446,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       JObject(
         "class" -> classOf[SubqueryAlias].getName,
         "num-children" -> 1,
-        "name" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName),
-          "identifier" -> "t1"),
+        "identifier" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName),
+          "name" -> "t1",
+          "qualifier" -> JArray(Nil)),
         "child" -> 0),
       JObject(
         "class" -> classOf[JsonTestTreeNode].getName,
```
