Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Modifier

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

Expand Down Expand Up @@ -87,9 +85,11 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveUpCast ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveSortReferences ::
ResolveGenerate ::
Expand Down Expand Up @@ -499,18 +499,9 @@ class Analyzer(
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}

// A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
// should be resolved by their corresponding attributes instead of children's output.
case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
val deserializerToAttributes = o.deserializers.map {
case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
}.toMap

o.transformExpressions {
case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
resolveDeserializer(expr, attributes)
}.getOrElse(expr)
}
// Skips plan which contains deserializer expressions, as they should be resolved by another
// rule: ResolveDeserializer.
case plan if containsDeserializer(plan.expressions) => plan

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
Expand All @@ -526,38 +517,6 @@ class Analyzer(
}
}

private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
exprs.exists { expr =>
!expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
}
}

def resolveDeserializer(
deserializer: Expression,
attributes: Seq[Attribute]): Expression = {
val unbound = deserializer transform {
case b: BoundReference => attributes(b.ordinal)
}

resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
case n: NewInstance
// If this is an inner class of another class, register the outer object in `OuterScopes`.
// Note that static inner classes (e.g., inner classes within Scala objects) don't need
// outer pointer registration.
if n.outerPointer.isEmpty &&
n.cls.isMemberClass &&
!Modifier.isStatic(n.cls.getModifiers) =>
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
"access to the scope that this class was defined in.\n" +
"Try moving this class out of its parent class.")
}
n.copy(outerPointer = Some(outer))
}
}

def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
Expand Down Expand Up @@ -623,6 +582,10 @@ class Analyzer(
}
}

private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
}

protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
Expand Down Expand Up @@ -1475,7 +1438,94 @@ class Analyzer(
Project(projectList, Join(left, right, joinType, newCondition))
}

/**
* Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p

case p => p transformExpressions {
case UnresolvedDeserializer(deserializer, inputAttributes) =>
val inputs = if (inputAttributes.isEmpty) {
p.children.flatMap(_.output)
} else {
inputAttributes
}
val unbound = deserializer transform {
case b: BoundReference => inputs(b.ordinal)
}
resolveExpression(unbound, LocalRelation(inputs), throws = true)
}
}
}

/**
* Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p

case p => p transformExpressions {
case n: NewInstance if n.childrenResolved && !n.resolved =>
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
"access to the scope that this class was defined in.\n" +
"Try moving this class out of its parent class.")
}
n.copy(outerPointer = Some(outer))
}
}
}

/**
* Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
*/
object ResolveUpCast extends Rule[LogicalPlan] {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move rules related to deserializer expressions together.

private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
"type of the field in the target object")
}

private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
toPrecedence > 0 && fromPrecedence > toPrecedence
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p

case p => p transformExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u

case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
fail(child, to, walkedTypePath)
case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
fail(child, to, walkedTypePath)
case (from, to) if illegalNumericPrecedence(from, to) =>
fail(child, to, walkedTypePath)
case (TimestampType, DateType) =>
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
case _ => Cast(child, dataType.asNullable)
}
}
}
}
}

/**
Expand Down Expand Up @@ -1559,45 +1609,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}

/**
* Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
"type of the field in the target object")
}

private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
toPrecedence > 0 && fromPrecedence > toPrecedence
}

def apply(plan: LogicalPlan): LogicalPlan = {
plan transformAllExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u

case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
fail(child, to, walkedTypePath)
case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
fail(child, to, walkedTypePath)
case (from, to) if illegalNumericPrecedence(from, to) =>
fail(child, to, walkedTypePath)
case (TimestampType, DateType) =>
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
case _ => Cast(child, dataType.asNullable)
}
}
}
}

/**
* Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
* figure out how many windows a time column can map to, we over-estimate the number of windows and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)

override lazy val resolved = false
}

/**
* Holds the deserializer expression and the attributes that are available during the resolution
* for it. Deserializer expression is a special kind of expression that is not always resolved by
* children output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be
* resolved by `groupingAttributes` instead of children output.
*
* @param deserializer The unresolved deserializer expression
* @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
* if we want to resolve deserializer by children output.
*/
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
extends UnaryExpression with Unevaluable with NonSQLExpression {
// The input attributes used to resolve deserializer expression must be all resolved.
require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")

override def child: Expression = deserializer
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
Expand Down Expand Up @@ -317,11 +317,11 @@ case class ExpressionEncoder[T](
def resolve(
schema: Seq[Attribute],
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)

// Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
// analysis, go through optimizer, etc.
val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
val plan = Project(
Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.lang.reflect.Modifier

import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag
Expand Down Expand Up @@ -112,7 +114,7 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {

override def nullable: Boolean = true
override def children: Seq[Expression] = arguments.+:(targetObject)
override def children: Seq[Expression] = targetObject +: arguments

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
Expand Down Expand Up @@ -214,6 +216,16 @@ case class NewInstance(

override def children: Seq[Expression] = arguments

override lazy val resolved: Boolean = {
// If the class to construct is an inner class, we need to get its outer pointer, or this
// expression should be regarded as unresolved.
// Note that static inner classes (e.g., inner classes within Scala objects) don't need
// outer pointer registration.
val needOuterPointer =
outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
childrenResolved && !needOuterPointer
}

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")

Expand Down
Loading