@@ -20,7 +20,7 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
-import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util._
 
 
 private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    accumulator: PythonAccumulatorV2)
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
             val updateLen = stream.readInt()
             val update = new Array[Byte](updateLen)
             stream.readFully(update)
-            accumulator += Collections.singletonList(update)
+            accumulator.add(update)
           }
           // Check whether the worker is ready to be re-used.
           if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
       JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     try {
-      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
           val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
 }
 
 /**
- * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
  * collects a list of pickled strings that we pass to Python through a socket.
  */
-private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
-  extends AccumulatorParam[JList[Array[Byte]]] {
+private[spark] class PythonAccumulatorV2(
+    @transient private val serverHost: String,
+    private val serverPort: Int)
+  extends CollectionAccumulator[Array[Byte]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
    * We try to reuse a single Socket to transfer accumulator updates, as they are all added
    * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient var socket: Socket = _
+  @transient private var socket: Socket = _
 
-  def openSocket(): Socket = synchronized {
+  private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
     }
     socket
   }
 
-  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+  // Need to override so the types match with PythonFunction
+  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
 
-  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = synchronized {
+  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
+    val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
+    // This conditional isn't strictly speaking needed - merging currently only happens on the
+    // driver program - but that isn't guaranteed, so keep it in case this changes.
     if (serverHost == null) {
-      // This happens on the worker node, where we just want to remember all the updates
-      val1.addAll(val2)
-      val1
+      // We are on the worker
+      super.merge(otherPythonAccumulator)
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
-      out.writeInt(val2.size)
-      for (array <- val2.asScala) {
+      val values = other.value
+      out.writeInt(values.size)
+      for (array <- values.asScala) {
         out.writeInt(array.length)
         out.write(array)
       }
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
       if (byteRead == -1) {
         throw new SparkException("EOF reached before Python server acknowledged")
       }
-      null
     }
   }
 }
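
For orientation, here is a minimal driver-side sketch of how the reworked accumulator could be instantiated and registered under the AccumulatorV2 API. The object name, host, port, and accumulator name below are illustrative assumptions, not part of this change; only the `PythonAccumulatorV2` constructor, `add`, `merge`, and `value` behavior come from the diff above.

package org.apache.spark.api.python

import java.util.{List => JList}

import org.apache.spark.SparkContext

// Hypothetical usage sketch only; PythonAccumulatorV2 is private[spark], so this
// example lives in the same package. Host and port would normally come from the
// Python accumulator server.
object PythonAccumulatorV2Sketch {

  def wireUp(sc: SparkContext, serverHost: String, serverPort: Int): PythonAccumulatorV2 = {
    val acc = new PythonAccumulatorV2(serverHost, serverPort)
    // AccumulatorV2 instances are registered with the SparkContext instead of
    // being created through the old sc.accumulator(initialValue)(param) API.
    sc.register(acc, "PythonAccumulator")
    acc
  }

  // Executors call acc.add(pickledUpdate) with each Array[Byte] read from the
  // Python worker (see the PythonRunner hunk above); on the driver, merge()
  // forwards the buffered updates to the Python accumulator server over a socket.
  def bufferedUpdates(acc: PythonAccumulatorV2): JList[Array[Byte]] = acc.value
}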