
Commit 90d5754

holdenk authored and srowen committed
[SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top of Accumulator V2
## What changes were proposed in this pull request?

Move the internals of the PySpark accumulator API from the old deprecated API on top of the new accumulator API.

## How was this patch tested?

The existing PySpark accumulator tests (both unit tests and doc tests at the start of accumulator.py).

Author: Holden Karau <[email protected]>

Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api.
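For context, the refactor moves the Python accumulator off the old AccumulatorParam (zero/addInPlace) and onto the AccumulatorV2 contract (isZero/copy/reset/add/merge/value). Below is a minimal sketch of that contract; the class name is illustrative and not part of this commit.

import java.util.{ArrayList => JArrayList, List => JList}

import org.apache.spark.util.AccumulatorV2

// Minimal sketch of the AccumulatorV2 contract (illustrative class, not from this commit).
class ByteListAccumulator extends AccumulatorV2[Array[Byte], JList[Array[Byte]]] {
  private val list = new JArrayList[Array[Byte]]()

  override def isZero: Boolean = list.isEmpty
  override def copy(): ByteListAccumulator = {
    val newAcc = new ByteListAccumulator
    newAcc.list.addAll(list)
    newAcc
  }
  override def reset(): Unit = list.clear()
  override def add(v: Array[Byte]): Unit = list.add(v)
  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit =
    list.addAll(other.value)
  override def value: JList[Array[Byte]] = list
}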
1 parent 5c5396c commit 90d5754

File tree

2 files changed: +25, -22 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 23 additions & 19 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
-import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util._
 
 
 private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    accumulator: PythonAccumulatorV2)
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
                 val updateLen = stream.readInt()
                 val update = new Array[Byte](updateLen)
                 stream.readFully(update)
-                accumulator += Collections.singletonList(update)
+                accumulator.add(update)
               }
               // Check whether the worker is ready to be re-used.
               if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
       JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     try {
-      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
           val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
 }
 
 /**
- * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
  * collects a list of pickled strings that we pass to Python through a socket.
  */
-private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
-  extends AccumulatorParam[JList[Array[Byte]]] {
+private[spark] class PythonAccumulatorV2(
+    @transient private val serverHost: String,
+    private val serverPort: Int)
+  extends CollectionAccumulator[Array[Byte]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
   * We try to reuse a single Socket to transfer accumulator updates, as they are all added
   * by the DAGScheduler's single-threaded RpcEndpoint anyway.
   */
-  @transient var socket: Socket = _
+  @transient private var socket: Socket = _
 
-  def openSocket(): Socket = synchronized {
+  private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
     }
     socket
   }
 
-  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+  // Need to override so the types match with PythonFunction
+  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
 
-  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = synchronized {
+  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
+    val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
+    // This conditional isn't strictly speaking needed - merging only currently happens on the
+    // driver program - but that isn't gauranteed so incase this changes.
     if (serverHost == null) {
-      // This happens on the worker node, where we just want to remember all the updates
-      val1.addAll(val2)
-      val1
+      // We are on the worker
+      super.merge(otherPythonAccumulator)
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
-      out.writeInt(val2.size)
-      for (array <- val2.asScala) {
+      val values = other.value
+      out.writeInt(values.size)
+      for (array <- values.asScala) {
         out.writeInt(array.length)
         out.write(array)
       }
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
       if (byteRead == -1) {
         throw new SparkException("EOF reached before Python server acknowledged")
       }
-      null
     }
   }
 }
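Illustrative note: when merge() above runs on the driver, it forwards the accumulated pickled updates to the Python accumulator server as a count followed by length-prefixed byte arrays, then blocks on a one-byte acknowledgement. A hypothetical helper sketching that framing (not code from this commit; names and the buffer size are assumptions):

import java.io.{BufferedOutputStream, DataOutputStream}
import java.net.Socket

object AccumulatorSocketSketch {
  // Send a batch of pickled updates and wait for the Python server's ack byte.
  def sendUpdates(socket: Socket, updates: Seq[Array[Byte]], bufferSize: Int = 65536): Unit = {
    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
    out.writeInt(updates.size)               // number of updates in this batch
    for (bytes <- updates) {
      out.writeInt(bytes.length)             // length prefix
      out.write(bytes)                       // pickled accumulator update
    }
    out.flush()
    if (socket.getInputStream.read() == -1) {
      throw new RuntimeException("EOF reached before Python server acknowledged")
    }
  }
}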

python/pyspark/context.py

Lines changed: 2 additions & 3 deletions
@@ -173,9 +173,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # they will be passed back to us through a TCP server
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jsc.accumulator(
-            self._jvm.java.util.ArrayList(),
-            self._jvm.PythonAccumulatorParam(host, port))
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._jsc.sc().register(self._javaAccumulator)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]
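Illustrative note: the two added lines are the py4j equivalent of constructing an AccumulatorV2 and registering it with the JVM SparkContext. A rough Scala analogue using the public CollectionAccumulator (PythonAccumulatorV2 itself is private[spark]), e.g. in spark-shell; names are illustrative only:

import org.apache.spark.SparkContext
import org.apache.spark.util.CollectionAccumulator

val sc = SparkContext.getOrCreate()
val acc = new CollectionAccumulator[Array[Byte]]()
sc.register(acc, "python-accumulator-updates")   // register with the driver's SparkContext

// Tasks call add(); the DAGScheduler merges the task-side copies back on the driver.
sc.parallelize(1 to 3).foreach { i => acc.add(Array.fill(i)(i.toByte)) }
println(acc.value.size())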
