
Commit d3b8869

committed
[SPARK-17585][PYSPARK][CORE] PySpark SparkContext.addFile supports adding files recursively
## What changes were proposed in this pull request?

Users sometimes want to add a directory as a dependency. Scala users can call ```SparkContext.addFile``` with the argument ```recursive=true``` to recursively add all files under that directory, but Python users can only add a single file, not a directory. This patch adds the same support to PySpark.

## How was this patch tested?

Unit test.

Author: Yanbo Liang <[email protected]>

Closes #15140 from yanboliang/spark-17585.
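For illustration only (not part of the commit), a minimal sketch of the new PySpark API; the directory layout and all paths here are hypothetical:

```python
import os
from pyspark import SparkContext, SparkFiles

sc = SparkContext("local", "addFile-recursive-demo")

# Hypothetical dependency directory on the driver:
#   /tmp/deps/
#     model.txt
#     sub/extra.txt
sc.addFile("/tmp/deps", recursive=True)  # new: ship the whole directory

def read_model(_):
    # On each node, SparkFiles.get resolves the download location by name.
    root = SparkFiles.get("deps")
    with open(os.path.join(root, "model.txt")) as f:
        yield f.readline()

print(sc.parallelize([0], 1).mapPartitions(read_model).collect())
```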
1 parent 61876a4

File tree: 5 files changed, 34 additions and 7 deletions

- core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
- python/pyspark/context.py
- python/pyspark/tests.py
- python/test_support/hello/hello.txt (renamed from python/test_support/hello.txt)
- python/test_support/hello/sub_hello/sub_hello.txt (new)

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
Lines changed: 13 additions & 0 deletions

@@ -669,6 +669,19 @@ class JavaSparkContext(val sc: SparkContext)
     sc.addFile(path)
   }
 
+  /**
+   * Add a file to be downloaded with this Spark job on every node.
+   * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+   * use `SparkFiles.get(fileName)` to find its download location.
+   *
+   * A directory can be given if the recursive option is set to true. Currently directories are only
+   * supported for Hadoop-supported filesystems.
+   */
+  def addFile(path: String, recursive: Boolean): Unit = {
+    sc.addFile(path, recursive)
+  }
+
   /**
    * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
    * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported

python/pyspark/context.py
Lines changed: 5 additions & 2 deletions

@@ -767,7 +767,7 @@ def accumulator(self, value, accum_param=None):
         SparkContext._next_accum_id += 1
         return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
 
-    def addFile(self, path):
+    def addFile(self, path, recursive=False):
         """
         Add a file to be downloaded with this Spark job on every node.
         The C{path} passed can be either a local file, a file in HDFS
@@ -778,6 +778,9 @@ def addFile(self, path):
         L{SparkFiles.get(fileName)<pyspark.files.SparkFiles.get>} with the
         filename to find its download location.
 
+        A directory can be given if the recursive option is set to True.
+        Currently directories are only supported for Hadoop-supported filesystems.
+
         >>> from pyspark import SparkFiles
         >>> path = os.path.join(tempdir, "test.txt")
         >>> with open(path, "w") as testFile:
@@ -790,7 +793,7 @@ def addFile(self, path):
         >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
         [100, 200, 300, 400]
         """
-        self._jsc.sc().addFile(path)
+        self._jsc.sc().addFile(path, recursive)
 
     def addPyFile(self, path):
         """

python/pyspark/tests.py
Lines changed: 15 additions & 5 deletions

@@ -409,13 +409,23 @@ def func(x):
         self.assertEqual("Hello World!", res)
 
     def test_add_file_locally(self):
-        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         self.sc.addFile(path)
         download_path = SparkFiles.get("hello.txt")
         self.assertNotEqual(path, download_path)
         with open(download_path) as test_file:
             self.assertEqual("Hello World!\n", test_file.readline())
 
+    def test_add_file_recursively_locally(self):
+        path = os.path.join(SPARK_HOME, "python/test_support/hello")
+        self.sc.addFile(path, True)
+        download_path = SparkFiles.get("hello")
+        self.assertNotEqual(path, download_path)
+        with open(download_path + "/hello.txt") as test_file:
+            self.assertEqual("Hello World!\n", test_file.readline())
+        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
+            self.assertEqual("Sub Hello World!\n", test_file.readline())
+
     def test_add_py_file_locally(self):
         # To ensure that we're actually testing addPyFile's effects, check that
         # this fails due to `userlibrary` not being on the Python path:
@@ -514,7 +524,7 @@ def test_transforming_pickle_file(self):
 
     def test_cartesian_on_textfile(self):
         # Regression test for
-        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         a = self.sc.textFile(path)
         result = a.cartesian(a).collect()
         (x, y) = result[0]
@@ -751,7 +761,7 @@ def test_zip_with_different_serializers(self):
         b = b._reserialize(MarshalSerializer())
         self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
         # regression test for SPARK-4841
-        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         t = self.sc.textFile(path)
         cnt = t.count()
         self.assertEqual(cnt, t.zip(t).count())
@@ -1214,7 +1224,7 @@ def test_oldhadoop(self):
         ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
         self.assertEqual(ints, ei)
 
-        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         oldconf = {"mapred.input.dir": hellopath}
         hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
                                   "org.apache.hadoop.io.LongWritable",
@@ -1233,7 +1243,7 @@ def test_newhadoop(self):
         ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
         self.assertEqual(ints, ei)
 
-        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
         newconf = {"mapred.input.dir": hellopath}
         hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
                                         "org.apache.hadoop.io.LongWritable",
python/test_support/hello/hello.txt
File renamed without changes (previously python/test_support/hello.txt).

python/test_support/hello/sub_hello/sub_hello.txt
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Sub Hello World!
