Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ case class WrapOption(child: Expression, optType: DataType)
}

/**
* A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
* A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed
* manually, but will instead be passed into the provided lambda function.
*/
case class LambdaVariable(
Expand All @@ -421,6 +421,27 @@ case class LambdaVariable(
}
}

/**
* When constructing [[MapObjects]], the element type must be given, which may not be available
* before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by
* [[MapObjects]] during analysis after the input data is resolved.
* Note that ideally we should not serialize and send unresolved expressions to executors, but
* users may accidentally do this (e.g. by mistakenly referencing an encoder instance when
* implementing an Aggregator). Here we mark `function` as transient because it may reference a
* Scala Type, which is not serializable. Then, even if users mistakenly reference an unresolved
* expression and serialize it, it's just a performance issue (more network traffic) and will not
* fail.
*/
case class UnresolvedMapObjects(
    @transient function: Expression => Expression,
    child: Expression,
    customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {

  // This node always reports itself as unresolved.
  override lazy val resolved: Boolean = false

  // The data type is only known when a custom collection class was supplied; in that case it is
  // the ObjectType wrapping that class. Otherwise asking for the type is an error.
  override def dataType: DataType = customCollectionCls match {
    case Some(collectionCls) => ObjectType(collectionCls)
    case None => throw new UnsupportedOperationException("not resolved")
  }
}

object MapObjects {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

Expand All @@ -442,20 +463,8 @@ object MapObjects {
val loopValue = s"MapObjects_loopValue$id"
val loopIsNull = s"MapObjects_loopIsNull$id"
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
val builderValue = s"MapObjects_builderValue$id"
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
customCollectionCls, builderValue)
}
}

case class UnresolvedMapObjects(
function: Expression => Expression,
child: Expression,
customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
override lazy val resolved = false

override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
throw new UnsupportedOperationException("not resolved")
MapObjects(
loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls)
}
}

Expand All @@ -482,17 +491,14 @@ case class UnresolvedMapObjects(
* @param inputData An expression that when evaluated returns a collection object.
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
* or None (returning ArrayType)
* @param builderValue The name of the builder variable used to construct the resulting collection
* (used only when returning ObjectType)
*/
case class MapObjects private(
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression,
customCollectionCls: Option[Class[_]],
builderValue: String) extends Expression with NonSQLExpression {
customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression {

override def nullable: Boolean = inputData.nullable

Expand Down Expand Up @@ -590,15 +596,15 @@ case class MapObjects private(
customCollectionCls match {
case Some(cls) =>
// collection
val collObjectName = s"${cls.getName}$$.MODULE$$"
val getBuilderVar = s"$collObjectName.newBuilder()"
val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
val builder = ctx.freshName("collectionBuilder")
(
s"""
${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
genValue => s"$builderValue.$$plus$$eq($genValue);",
s"(${cls.getName}) $builderValue.result();"
genValue => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) $builder.result();"
)
case None =>
// array
Expand Down