diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst
index 6d4d0b55477c..bf4e66ee3353 100644
--- a/python/docs/source/reference/pyspark.rst
+++ b/python/docs/source/reference/pyspark.rst
@@ -53,6 +53,7 @@ Spark Context APIs
 
     SparkContext.PACKAGE_EXTENSIONS
     SparkContext.accumulator
+    SparkContext.addArchive
     SparkContext.addFile
     SparkContext.addPyFile
     SparkContext.applicationId
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 68f748e68faa..2f1746b0a434 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1278,6 +1278,50 @@ def addPyFile(self, path: str) -> None:
 
         importlib.invalidate_caches()
 
+    def addArchive(self, path: str) -> None:
+        """
+        Add an archive to be downloaded with this Spark job on every node.
+        The `path` passed can be either a local file, a file in HDFS
+        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+        FTP URI.
+
+        To access the archive in Spark jobs, use :meth:`SparkFiles.get` with the
+        archive's filename to find its download/unpacked location. The given path
+        should point to a .zip, .tar, .tar.gz, .tgz or .jar file.
+
+        .. versionadded:: 3.3.0
+
+        Notes
+        -----
+        A path can be added only once. Subsequent additions of the same path are ignored.
+        This API is experimental.
+
+        Examples
+        --------
+        Creates a zip archive containing a text file whose content is '100'.
+
+        >>> import zipfile
+        >>> from pyspark import SparkFiles
+        >>> path = os.path.join(tempdir, "test.txt")
+        >>> zip_path = os.path.join(tempdir, "test.zip")
+        >>> with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipped:
+        ...     with open(path, "w") as f:
+        ...         _ = f.write("100")
+        ...     zipped.write(path, os.path.basename(path))
+        >>> sc.addArchive(zip_path)
+
+        Reads the '100' from the text file inside the archive as an integer,
+        and uses it to process the data in the RDD.
+
+        >>> def func(iterator):
+        ...     with open("%s/test.txt" % SparkFiles.get("test.zip")) as f:
+        ...         v = int(f.readline())
+        ...     return [x * int(v) for x in iterator]
+        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+        [100, 200, 300, 400]
+        """
+        self._jsc.sc().addArchive(path)
+
     def setCheckpointDir(self, dirName: str) -> None:
         """
         Set the directory under which RDDs are going to be checkpointed. The
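
For reference, below is a minimal standalone sketch (not part of the patch) of how the new addArchive API could be driven from an ordinary driver script rather than a doctest, this time with a .tar.gz archive, one of the other extensions the docstring allows. The local master setting, app name, and file names are illustrative assumptions, and the sketch assumes a .tar.gz is unpacked under a directory named after the archive file, mirroring the .zip behaviour shown in the doctest above.

    # addarchive_example.py -- illustrative sketch, names are assumptions
    import os
    import tarfile
    import tempfile

    from pyspark import SparkContext, SparkFiles

    sc = SparkContext("local[2]", "addArchiveExample")

    work_dir = tempfile.mkdtemp()
    txt_path = os.path.join(work_dir, "value.txt")
    tar_path = os.path.join(work_dir, "value.tar.gz")

    with open(txt_path, "w") as f:
        f.write("100")

    # Pack the text file into a gzipped tar archive.
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(txt_path, arcname="value.txt")

    # Distribute the archive; Spark unpacks it on every node.
    sc.addArchive(tar_path)

    def multiply_by_archive_value(iterator):
        # SparkFiles.get returns the unpacked location of the archive.
        with open(os.path.join(SparkFiles.get("value.tar.gz"), "value.txt")) as f:
            factor = int(f.readline())
        return [x * factor for x in iterator]

    print(sc.parallelize([1, 2, 3, 4]).mapPartitions(multiply_by_archive_value).collect())
    # Expected: [100, 200, 300, 400]
    sc.stop()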