@@ -25,19 +25,52 @@ import scala.sys.process.BasicIO
2525import scala .sys .process .Process
2626import scala .sys .process .ProcessBuilder
2727import scala .sys .process .ProcessIO
28+ import scala .sys .process .ProcessLogger
29+
30+ import com .fasterxml .jackson .databind .ObjectMapper
31+ import com .fasterxml .jackson .dataformat .yaml .YAMLFactory
32+ import org .json4s .JsonAST .JValue
33+ import org .json4s .jackson .Json4sScalaModule
34+ import org .json4s .jackson .JsonMethods
2835
2936import org .apache .spark .SparkConf
3037import org .apache .spark .SparkException
3138import org .apache .spark .internal .Logging
3239import org .apache .spark .internal .config .CONDA_BINARY_PATH
33- import org .apache .spark .internal .config .CONDA_CHANNEL_URLS
40+ import org .apache .spark .internal .config .CONDA_GLOBAL_PACKAGE_DIRS
3441import org .apache .spark .internal .config .CONDA_VERBOSITY
3542import org .apache .spark .util .Utils
3643
37- final class CondaEnvironmentManager (condaBinaryPath : String , verbosity : Int = 0 ) extends Logging {
44+ final class CondaEnvironmentManager (condaBinaryPath : String ,
45+ verbosity : Int = 0 ,
46+ packageDirs : Seq [String ] = Nil ) extends Logging {
3847
3948 require(verbosity >= 0 && verbosity <= 3 , " Verbosity must be between 0 and 3 inclusively" )
4049
/**
 * Top-level fields of the JSON document printed by `conda info --json`, keyed by field name.
 *
 * Computed on first access by running the configured conda binary; since this is a
 * `lazy val`, the result is then reused by every later access on this manager.
 *
 * Throws a [[SparkException]] if the conda process exits with a non-zero code. The message
 * includes the command, its standard output and its standard error.
 */
lazy val defaultInfo: Map[String, JValue] = {
  logInfo("Retrieving the conda installation's info")
  val command = Process(List(condaBinaryPath, "info", "--json"), None)

  // stdout holds the JSON document. stderr is accumulated separately so that it can be
  // reported on failure without ever being mixed into the text handed to the JSON parser.
  val stdout = new StringBuffer
  val stderr = new StringBuffer
  val io = BasicIO(withIn = false,
    stdout,
    Some(ProcessLogger { line =>
      logDebug(s"<conda> $line")
      stderr.append(line).append(System.lineSeparator())
    }))

  val exitCode = command.run(io).exitValue()
  if (exitCode != 0) {
    throw new SparkException(s"Attempt to retrieve initial conda info exited with code: "
      + f"$exitCode%nCommand was: $command%nOutput was:%n${stdout.toString}"
      + f"%nError output was:%n${stderr.toString}")
  }

  // Explicitly typed so the implicit's type does not depend on inference.
  implicit val format: org.json4s.Formats = org.json4s.DefaultFormats
  JsonMethods.parse(stdout.toString).extract[Map[String, JValue]]
}
68+
/**
 * The `pkgs_dirs` entry of [[defaultInfo]], extracted as a list of strings.
 *
 * Throws a [[SparkException]] if the conda info has no `pkgs_dirs` field, rather than the
 * uninformative "key not found" error a plain map lookup would produce.
 */
lazy val defaultPkgsDirs: List[String] = {
  implicit val format: org.json4s.Formats = org.json4s.DefaultFormats
  val pkgsDirs: JValue = defaultInfo.getOrElse("pkgs_dirs",
    throw new SparkException(
      s"Conda info retrieved via '$condaBinaryPath info --json' has no 'pkgs_dirs' field"))
  pkgsDirs.extract[List[String]]
}
73+
4174 def create (
4275 baseDir : String ,
4376 condaPackages : Seq [String ],
@@ -58,11 +91,11 @@ final class CondaEnvironmentManager(condaBinaryPath: String, verbosity: Int = 0)
5891 // Attempt to create environment
5992 runCondaProcess(
6093 linkedBaseDir,
61- List (" create" , " -n" , name, " -y" , " --override-channels " , " -- no-default-packages" )
94+ List (" create" , " -n" , name, " -y" , " --no-default-packages" )
6295 ::: verbosityFlags
63- ::: condaChannelUrls.flatMap(Iterator (" --channel" , _)).toList
6496 ::: " --" :: condaPackages.toList,
65- description = " create conda env"
97+ description = " create conda env" ,
98+ channels = condaChannelUrls.toList
6699 )
67100
68101 new CondaEnvironment (this , linkedBaseDir, name, condaPackages, condaChannelUrls)
@@ -77,28 +110,37 @@ final class CondaEnvironmentManager(condaBinaryPath: String, verbosity: Int = 0)
77110 *
78111 * This hack is necessary; otherwise conda tries to use the home directory for the pkgs cache.
79112 */
private[this] def generateCondarc(baseRoot: Path, channelUrls: Seq[String]): Path = {
  import org.json4s.JsonAST._
  import org.json4s.JsonDSL._

  // Writes a `condarc` file under `baseRoot` and returns its path.
  val condarcPath = baseRoot.resolve("condarc")

  // Directory lists are built up front: the manager-level package dirs come first, then the
  // per-environment one under baseRoot, then whatever the conda installation itself reports.
  val pkgsDirs: List[String] = packageDirs ++: s"$baseRoot/pkgs" +: defaultPkgsDirs
  val envsDirs: List[String] = List(s"$baseRoot/envs")

  // The document is assembled as a json4s AST, which gives explicit control over how each
  // field is mapped, and is then serialised as YAML through Jackson.
  val condarcJson = JObject(
    "pkgs_dirs" -> pkgsDirs,
    "envs_dirs" -> envsDirs,
    "show_channel_urls" -> false,
    "default_channels" -> JArray(Nil),
    "channels" -> channelUrls
  )
  val yamlMapper = new ObjectMapper(new YAMLFactory()).registerModule(Json4sScalaModule)

  Files.write(condarcPath, yamlMapper.writeValueAsBytes(condarcJson))

  // The logged copy omits the `channels` field.
  // NOTE(review): presumably because channel URLs can embed credentials - confirm.
  val loggableCondarc =
    condarcJson.removeField { case (fieldName, _) => fieldName == "channels" }
  logInfo(f"Using condarc at $condarcPath (channels have been edited out):%n"
    + yamlMapper.writeValueAsString(loggableCondarc))

  condarcPath
}
97138
98139 private [conda] def runCondaProcess (baseRoot : Path ,
99140 args : List [String ],
141+ channels : List [String ],
100142 description : String ): Unit = {
101- val condarc = generateCondarc(baseRoot)
143+ val condarc = generateCondarc(baseRoot, channels )
102144 val fakeHomeDir = baseRoot.resolve(" home" )
103145 // Attempt to create fake home dir
104146 Files .createDirectories(fakeHomeDir)
@@ -142,6 +184,7 @@ object CondaEnvironmentManager {
142184 val condaBinaryPath = sparkConf.get(CONDA_BINARY_PATH ).getOrElse(
143185 sys.error(s " Expected config ${CONDA_BINARY_PATH .key} to be set " ))
144186 val verbosity = sparkConf.get(CONDA_VERBOSITY )
145- new CondaEnvironmentManager (condaBinaryPath, verbosity)
187+ val packageDirs = sparkConf.get(CONDA_GLOBAL_PACKAGE_DIRS )
188+ new CondaEnvironmentManager (condaBinaryPath, verbosity, packageDirs)
146189 }
147190}
0 commit comments