@@ -94,7 +94,7 @@ case class DeserializeToObject(
*/
case class SerializeFromObject(
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectConsumer {
child: LogicalPlan) extends ObjectConsumer {

override def output: Seq[Attribute] = serializer.map(_.toAttribute)
}
@@ -118,7 +118,7 @@ object MapPartitions {
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
child: LogicalPlan) extends ObjectConsumer with ObjectProducer

object MapPartitionsInR {
def apply(
@@ -152,7 +152,7 @@ case class MapPartitionsInR(
inputSchema: StructType,
outputSchema: StructType,
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
override lazy val schema = outputSchema
}

@@ -175,7 +175,7 @@ object MapElements {
case class MapElements(
func: AnyRef,
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
child: LogicalPlan) extends ObjectConsumer with ObjectProducer

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
@@ -215,7 +215,7 @@ case class AppendColumnsWithObject(
func: Any => Any,
childSerializer: Seq[NamedExpression],
newColumnsSerializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectConsumer {
child: LogicalPlan) extends ObjectConsumer {

override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
}
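The change repeated across these hunks drops the redundant UnaryNode parent: extends UnaryNode with ObjectConsumer becomes extends ObjectConsumer, which only type-checks if the ObjectConsumer trait itself already mixes in UnaryNode. A minimal standalone sketch of that pattern, not Spark code, with names reused purely for illustration:

trait UnaryNode { def child: String }
trait ObjectConsumer extends UnaryNode {
  // an object consumer always reads the single object column of its child
  def inputObject: String = child
}

// Before: the shared parent is restated explicitly.
case class Before(child: String) extends UnaryNode with ObjectConsumer
// After: identical linearization, shorter declaration.
case class After(child: String) extends ObjectConsumer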
@@ -303,7 +303,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"logical except operator should have been replaced by anti-join in the optimizer")

case logical.DeserializeToObject(deserializer, objAttr, child) =>
execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
case logical.SerializeFromObject(serializer, child) =>
execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, objAttr, child) =>
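For context, a hedged usage sketch of the kind of query that reaches these strategy cases: in the Dataset API a typed map goes through user objects, so its plan contains the object operators planned above. This assumes a SparkSession named spark is in scope, and the exact physical plan text may vary by version:

import spark.implicits._

// A typed map forces a round trip through user objects, so the explained plan
// should show the object operators renamed and refactored in this PR.
val ds = Seq(1, 2, 3).toDS().map(_ + 1)
ds.explain()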
@@ -28,17 +28,41 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.{DataType, ObjectType}


/**
* Physical version of `ObjectProducer`.
*/
trait ObjectProducerExec extends SparkPlan {
  // The attribute that references the single object field this operator outputs.
protected def outputObjAttr: Attribute

override def output: Seq[Attribute] = outputObjAttr :: Nil

override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)

def outputObjectType: DataType = outputObjAttr.dataType
}

/**
* Physical version of `ObjectConsumer`.
*/
trait ObjectConsumerExec extends UnaryExecNode {
assert(child.output.length == 1)

  // This operator always needs all columns of its child, even if it doesn't reference them.
override def references: AttributeSet = child.outputSet

def inputObjectType: DataType = child.output.head.dataType
}
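The references override is the substantive part of this trait, matching the comment above: a node's references are normally derived from the columns its expressions mention, so an object consumer that never names its child's single column could otherwise be treated as needing nothing from it. A minimal sketch of the contract, not Spark code:

trait PlanLike {
  def expressions: Seq[String]
  def childOutput: Seq[String]
  // default: a node is assumed to need only the columns its expressions mention
  def references: Set[String] = expressions.toSet
}

trait ConsumerLike extends PlanLike {
  // an object consumer deserializes the whole child row, so it claims the
  // child's full output even when its expressions never name those columns
  override def references: Set[String] = childOutput.toSet
}

The new DatasetSuite test at the end of this diff exercises exactly this case.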

/**
 * Takes the input row from the child and turns it into an object using the given deserializer expression.
* The output of this operator is a single-field safe row containing the deserialized object.
*/
case class DeserializeToObject(
case class DeserializeToObjectExec(
Contributor Author comment: not related, but to make the name consistent with other physical plans.
deserializer: Expression,
outputObjAttr: Attribute,
child: SparkPlan) extends UnaryExecNode with CodegenSupport {

override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {

override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -70,7 +94,7 @@ case class DeserializeToObject(
*/
case class SerializeFromObjectExec(
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryExecNode with CodegenSupport {
child: SparkPlan) extends ObjectConsumerExec with CodegenSupport {

override def output: Seq[Attribute] = serializer.map(_.toAttribute)

@@ -102,7 +126,7 @@ case class SerializeFromObjectExec(
/**
* Helper functions for physical operators that work with user defined objects.
*/
trait ObjectOperator extends SparkPlan {
object ObjectOperator {
def deserializeRowToObject(
deserializer: Expression,
inputSchema: Seq[Attribute]): InternalRow => Any = {
@@ -141,15 +165,12 @@ case class MapPartitionsExec(
func: Iterator[Any] => Iterator[Any],
outputObjAttr: Attribute,
child: SparkPlan)
extends UnaryExecNode with ObjectOperator {

override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
extends ObjectConsumerExec with ObjectProducerExec {

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
val getObject = unwrapObjectFromRow(child.output.head.dataType)
val outputObject = wrapObjectToRow(outputObjAttr.dataType)
val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
func(iter.map(getObject)).map(outputObject)
}
}
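ObjectOperator also changes from a mixin trait to a standalone object in this file, so operators such as MapPartitionsExec above call the helpers by name (ObjectOperator.unwrapObjectFromRow, ObjectOperator.wrapObjectToRow) instead of inheriting them. A minimal sketch of the trait-to-object pattern, not Spark code:

object RowHelpers {
  // stateless conversions live on a plain object instead of a mixin trait
  def unwrap(value: String): String = s"object($value)"
  def wrap(value: String): String = s"row($value)"
}

case class MapPartitionsSketch(child: Seq[String]) {
  // call sites name the helper object explicitly rather than mixing it in
  def run: Seq[String] = child.map(RowHelpers.unwrap).map(RowHelpers.wrap)
}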
@@ -166,10 +187,7 @@ case class MapElementsExec(
func: AnyRef,
outputObjAttr: Attribute,
child: SparkPlan)
extends UnaryExecNode with ObjectOperator with CodegenSupport {

override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport {

override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
Expand Down Expand Up @@ -202,8 +220,8 @@ case class MapElementsExec(
}

child.execute().mapPartitionsInternal { iter =>
val getObject = unwrapObjectFromRow(child.output.head.dataType)
val outputObject = wrapObjectToRow(outputObjAttr.dataType)
val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
iter.map(row => outputObject(callFunc(getObject(row))))
}
}
@@ -218,17 +236,17 @@ case class AppendColumnsExec(
func: Any => Any,
deserializer: Expression,
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryExecNode with ObjectOperator {
child: SparkPlan) extends UnaryExecNode {

override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)

private def newColumnSchema = serializer.map(_.toAttribute).toStructType

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
val getObject = deserializeRowToObject(deserializer, child.output)
val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
val outputObject = serializeObjectToRow(serializer)
val outputObject = ObjectOperator.serializeObjectToRow(serializer)

iter.map { row =>
val newColumns = outputObject(func(getObject(row)))
@@ -246,7 +264,7 @@ case class AppendColumnsWithObjectExec(
func: Any => Any,
inputSerializer: Seq[NamedExpression],
newColumnsSerializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryExecNode with ObjectOperator {
child: SparkPlan) extends ObjectConsumerExec {

override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)

@@ -255,9 +273,9 @@

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
val outputChildObject = serializeObjectToRow(inputSerializer)
val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer)
val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)

iter.map { row =>
@@ -280,10 +298,7 @@ case class MapGroupsExec(
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
child: SparkPlan) extends UnaryExecNode with ObjectOperator {

override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
@@ -295,9 +310,9 @@
child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)

val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
val outputObject = wrapObjectToRow(outputObjAttr.dataType)
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

grouped.flatMap { case (key, rowIter) =>
val result = func(
@@ -325,10 +340,7 @@ case class CoGroupExec(
rightAttr: Seq[Attribute],
outputObjAttr: Attribute,
left: SparkPlan,
right: SparkPlan) extends BinaryExecNode with ObjectOperator {

override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
@@ -341,10 +353,10 @@ case class CoGroupExec(
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)

val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
val outputObject = wrapObjectToRow(outputObjAttr.dataType)
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
case (key, leftResult, rightResult) =>
@@ -702,6 +702,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("already exists"))
dataset.sparkSession.catalog.dropTempView("tempView")
}

test("SPARK-15381: physical object operator should define `reference` correctly") {
val df = Seq(1 -> 2).toDF("a", "b")
checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1))
}
}

case class Generic[T](id: T, value: Double)