Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case u: UnresolvedRelation =>
lookupV2Relation(u.multipartIdentifier)
.map(SubqueryAlias(u.multipartIdentifier, _))
.getOrElse(u)

case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) =>
Expand Down Expand Up @@ -923,7 +924,9 @@ class Analyzer(
case v1Table: V1Table =>
v1SessionCatalog.getRelation(v1Table.v1Table)
case table =>
DataSourceV2Relation.create(table, Some(catalog), Some(ident))
SubqueryAlias(
identifier,
DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
}
val key = catalog.name +: ident.namespace :+ ident.name
Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,6 @@ case class AttributeReference(
val qualifier: Seq[String] = Seq.empty[String])
extends Attribute with Unevaluable {

// currently can only handle qualifier of length 2
require(qualifier.length <= 2)
/**
* Returns true iff the expression id is the same for both attributes.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import com.google.common.collect.Maps

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructField, StructType}

/**
Expand Down Expand Up @@ -153,13 +152,19 @@ package object expressions {
unique(grouped)
}

/** Perform attribute resolution given a name and a resolver. */
def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
/** True when no attribute in `attrs` carries a qualifier longer than two parts. */
@transient private val hasTwoOrLessQualifierParts: Boolean =
  !attrs.exists(_.qualifier.length > 2)

/** Match attributes for the case where all qualifiers in `attrs` have 2 or less parts. */
private def matchWithTwoOrLessQualifierParts(
nameParts: Seq[String],
resolver: Resolver): (Seq[Attribute], Seq[String]) = {
// Collect matching attributes given a name and a lookup.
def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
candidates.toSeq.flatMap(_.collect {
candidates.getOrElse(Nil).collect {
case a if resolver(a.name, name) => a.withName(name)
})
}
}

// Find matches for the given name assuming that the 1st two parts are qualifier
Expand Down Expand Up @@ -204,13 +209,79 @@ package object expressions {

// If none of attributes match database.table.column pattern or
// `table.column` pattern, we try to resolve it as a column.
val (candidates, nestedFields) = matches match {
matches match {
case (Seq(), _) =>
val name = nameParts.head
val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
(attributes, nameParts.tail)
case _ => matches
}
}

/**
 * Matches attributes when at least one qualifier in `attrs` has more than 2 parts.
 *
 * Every position in `nameParts` is tried as the attribute name, starting from the last one,
 * with all parts before it treated as the qualifier and all parts after it treated as nested
 * fields. Starting from the end means the longest possible qualifier is tried first.
 * For example, Seq("a", "b", "c") is tried as:
 *   1) name = "c", qualifier = Seq("a", "b")
 *   2) name = "b", qualifier = Seq("a")
 *   3) name = "a", qualifier = Seq()
 *
 * @param nameParts the parts of the (possibly qualified) name to resolve
 * @param resolver the function used to compare two name strings
 * @return the attributes matched by the first successful attempt together with the remaining
 *         name parts (the nested fields); two empty sequences if no attempt matched
 */
private def matchWithThreeOrMoreQualifierParts(
    nameParts: Seq[String],
    resolver: Resolver): (Seq[Attribute], Seq[String]) = {
  // True when `wanted` lines up, element by element, with the trailing elements of
  // `attrQualifier`. For example, Seq("a", "b") lines up with the end of Seq("a", "a", "b")
  // but not with the end of Seq("a", "b", "b").
  def qualifierEndsWith(attrQualifier: Seq[String], wanted: Seq[String]): Boolean = {
    if (attrQualifier.length < wanted.length) {
      false
    } else {
      attrQualifier.takeRight(wanted.length).zip(wanted).forall {
        case (actual, expected) => resolver(actual, expected)
      }
    }
  }

  // Treats `nameParts(nameIndex)` as the attribute name and everything before it as the
  // qualifier. An attribute is kept when the resolver accepts its name and the requested
  // qualifier lines up with the end of the attribute's own qualifier.
  def attemptAt(nameIndex: Int): Seq[Attribute] = {
    val name = nameParts(nameIndex)
    val wantedQualifier = nameParts.take(nameIndex)
    direct.get(name.toLowerCase(Locale.ROOT)).getOrElse(Nil).collect {
      case a if resolver(name, a.name) && qualifierEndsWith(a.qualifier, wantedQualifier) =>
        a.withName(name)
    }
  }

  // Walk the candidate name positions from right to left and stop at the first position that
  // yields a match; the iterator is lazy, so later positions are never evaluated.
  val firstHit = nameParts.indices.reverseIterator
    .map(idx => (idx, attemptAt(idx)))
    .find { case (_, matched) => matched.nonEmpty }

  firstHit match {
    case Some((idx, matched)) =>
      // Whatever follows the matched name is returned as nested fields.
      (matched, nameParts.drop(idx + 1))
    case None =>
      (Seq.empty[Attribute], Seq.empty[String])
  }
}

/** Perform attribute resolution given a name and a resolver. */
def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
val (candidates, nestedFields) = if (hasTwoOrLessQualifierParts) {
matchWithTwoOrLessQualifierParts(nameParts, resolver)
} else {
matchWithThreeOrMoreQualifierParts(nameParts, resolver)
}

def name = UnresolvedAttribute(nameParts).name
candidates match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,21 @@ sealed trait IdentifierWithDatabase {

/**
* Encapsulates an identifier that is either an alias name or an identifier that has table
* name and optionally a database name.
* name and a qualifier.
* The SubqueryAlias node keeps track of the qualifier using the information in this structure
* @param identifier - Is an alias name or a table name
* @param database - Is a database name and is optional
* @param name - Is an alias name or a table name
* @param qualifier - Is a qualifier
*/
case class AliasIdentifier(identifier: String, database: Option[String])
extends IdentifierWithDatabase {
case class AliasIdentifier(name: String, qualifier: Seq[String]) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

def this(identifier: String) = this(identifier, Seq())

def this(identifier: String) = this(identifier, None)
override def toString: String = (qualifier :+ name).quoted
}

object AliasIdentifier {
def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier)
def apply(name: String): AliasIdentifier = new AliasIdentifier(name)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.types._
import org.apache.spark.util.random.RandomSampler

Expand Down Expand Up @@ -849,18 +850,18 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi
/**
* Aliased subquery.
*
* @param name the alias identifier for this subquery.
* @param identifier the alias identifier for this subquery.
* @param child the logical plan of this subquery.
*/
case class SubqueryAlias(
name: AliasIdentifier,
identifier: AliasIdentifier,
child: LogicalPlan)
extends OrderPreservingUnaryNode {

def alias: String = name.identifier
def alias: String = identifier.name

override def output: Seq[Attribute] = {
val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias))
val qualifierList = identifier.qualifier :+ alias
child.output.map(_.withQualifier(qualifierList))
}
override def doCanonicalize(): LogicalPlan = child.canonicalized
Expand All @@ -877,7 +878,13 @@ object SubqueryAlias {
identifier: String,
database: String,
child: LogicalPlan): SubqueryAlias = {
SubqueryAlias(AliasIdentifier(identifier, Some(database)), child)
SubqueryAlias(AliasIdentifier(identifier, Seq(database)), child)
}

/**
 * Builds a [[SubqueryAlias]] from a multi-part identifier: the last part becomes the alias
 * name and all preceding parts become its qualifier.
 *
 * @param multipartIdentifier the name parts; `last` is called on it, so it must be non-empty
 * @param child the logical plan being aliased
 */
def apply(
    multipartIdentifier: Seq[String],
    child: LogicalPlan): SubqueryAlias = {
  val aliasName = multipartIdentifier.last
  val qualifierParts = multipartIdentifier.init
  SubqueryAlias(AliasIdentifier(aliasName, qualifierParts), child)
}
}
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.sql.catalyst.IdentifierWithDatabase
import org.apache.spark.sql.catalyst.{AliasIdentifier, IdentifierWithDatabase}
import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
import org.apache.spark.sql.catalyst.errors._
Expand Down Expand Up @@ -780,6 +780,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case exprId: ExprId => true
case field: StructField => true
case id: IdentifierWithDatabase => true
case alias: AliasIdentifier => true
case join: JoinType => true
case spec: BucketSpec => true
case catalog: CatalogTable => true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class AttributeResolutionSuite extends SparkFunSuite {
val resolver = caseInsensitiveResolution

test("basic attribute resolution with namespaces") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a test where the table name and the column name are the same, and make sure resolution works? Something like:

val attrs = Seq(AttributeReference("t", IntegerType)(qualifier = Seq("ns1", "ns2", "t")))
attrs.resolve(Seq("ns1", "ns2", "t"), resolver) match {
      case Some(attr) => fail()
      case _ => fail()
}

val attrs = Seq(
AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t1")),
AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "ns3", "t2")))

// Try to match attribute reference with name "a" with qualifier "ns1.ns2.t1".
Seq(Seq("t1", "a"), Seq("ns2", "t1", "a"), Seq("ns1", "ns2", "t1", "a")).foreach { nameParts =>
attrs.resolve(nameParts, resolver) match {
case Some(attr) => assert(attr.semanticEquals(attrs(0)))
case _ => fail()
}
}

// Non-matching cases
Seq(Seq("ns1", "ns2", "t1"), Seq("ns2", "a")).foreach { nameParts =>
assert(attrs.resolve(nameParts, resolver).isEmpty)
}
}

test("attribute resolution where table and attribute names are the same") {
  // A column named "t" whose qualifier also ends in "t".
  val column = AttributeReference("t", IntegerType)(qualifier = Seq("ns1", "ns2", "t"))
  val candidates = Seq(column)

  // The bare column name, and the column name prefixed by any trailing portion of the
  // qualifier, must all resolve to the column.
  val resolvableNames = Seq(
    Seq("t"),
    Seq("t", "t"),
    Seq("ns2", "t", "t"),
    Seq("ns1", "ns2", "t", "t"))
  for (parts <- resolvableNames) {
    candidates.resolve(parts, resolver) match {
      case Some(found) => assert(found.semanticEquals(column))
      case _ => fail()
    }
  }

  // The full qualifier on its own, with no column name after it, must not resolve.
  val qualifierOnly = Seq("ns1", "ns2", "t")
  assert(candidates.resolve(qualifierOnly, resolver).isEmpty)
}

test("attribute resolution ambiguity at the attribute name level") {
  // Two columns share the name "a" but carry different qualifiers.
  val fromT1 = AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t1"))
  val fromT2 = AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t2"))
  val candidates = Seq(fromT1, fromT2)

  // An unqualified reference to "a" must be rejected and the error must list both columns.
  val expectedMessage = "Reference 'a' is ambiguous, could be: ns1.t1.a, ns1.ns2.t2.a."
  val error = intercept[AnalysisException] {
    candidates.resolve(Seq("a"), resolver)
  }
  assert(error.getMessage.contains(expectedMessage))
}

test("attribute resolution ambiguity at the qualifier level") {
  // Both qualifiers end in "ns1.t", so the reference "ns1.t.a" fits either column.
  val shortQualified = AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t"))
  val longQualified = AttributeReference("a", IntegerType)(qualifier = Seq("ns2", "ns1", "t"))
  val candidates = Seq(shortQualified, longQualified)

  val expectedMessage = "Reference 'ns1.t.a' is ambiguous, could be: ns1.t.a, ns2.ns1.t.a."
  val error = intercept[AnalysisException] {
    candidates.resolve(Seq("ns1", "t", "a"), resolver)
  }
  assert(error.getMessage.contains(expectedMessage))
}

test("attribute resolution with nested fields") {
  // A struct column "a" with integer fields "aa" and "bb", qualified by "ns1.t".
  val structFields = Seq(StructField("aa", IntegerType), StructField("bb", IntegerType))
  val structColumn = AttributeReference("a", StructType(structFields))(
    qualifier = Seq("ns1", "t"))
  val candidates = Seq(structColumn)

  // An existing nested field resolves to an alias named after that field.
  candidates.resolve(Seq("ns1", "t", "a", "aa"), resolver) match {
    case Some(Alias(_, fieldName)) => assert(fieldName == "aa")
    case _ => fail()
  }

  // A nested field that does not exist is reported with the list of available fields.
  val error = intercept[AnalysisException] {
    candidates.resolve(Seq("ns1", "t", "a", "cc"), resolver)
  }
  assert(error.getMessage.contains("No such struct field cc in aa, bb"))
}

test("attribute resolution with case insensitive resolver") {
  val column = AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t"))
  val candidates = Seq(column)

  // A differently-cased reference must still resolve, and the resolved attribute must carry
  // the name exactly as it was written in the reference ("A").
  val mixedCaseName = Seq("Ns1", "T", "A")
  candidates.resolve(mixedCaseName, caseInsensitiveResolution) match {
    case Some(found) =>
      assert(found.semanticEquals(column))
      assert(found.name == "A")
    case _ => fail()
  }
}

test("attribute resolution with case sensitive resolver") {
  val column = AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t"))
  val candidates = Seq(column)

  // A case difference anywhere, in the qualifier or in the column name, prevents a match.
  val wrongCaseNames = Seq(Seq("Ns1", "T", "A"), Seq("ns1", "t", "A"))
  wrongCaseNames.foreach { parts =>
    assert(candidates.resolve(parts, caseSensitiveResolution).isEmpty)
  }

  // The exactly-cased reference resolves to the column.
  val exactMatch = candidates.resolve(Seq("ns1", "t", "a"), caseSensitiveResolution)
  exactMatch match {
    case Some(found) => assert(found.semanticEquals(column))
    case _ => fail()
  }
}

test("attribute resolution should try to match the longest qualifier") {
  // Two attributes compete for the reference "a.b":
  //   1) an unqualified struct column "a", for which "b" would be a nested field;
  //   2) a struct column "b" qualified by "a" (its own nested field is also called "a").
  // Reading "b" as the column name uses the longer qualifier, so attribute #2 is expected.
  val unqualifiedA = AttributeReference(
    "a", StructType(Seq(StructField("b", IntegerType))))()
  val qualifiedB = AttributeReference(
    "b", StructType(Seq(StructField("a", IntegerType))))(qualifier = Seq("a"))
  val candidates = Seq(unqualifiedA, qualifiedB)

  candidates.resolve(Seq("a", "b"), resolver) match {
    case Some(found) => assert(found.semanticEquals(qualifiedB))
    case _ => fail()
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,11 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {

// Converts AliasIdentifier to JSON
assertJSON(
AliasIdentifier("alias"),
AliasIdentifier("alias", Seq("ns1", "ns2")),
JObject(
"product-class" -> JString(classOf[AliasIdentifier].getName),
"identifier" -> "alias"))
"name" -> "alias",
"qualifier" -> "[ns1, ns2]"))

// Converts SubqueryAlias to JSON
assertJSON(
Expand All @@ -445,8 +446,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
JObject(
"class" -> classOf[SubqueryAlias].getName,
"num-children" -> 1,
"name" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName),
"identifier" -> "t1"),
"identifier" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName),
"name" -> "t1",
"qualifier" -> JArray(Nil)),
"child" -> 0),
JObject(
"class" -> classOf[JsonTestTreeNode].getName,
Expand Down
Loading