
Commit c277a2c

PySpark SparkContext.addFile supports adding files recursively
1 parent 3a3c9ff commit c277a2c

5 files changed: +33 additions, -7 deletions

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 13 additions & 0 deletions
Lines changed: 13 additions & 0 deletions

@@ -669,6 +669,19 @@ class JavaSparkContext(val sc: SparkContext)
     sc.addFile(path)
   }
 
+  /**
+   * Add a file to be downloaded with this Spark job on every node.
+   * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+   * use `SparkFiles.get(fileName)` to find its download location.
+   *
+   * A directory can be given if the recursive option is set to true. Currently directories are only
+   * supported for Hadoop-supported filesystems.
+   */
+  def addFile(path: String, recursive: Boolean): Unit = {
+    sc.addFile(path, recursive)
+  }
+
   /**
    * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
    * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported

python/pyspark/context.py

Lines changed: 5 additions & 2 deletions
Lines changed: 5 additions & 2 deletions

@@ -762,7 +762,7 @@ def accumulator(self, value, accum_param=None):
         SparkContext._next_accum_id += 1
         return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
 
-    def addFile(self, path):
+    def addFile(self, path, recursive=False):
         """
         Add a file to be downloaded with this Spark job on every node.
         The C{path} passed can be either a local file, a file in HDFS
@@ -773,6 +773,9 @@ def addFile(self, path):
         L{SparkFiles.get(fileName)<pyspark.files.SparkFiles.get>} with the
         filename to find its download location.
 
+        A directory can be given if the recursive option is set to True.
+        Currently directories are only supported for Hadoop-supported filesystems.
+
        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
@@ -785,7 +788,7 @@ def addFile(self, path):
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
-        self._jsc.sc().addFile(path)
+        self._jsc.sc().addFile(path, recursive)
 
     def addPyFile(self, path):
         """

python/pyspark/tests.py

Lines changed: 14 additions & 5 deletions
Lines changed: 14 additions & 5 deletions

@@ -409,13 +409,22 @@ def func(x):
         self.assertEqual("Hello World!", res)
 
     def test_add_file_locally(self):
-        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         self.sc.addFile(path)
         download_path = SparkFiles.get("hello.txt")
         self.assertNotEqual(path, download_path)
         with open(download_path) as test_file:
             self.assertEqual("Hello World!\n", test_file.readline())
 
+        path = os.path.join(SPARK_HOME, "python/test_support/hello")
+        self.sc.addFile(path, True)
+        download_path = SparkFiles.get("hello")
+        self.assertNotEqual(path, download_path)
+        with open(download_path + "/hello.txt") as test_file:
+            self.assertEqual("Hello World!\n", test_file.readline())
+        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
+            self.assertEqual("Sub Hello World!\n", test_file.readline())
+
     def test_add_py_file_locally(self):
         # To ensure that we're actually testing addPyFile's effects, check that
         # this fails due to `userlibrary` not being on the Python path:
@@ -514,7 +523,7 @@ def test_transforming_pickle_file(self):
 
     def test_cartesian_on_textfile(self):
         # Regression test for
-        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         a = self.sc.textFile(path)
         result = a.cartesian(a).collect()
         (x, y) = result[0]
@@ -751,7 +760,7 @@ def test_zip_with_different_serializers(self):
         b = b._reserialize(MarshalSerializer())
         self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
         # regression test for SPARK-4841
-        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         t = self.sc.textFile(path)
         cnt = t.count()
         self.assertEqual(cnt, t.zip(t).count())
@@ -1214,7 +1223,7 @@ def test_oldhadoop(self):
         ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
         self.assertEqual(ints, ei)
 
-        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         oldconf = {"mapred.input.dir": hellopath}
         hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
                                   "org.apache.hadoop.io.LongWritable",
@@ -1233,7 +1242,7 @@ def test_newhadoop(self):
         ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
         self.assertEqual(ints, ei)
 
-        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         newconf = {"mapred.input.dir": hellopath}
         hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
                                         "org.apache.hadoop.io.LongWritable",
python/test_support/hello/hello.txt (renamed from python/test_support/hello.txt)

File renamed without changes.
python/test_support/hello/sub_hello/sub_hello.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Sub Hello World!
