@@ -28,17 +28,41 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.{DataType, ObjectType}

+
+/**
+ * Physical version of `ObjectProducer`.
+ */
+trait ObjectProducerExec extends SparkPlan {
+  // The attribute that references the single object field this operator outputs.
+  protected def outputObjAttr: Attribute
+
+  override def output: Seq[Attribute] = outputObjAttr :: Nil
+
+  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+  def outputObjectType: DataType = outputObjAttr.dataType
+}
+
+/**
+ * Physical version of `ObjectConsumer`.
+ */
+trait ObjectConsumerExec extends UnaryExecNode {
+  assert(child.output.length == 1)
+
+  // This operator always needs all columns of its child, even those it doesn't reference.
+  override def references: AttributeSet = child.outputSet
+
+  def inputObjectType: DataType = child.output.head.dataType
+}
+
 /**
  * Takes the input row from child and turns it into object using the given deserializer expression.
  * The output of this operator is a single-field safe row containing the deserialized object.
  */
-case class DeserializeToObject(
+case class DeserializeToObjectExec(
     deserializer: Expression,
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {

   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
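
The two new traits in this hunk exist so that object-producing and object-consuming operators stop repeating the `output`/`producedAttributes`/`references` boilerplate. As a rough illustration of the intended usage (this operator is hypothetical, not part of the patch), a new node now reduces to its execution logic:

case class MyObjectMapExec(            // hypothetical example, not in this patch
    func: Any => Any,
    outputObjAttr: Attribute,
    child: SparkPlan) extends ObjectConsumerExec with ObjectProducerExec {

  // output, producedAttributes, references, and the single-column assertion
  // are all inherited from the two traits introduced above.
  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsInternal { iter =>
      val getObject = ObjectOperator.unwrapObjectFromRow(inputObjectType)
      val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
      iter.map(row => outputObject(func(getObject(row))))
    }
  }
}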
@@ -70,7 +94,7 @@ case class DeserializeToObject(
  */
 case class SerializeFromObjectExec(
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+    child: SparkPlan) extends ObjectConsumerExec with CodegenSupport {

   override def output: Seq[Attribute] = serializer.map(_.toAttribute)

@@ -102,7 +126,7 @@ case class SerializeFromObjectExec(
 /**
  * Helper functions for physical operators that work with user defined objects.
  */
-trait ObjectOperator extends SparkPlan {
+object ObjectOperator {
   def deserializeRowToObject(
       deserializer: Expression,
       inputSchema: Seq[Attribute]): InternalRow => Any = {
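
With `ObjectOperator` demoted from a mixin trait to a standalone object, its helpers become plain functions, and every call site below now qualifies them explicitly. For orientation, the two simplest helpers can be sketched as follows (a sketch of the shape only; the bodies in the actual file may differ, and `GenericInternalRow` here is illustrative):

// Unwrap: read the single object column out of an input row.
def unwrapObjectFromRow(objType: DataType): InternalRow => Any =
  (row: InternalRow) => row.get(0, objType)

// Wrap: put an object back into a one-field InternalRow.
def wrapObjectToRow(objType: DataType): Any => InternalRow =
  (obj: Any) => new GenericInternalRow(Array[Any](obj))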
@@ -141,15 +165,12 @@ case class MapPartitionsExec(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends ObjectConsumerExec with ObjectProducerExec {

   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       func(iter.map(getObject)).map(outputObject)
     }
   }
@@ -166,10 +187,7 @@ case class MapElementsExec(
     func: AnyRef,
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport {

   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -202,8 +220,8 @@ case class MapElementsExec(
     }

     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       iter.map(row => outputObject(callFunc(getObject(row))))
     }
   }
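
The `callFunc` referenced above bridges the two forms the user function can take in `MapElementsExec`: a Scala `T => U` or Java's `MapFunction[T, U]`. A plausible reconstruction of that dispatch (assumed from the call site; the definition itself is outside this hunk):

// Dispatch once, outside the per-row loop, so each row pays no match cost.
val callFunc: Any => Any = func match {
  case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
  case _ => func.asInstanceOf[Any => Any]
}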
@@ -218,17 +236,17 @@ case class AppendColumnsExec(
     func: Any => Any,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends UnaryExecNode {

   override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)

   private def newColumnSchema = serializer.map(_.toAttribute).toStructType

   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = deserializeRowToObject(deserializer, child.output)
+      val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
       val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
-      val outputObject = serializeObjectToRow(serializer)
+      val outputObject = ObjectOperator.serializeObjectToRow(serializer)

       iter.map { row =>
         val newColumns = outputObject(func(getObject(row)))
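
For context on the elided tail of this `map`: `GenerateUnsafeRowJoiner.create(leftSchema, rightSchema)` code-generates a joiner that concatenates two `UnsafeRow`s, so the remaining lines presumably join each input row with its serialized new columns, along the lines of (a hedged sketch, not the verbatim source):

combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow])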
@@ -246,7 +264,7 @@ case class AppendColumnsWithObjectExec(
     func: Any => Any,
     inputSerializer: Seq[NamedExpression],
     newColumnsSerializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends ObjectConsumerExec {

   override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)

@@ -255,9 +273,9 @@ case class AppendColumnsWithObjectExec(

   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputChildObject = serializeObjectToRow(inputSerializer)
-      val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer)
+      val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
+      val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
       val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)

       iter.map { row =>
@@ -280,10 +298,7 @@ case class MapGroupsExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {

   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(groupingAttributes) :: Nil
@@ -295,9 +310,9 @@ case class MapGroupsExec(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)

-      val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

       grouped.flatMap { case (key, rowIter) =>
         val result = func(
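
`MapGroupsExec` is the physical node behind `KeyValueGroupedDataset.mapGroups`. At the API level, the plumbing above corresponds to something like the following (a sketch; assumes a SparkSession `spark` with `spark.implicits._` imported):

val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
ds.groupByKey(_._1)
  .mapGroups((key, rows) => (key, rows.map(_._2).sum))  // planned as MapGroupsExec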
@@ -325,10 +340,7 @@ case class CoGroupExec(
     rightAttr: Seq[Attribute],
     outputObjAttr: Attribute,
     left: SparkPlan,
-    right: SparkPlan) extends BinaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {

   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
@@ -341,10 +353,10 @@ case class CoGroupExec(
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)

-      val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
-      val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
-      val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
+      val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
+      val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
         case (key, leftResult, rightResult) =>
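
Similarly, `CoGroupExec` backs `KeyValueGroupedDataset.cogroup`, grouping both children by their keys and handing the co-grouped iterators to the user function. A hedged end-to-end sketch (again assuming `spark.implicits._` is imported):

val left  = Seq((1, "a"), (2, "b")).toDS().groupByKey(_._1)
val right = Seq((1, "x"), (3, "y")).toDS().groupByKey(_._1)
left.cogroup(right) { (key, ls, rs) =>
  Iterator((key, ls.size + rs.size))   // planned as CoGroupExec over the two children
}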