diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 5dd8cb8440be6..df03a2174dd9d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -230,7 +230,7 @@ private[spark] class PipedRDD[T: ClassTag](
   }
 }

-private object PipedRDD {
+private[spark] object PipedRDD {
   // Split a string into words using a standard StringTokenizer
   def tokenize(command: String): Seq[String] = {
     val buf = new ArrayBuffer[String]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index d383532cbd3d3..56f503624bfb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.{Encoder, Encoders, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
@@ -585,3 +585,32 @@ case class CoGroup(
     outputObjAttr: Attribute,
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode with ObjectProducer
+
+object PipeElements {
+  def apply[T : Encoder](
+      command: String,
+      printElement: (Any, String => Unit) => Unit,
+      child: LogicalPlan): LogicalPlan = {
+    val deserialized = CatalystSerde.deserialize[T](child)
+    implicit val encoder = Encoders.STRING
+    val piped = PipeElements(
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      CatalystSerde.generateObjAttr[String],
+      command,
+      printElement,
+      deserialized)
+    CatalystSerde.serialize[String](piped)
+  }
+}
+
+/**
+ * A relation produced by piping elements to a forked external process.
+ */
+case class PipeElements[T](
+    argumentClass: Class[_],
+    argumentSchema: StructType,
+    outputObjAttr: Attribute,
+    command: String,
+    printElement: (Any, String => Unit) => Unit,
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f9590797434a1..06b23bcf64e7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2889,6 +2889,35 @@ class Dataset[T] private[sql](
     flatMap(func)(encoder)
   }

+  /**
+   * Returns a new Dataset of strings created by piping elements to a forked external process.
+   * The resulting Dataset is computed by executing the given process once per partition.
+   * All elements of each input partition are written to the process's stdin as lines of input
+   * separated by newlines. The resulting partition consists of the process's stdout output, with
+   * each line of stdout resulting in one element of the output partition. A process is invoked
+   * even for empty partitions.
+   *
+   * Note that for a micro-batch streaming Dataset, the effect of pipe is per micro-batch only,
+   * not across the entire stream. If the external process performs an aggregation-like
+   * operation on its input, e.g. `wc -l`, the aggregation is applied per partition within each
+   * micro-batch. You may want to aggregate these outputs after calling pipe to get a global
+   * aggregation across partitions and also across micro-batches.
+   *
+   * @param command the command to run in the forked process.
+   * @param printElement a function that customizes how elements are piped. It is called with
+   *                     each Dataset element as the 1st parameter and the print-line function
+   *                     (like out.println()) as the 2nd parameter.
+   * @group typedrel
+   * @since 3.2.0
+   */
+  def pipe(command: String, printElement: (T, String => Unit) => Unit): Dataset[String] = {
+    implicit val stringEncoder = Encoders.STRING
+    withTypedPlan[String](PipeElements[T](
+      command,
+      printElement.asInstanceOf[(Any, String => Unit) => Unit],
+      logicalPlan))
+  }
+
   /**
    * Applies a function `f` to all rows.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a8d788f59d271..d71b8ad715a81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -666,6 +666,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
+      case logical.PipeElements(_, _, objAttr, command, printElement, child) =>
+        execution.PipeElementsExec(objAttr, command, printElement, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index c08db132c946f..fa208ac958b90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
 import org.apache.spark.api.java.function.MapFunction
 import org.apache.spark.api.r._
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PipedRDD, RDD}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.api.r.SQLUtils._
 import org.apache.spark.sql.catalyst.InternalRow
@@ -624,3 +624,34 @@ case class CoGroupExec(
     }
   }
 }
+
+/**
+ * Piping elements to a forked external process.
+ * The output of its child must be a single-field row containing the input object.
+ */
+case class PipeElementsExec(
+    outputObjAttr: Attribute,
+    command: String,
+    printElement: (Any, String => Unit) => Unit,
+    child: SparkPlan)
+  extends ObjectConsumerExec with ObjectProducerExec {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+    val printRDDElement: (InternalRow, String => Unit) => Unit = (row, printFunc) => {
+      val obj = getObject(row)
+      printElement(obj, printFunc)
+    }
+
+    child.execute()
+      .pipe(command = PipedRDD.tokenize(command), printRDDElement = printRDDElement)
+      .mapPartitionsInternal { iter =>
+        val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
+        iter.map(e => outputObject(e))
+      }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 2ec4c6918a248..ce021b5e15470 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.Assertions._
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._

-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkException, TaskContext, TestUtils}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
@@ -2007,6 +2007,54 @@ class DatasetSuite extends QueryTest
     checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
   }

+  test("SPARK-34205: Pipe Dataset") {
+    assume(TestUtils.testCommandAvailable("cat"))
+
+    val nums = spark.range(4)
+    val piped = nums.pipe("cat", (l, printFunc) => printFunc(l.toString)).toDF
+
+    checkAnswer(piped, Row("0") :: Row("1") :: Row("2") :: Row("3") :: Nil)
+
+    val piped2 = nums.pipe("wc -l", (l, printFunc) => printFunc(l.toString)).toDF.collect()
+    assert(piped2.size == 2)
+    assert(piped2(0).getString(0).trim == "2")
+    assert(piped2(1).getString(0).trim == "2")
+  }
+
+  test("SPARK-34205: Pipe DataFrame") {
+    assume(TestUtils.testCommandAvailable("cat"))
+
+    val data = Seq((123, "first"), (4567, "second")).toDF("num", "word")
+
+    def printElement(row: Row, printFunc: (String) => Unit): Unit = {
+      val line = s"num: ${row.getInt(0)}, word: ${row.getString(1)}"
+      printFunc.apply(line)
+    }
+    val piped = data.pipe("cat", printElement).toDF
+    checkAnswer(piped, Row("num: 123, word: first") :: Row("num: 4567, word: second") :: Nil)
+  }
+
+  test("SPARK-34205: Pipe complex type Dataset") {
+    assume(TestUtils.testCommandAvailable("cat"))
+
+    val data = Seq(DoubleData(123, "first"), DoubleData(4567, "second")).toDS
+
+    def printElement(data: DoubleData, printFunc: (String) => Unit): Unit = {
+      val line = s"num: ${data.id}, word: ${data.val1}"
+      printFunc.apply(line)
+    }
+    val piped = data.pipe("cat", printElement).toDF
+    checkAnswer(piped, Row("num: 123, word: first") :: Row("num: 4567, word: second") :: Nil)
+  }
+
+  test("SPARK-34205: pipe Dataset with empty partition") {
+    val data = Seq(123, 4567).toDF("num").repartition(8, $"num")
+    val piped = data.pipe("wc -l", (row, printFunc) => printFunc(row.getInt(0).toString))
+    assert(piped.count == 8)
+    val lineCounts = piped.map(_.trim.toInt).collect().toSet
+    assert(Set(0, 1, 1) == lineCounts)
+  }
 }

 case class Bar(a: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 440fe997ae133..4e6b652dcc58d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -1264,6 +1264,20 @@ class StreamSuite extends StreamTest {
       }
     }
   }
+
+  test("SPARK-34205: Pipe Streaming Dataset") {
+    assume(TestUtils.testCommandAvailable("cat"))
+
+    val inputData = MemoryStream[Int]
+    val piped = inputData.toDS()
+      .pipe("cat", (n, printFunc) => printFunc(n.toString)).toDF
+
+    testStream(piped)(
+      AddData(inputData, 1, 2, 3),
+      CheckAnswer(Row("1"), Row("2"), Row("3")),
+      AddData(inputData, 4),
+      CheckNewAnswer(Row("4")))
+  }
 }

 abstract class FakeSource extends StreamSourceProvider {
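
For reference, here is a minimal usage sketch of the new `Dataset.pipe` API added by this patch. It is an illustration only, not part of the diff; it assumes a local `SparkSession` and a POSIX `tr` binary available on every executor host, and the object name `PipeExample` is made up for the example.

// Minimal usage sketch of Dataset.pipe (illustration only, not part of this patch).
// Assumes a local SparkSession and that the `tr` binary exists on every executor host.
import org.apache.spark.sql.SparkSession

object PipeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("dataset-pipe-example")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    val words = Seq("spark", "hadoop", "flink").toDS()

    // Each partition's elements are written to the process's stdin, one line per element,
    // and every line the process writes to stdout becomes one element of the result.
    val upper = words.pipe("tr a-z A-Z", (word, printLine) => printLine(word))

    upper.show()  // expected rows: SPARK, HADOOP, FLINK (order may vary across partitions)
    spark.stop()
  }
}

As the scaladoc above notes, aggregation-like commands such as `wc -l` only aggregate per partition (and per micro-batch for streaming inputs), so a global aggregation has to be computed on the returned Dataset afterwards.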