-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-3698][SQL] Correctly check case sensitivity in GetField #2543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField | ||
|
|
||
| import scala.collection.Map | ||
|
|
||
| import org.apache.spark.sql.catalyst.types._ | ||
|
|
@@ -73,33 +75,38 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { | |
| /** | ||
| * Returns the value of fields in the Struct `child`. | ||
| */ | ||
| case class GetField(child: Expression, fieldName: String) extends UnaryExpression { | ||
| case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { | ||
| type EvaluatedType = Any | ||
|
|
||
| def dataType = field.dataType | ||
| override def nullable = child.nullable || field.nullable | ||
| override def foldable = child.foldable | ||
|
|
||
| protected def structType = child.dataType match { | ||
| case s: StructType => s | ||
| case otherType => sys.error(s"GetField is not valid on fields of type $otherType") | ||
| } | ||
|
|
||
| lazy val field = | ||
| structType.fields | ||
| .find(_.name == fieldName) | ||
| .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}")) | ||
|
|
||
| lazy val ordinal = structType.fields.indexOf(field) | ||
|
|
||
| override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] | ||
|
|
||
| override def eval(input: Row): Any = { | ||
| val baseValue = child.eval(input).asInstanceOf[Row] | ||
| if (baseValue == null) null else baseValue(ordinal) | ||
| } | ||
|
|
||
| override def toString = s"$child.$fieldName" | ||
| override def toString = s"$child.${field.name}" | ||
| } | ||
|
|
||
| object GetField { | ||
| def apply( | ||
| e: Expression, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we can check to see if the field actually exists in the Struct; otherwise an error should be reported.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently all |
||
| fieldName: String, | ||
| equality: (String, String) => Boolean = _ == _): GetField = { | ||
| val structType = e.dataType match { | ||
| case s: StructType => s | ||
| case otherType => sys.error(s"GetField is not valid on fields of type $otherType") | ||
| } | ||
| val field = structType.fields | ||
| .find(f => equality(f.name, fieldName)) | ||
| .getOrElse(sys.error(s"No such field $fieldName in ${e.dataType}")) | ||
| val ordinal = structType.fields.indexOf(field) | ||
| GetField(e, field, ordinal) | ||
| } | ||
|
|
||
| def apply(ug: UnresolvedGetField): GetField = GetField(ug.child, ug.fieldName) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,12 @@ | |
| package org.apache.spark.sql.catalyst.plans.logical | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.sql.catalyst.analysis.Resolver | ||
| import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedGetField} | ||
| import org.apache.spark.sql.catalyst.errors.TreeNodeException | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.QueryPlan | ||
| import org.apache.spark.sql.catalyst.trees.TreeNode | ||
| import org.apache.spark.sql.catalyst.types.StructType | ||
| import org.apache.spark.sql.catalyst.trees | ||
| import org.apache.spark.sql.catalyst.trees.TreeNode | ||
|
|
||
| /** | ||
| * Estimates of various statistics. The default estimation logic simply lazily multiplies the | ||
|
|
@@ -160,11 +159,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { | |
|
|
||
| // One match, but we also need to extract the requested nested field. | ||
| case Seq((a, nestedFields)) => | ||
| val aliased = | ||
| Alias( | ||
| resolveNesting(nestedFields, a, resolver), | ||
| nestedFields.last)() // Preserve the case of the user's field access. | ||
| Some(aliased) | ||
| Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)()) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For something like |
||
|
|
||
| // No matches. | ||
| case Seq() => | ||
|
|
@@ -177,32 +172,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { | |
| this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Given a list of successive nested field accesses, and a based expression, attempt to resolve | ||
| * the actual field lookups on this expression. | ||
| */ | ||
| private def resolveNesting( | ||
| nestedFields: List[String], | ||
| expression: Expression, | ||
| resolver: Resolver): Expression = { | ||
|
|
||
| (nestedFields, expression.dataType) match { | ||
| case (Nil, _) => expression | ||
| case (requestedField :: rest, StructType(fields)) => | ||
| val actualField = fields.filter(f => resolver(f.name, requestedField)) | ||
| actualField match { | ||
| case Seq() => | ||
| sys.error( | ||
| s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") | ||
| case Seq(singleMatch) => | ||
| resolveNesting(rest, GetField(expression, singleMatch.name), resolver) | ||
| case multipleMatches => | ||
| sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}") | ||
| } | ||
| case (_, dt) => sys.error(s"Can't access nested field in type $dt") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible, I think it might be clearer to keep the resolver logic in the Analyzer rule.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going to put this logic into an Analyzer rule, but found some tests depend on
`GetField(child, fieldName)`, so I have to create this constructor of `GetField`. And these two are so similar, so I combined them together. Maybe I should fix those tests instead?