
Commit ab03652

Merge pull request apache-spark-on-k8s#115 from palantir/ds/conda-runner
[SPARK-20001] Conda Runner & full Python conda support
2 parents: 55fb50d + 2430478

File tree: 20 files changed (+618 −36 lines)


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 28 additions & 1 deletion
@@ -42,8 +42,10 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.conda.CondaEnvironment
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -336,6 +338,9 @@ class SparkContext(config: SparkConf) extends Logging {
     override protected def initialValue(): Properties = new Properties()
   }
 
+  // Retrieve the Conda Environment from CondaRunner if it has set one up for us
+  val condaEnvironment: Option[CondaEnvironment] = CondaRunner.condaEnvironment
+
   /* ------------------------------------------------------------------------------------- *
    | Initialization. This code initializes the context in a manner that is exception-safe. |
    | All internal fields holding state are initialized here, and any error prompts the     |
@@ -1851,6 +1856,28 @@ class SparkContext(config: SparkConf) extends Logging {
    */
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
+  private[this] def condaEnvironmentOrFail(): CondaEnvironment = {
+    condaEnvironment.getOrElse(sys.error("A conda environment was not set up."))
+  }
+
+  /**
+   * Add a set of conda packages (identified by <a href="
+   * https://conda.io/docs/spec.html#build-version-spec">package match specification</a>)
+   * for all tasks to be executed on this SparkContext in the future.
+   */
+  def addCondaPackages(packages: Seq[String]): Unit = {
+    condaEnvironmentOrFail().installPackages(packages)
+  }
+
+  def addCondaChannel(url: String): Unit = {
+    condaEnvironmentOrFail().addChannel(url)
+  }
+
+  private[spark] def buildCondaInstructions(): Option[CondaSetupInstructions] = {
+    condaEnvironment.map(_.buildSetupInstructions)
+  }
+
+
   /**
    * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark
    * may wait for some internal threads to finish. It's better to use this method to stop
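
The new SparkContext methods above are the driver-facing entry points for mutating the conda environment at runtime. A minimal usage sketch, assuming the application was launched through CondaRunner (so that condaEnvironment is defined) and that this patched spark-core is on the classpath; the channel URL and package specs are illustrative placeholders:

import org.apache.spark.{SparkConf, SparkContext}

object CondaContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("conda-context-sketch"))
    // Register an extra channel first so the subsequent install can resolve from it.
    sc.addCondaChannel("https://repo.example.org/conda/main")   // hypothetical channel URL
    // Package match specifications, as referenced in the scaladoc above.
    sc.addCondaPackages(Seq("numpy=1.14.0", "pandas"))          // illustrative packages
    sc.stop()
  }
}

If CondaRunner did not set up an environment, both calls fail fast via condaEnvironmentOrFail, as in the diff.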

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 16 additions & 8 deletions
@@ -26,6 +26,7 @@ import scala.util.Properties
 import com.google.common.collect.MapMaker
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
 import org.apache.spark.api.python.PythonWorkerFactory
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.internal.Logging
@@ -70,7 +71,10 @@ class SparkEnv (
     val conf: SparkConf) extends Logging {
 
   private[spark] var isStopped = false
-  private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
+  case class PythonWorkerKey(pythonExec: Option[String], envVars: Map[String, String],
+      condaInstructions: Option[CondaSetupInstructions])
+  private val pythonWorkers = mutable.HashMap[PythonWorkerKey, PythonWorkerFactory]()
 
   // A general, soft-reference map for metadata needed during HadoopRDD split computation
   // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
@@ -110,25 +114,29 @@
   }
 
   private[spark]
-  def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+  def createPythonWorker(pythonExec: Option[String], envVars: Map[String, String],
+      condaInstructions: Option[CondaSetupInstructions]): java.net.Socket = {
     synchronized {
-      val key = (pythonExec, envVars)
-      pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+      val key = PythonWorkerKey(pythonExec, envVars, condaInstructions)
+      pythonWorkers.getOrElseUpdate(key,
+        new PythonWorkerFactory(pythonExec, envVars, condaInstructions)).create()
     }
   }
 
   private[spark]
-  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+  def destroyPythonWorker(pythonExec: Option[String], envVars: Map[String, String],
+      condaInstructions: Option[CondaSetupInstructions], worker: Socket) {
     synchronized {
-      val key = (pythonExec, envVars)
+      val key = PythonWorkerKey(pythonExec, envVars, condaInstructions)
       pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
 
   private[spark]
-  def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+  def releasePythonWorker(pythonExec: Option[String], envVars: Map[String, String],
+      condaInstructions: Option[CondaSetupInstructions], worker: Socket) {
     synchronized {
-      val key = (pythonExec, envVars)
+      val key = PythonWorkerKey(pythonExec, envVars, condaInstructions)
       pythonWorkers.get(key).foreach(_.releaseWorker(worker))
     }
   }
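
The change above widens the worker-factory cache key from (pythonExec, envVars) to a PythonWorkerKey that also carries the conda setup instructions, so workers created for different conda setups are never shared. A self-contained sketch of that keying behaviour, using simplified stand-in types and illustrative values rather than the real PythonWorkerFactory:

import scala.collection.mutable

// Simplified stand-in for the real key and factory types.
case class Key(pythonExec: Option[String], envVars: Map[String, String], conda: Option[Seq[String]])

object WorkerKeySketch {
  private val factories = mutable.HashMap[Key, String]()

  // Mirrors pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(...)).
  def factoryFor(key: Key): String =
    factories.getOrElseUpdate(key, s"factory-${factories.size}")

  def main(args: Array[String]): Unit = {
    val a = factoryFor(Key(None, Map.empty, Some(Seq("numpy"))))
    val b = factoryFor(Key(None, Map.empty, Some(Seq("pandas"))))
    assert(a != b, "different conda instructions must map to different worker factories")
    println(s"$a / $b")
  }
}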
core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.conda
+
+import java.io.File
+import java.nio.file.Path
+import java.util.{Map => JMap}
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+
+/**
+ * A stateful class that describes a Conda environment and also keeps track of packages that have
+ * been added, as well as additional channels.
+ *
+ * @param rootPath The root path under which envs/ and pkgs/ are located.
+ * @param envName  The name of the environment.
+ */
+final class CondaEnvironment(val manager: CondaEnvironmentManager,
+                             val rootPath: Path,
+                             val envName: String,
+                             bootstrapPackages: Seq[String],
+                             bootstrapChannels: Seq[String]) extends Logging {
+
+  import CondaEnvironment._
+
+  private[this] val packages = mutable.Buffer(bootstrapPackages: _*)
+  private[this] val channels = bootstrapChannels.toBuffer
+
+  val condaEnvDir: Path = rootPath.resolve("envs").resolve(envName)
+
+  def activatedEnvironment(startEnv: Map[String, String] = Map.empty): Map[String, String] = {
+    require(!startEnv.contains("PATH"), "Defining PATH in a CondaEnvironment's startEnv is " +
+      s"prohibited; found PATH=${startEnv("PATH")}")
+    import collection.JavaConverters._
+    val newVars = System.getenv().asScala.toIterator ++ startEnv ++ List(
+      "CONDA_PREFIX" -> condaEnvDir.toString,
+      "CONDA_DEFAULT_ENV" -> condaEnvDir.toString,
+      "PATH" -> (condaEnvDir.resolve("bin").toString +
+        sys.env.get("PATH").map(File.pathSeparator + _).getOrElse(""))
+    )
+    newVars.toMap
+  }
+
+  def addChannel(url: String): Unit = {
+    channels += url
+  }
+
+  def installPackages(packages: Seq[String]): Unit = {
+    manager.runCondaProcess(rootPath,
+      List("install", "-n", envName, "-y", "--override-channels")
+        ::: channels.iterator.flatMap(Iterator("--channel", _)).toList
+        ::: "--" :: packages.toList,
+      description = s"install dependencies in conda env $condaEnvDir"
+    )
+
+    this.packages ++= packages
+  }
+
+  /**
+   * Clears the given java environment and replaces all variables with the environment
+   * produced after calling `activate` inside this conda environment.
+   */
+  def initializeJavaEnvironment(env: JMap[String, String]): Unit = {
+    env.clear()
+    val activatedEnv = activatedEnvironment()
+    activatedEnv.foreach { case (k, v) => env.put(k, v) }
+    logDebug(s"Initialised environment from conda: $activatedEnv")
+  }
+
+  /**
+   * This is for sending the instructions to the executors so they can replicate the same steps.
+   */
+  def buildSetupInstructions: CondaSetupInstructions = {
+    CondaSetupInstructions(packages.toList, channels.toList)
+  }
+}
+
+object CondaEnvironment {
+  case class CondaSetupInstructions(packages: Seq[String], channels: Seq[String]) {
+    require(channels.nonEmpty)
+    require(packages.nonEmpty)
+  }
+}
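
The heart of activatedEnvironment above is prepending the environment's bin/ directory to the inherited PATH so that the conda-provided python is resolved before any system interpreter. A standalone sketch of just that PATH handling; the env directory is a hypothetical example of what rootPath.resolve("envs").resolve(envName) produces:

import java.io.File
import java.nio.file.Paths

object ActivatedPathSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical env dir; in the diff this is rootPath.resolve("envs").resolve(envName).
    val condaEnvDir = Paths.get("/tmp/conda/real/envs/conda-env")
    // Prepend the env's bin/ to the current PATH, exactly as activatedEnvironment does.
    val path = condaEnvDir.resolve("bin").toString +
      sys.env.get("PATH").map(File.pathSeparator + _).getOrElse("")
    println(s"PATH=$path")
  }
}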
core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.conda
+
+import java.nio.file.Files
+import java.nio.file.Path
+import java.nio.file.Paths
+
+import scala.collection.JavaConverters._
+import scala.sys.process.BasicIO
+import scala.sys.process.Process
+import scala.sys.process.ProcessBuilder
+import scala.sys.process.ProcessIO
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CONDA_BINARY_PATH
+import org.apache.spark.internal.config.CONDA_CHANNEL_URLS
+import org.apache.spark.internal.config.CONDA_VERBOSITY
+import org.apache.spark.util.Utils
+
+final class CondaEnvironmentManager(condaBinaryPath: String, condaChannelUrls: Seq[String],
+                                    verbosity: Int = 0)
+  extends Logging {
+
+  require(condaChannelUrls.nonEmpty, "Can't have an empty list of conda channel URLs")
+  require(verbosity >= 0 && verbosity <= 3, "Verbosity must be between 0 and 3 inclusively")
+
+  def create(
+      baseDir: String,
+      bootstrapPackages: Seq[String]): CondaEnvironment = {
+    require(bootstrapPackages.nonEmpty, "Expected at least one bootstrap package.")
+    val name = "conda-env"
+
+    // must link in /tmp to reduce path length in case baseDir is very long...
+    // If baseDir path is too long, this breaks conda's 220 character limit for binary replacement.
+    // Don't even try to use java.io.tmpdir - yarn sets this to a very long path
+    val linkedBaseDir = Utils.createTempDir("/tmp", "conda").toPath.resolve("real")
+    logInfo(s"Creating symlink $linkedBaseDir -> $baseDir")
+    Files.createSymbolicLink(linkedBaseDir, Paths.get(baseDir))
+
+    val verbosityFlags = 0.until(verbosity).map(_ => "-v").toList
+
+    // Attempt to create environment
+    runCondaProcess(
+      linkedBaseDir,
+      List("create", "-n", name, "-y", "--override-channels", "--no-default-packages")
+        ::: verbosityFlags
+        ::: condaChannelUrls.flatMap(Iterator("--channel", _)).toList
+        ::: "--" :: bootstrapPackages.toList,
+      description = "create conda env"
+    )
+
+    new CondaEnvironment(this, linkedBaseDir, name, bootstrapPackages, condaChannelUrls)
+  }
+
+  /**
+   * Create a condarc that only exposes package and env directories under the given baseRoot,
+   * on top of the default pkgs directory inferred from condaBinaryPath.
+   *
+   * The file will be placed directly inside the given `baseRoot` dir, and link to `baseRoot/pkgs`
+   * as the first package cache.
+   *
+   * This hack is necessary otherwise conda tries to use the homedir for pkgs cache.
+   */
+  private[this] def generateCondarc(baseRoot: Path): Path = {
+    val condaPkgsPath = Paths.get(condaBinaryPath).getParent.getParent.resolve("pkgs")
+    val condarc = baseRoot.resolve("condarc")
+    val condarcContents =
+      s"""pkgs_dirs:
+         | - $baseRoot/pkgs
+         | - $condaPkgsPath
+         |envs_dirs:
+         | - $baseRoot/envs
+         |show_channel_urls: false
+       """.stripMargin
+    Files.write(condarc, List(condarcContents).asJava)
+    logInfo(f"Using condarc at $condarc:%n$condarcContents")
+    condarc
+  }
+
+  private[conda] def runCondaProcess(baseRoot: Path,
+                                     args: List[String],
+                                     description: String): Unit = {
+    val condarc = generateCondarc(baseRoot)
+    val fakeHomeDir = baseRoot.resolve("home")
+    // Attempt to create fake home dir
+    Files.createDirectories(fakeHomeDir)
+
+    val extraEnv = List(
+      "CONDARC" -> condarc.toString,
+      "HOME" -> fakeHomeDir.toString
+    )
+
+    val command = Process(
+      condaBinaryPath :: args,
+      None,
+      extraEnv: _*
+    )
+
+    logInfo(s"About to execute $command with environment $extraEnv")
+    runOrFail(command, description)
+    logInfo(s"Successfully executed $command with environment $extraEnv")
+  }
+
+  private[this] def runOrFail(command: ProcessBuilder, description: String): Unit = {
+    val buffer = new StringBuffer
+    val collectErrOutToBuffer = new ProcessIO(
+      BasicIO.input(false),
+      BasicIO.processFully(buffer),
+      BasicIO.processFully(buffer))
+    val exitCode = command.run(collectErrOutToBuffer).exitValue()
+    if (exitCode != 0) {
+      throw new SparkException(s"Attempt to $description exited with code: "
+        + f"$exitCode%nCommand was: $command%nOutput was:%n${buffer.toString}")
+    }
+  }
+}
+
+object CondaEnvironmentManager {
+  def isConfigured(sparkConf: SparkConf): Boolean = {
+    sparkConf.contains(CONDA_BINARY_PATH)
+  }
+
+  def fromConf(sparkConf: SparkConf): CondaEnvironmentManager = {
+    val condaBinaryPath = sparkConf.get(CONDA_BINARY_PATH).getOrElse(
+      sys.error(s"Expected config ${CONDA_BINARY_PATH.key} to be set"))
+    val condaChannelUrls = sparkConf.get(CONDA_CHANNEL_URLS)
+    require(condaChannelUrls.nonEmpty,
+      s"Must define at least one conda channel in config ${CONDA_CHANNEL_URLS.key}")
+    val verbosity = sparkConf.get(CONDA_VERBOSITY)
+    new CondaEnvironmentManager(condaBinaryPath, condaChannelUrls, verbosity)
+  }
+}
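
For reference, create() above boils down to a single conda create invocation whose argument list is assembled from the bootstrap packages, the configured channels, and the verbosity flags. A self-contained sketch that rebuilds that argument list for illustrative inputs (the channel URL and package specs are placeholders, not values from the diff):

object CondaCreateArgsSketch {
  def main(args: Array[String]): Unit = {
    val name = "conda-env"
    val verbosity = 1
    val channels = Seq("https://repo.example.org/conda/main")   // hypothetical channel URL
    val bootstrapPackages = Seq("python=3.6", "numpy")          // illustrative packages

    // Mirrors the list built in CondaEnvironmentManager.create before runCondaProcess.
    val condaArgs =
      List("create", "-n", name, "-y", "--override-channels", "--no-default-packages") :::
        0.until(verbosity).map(_ => "-v").toList :::
        channels.flatMap(Iterator("--channel", _)).toList :::
        "--" :: bootstrapPackages.toList

    println(condaArgs.mkString(" "))
  }
}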

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 9 additions & 0 deletions
@@ -695,6 +695,15 @@ class JavaSparkContext(val sc: SparkContext)
     sc.addJar(path)
   }
 
+  /**
+   * Add a set of conda packages (identified by <a href="
+   * https://conda.io/docs/spec.html#build-version-spec">package match specification</a>)
+   * for all tasks to be executed on this SparkContext in the future.
+   */
+  def addCondaPackages(packages: java.util.List[String]): Unit = {
+    sc.addCondaPackages(packages.asScala)
+  }
+
   /**
    * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
    *
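
A short sketch of the Java-facing overload above, shown in Scala to stay consistent with the other examples; the packages are illustrative and the JavaSparkContext is assumed to come from an application launched via CondaRunner:

import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext

object JavaApiSketch {
  // Passes a java.util.List straight to the wrapper, which converts it via asScala.
  def addPackages(jsc: JavaSparkContext): Unit = {
    jsc.addCondaPackages(Arrays.asList("numpy=1.14.0", "pandas")) // illustrative packages
  }
}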
