1- from typing import Optional
1+ from typing import TYPE_CHECKING , Optional
22
3- from pyspark .sql .datasource import DataSource , DataSourceArrowWriter , DataSourceReader
4- from pyspark .sql .types import StructType
3+ from pyspark .sql .datasource import DataSource
54
6- from pyspark_huggingface .huggingface_sink import HuggingFaceSink
7- from pyspark_huggingface .huggingface_source import HuggingFaceSource
5+ if TYPE_CHECKING :
6+ from pyspark .sql .datasource import DataSourceWriter , DataSourceReader
7+ from pyspark .sql .types import StructType
8+
9+ from pyspark_huggingface .huggingface_sink import HuggingFaceSink
10+ from pyspark_huggingface .huggingface_source import HuggingFaceSource
811
912
1013class HuggingFaceDatasets (DataSource ):
@@ -24,15 +27,19 @@ class HuggingFaceDatasets(DataSource):
def __init__(self, options: dict):
    """Set up the data source with its options and empty source/sink slots.

    Args:
        options: Option mapping passed to the base ``DataSource``
            initializer and also kept on the instance as ``self.options``.
    """
    super().__init__(options)
    self.options = options
    # Both slots start unset; get_source() / get_sink() fill them in on
    # first use. The annotations are quoted because the two classes are
    # only imported at module level under TYPE_CHECKING.
    self.source: Optional["HuggingFaceSource"] = None
    self.sink: Optional["HuggingFaceSink"] = None
32+
def get_source(self) -> "HuggingFaceSource":
    """Return the cached ``HuggingFaceSource``, creating it on first call.

    The source is constructed from a copy of ``self.options`` and stored
    on ``self.source`` so later calls return the same object.
    """
    # Imported inside the method: at module level this name is only
    # available under TYPE_CHECKING.
    from pyspark_huggingface.huggingface_source import HuggingFaceSource

    if self.source is None:
        self.source = HuggingFaceSource(self.options.copy())
    return self.source
3439
def get_sink(self) -> "HuggingFaceSink":
    """Return the cached ``HuggingFaceSink``, creating it on first call.

    The sink is constructed from a copy of ``self.options`` and stored on
    ``self.sink`` so later calls return the same object.
    """
    # Imported inside the method: at module level this name is only
    # available under TYPE_CHECKING.
    from pyspark_huggingface.huggingface_sink import HuggingFaceSink

    if self.sink is None:
        self.sink = HuggingFaceSink(self.options.copy())
    return self.sink
@@ -44,8 +51,8 @@ def name(cls):
def schema(self):
    """Return whatever ``schema()`` of the object from ``get_source()`` returns."""
    return self.get_source().schema()
4653
def reader(self, schema: "StructType") -> "DataSourceReader":
    """Return the reader produced by the source for ``schema``.

    Args:
        schema: Schema passed unchanged to the source's ``reader`` method.

    Returns:
        The result of ``self.get_source().reader(schema)``.
    """
    return self.get_source().reader(schema)
4956
def writer(self, schema: "StructType", overwrite: bool) -> "DataSourceWriter":
    """Return the writer produced by the sink for ``schema``.

    Args:
        schema: Schema passed unchanged to the sink's ``writer`` method.
        overwrite: Flag passed unchanged to the sink's ``writer`` method.

    Returns:
        The result of ``self.get_sink().writer(schema, overwrite)``.
    """
    return self.get_sink().writer(schema, overwrite)
0 commit comments