
Commit 1df6fc6

[SPARK-51316][PYTHON][FOLLOW-UP] Revert unrelated changes and mark mapInPandas/mapInArrow batched in byte size
### What changes were proposed in this pull request?

This PR is a follow-up of #50096 that reverts unrelated changes and marks mapInPandas/mapInArrow as batched in byte size.

### Why are the changes needed?

To make the original change self-contained, and to mark mapInPandas/mapInArrow as batched in byte size for consistency.

### Does this PR introduce _any_ user-facing change?

No, the main change has not been released yet.

### How was this patch tested?

Manually.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50111 from HyukjinKwon/SPARK-51316-followup.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>

(cherry picked from commit 5b45671)
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent dcc2f3c commit 1df6fc6

File tree

3 files changed (+11, -3 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 6 additions & 1 deletion
```diff
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.execution.python.BatchIterator
 import org.apache.spark.sql.execution.r.ArrowRRunner
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.internal.SQLConf
@@ -218,13 +219,17 @@ case class MapPartitionsInRWithArrowExec(
     child: SparkPlan) extends UnaryExecNode {
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
+  private val batchSize = conf.arrowMaxRecordsPerBatch
+
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { inputIter =>
       val outputTypes = schema.map(_.dataType)
 
-      val batchIter = Iterator(inputIter)
+      // DO NOT use iter.grouped(). See BatchIterator.
+      val batchIter =
+        if (batchSize > 0) new BatchIterator(inputIter, batchSize) else Iterator(inputIter)
 
       val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
         SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_DAPPLY)
```
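
For context on the `// DO NOT use iter.grouped()` comment: `Iterator.grouped(n)` buffers each group into a `Seq`, which holds references to the grouped elements; that is unsafe when the source reuses mutable row objects between calls, as Spark's internal rows can be. The sketch below illustrates the lazy batching style that `BatchIterator` provides; it is a simplified stand-in, not Spark's actual implementation, and the name `LazyBatchIterator` is made up for this example.

```scala
// A minimal sketch, assuming a simplified setting: batch an iterator without
// buffering, by handing out sub-iterators that pull from the source on demand.
// As with Spark's real BatchIterator, each inner iterator must be fully
// consumed before asking the outer iterator for the next batch.
class LazyBatchIterator[T](iter: Iterator[T], batchSize: Int)
    extends Iterator[Iterator[T]] {
  require(batchSize > 0, "batchSize must be positive")

  override def hasNext: Boolean = iter.hasNext

  override def next(): Iterator[T] = new Iterator[T] {
    private var taken = 0
    override def hasNext: Boolean = taken < batchSize && iter.hasNext
    override def next(): T = {
      taken += 1
      iter.next()
    }
  }
}

object LazyBatchIteratorDemo {
  def main(args: Array[String]): Unit = {
    val batches = new LazyBatchIterator(Iterator.range(0, 10), batchSize = 4)
    // Prints 0,1,2,3 then 4,5,6,7 then 8,9 -- no batch is materialized as a Seq.
    batches.foreach(b => println(b.mkString(",")))
  }
}
```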

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ class MapInBatchEvaluatorFactory(
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
-        None)
+        None) with BatchedPythonArrowInput
       val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
 
       val unsafeProj = UnsafeProjection.create(output, output)
```
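
The one-line change above mixes `BatchedPythonArrowInput` into the runner at the `new` site, so the runner used by mapInPandas/mapInArrow slices its Arrow input by record count and byte size rather than streaming it unsliced. In Scala, `new C(...) with T` instantiates an anonymous subclass of `C` that also carries `T`'s members. A minimal sketch of that per-instance mixin pattern, with made-up stand-ins (none of these types are Spark's):

```scala
// Illustrative batching trait; Spark's real trait reads SQLConf instead.
trait ByteSizeBatched {
  def maxBytesPerBatch: Long = 64L * 1024 * 1024
}

class Runner(name: String) {
  def describe(): String = s"runner=$name"
}

object MixinDemo {
  def main(args: Array[String]): Unit = {
    // Anonymous subclass: this instance is both a Runner and ByteSizeBatched,
    // without declaring a named class for the combination.
    val runner = new Runner("arrow") with ByteSizeBatched
    println(s"${runner.describe()}, maxBytesPerBatch=${runner.maxBytesPerBatch}")
  }
}
```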

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -145,7 +145,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
 
 private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
   self: BasePythonRunner[Iterator[InternalRow], _] =>
-  private val arrowMaxRecordsPerBatch = SQLConf.get.arrowMaxRecordsPerBatch
+  private val arrowMaxRecordsPerBatch = {
+    val v = SQLConf.get.arrowMaxRecordsPerBatch
+    if (v > 0) v else Int.MaxValue
+  }
 
   private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch
 
```
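This hunk treats a non-positive `arrowMaxRecordsPerBatch` (conventionally "unlimited") as `Int.MaxValue`, so the record-count check stays a plain comparison while the byte-size cap (`maxBytesPerBatch`) still bounds each batch. A hedged sketch of that normalization; `normalizeRecordLimit` and `shouldCutBatch` are hypothetical helpers for illustration, not Spark APIs:

```scala
object BatchLimitSketch {
  // Zero or negative config means "no record limit": clamp to Int.MaxValue
  // so callers can always use a simple >= comparison.
  def normalizeRecordLimit(configured: Int): Int =
    if (configured > 0) configured else Int.MaxValue

  // A batch is cut as soon as EITHER limit is reached.
  def shouldCutBatch(rows: Int, bytes: Long, maxRows: Int, maxBytes: Long): Boolean =
    rows >= normalizeRecordLimit(maxRows) || bytes >= maxBytes

  def main(args: Array[String]): Unit = {
    val maxBytes = 64L * 1024 * 1024
    // With maxRows = 0 (unlimited), only the byte threshold triggers a cut:
    assert(!shouldCutBatch(rows = 1000000, bytes = 1024L, maxRows = 0, maxBytes = maxBytes))
    assert(shouldCutBatch(rows = 10, bytes = 128L * 1024 * 1024, maxRows = 0, maxBytes = maxBytes))
    println("ok")
  }
}
```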