Closed
Commits (17)
3c1ea65
one attempt
holdenk Aug 1, 2016
1d538fe
Revert "one attempt"
holdenk Aug 1, 2016
4bc43c0
Start switching the Python Accumulator backing to the V2 API
holdenk Aug 1, 2016
46fa97d
Start switching the Python side code to match the JVM side code
holdenk Aug 1, 2016
736f6ce
Merge branch 'master' into SPARK-16775-reduce-internal-warnings-from-…
holdenk Aug 2, 2016
4756853
If we start merging on the worker go through the same add path so its…
holdenk Aug 3, 2016
4b1b872
Do a deep copy on copy and implemented specialized copyAndReset to av…
holdenk Aug 3, 2016
a4d87e8
Use Collections.synchronizedList for safety
holdenk Aug 3, 2016
cc5f435
Merge branch 'master' into SPARK-16861-refactor-pyspark-accumulator-api
holdenk Aug 4, 2016
5fcaa5a
Merge branch 'master' into SPARK-16861-refactor-pyspark-accumulator-api
holdenk Aug 18, 2016
04a1d37
synchronized on otherPythonAccumulator during merge step
holdenk Aug 19, 2016
2f0af6a
Merge branch 'master' into SPARK-16861-refactor-pyspark-accumulator-api
holdenk Sep 8, 2016
6169c3c
Merge branch 'master' into SPARK-16861-refactor-pyspark-accumulator-api
holdenk Sep 12, 2016
b29d8cd
Merge branch 'master' into SPARK-16861-refactor-pyspark-accumulator-api
holdenk Sep 14, 2016
fca20c0
Merge branch 'master' into SPARK-16861-refactor-pyspark-accumulator-api
holdenk Sep 20, 2016
45ec1ef
Simplify to AccumulatorV2[JList[Array[Byte]], Unit]
holdenk Sep 20, 2016
76f1fac
Use the CollectionAccumulator base trait for Python Accumulator
holdenk Sep 21, 2016
42 changes: 23 additions & 19 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.charset.StandardCharsets
import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util._


private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])
accumulator: PythonAccumulatorV2)

/**
* A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)
accumulator.add(update)
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
try {
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
val objs = new mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
}

/**
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
private[spark] class PythonAccumulatorV2(
@transient private val serverHost: String,
private val serverPort: Int)
extends CollectionAccumulator[Array[Byte]] {

Utils.checkHost(serverHost, "Expected hostname")

@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded RpcEndpoint anyway.
*/
@transient var socket: Socket = _
@transient private var socket: Socket = _

def openSocket(): Socket = synchronized {
private def openSocket(): Socket = synchronized {
if (socket == null || socket.isClosed) {
socket = new Socket(serverHost, serverPort)
}
socket
}

override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
// Need to override so the types match with PythonFunction
override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)

override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
: JList[Array[Byte]] = synchronized {
override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
// This conditional isn't strictly needed - merging currently only happens on the
// driver program - but that isn't guaranteed, so we handle the case where that changes.
if (serverHost == null) {
// This happens on the worker node, where we just want to remember all the updates
val1.addAll(val2)
val1
// We are on the worker
super.merge(otherPythonAccumulator)
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = openSocket()
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))

Review comment (Contributor Author):
On a bit of a side note: once Py4J has its performance improvements in, we could consider using the callback server here if we wanted to enable it in general rather than just for streaming.

out.writeInt(val2.size)
for (array <- val2.asScala) {
val values = other.value
out.writeInt(values.size)
for (array <- values.asScala) {
out.writeInt(array.length)
out.write(array)
}
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
null
}
}
}
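For context on the driver-side branch of merge above: it speaks a small length-prefixed protocol to the accumulator update server that PySpark starts on the driver (the writeInt/write calls followed by a one-byte acknowledgement). Below is a minimal, hedged sketch of what the receiving end of that protocol could look like in Python; the class and variable names are illustrative and this is not the actual pyspark.accumulators handler.

import pickle
import socketserver
import struct

class AccumulatorUpdateHandler(socketserver.StreamRequestHandler):
    """Illustrative handler for the protocol written by PythonAccumulatorV2.merge."""
    def handle(self):
        # The JVM side first writes the number of pending updates as a big-endian int.
        (num_updates,) = struct.unpack("!i", self.rfile.read(4))
        for _ in range(num_updates):
            # Each update is a length-prefixed pickled blob produced by the Python worker.
            (length,) = struct.unpack("!i", self.rfile.read(4))
            payload = pickle.loads(self.rfile.read(length))
            print("received accumulator update:", payload)  # a real handler would apply this to a registry
        # Write a single acknowledgement byte; the Scala side blocks reading it.
        self.wfile.write(struct.pack("!b", 1))

# Hypothetical usage: serve on an ephemeral port and hand that port to the JVM side.
# server = socketserver.TCPServer(("localhost", 0), AccumulatorUpdateHandler)
# server.serve_forever()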
5 changes: 2 additions & 3 deletions python/pyspark/context.py
@@ -173,9 +173,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# they will be passed back to us through a TCP server
self._accumulatorServer = accumulators._start_update_server()
(host, port) = self._accumulatorServer.server_address
self._javaAccumulator = self._jsc.accumulator(
self._jvm.java.util.ArrayList(),
self._jvm.PythonAccumulatorParam(host, port))
self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
self._jsc.sc().register(self._javaAccumulator)

Review comment (Contributor):
I cannot fully understand why an accumulator is created for every instance of SparkContext. I see it is used when the attribute _jrdd is called, but that still does not clear things up :(

Reply (Contributor Author):
So in general you would have one SparkContext and many RDDs. The accumulator here doesn't represent a specific accumulator but rather the general mechanism that all of the Python accumulators are built on top of. The design is certainly a bit confusing if you try to think of it as a regular accumulator - I found it helped to look at how the Scala side "merge" is implemented.


self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
self.pythonVer = "%d.%d" % sys.version_info[:2]
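To make the review discussion above concrete, here is a standard user-level PySpark sketch (assuming a local master) of why one registered PythonAccumulatorV2 per SparkContext is enough: every Python accumulator created from that context ships its updates through the same backing JVM accumulator, and each pickled update carries the accumulator's id so the driver can route the value back to the right Python object.

from pyspark import SparkContext

sc = SparkContext("local[2]", "accumulator-example")

# Two independent Python accumulators; both ride on the single backing
# accumulator that context.py registers for this SparkContext.
total = sc.accumulator(0)
evens = sc.accumulator(0)

def track(x):
    total.add(x)
    if x % 2 == 0:
        evens.add(1)

sc.parallelize(range(10)).foreach(track)

# Updates only become visible on the driver once the tasks have finished.
print(total.value)  # 45
print(evens.value)  # 5

sc.stop()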