Commit f77f11c

cloud-fan authored and marmbrus committed
[SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator
## What changes were proposed in this pull request?

This PR decouples deserializer expression resolution from `ObjectOperator`, so that we can use deserializer expressions in normal operators. This is needed by #12061 and #12067; I abstracted the logic out and put it in this PR to reduce code changes in the future.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes #12131 from cloud-fan/separate.
1 parent e4bd504 commit f77f11c
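For readers new to this part of the analyzer, here is a minimal sketch (not part of the commit) of the flow the change enables: the deserializer is wrapped in `UnresolvedDeserializer` and resolved by the ordinary analyzer batch instead of an `ObjectOperator` special case. It mirrors the updated `ExpressionEncoder.resolve` shown below; the tuple encoder, the alias name `value`, and the use of `SimpleAnalyzer` are illustrative choices, not code from this patch.

```scala
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedDeserializer}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

// An encoder's deserializer contains BoundReferences and starts out unresolved.
val enc = ExpressionEncoder[(Int, String)]()
val attrs = enc.schema.toAttributes

// Wrap it in UnresolvedDeserializer and analyze a throwaway plan; the new
// ResolveDeserializer rule binds the expression to `attrs`, so the old
// ObjectOperator special case is no longer involved.
val plan = Project(
  Alias(UnresolvedDeserializer(enc.deserializer, attrs), "value")() :: Nil,
  LocalRelation(attrs))
val analyzed = SimpleAnalyzer.execute(plan)
```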

5 files changed (+153, -126 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 97 additions & 86 deletions
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import java.lang.reflect.Modifier
-
 import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 
@@ -87,9 +85,11 @@ class Analyzer(
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
       ResolveReferences ::
+      ResolveDeserializer ::
+      ResolveNewInstance ::
+      ResolveUpCast ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
-      ResolveUpCast ::
       ResolveOrdinalInOrderByAndGroupBy ::
       ResolveSortReferences ::
       ResolveGenerate ::
@@ -499,18 +499,9 @@
           Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
         }
 
-      // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
-      // should be resolved by their corresponding attributes instead of children's output.
-      case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
-        val deserializerToAttributes = o.deserializers.map {
-          case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
-        }.toMap
-
-        o.transformExpressions {
-          case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
-            resolveDeserializer(expr, attributes)
-          }.getOrElse(expr)
-        }
+      // Skips plan which contains deserializer expressions, as they should be resolved by another
+      // rule: ResolveDeserializer.
+      case plan if containsDeserializer(plan.expressions) => plan
 
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -526,38 +517,6 @@
       }
     }
 
-    private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
-      exprs.exists { expr =>
-        !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
-      }
-    }
-
-    def resolveDeserializer(
-        deserializer: Expression,
-        attributes: Seq[Attribute]): Expression = {
-      val unbound = deserializer transform {
-        case b: BoundReference => attributes(b.ordinal)
-      }
-
-      resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
-        case n: NewInstance
-          // If this is an inner class of another class, register the outer object in `OuterScopes`.
-          // Note that static inner classes (e.g., inner classes within Scala objects) don't need
-          // outer pointer registration.
-          if n.outerPointer.isEmpty &&
-             n.cls.isMemberClass &&
-             !Modifier.isStatic(n.cls.getModifiers) =>
-          val outer = OuterScopes.getOuterScope(n.cls)
-          if (outer == null) {
-            throw new AnalysisException(
-              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
-                "access to the scope that this class was defined in.\n" +
-                "Try moving this class out of its parent class.")
-          }
-          n.copy(outerPointer = Some(outer))
-      }
-    }
-
     def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
       expressions.map {
         case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
@@ -623,6 +582,10 @@
     }
   }
 
+  private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
+    exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
+  }
+
   protected[sql] def resolveExpression(
       expr: Expression,
       plan: LogicalPlan,
@@ -1475,7 +1438,94 @@
     Project(projectList, Join(left, right, joinType, newCondition))
   }
 
+  /**
+   * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
+   * to the given input attributes.
+   */
+  object ResolveDeserializer extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if !p.childrenResolved => p
+      case p if p.resolved => p
 
+      case p => p transformExpressions {
+        case UnresolvedDeserializer(deserializer, inputAttributes) =>
+          val inputs = if (inputAttributes.isEmpty) {
+            p.children.flatMap(_.output)
+          } else {
+            inputAttributes
+          }
+          val unbound = deserializer transform {
+            case b: BoundReference => inputs(b.ordinal)
+          }
+          resolveExpression(unbound, LocalRelation(inputs), throws = true)
+      }
+    }
+  }
+
+  /**
+   * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
+   * constructed is an inner class.
+   */
+  object ResolveNewInstance extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if !p.childrenResolved => p
+      case p if p.resolved => p
+
+      case p => p transformExpressions {
+        case n: NewInstance if n.childrenResolved && !n.resolved =>
+          val outer = OuterScopes.getOuterScope(n.cls)
+          if (outer == null) {
+            throw new AnalysisException(
+              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+                "access to the scope that this class was defined in.\n" +
+                "Try moving this class out of its parent class.")
+          }
+          n.copy(outerPointer = Some(outer))
+      }
+    }
+  }
+
+  /**
+   * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
+   */
+  object ResolveUpCast extends Rule[LogicalPlan] {
+    private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+      throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+        s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+        "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+        "You can either add an explicit cast to the input data or choose a higher precision " +
+        "type of the field in the target object")
+    }
+
+    private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+      val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+      val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+      toPrecedence > 0 && fromPrecedence > toPrecedence
+    }
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if !p.childrenResolved => p
+      case p if p.resolved => p
+
+      case p => p transformExpressions {
+        case u @ UpCast(child, _, _) if !child.resolved => u
+
+        case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+          case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+            fail(child, to, walkedTypePath)
+          case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+            fail(child, to, walkedTypePath)
+          case (from, to) if illegalNumericPrecedence(from, to) =>
+            fail(child, to, walkedTypePath)
+          case (TimestampType, DateType) =>
+            fail(child, DateType, walkedTypePath)
+          case (StringType, to: NumericType) =>
+            fail(child, to, walkedTypePath)
+          case _ => Cast(child, dataType.asNullable)
+        }
+      }
+    }
+  }
 }
 
 /**
@@ -1559,45 +1609,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
   }
 }
 
-/**
- * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
- */
-object ResolveUpCast extends Rule[LogicalPlan] {
-  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
-    throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
-      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
-      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
-      "You can either add an explicit cast to the input data or choose a higher precision " +
-      "type of the field in the target object")
-  }
-
-  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
-    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
-    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
-    toPrecedence > 0 && fromPrecedence > toPrecedence
-  }
-
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transformAllExpressions {
-      case u @ UpCast(child, _, _) if !child.resolved => u
-
-      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
-        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
-          fail(child, to, walkedTypePath)
-        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
-          fail(child, to, walkedTypePath)
-        case (from, to) if illegalNumericPrecedence(from, to) =>
-          fail(child, to, walkedTypePath)
-        case (TimestampType, DateType) =>
-          fail(child, DateType, walkedTypePath)
-        case (StringType, to: NumericType) =>
-          fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType.asNullable)
-      }
-    }
-  }
-}
-
 /**
  * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
  * figure out how many windows a time column can map to, we over-estimate the number of windows and
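Side note on the relocated `ResolveUpCast` rule: its truncation guard relies on a numeric precedence ordering. Below is a standalone sketch of that check, with the precedence list written out explicitly (assumed to mirror `HiveTypeCoercion.numericPrecedence`, decimal types omitted for brevity).

```scala
import org.apache.spark.sql.types._

// Assumed ordering from narrowest to widest numeric type.
val numericPrecedence: IndexedSeq[DataType] =
  IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
  val fromPrecedence = numericPrecedence.indexOf(from)
  val toPrecedence = numericPrecedence.indexOf(to)
  toPrecedence > 0 && fromPrecedence > toPrecedence
}

illegalNumericPrecedence(LongType, IntegerType)  // true: the up cast would truncate
illegalNumericPrecedence(IntegerType, LongType)  // false: widening is allowed
```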

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 22 additions & 0 deletions
@@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
 
   override lazy val resolved = false
 }
+
+/**
+ * Holds the deserializer expression and the attributes that are available during the resolution
+ * for it. Deserializer expression is a special kind of expression that is not always resolved by
+ * children output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be
+ * resolved by `groupingAttributes` instead of children output.
+ *
+ * @param deserializer The unresolved deserializer expression
+ * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
+ *                        if we want to resolve deserializer by children output.
+ */
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
+  extends UnaryExpression with Unevaluable with NonSQLExpression {
+  // The input attributes used to resolve deserializer expression must be all resolved.
+  require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
+
+  override def child: Expression = deserializer
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+}
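A hypothetical usage sketch of the two resolution modes the scaladoc describes: explicit input attributes (e.g. a grouping key), versus an empty attribute list that defers to the children's output. The encoders and the `key` attribute below are invented for illustration.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// A grouping-key deserializer resolved against explicitly supplied attributes,
// as in the MapGroups example from the scaladoc.
val keyEncoder = ExpressionEncoder[Int]()
val groupingAttributes = Seq(AttributeReference("key", IntegerType)())
val keyDeserializer = UnresolvedDeserializer(keyEncoder.deserializer, groupingAttributes)

// Passing an empty attribute list defers resolution to the children's output.
val valueEncoder = ExpressionEncoder[String]()
val valueDeserializer = UnresolvedDeserializer(valueEncoder.deserializer, Nil)
```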

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.{AnalysisException, Encoder}
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
@@ -317,11 +317,11 @@
   def resolve(
       schema: Seq[Attribute],
      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
-
     // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
     // analysis, go through optimizer, etc.
-    val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
+    val plan = Project(
+      Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
+      LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
     copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
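A rough usage sketch of the updated `resolve` path, assuming the signature shown in the hunk above; for a top-level class no outer scope is required, so an empty registry stands in for the real one.

```scala
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Person(name: String, age: Int)

val enc = ExpressionEncoder[Person]()
// Resolution now routes through the analyzer's ResolveDeserializer rule. For a
// top-level class no outer scope is needed, so an empty map suffices here.
val resolved = enc.resolve(enc.schema.toAttributes, new ConcurrentHashMap[String, AnyRef]())
```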

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 13 additions & 1 deletion
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.lang.reflect.Modifier
+
 import scala.annotation.tailrec
 import scala.language.existentials
 import scala.reflect.ClassTag
@@ -112,7 +114,7 @@ case class Invoke(
     arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
 
   override def nullable: Boolean = true
-  override def children: Seq[Expression] = arguments.+:(targetObject)
+  override def children: Seq[Expression] = targetObject +: arguments
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -214,6 +216,16 @@ case class NewInstance(
 
   override def children: Seq[Expression] = arguments
 
+  override lazy val resolved: Boolean = {
+    // If the class to construct is an inner class, we need to get its outer pointer, or this
+    // expression should be regarded as unresolved.
+    // Note that static inner classes (e.g., inner classes within Scala objects) don't need
+    // outer pointer registration.
+    val needOuterPointer =
+      outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
+    childrenResolved && !needOuterPointer
+  }
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
 
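The `needOuterPointer` condition added to `NewInstance` can be exercised in isolation with plain Java reflection; a small self-contained sketch follows (the `Outer`/`Inner` classes are made up for illustration).

```scala
import java.lang.reflect.Modifier

class Outer { class Inner }

// Mirrors the needOuterPointer condition from the diff above, outside of Catalyst.
def needsOuterPointer(cls: Class[_], hasOuterPointer: Boolean): Boolean =
  !hasOuterPointer && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)

println(needsOuterPointer(classOf[Outer#Inner], hasOuterPointer = false))  // true
println(needsOuterPointer(classOf[Outer], hasOuterPointer = false))        // false
```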