8 changes: 3 additions & 5 deletions core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark

import java.io.{File, PrintWriter}
import java.net.URL
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.FileUtil
import java.io.{File}
import com.google.common.io.Files

private[spark] class HttpFileServer extends Logging {

@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}

def addFileToDir(file: File, dir: File) : String = {
Utils.copyFile(file, new File(dir, file.getName))
Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}

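The addFileToDir change above replaces the hand-rolled Utils.copyFile with Guava's Files.copy. A minimal sketch of the same pattern, with an illustrative source file and server directory (both hypothetical, not from the patch):

```scala
import java.io.File
import com.google.common.io.Files

// Hypothetical input file and server directory.
val source = new File("/tmp/example.txt")
val served = new File("/tmp/http-file-server-root")
served.mkdirs()

// Guava copies the contents and overwrites any existing target, which is all
// addFileToDir needs before returning the served path.
Files.copy(source, new File(served, source.getName))
```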
7 changes: 4 additions & 3 deletions core/src/main/scala/spark/SparkContext.scala
@@ -439,9 +439,10 @@ class SparkContext(
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)

/**
* Add a file to be downloaded into the working directory of this Spark job on every node.
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@@ -454,7 +455,7 @@
// Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
Utils.fetchFile(path, new File("."))
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))

logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
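The updated doc comment tells callers to resolve the shipped file through SparkFiles rather than assuming it lands in the working directory. A rough usage sketch under the new API; the master URL, file name, and job logic below are illustrative, not part of the patch:

```scala
import spark.{SparkContext, SparkFiles}

val sc = new SparkContext("local", "addFile example") // illustrative local context
sc.addFile("/tmp/lookup.txt")                          // hypothetical file shipped to every node

// Inside a task, resolve the downloaded copy by name instead of assuming it
// sits in the task's working directory.
val sizes = sc.parallelize(1 to 4).map { _ =>
  val in = scala.io.Source.fromFile(SparkFiles.get("lookup.txt"))
  try in.getLines().size finally in.close()
}.collect()
```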
20 changes: 13 additions & 7 deletions core/src/main/scala/spark/SparkEnv.scala
@@ -28,14 +28,10 @@ class SparkEnv (
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer
val httpFileServer: HttpFileServer,
val sparkFilesDir: String
) {

/** No-parameter constructor for unit tests. */
def this() = {
this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}

def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
@@ -112,6 +108,15 @@ object SparkEnv extends Logging {
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)

// Set the sparkFiles directory, used when downloading dependencies. In local mode,
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isMaster) {
Utils.createTempDir().getAbsolutePath
} else {
"."
}

// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -128,6 +133,7 @@
broadcastManager,
blockManager,
connectionManager,
httpFileServer)
httpFileServer,
sparkFilesDir)
}
}
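sparkFilesDir is the one piece of SparkEnv state that differs between the driver and an executor: a private temp directory when isMaster is true, the executor's working directory otherwise. A small stand-alone sketch of that choice (Guava's createTempDir stands in here for Utils.createTempDir):

```scala
import com.google.common.io.Files

// The driver (isMaster) fetches into a private temp directory; executors use the
// per-executor working directory that the standalone worker already sets up.
def chooseSparkFilesDir(isMaster: Boolean): String =
  if (isMaster) Files.createTempDir().getAbsolutePath // stand-in for Utils.createTempDir()
  else "."
```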
25 changes: 25 additions & 0 deletions core/src/main/scala/spark/SparkFiles.java
@@ -0,0 +1,25 @@
package spark;

import java.io.File;

/**
* Resolves paths to files added through `SparkContext.addFile()`.
*/
public class SparkFiles {

private SparkFiles() {}

/**
* Get the absolute path of a file added through `SparkContext.addFile()`.
*/
public static String get(String filename) {
return new File(getRootDirectory(), filename).getAbsolutePath();
}

/**
* Get the root directory that contains files added through `SparkContext.addFile()`.
*/
public static String getRootDirectory() {
return SparkEnv.get().sparkFilesDir();
}
}
16 changes: 1 addition & 15 deletions core/src/main/scala/spark/Utils.scala
@@ -111,20 +111,6 @@ private object Utils extends Logging {
}
}

/** Copy a file on the local file system */
def copyFile(source: File, dest: File) {
val in = new FileInputStream(source)
val out = new FileOutputStream(dest)
copyStream(in, out, true)
}

/** Download a file from a given URL to the local filesystem */
def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
}

/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@@ -201,7 +187,7 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
FileUtil.chmod(filename, "a+x")
FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}

/**
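The chmod fix above passes the fetched file's absolute path instead of the bare filename, so the call no longer depends on the JVM's current working directory. A hedged sketch of the corrected call with illustrative paths:

```scala
import java.io.File
import org.apache.hadoop.fs.FileUtil

// Illustrative fetched dependency; in the patch this is the file fetchFile just wrote.
val targetDir  = new File("/tmp/spark-fetch")
val targetFile = new File(targetDir, "setup.sh")

// Use the absolute path so the chmod is independent of the working directory.
FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
```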
5 changes: 3 additions & 2 deletions core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def getSparkHome(): Option[String] = sc.getSparkHome()

/**
* Add a file to be downloaded into the working directory of this Spark job on every node.
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
2 changes: 2 additions & 0 deletions core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest](
val dOut = new DataOutputStream(proc.getOutputStream)
// Split index
dOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
// Broadcast variables
dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
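With this change the Python worker receives the SparkFiles root right after the split index and ahead of the broadcast variables, so pyspark tasks can resolve shipped files the same way Scala tasks do. A conceptual sketch of that ordering; the byte buffer and writeUTF are used purely for illustration, while the real code pickles the string with PythonRDD.writeAsPickle onto the worker's stdin:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Conceptual ordering only; in the patch dOut wraps the Python worker's stdin.
val buffer = new ByteArrayOutputStream()
val dOut = new DataOutputStream(buffer)
dOut.writeInt(0)                        // 1. split index
dOut.writeUTF("/tmp/spark-files-root")  // 2. SparkFiles root (SparkFiles.getRootDirectory at runtime)
dOut.writeInt(0)                        // 3. number of broadcast variables
dOut.flush()
```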
5 changes: 0 additions & 5 deletions core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}

// Download the files it depends on into it (disabled for now)
//for (url <- jobDesc.fileUrls) {
// fetchFile(url, executorDir)
//}

// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
34 changes: 18 additions & 16 deletions core/src/main/scala/spark/executor/Executor.scala
@@ -159,22 +159,24 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
synchronized {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
}
}
}
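updateDependencies is now wrapped in synchronized because several task threads can reach it at once, and each dependency is refetched only when the driver advertises a newer timestamp. A stripped-down model of that timestamp gate; the fetch function below stands in for Utils.fetchFile and is purely illustrative:

```scala
import scala.collection.mutable.HashMap

// Minimal model of the executor-side dependency cache; `fetch` stands in for
// Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)).
class DependencyCache(fetch: String => Unit) {
  private val currentFiles = new HashMap[String, Long] // name -> timestamp already fetched

  // Synchronized because several task threads may call this concurrently, which
  // is why the loop above gained its synchronized block.
  def update(newFiles: HashMap[String, Long]): Unit = synchronized {
    for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
      fetch(name)
      currentFiles(name) = timestamp
    }
  }
}
```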
@@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!classLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
classLoader.addURL(url)
9 changes: 6 additions & 3 deletions core/src/test/scala/spark/FileServerSuite.scala
@@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
val path = SparkFiles.get("FileServerSuite.txt")
val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
val path = SparkFiles.get("FileServerSuite.txt")
val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
val path = SparkFiles.get("FileServerSuite.txt")
val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
5 changes: 4 additions & 1 deletion python/pyspark/__init__.py
@@ -11,6 +11,8 @@
A broadcast variable that gets reused across tasks.
- L{Accumulator<pyspark.accumulators.Accumulator>}
An "add-only" shared variable that tasks can only add values to.
- L{SparkFiles<pyspark.files.SparkFiles>}
Access files shipped with jobs.
"""
import sys
import os
@@ -19,6 +21,7 @@

from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.files import SparkFiles


__all__ = ["SparkContext", "RDD"]
__all__ = ["SparkContext", "RDD", "SparkFiles"]