
Commit 951a5d9

[SPARK-1549] Add Python support to spark-submit
This PR updates spark-submit to allow submitting Python scripts (currently only with deploy-mode=client, but that's all that was supported before) and updates the PySpark code to properly find various paths, etc. One significant change is that we assume we can always find the Python files either from the Spark assembly JAR (which will happen with the Maven assembly build in make-distribution.sh) or from SPARK_HOME (which will exist in local mode even if you use sbt assembly, and should be enough for testing). This means we no longer need a weird hack to modify the environment for YARN.

This patch also updates the Python worker manager to run python with -u, which means unbuffered output (send it to our logs right away instead of waiting a while after stuff was written); this should simplify debugging.

In addition, it fixes https://issues.apache.org/jira/browse/SPARK-1709, setting the main class from a JAR's Main-Class attribute if not specified by the user, and fixes a few help strings and style issues in spark-submit.

In the future we may want to make the `pyspark` shell use spark-submit as well, but it seems unnecessary for 1.0.

Author: Matei Zaharia <[email protected]>

Closes #664 from mateiz/py-submit and squashes the following commits:

15e9669 [Matei Zaharia] Fix some uses of path.separator property
051278c [Matei Zaharia] Small style fixes
0afe886 [Matei Zaharia] Add license headers
4650412 [Matei Zaharia] Add pyFiles to PYTHONPATH in executors, remove old YARN stuff, add tests
15f8e1e [Matei Zaharia] Set PYTHONPATH in PythonWorkerFactory in case it wasn't set from outside
47c0655 [Matei Zaharia] More work to make spark-submit work with Python:
d4375bd [Matei Zaharia] Clean up description of spark-submit args a bit and add Python ones
1 parent ec09acd commit 951a5d9
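
A side note on the SPARK-1709 fix mentioned above: reading a JAR's Main-Class attribute is standard java.util.jar functionality. The snippet below is only an illustrative sketch of that lookup; the helper name and fallback behaviour are assumptions, not the patch's actual code.

    import java.util.jar.JarFile

    // Hypothetical helper: read the Main-Class manifest attribute from a user JAR,
    // returning None if the manifest or the attribute is missing.
    def mainClassFromJar(jarPath: String): Option[String] = {
      val jar = new JarFile(jarPath)
      try {
        Option(jar.getManifest)
          .flatMap(m => Option(m.getMainAttributes.getValue("Main-Class")))
      } finally {
        jar.close()
      }
    }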

16 files changed: +505, −194 lines

assembly/pom.xml

Lines changed: 0 additions & 13 deletions
@@ -40,14 +40,6 @@
     <deb.user>root</deb.user>
   </properties>
 
-  <repositories>
-    <!-- A repository in the local filesystem for the Py4J JAR, which is not in Maven central -->
-    <repository>
-      <id>lib</id>
-      <url>file://${project.basedir}/lib</url>
-    </repository>
-  </repositories>
-
   <dependencies>
     <dependency>
       <groupId>org.apache.spark</groupId>
@@ -84,11 +76,6 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
-    <dependency>
-      <groupId>net.sf.py4j</groupId>
-      <artifactId>py4j</artifactId>
-      <version>0.8.1</version>
-    </dependency>
   </dependencies>
 
   <build>

core/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -247,6 +247,11 @@
       <artifactId>pyrolite</artifactId>
       <version>2.0.1</version>
     </dependency>
+    <dependency>
+      <groupId>net.sf.py4j</groupId>
+      <artifactId>py4j</artifactId>
+      <version>0.8.1</version>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,8 @@
 
 package org.apache.spark
 
+import java.io.File
+
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.concurrent.Await
@@ -304,7 +306,7 @@ object SparkEnv extends Logging {
       k == "java.class.path"
     }.getOrElse(("", ""))
     val classPathEntries = classPathProperty._2
-      .split(conf.get("path.separator", ":"))
+      .split(File.pathSeparator)
       .filterNot(e => e.isEmpty)
      .map(e => (e, "System Classpath"))
     val addedJarsAndFiles = (addedJars ++ addedFiles).map((_, "Added By User"))
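
The old code read the separator from SparkConf with a Unix-only ":" default; java.io.File.pathSeparator is the JVM's platform value (":" on Unix, ";" on Windows), so the classpath now splits correctly on every OS. A minimal standalone sketch of the same splitting, outside SparkEnv:

    import java.io.File

    // Split the JVM classpath on the platform separator and drop empty entries,
    // mirroring what the updated SparkEnv code does when listing system classpath entries.
    val classPathEntries = sys.props.getOrElse("java.class.path", "")
      .split(File.pathSeparator)
      .filterNot(_.isEmpty)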
core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkContext
+
+private[spark] object PythonUtils {
+  /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
+  def sparkPythonPath: String = {
+    val pythonPath = new ArrayBuffer[String]
+    for (sparkHome <- sys.env.get("SPARK_HOME")) {
+      pythonPath += Seq(sparkHome, "python").mkString(File.separator)
+      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.1-src.zip").mkString(File.separator)
+    }
+    pythonPath ++= SparkContext.jarOfObject(this)
+    pythonPath.mkString(File.pathSeparator)
+  }
+
+  /** Merge PYTHONPATHS with the appropriate separator. Ignores blank strings. */
+  def mergePythonPaths(paths: String*): String = {
+    paths.filter(_ != "").mkString(File.pathSeparator)
+  }
+}
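
Since PythonUtils is private[spark], these helpers are only called from Spark's own code, as PythonRunner further down does. A small sketch of mergePythonPaths' behaviour with hypothetical inputs:

    // Hypothetical path entries, to show that blank strings are dropped and the
    // rest are joined with File.pathSeparator (":" on Unix, ";" on Windows).
    val pythonPath = PythonUtils.mergePythonPaths(
      "/opt/spark/python",                  // hypothetical SPARK_HOME entry
      "",                                   // blank: ignored
      sys.env.getOrElse("PYTHONPATH", ""))  // also ignored when the variable is unset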

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 7 additions & 2 deletions
@@ -37,6 +37,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
 
+  val pythonPath = PythonUtils.mergePythonPaths(
+    PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", ""))
+
   def create(): Socket = {
     if (useDaemon) {
       createThroughDaemon()
@@ -78,9 +81,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
 
       // Create and start the worker
-      val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker"))
+      val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker"))
       val workerEnv = pb.environment()
       workerEnv.putAll(envVars)
+      workerEnv.put("PYTHONPATH", pythonPath)
       val worker = pb.start()
 
       // Redirect the worker's stderr to ours
@@ -151,9 +155,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
       try {
         // Create and start the daemon
-        val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon"))
+        val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon"))
         val workerEnv = pb.environment()
         workerEnv.putAll(envVars)
+        workerEnv.put("PYTHONPATH", pythonPath)
         daemon = pb.start()
 
         // Redirect the stderr to ours
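
Two details of the worker/daemon launch above are worth noting: "-u" makes the Python interpreter use unbuffered stdout/stderr, so worker output reaches the logs immediately, and PYTHONPATH is set after putAll(envVars), so the merged Spark path overrides anything in the caller's environment. A standalone sketch of the same launch pattern, with hypothetical interpreter, module, and path values:

    import java.io.File
    import scala.collection.JavaConversions._

    // Hypothetical values; the structure mirrors the worker launch above:
    // "-u" for unbuffered output, PYTHONPATH set last so it wins over the
    // inherited environment.
    val pb = new ProcessBuilder(Seq("python", "-u", "-m", "some.module"))
    val env = pb.environment()
    env.put("PYTHONPATH", Seq("/opt/spark/python", "/tmp/deps.zip").mkString(File.pathSeparator))
    val proc = pb.start()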
core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.{IOException, File, InputStream, OutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.api.python.PythonUtils
+
+/**
+ * A main class used by spark-submit to launch Python applications. It executes python as a
+ * subprocess and then has it connect back to the JVM to access system properties, etc.
+ */
+object PythonRunner {
+  def main(args: Array[String]) {
+    val primaryResource = args(0)
+    val pyFiles = args(1)
+    val otherArgs = args.slice(2, args.length)
+
+    val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf
+
+    // Launch a Py4J gateway server for the process to connect to; this will let it see our
+    // Java system properties and such
+    val gatewayServer = new py4j.GatewayServer(null, 0)
+    gatewayServer.start()
+
+    // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the
+    // python directories in SPARK_HOME (if set), and any files in the pyFiles argument
+    val pathElements = new ArrayBuffer[String]
+    pathElements ++= pyFiles.split(",")
+    pathElements += PythonUtils.sparkPythonPath
+    pathElements += sys.env.getOrElse("PYTHONPATH", "")
+    val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)
+
+    // Launch Python process
+    val builder = new ProcessBuilder(Seq(pythonExec, "-u", primaryResource) ++ otherArgs)
+    val env = builder.environment()
+    env.put("PYTHONPATH", pythonPath)
+    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
+    val process = builder.start()
+
+    new RedirectThread(process.getInputStream, System.out, "redirect output").start()
+
+    System.exit(process.waitFor())
+  }
+
+  /**
+   * A utility class to redirect the child process's stdout or stderr
+   */
+  class RedirectThread(in: InputStream, out: OutputStream, name: String) extends Thread(name) {
+    setDaemon(true)
+    override def run() {
+      scala.util.control.Exception.ignoring(classOf[IOException]) {
+        // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+        val buf = new Array[Byte](1024)
+        var len = in.read(buf)
+        while (len != -1) {
+          out.write(buf, 0, len)
+          out.flush()
+          len = in.read(buf)
+        }
+      }
+    }
+  }
+}
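
For context, a hypothetical invocation of the runner above, in the shape spark-submit would use in client mode: args(0) is the primary .py file, args(1) the comma-separated list of extra Python files, and everything after that is forwarded to the user's script. All paths here are made up for illustration.

    // Hypothetical arguments: PythonRunner prepends the pyFiles entries and Spark's
    // own Python path to PYTHONPATH, starts a Py4J gateway, then execs the script.
    PythonRunner.main(Array(
      "/path/to/my_app.py",       // primary resource
      "deps.zip,helpers.py",      // extra Python files, comma-separated
      "--input", "data.txt"))     // remaining args go to my_app.py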
