
Commit 34c743c

cloud-fan authored and davies committed
[SPARK-15381] [SQL] physical object operator should define reference correctly
## What changes were proposed in this pull request?

Whole-stage codegen depends on `SparkPlan.references` for some of its optimizations. The physical object operators should be consistent with their logical counterparts and define `references` correctly.

## How was this patch tested?

New test in DatasetSuite.

Author: Wenchen Fan <[email protected]>

Closes #13167 from cloud-fan/bug.

(cherry picked from commit 661c210)
Signed-off-by: Davies Liu <[email protected]>
1 parent a1948a0 commit 34c743c
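
To make the intent concrete, here is a minimal, self-contained Scala sketch of the rule the new `ObjectConsumerExec` trait encodes. It is not part of the patch; `Attr`, `PlanSketch`, and the demo names are invented stand-ins for Spark's planner types. The point it illustrates: a consumer must report its child's entire output in `references`, otherwise a pruning or codegen pass that trusts `references` could drop the one object column the operator still needs.

```scala
// Hypothetical stand-ins for Spark's planner types; illustrative only.
case class Attr(name: String)

trait PlanSketch {
  def output: Seq[Attr]      // columns this node produces
  def references: Set[Attr]  // columns this node claims to need from its children
}

// An object consumer reads its child's single object-typed column as JVM objects.
// Its expressions may not mention that column explicitly, so an expression-derived
// `references` set could be empty -- the inconsistency this patch fixes.
case class ObjectConsumerSketch(child: PlanSketch) extends PlanSketch {
  require(child.output.length == 1, "object consumers expect exactly one object column")
  def output: Seq[Attr] = child.output
  // The rule, mirrored from ObjectConsumerExec: always claim every child column.
  def references: Set[Attr] = child.output.toSet
}

object ReferencesDemo extends App {
  val child = new PlanSketch {
    val output = Seq(Attr("obj"))
    val references = Set.empty[Attr]
  }
  val consumer = ObjectConsumerSketch(child)
  // A naive pruning pass that keeps only referenced child columns now keeps "obj".
  println(child.output.filter(consumer.references))  // List(Attr(obj))
}
```

The real traits live in sql/core's objects.scala (see the diff below) and express the same rule over `AttributeSet`.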

File tree

4 files changed (+64, -47 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 5 additions & 5 deletions
@@ -94,7 +94,7 @@ case class DeserializeToObject(
  */
 case class SerializeFromObject(
     serializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+    child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 }
@@ -118,7 +118,7 @@ object MapPartitions {
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
 object MapPartitionsInR {
   def apply(
@@ -152,7 +152,7 @@ case class MapPartitionsInR(
     inputSchema: StructType,
     outputSchema: StructType,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
   override lazy val schema = outputSchema
 }
 
@@ -175,7 +175,7 @@ object MapElements {
 case class MapElements(
     func: AnyRef,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
@@ -215,7 +215,7 @@ case class AppendColumnsWithObject(
     func: Any => Any,
     childSerializer: Seq[NamedExpression],
     newColumnsSerializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+    child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
@@ -303,7 +303,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           "logical except operator should have been replaced by anti-join in the optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>
-        execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
+        execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
       case logical.SerializeFromObject(serializer, child) =>
        execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
       case logical.MapPartitions(f, objAttr, child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 53 additions & 41 deletions
@@ -28,17 +28,41 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.{DataType, ObjectType}
 
+
+/**
+ * Physical version of `ObjectProducer`.
+ */
+trait ObjectProducerExec extends SparkPlan {
+  // The attribute that reference to the single object field this operator outputs.
+  protected def outputObjAttr: Attribute
+
+  override def output: Seq[Attribute] = outputObjAttr :: Nil
+
+  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+  def outputObjectType: DataType = outputObjAttr.dataType
+}
+
+/**
+ * Physical version of `ObjectConsumer`.
+ */
+trait ObjectConsumerExec extends UnaryExecNode {
+  assert(child.output.length == 1)
+
+  // This operator always need all columns of its child, even it doesn't reference to.
+  override def references: AttributeSet = child.outputSet
+
+  def inputObjectType: DataType = child.output.head.dataType
+}
+
 /**
  * Takes the input row from child and turns it into object using the given deserializer expression.
  * The output of this operator is a single-field safe row containing the deserialized object.
  */
-case class DeserializeToObject(
+case class DeserializeToObjectExec(
     deserializer: Expression,
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -70,7 +94,7 @@ case class DeserializeToObject(
  */
 case class SerializeFromObjectExec(
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+    child: SparkPlan) extends ObjectConsumerExec with CodegenSupport {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
@@ -102,7 +126,7 @@ case class SerializeFromObjectExec(
 /**
  * Helper functions for physical operators that work with user defined objects.
  */
-trait ObjectOperator extends SparkPlan {
+object ObjectOperator {
   def deserializeRowToObject(
       deserializer: Expression,
       inputSchema: Seq[Attribute]): InternalRow => Any = {
@@ -141,15 +165,12 @@ case class MapPartitionsExec(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends ObjectConsumerExec with ObjectProducerExec {
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       func(iter.map(getObject)).map(outputObject)
     }
   }
@@ -166,10 +187,7 @@ case class MapElementsExec(
     func: AnyRef,
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport {
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -202,8 +220,8 @@ case class MapElementsExec(
     }
 
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       iter.map(row => outputObject(callFunc(getObject(row))))
     }
   }
@@ -218,17 +236,17 @@ case class AppendColumnsExec(
     func: Any => Any,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)
 
   private def newColumnSchema = serializer.map(_.toAttribute).toStructType
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = deserializeRowToObject(deserializer, child.output)
+      val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
       val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
-      val outputObject = serializeObjectToRow(serializer)
+      val outputObject = ObjectOperator.serializeObjectToRow(serializer)
 
       iter.map { row =>
         val newColumns = outputObject(func(getObject(row)))
@@ -246,7 +264,7 @@ case class AppendColumnsWithObjectExec(
     func: Any => Any,
     inputSerializer: Seq[NamedExpression],
     newColumnsSerializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends ObjectConsumerExec {
 
   override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)
 
@@ -255,9 +273,9 @@ case class AppendColumnsWithObjectExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputChildObject = serializeObjectToRow(inputSerializer)
-      val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer)
+      val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
+      val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
       val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)
 
       iter.map { row =>
@@ -280,10 +298,7 @@ case class MapGroupsExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(groupingAttributes) :: Nil
@@ -295,9 +310,9 @@ case class MapGroupsExec(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
 
-      val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
 
       grouped.flatMap { case (key, rowIter) =>
         val result = func(
@@ -325,10 +340,7 @@ case class CoGroupExec(
     rightAttr: Seq[Attribute],
     outputObjAttr: Attribute,
     left: SparkPlan,
-    right: SparkPlan) extends BinaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
@@ -341,10 +353,10 @@ case class CoGroupExec(
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
 
-      val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
-      val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
-      val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
+      val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
+      val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
 
       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
         case (key, leftResult, rightResult) =>
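
A rough companion sketch of the producer side, under the assumption (standard `QueryPlan` bookkeeping, not restated in this diff) that `missingInput = references -- inputSet -- producedAttributes`: because `outputObjAttr` is a constructor argument, it shows up among the node's expressions and hence in `references`, so a producer must also declare it in `producedAttributes` or the plan would appear to read an attribute that nothing supplies. The names below are invented stand-ins, not Spark classes.

```scala
// Hypothetical stand-ins; illustrative of the bookkeeping ObjectProducerExec relies on.
case class Attr(name: String)

trait NodeSketch {
  def inputSet: Set[Attr]                        // what the children provide
  def references: Set[Attr] = Set.empty          // what this node reads
  def producedAttributes: Set[Attr] = Set.empty  // what this node creates itself
  // Attributes that appear to come from nowhere; should stay empty in a valid plan.
  def missingInput: Set[Attr] = references -- inputSet -- producedAttributes
}

// A producer emits a fresh object attribute that no child supplies.
case class ObjectProducerSketch(outputObjAttr: Attr, childOutput: Seq[Attr]) extends NodeSketch {
  def inputSet: Set[Attr] = childOutput.toSet
  // The object attribute is part of this node's expressions, so it lands in `references`...
  override def references: Set[Attr] = Set(outputObjAttr)
  // ...and must be declared as produced here, matching ObjectProducerExec.producedAttributes.
  override def producedAttributes: Set[Attr] = Set(outputObjAttr)
}

object ProducedDemo extends App {
  val node = ObjectProducerSketch(Attr("obj"), childOutput = Seq(Attr("a"), Attr("b")))
  println(node.missingInput)  // Set()
}
```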

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 5 additions & 0 deletions
@@ -711,6 +711,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("already exists"))
     dataset.sparkSession.catalog.dropTempView("tempView")
   }
+
+  test("SPARK-15381: physical object operator should define `reference` correctly") {
+    val df = Seq(1 -> 2).toDF("a", "b")
+    checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1))
+  }
 }
 
 case class Generic[T](id: T, value: Double)
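
For readers who want to exercise the same check outside the test suite, a hedged standalone version of the new test follows. It assumes Spark 2.0-era APIs (`RowEncoder` in `org.apache.spark.sql.catalyst.encoders`, a local `SparkSession`); the object name, master, and app name are arbitrary.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder

object Spark15381Repro extends App {
  val spark = SparkSession.builder().master("local[2]").appName("SPARK-15381").getOrCreate()
  import spark.implicits._

  val df = Seq(1 -> 2).toDF("a", "b")
  // map(identity) wraps the scan in DeserializeToObjectExec/SerializeFromObjectExec;
  // the trailing select("b", "a") then relies on those operators reporting
  // `references` correctly, which is exactly what this patch fixes.
  val result = df.map(row => row)(RowEncoder(df.schema)).select("b", "a").collect()
  assert(result.sameElements(Array(Row(2, 1))), s"unexpected result: ${result.toSeq}")

  spark.stop()
}
```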
