Skip to content

Commit 580864e

Browse files
authored
Reduce requirements by deferring import until when it's actually needed (#8)
* Defer imports in huggingface.py until it's actually needed * format * fix * make dependencies optional * make dependencies required
1 parent 0f27b1a commit 580864e

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ packages = [
1212
[tool.poetry.dependencies]
1313
python = "^3.9"
1414
datasets = "^3.2"
15+
huggingface_hub = "^0.27.1"
1516

1617
[tool.poetry.group.dev.dependencies]
1718
pytest = "^8.0.0"

pyspark_huggingface/huggingface.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import Optional
1+
from typing import TYPE_CHECKING, Optional
22

3-
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader
4-
from pyspark.sql.types import StructType
3+
from pyspark.sql.datasource import DataSource
54

6-
from pyspark_huggingface.huggingface_sink import HuggingFaceSink
7-
from pyspark_huggingface.huggingface_source import HuggingFaceSource
5+
if TYPE_CHECKING:
6+
from pyspark.sql.datasource import DataSourceWriter, DataSourceReader
7+
from pyspark.sql.types import StructType
8+
9+
from pyspark_huggingface.huggingface_sink import HuggingFaceSink
10+
from pyspark_huggingface.huggingface_source import HuggingFaceSource
811

912

1013
class HuggingFaceDatasets(DataSource):
@@ -24,15 +27,19 @@ class HuggingFaceDatasets(DataSource):
2427
def __init__(self, options: dict):
2528
super().__init__(options)
2629
self.options = options
27-
self.source: Optional[HuggingFaceSource] = None
28-
self.sink: Optional[HuggingFaceSink] = None
30+
self.source: Optional["HuggingFaceSource"] = None
31+
self.sink: Optional["HuggingFaceSink"] = None
32+
33+
def get_source(self) -> "HuggingFaceSource":
34+
from pyspark_huggingface.huggingface_source import HuggingFaceSource
2935

30-
def get_source(self) -> HuggingFaceSource:
3136
if self.source is None:
3237
self.source = HuggingFaceSource(self.options.copy())
3338
return self.source
3439

35-
def get_sink(self):
40+
def get_sink(self) -> "HuggingFaceSink":
41+
from pyspark_huggingface.huggingface_sink import HuggingFaceSink
42+
3643
if self.sink is None:
3744
self.sink = HuggingFaceSink(self.options.copy())
3845
return self.sink
@@ -44,8 +51,8 @@ def name(cls):
4451
def schema(self):
4552
return self.get_source().schema()
4653

47-
def reader(self, schema: StructType) -> "DataSourceReader":
54+
def reader(self, schema: "StructType") -> "DataSourceReader":
4855
return self.get_source().reader(schema)
4956

50-
def writer(self, schema: StructType, overwrite: bool) -> "DataSourceArrowWriter":
57+
def writer(self, schema: "StructType", overwrite: bool) -> "DataSourceWriter":
5158
return self.get_sink().writer(schema, overwrite)

0 commit comments

Comments
 (0)