Commit 052ad6b

physical object operator should define reference correctly
1 parent c36ca65 · commit 052ad6b
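In short, this commit adds physical-plan counterparts of the `ObjectProducer`/`ObjectConsumer` traits, makes the consumer side report `references = child.outputSet` so the single object column of its child cannot be pruned away, renames `DeserializeToObject` to `DeserializeToObjectExec`, and turns the `ObjectOperator` trait into a helper object. The sketch below is a minimal, self-contained way to exercise the scenario covered by the new `DatasetSuite` test; the `Spark15381Repro` object, the local master setting, and the `show()` call are illustrative additions, while the DataFrame and the expected (2, 1) row come straight from the commit's test.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Hypothetical driver (not part of the commit); assumes a local Spark build from this era.
object Spark15381Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-15381").getOrCreate()
    import spark.implicits._

    val df = Seq(1 -> 2).toDF("a", "b")
    // df.map plans an object consumer/producer pair around the identity function;
    // with `references` defined correctly the later select can still resolve both columns.
    val result = df.map(row => row)(RowEncoder(df.schema)).select("b", "a")
    result.show() // expect a single row: (2, 1)

    spark.stop()
  }
}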

3 files changed (+59, -42 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
@@ -303,7 +303,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           "logical except operator should have been replaced by anti-join in the optimizer")

       case logical.DeserializeToObject(deserializer, objAttr, child) =>
-        execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
+        execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
       case logical.SerializeFromObject(serializer, child) =>
         execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
       case logical.MapPartitions(f, objAttr, child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 53 additions & 41 deletions
@@ -28,17 +28,41 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.{DataType, ObjectType}

+
+/**
+ * Physical version of `ObjectProducer`.
+ */
+trait ObjectProducerExec extends SparkPlan {
+  // The attribute that references the single object field this operator outputs.
+  protected def outputObjAttr: Attribute
+
+  override def output: Seq[Attribute] = outputObjAttr :: Nil
+
+  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+  def outputObjectType: DataType = outputObjAttr.dataType
+}
+
+/**
+ * Physical version of `ObjectConsumer`.
+ */
+trait ObjectConsumerExec extends UnaryExecNode {
+  assert(child.output.length == 1)
+
+  // This operator always needs all columns of its child, even if it doesn't reference them.
+  override def references: AttributeSet = child.outputSet
+
+  def inputObjectType: DataType = child.output.head.dataType
+}
+
 /**
  * Takes the input row from child and turns it into object using the given deserializer expression.
  * The output of this operator is a single-field safe row containing the deserialized object.
  */
-case class DeserializeToObject(
+case class DeserializeToObjectExec(
     deserializer: Expression,
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {

   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
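The two traits above carry the heart of the fix. A plan node normally derives `references` from the Catalyst expressions it holds (minus `producedAttributes`), so an object consumer, whose work lives inside an opaque Scala closure, would appear to need nothing from its child; declaring `references = child.outputSet` keeps that child's object column from being pruned. The following toy model is a minimal sketch of that bookkeeping in plain Scala; none of these classes or the `pruneChildOutput` rule are Spark's actual code, they only illustrate the idea.

// Toy model of producer/consumer attribute bookkeeping (illustrative only).
object ReferencesSketch extends App {
  case class Attr(name: String)

  trait Node {
    def output: Set[Attr]                          // columns this node exposes
    def references: Set[Attr] = Set.empty          // child columns it claims to need
    def producedAttributes: Set[Attr] = Set.empty  // columns it introduces itself
  }

  // Producer: outputs a single object column that nothing below it provides.
  case class ObjectProducer(objAttr: Attr) extends Node {
    override def output: Set[Attr] = Set(objAttr)
    override def producedAttributes: Set[Attr] = Set(objAttr)
  }

  // Consumer: its work happens inside an opaque function, so it must claim
  // the whole child output instead of appearing to reference nothing.
  case class ObjectConsumer(child: Node) extends Node {
    override def output: Set[Attr] = child.output
    override def references: Set[Attr] = child.output
  }

  // A naive pruning rule keeps only the child columns the parent references.
  def pruneChildOutput(parent: Node, child: Node): Set[Attr] =
    child.output.intersect(parent.references)

  val producer = ObjectProducer(Attr("obj"))
  val consumer = ObjectConsumer(producer)
  println(pruneChildOutput(consumer, producer)) // Set(Attr(obj)): the object column survives
}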
@@ -70,7 +94,7 @@ case class DeserializeToObject(
  */
 case class SerializeFromObjectExec(
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+    child: SparkPlan) extends UnaryExecNode with ObjectConsumerExec with CodegenSupport {

   override def output: Seq[Attribute] = serializer.map(_.toAttribute)

@@ -102,7 +126,7 @@ case class SerializeFromObjectExec(
 /**
  * Helper functions for physical operators that work with user defined objects.
  */
-trait ObjectOperator extends SparkPlan {
+object ObjectOperator {
   def deserializeRowToObject(
       deserializer: Expression,
       inputSchema: Seq[Attribute]): InternalRow => Any = {

@@ -141,15 +165,12 @@ case class MapPartitionsExec(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends UnaryExecNode with ObjectProducerExec with ObjectConsumerExec {

   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       func(iter.map(getObject)).map(outputObject)
     }
   }

@@ -166,10 +187,7 @@ case class MapElementsExec(
     func: AnyRef,
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends UnaryExecNode with ObjectProducerExec with ObjectConsumerExec with CodegenSupport {

   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()

@@ -202,8 +220,8 @@ case class MapElementsExec(
     }

     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       iter.map(row => outputObject(callFunc(getObject(row))))
     }
   }

@@ -218,17 +236,17 @@ case class AppendColumnsExec(
     func: Any => Any,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends UnaryExecNode {

   override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)

   private def newColumnSchema = serializer.map(_.toAttribute).toStructType

   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = deserializeRowToObject(deserializer, child.output)
+      val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
       val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
-      val outputObject = serializeObjectToRow(serializer)
+      val outputObject = ObjectOperator.serializeObjectToRow(serializer)

       iter.map { row =>
         val newColumns = outputObject(func(getObject(row)))

@@ -246,7 +264,7 @@ case class AppendColumnsWithObjectExec(
     func: Any => Any,
     inputSerializer: Seq[NamedExpression],
     newColumnsSerializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends UnaryExecNode with ObjectConsumerExec {

   override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)

@@ -255,9 +273,9 @@ case class AppendColumnsWithObjectExec(

   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputChildObject = serializeObjectToRow(inputSerializer)
-      val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer)
+      val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
+      val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
       val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)

       iter.map { row =>

@@ -280,10 +298,7 @@ case class MapGroupsExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {

   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(groupingAttributes) :: Nil

@@ -295,9 +310,9 @@ case class MapGroupsExec(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)

-      val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

       grouped.flatMap { case (key, rowIter) =>
         val result = func(

@@ -325,10 +340,7 @@ case class CoGroupExec(
     rightAttr: Seq[Attribute],
     outputObjAttr: Attribute,
     left: SparkPlan,
-    right: SparkPlan) extends BinaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {

   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil

@@ -341,10 +353,10 @@ case class CoGroupExec(
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)

-      val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
-      val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
-      val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
+      val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
+      val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
         case (key, leftResult, rightResult) =>

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 5 additions & 0 deletions
@@ -702,6 +702,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("already exists"))
     dataset.sparkSession.catalog.dropTempView("tempView")
   }
+
+  test("SPARK-15381: physical object operator should define `reference` correctly") {
+    val df = Seq(1 -> 2).toDF("a", "b")
+    checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1))
+  }
 }

 case class Generic[T](id: T, value: Double)
