Skip to content

Commit ec40d23

Browse files
committed
unify GetStructField and GetInternalRowField
1 parent 426004a commit ec40d23

File tree

9 files changed

+18
-42
lines changed

9 files changed

+18
-42
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {
130130

131131
/** Returns the current path with a field at ordinal extracted. */
132132
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
133-
.map(p => GetInternalRowField(p, ordinal, dataType))
133+
.map(p => GetStructField(p, ordinal))
134134
.getOrElse(BoundReference(ordinal, dataType, false))
135135

136136
/** Returns the current path or `BoundReference`. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
201201
if (attribute.isDefined) {
202202
// This target resolved to an attribute in child. It must be a struct. Expand it.
203203
attribute.get.dataType match {
204-
case s: StructType => {
205-
s.fields.map( f => {
206-
val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
204+
case s: StructType => s.zipWithIndex.map {
205+
case (f, i) =>
206+
val extract = GetStructField(attribute.get, i)
207207
Alias(extract, target.get + "." + f.name)()
208-
})
209208
}
209+
210210
case _ => {
211211
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
212212
target.get + "`")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ object ExpressionEncoder {
100100
case UnresolvedAttribute(nameParts) =>
101101
assert(nameParts.length == 1)
102102
UnresolvedExtractValue(input, Literal(nameParts.head))
103-
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
103+
case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
104104
}
105105
}
106106
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ object RowEncoder {
220220
If(
221221
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
222222
Literal.create(null, externalDataTypeFor(f.dataType)),
223-
constructorFor(GetInternalRowField(input, i, f.dataType)))
223+
constructorFor(GetStructField(input, i)))
224224
}
225225
CreateExternalRow(convertedFields)
226226
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
206206
*/
207207
def prettyString: String = {
208208
transform {
209-
case a: AttributeReference => PrettyAttribute(a.name)
209+
case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
210210
case u: UnresolvedAttribute => PrettyAttribute(u.name)
211211
}.toString
212212
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ object ExtractValue {
5151
case (StructType(fields), NonNullLiteral(v, StringType)) =>
5252
val fieldName = v.toString
5353
val ordinal = findField(fields, fieldName, resolver)
54-
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
54+
GetStructField(child, ordinal, Some(fieldName))
5555

5656
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
5757
val fieldName = v.toString
@@ -97,18 +97,15 @@ object ExtractValue {
9797
* Returns the value of fields in the Struct `child`.
9898
*
9999
* No need to do type checking since it is handled by [[ExtractValue]].
100-
* TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
101100
*/
102-
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
101+
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
103102
extends UnaryExpression {
104103

105-
override def dataType: DataType = child.dataType match {
106-
case s: StructType => s(ordinal).dataType
107-
// This is a hack to avoid breaking existing code until we remove the need for the struct field
108-
case _ => field.dataType
109-
}
104+
private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
105+
106+
override def dataType: DataType = field.dataType
110107
override def nullable: Boolean = child.nullable || field.nullable
111-
override def toString: String = s"$child.${field.name}"
108+
override def toString: String = s"$child.${name.getOrElse(field.name)}"
112109

113110
protected override def nullSafeEval(input: Any): Any =
114111
input.asInstanceOf[InternalRow].get(ordinal, field.dataType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ case class AttributeReference(
273273
* A place holder used when printing expressions without debugging information such as the
274274
* expression id or the unresolved indicator.
275275
*/
276-
case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
276+
case class PrettyAttribute(name: String, dataType: DataType = NullType)
277+
extends Attribute with Unevaluable {
277278

278279
override def toString: String = name
279280

@@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
286287
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
287288
override def exprId: ExprId = throw new UnsupportedOperationException
288289
override def nullable: Boolean = throw new UnsupportedOperationException
289-
override def dataType: DataType = NullType
290290
}
291291

292292
object VirtualColumn {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -522,27 +522,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
522522
}
523523
}
524524

525-
case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
526-
extends UnaryExpression {
527-
528-
override def nullable: Boolean = true
529-
530-
override def eval(input: InternalRow): Any =
531-
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
532-
533-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
534-
val row = child.gen(ctx)
535-
s"""
536-
${row.code}
537-
final boolean ${ev.isNull} = ${row.isNull};
538-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
539-
if (!${ev.isNull}) {
540-
${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
541-
}
542-
"""
543-
}
544-
}
545-
546525
/**
547526
* Serializes an input object using a generic serializer (Kryo or Java).
548527
* @param kryo if true, use Kryo. Otherwise, use Java.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
7979
def getStructField(expr: Expression, fieldName: String): GetStructField = {
8080
expr.dataType match {
8181
case StructType(fields) =>
82-
val field = fields.find(_.name == fieldName).get
83-
GetStructField(expr, field, fields.indexOf(field))
82+
val index = fields.indexWhere(_.name == fieldName)
83+
GetStructField(expr, index)
8484
}
8585
}
8686

0 commit comments

Comments
 (0)