1111from lightning_app .utilities .component import _set_flow_context
1212
1313
14- class SyncWorkA (LightningWork ):
14+ class SyncWorkLITDriveA (LightningWork ):
1515 def __init__ (self , tmpdir ):
1616 super ().__init__ ()
1717 self .tmpdir = tmpdir
@@ -25,35 +25,35 @@ def run(self, drive: Drive):
2525 os .remove (f"{ self .tmpdir } /a.txt" )
2626
2727
28- class SyncWorkB (LightningWork ):
28+ class SyncWorkLITDriveB (LightningWork ):
2929 def run (self , drive : Drive ):
3030 assert not os .path .exists ("a.txt" )
3131 drive .get ("a.txt" )
3232 assert os .path .exists ("a.txt" )
3333
3434
35- class SyncFlow (LightningFlow ):
35+ class SyncFlowLITDrives (LightningFlow ):
3636 def __init__ (self , tmpdir ):
3737 super ().__init__ ()
3838 self .log_dir = Drive ("lit://log_dir" )
39- self .work_a = SyncWorkA (str (tmpdir ))
40- self .work_b = SyncWorkB ()
39+ self .work_a = SyncWorkLITDriveA (str (tmpdir ))
40+ self .work_b = SyncWorkLITDriveB ()
4141
4242 def run (self ):
4343 self .work_a .run (self .log_dir )
4444 self .work_b .run (self .log_dir )
4545 self ._exit ()
4646
4747
48- def test_synchronization_drive (tmpdir ):
48+ def test_synchronization_lit_drive (tmpdir ):
4949 if os .path .exists ("a.txt" ):
5050 os .remove ("a.txt" )
51- app = LightningApp (SyncFlow (tmpdir ))
51+ app = LightningApp (SyncFlowLITDrives (tmpdir ))
5252 MultiProcessRuntime (app , start_server = False ).dispatch ()
5353 os .remove ("a.txt" )
5454
5555
56- class Work (LightningWork ):
56+ class LITDriveWork (LightningWork ):
5757 def __init__ (self ):
5858 super ().__init__ (parallel = True )
5959 self .drive = None
@@ -75,7 +75,7 @@ def run(self, *args, **kwargs):
7575 self .counter += 1
7676
7777
78- class Work2 (LightningWork ):
78+ class LITDriveWork2 (LightningWork ):
7979 def __init__ (self ):
8080 super ().__init__ (parallel = True )
8181
@@ -86,11 +86,11 @@ def run(self, drive: Drive, **kwargs):
8686 assert drive .list ("." , component_name = self .name ) == []
8787
8888
89- class Flow (LightningFlow ):
89+ class LITDriveFlow (LightningFlow ):
9090 def __init__ (self ):
9191 super ().__init__ ()
92- self .work = Work ()
93- self .work2 = Work2 ()
92+ self .work = LITDriveWork ()
93+ self .work2 = LITDriveWork2 ()
9494
9595 def run (self ):
9696 self .work .run ("0" )
@@ -102,15 +102,15 @@ def run(self):
102102 self ._exit ()
103103
104104
105- def test_drive_transferring_files ():
106- app = LightningApp (Flow ())
105+ def test_lit_drive_transferring_files ():
106+ app = LightningApp (LITDriveFlow ())
107107 MultiProcessRuntime (app , start_server = False ).dispatch ()
108108 os .remove ("a.txt" )
109109
110110
111- def test_drive ():
112- with pytest .raises (Exception , match = "The Drive id needs to start with one of the following protocols " ):
113- Drive ("this_drive_id " )
111+ def test_lit_drive ():
112+ with pytest .raises (Exception , match = "Unknown protocol for the drive 'id' argument " ):
113+ Drive ("invalid_drive_id " )
114114
115115 with pytest .raises (
116116 Exception , match = "The id should be unique to identify your drive. Found `this_drive_id/something_else`."
@@ -213,19 +213,56 @@ def test_drive():
213213 os .remove ("a.txt" )
214214
215215
216- def test_maybe_create_drive ():
216+ def test_s3_drives ():
217+ drive = Drive ("s3://foo/" , allow_duplicates = True )
218+ drive .component_name = "root.work"
217219
218- drive = Drive ("lit://drive_3" , allow_duplicates = False )
220+ with pytest .raises (
221+ Exception , match = "S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
222+ ):
223+ drive .put ("a.txt" )
224+ with pytest .raises (
225+ Exception ,
226+ match = "S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?" ,
227+ ):
228+ drive .list ("a.txt" )
229+ with pytest .raises (
230+ Exception , match = "S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
231+ ):
232+ drive .get ("a.txt" )
233+ with pytest .raises (
234+ Exception ,
235+ match = "S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?" ,
236+ ):
237+ drive .delete ("a.txt" )
238+
239+ _set_flow_context ()
240+ with pytest .raises (Exception , match = "The flow isn't allowed to put files into a Drive." ):
241+ drive .put ("a.txt" )
242+ with pytest .raises (Exception , match = "The flow isn't allowed to list files from a Drive." ):
243+ drive .list ("a.txt" )
244+ with pytest .raises (Exception , match = "The flow isn't allowed to get files from a Drive." ):
245+ drive .get ("a.txt" )
246+
247+
248+ def test_create_s3_drive_without_trailing_slash_fails ():
249+ with pytest .raises (ValueError , match = "S3 drives must end in a trailing slash" ):
250+ Drive ("s3://foo" )
251+
252+
253+ @pytest .mark .parametrize ("drive_id" , ["lit://drive" , "s3://drive/" ])
254+ def test_maybe_create_drive (drive_id ):
255+ drive = Drive (drive_id , allow_duplicates = False )
219256 drive .component_name = "root.work1"
220257 new_drive = _maybe_create_drive (drive .component_name , drive .to_dict ())
221258 assert new_drive .protocol == drive .protocol
222259 assert new_drive .id == drive .id
223260 assert new_drive .component_name == drive .component_name
224261
225262
226- def test_drive_deepcopy ():
227-
228- drive = Drive ("lit://drive" , allow_duplicates = True )
263+ @ pytest . mark . parametrize ( "drive_id" , [ "lit://drive" , "s3://drive/" ])
264+ def test_drive_deepcopy ( drive_id ):
265+ drive = Drive (drive_id , allow_duplicates = True )
229266 drive .component_name = "root.work1"
230267 new_drive = deepcopy (drive )
231268 assert new_drive .id == drive .id
0 commit comments