Skip to content

Commit 784b604

Browse files
panos-is, Panos Lantavos-Stratigakis, pre-commit-ci[bot], rlizzo, Borda
authored
(app) Add s3 drive type (1/2) (#14002)
* Add S3 protocol and optimization field to the drive object
* Add a list of drives to the work specification
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* add only protocol for s3 drives, no optimization arguments, and add tests
* added trailing slash criteria
* allow slash in s3 drives
* fix
* fixed test issues

Co-authored-by: Panos Lantavos-Stratigakis <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rick Izzo <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rick Izzo <[email protected]>
1 parent 7e77367 commit 784b604

File tree

2 files changed

+90
-24
lines changed

2 files changed

+90
-24
lines changed

src/lightning_app/storage/drive.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class Drive:
1414

1515
__IDENTIFIER__ = "__drive__"
16-
__PROTOCOLS__ = ["lit://"]
16+
__PROTOCOLS__ = ["lit://", "s3://"]
1717

1818
def __init__(
1919
self,
@@ -35,15 +35,28 @@ def __init__(
3535
root_folder: This is the folder from where the Drive perceives the data (e.g this acts as a mount dir).
3636
"""
3737
self.id = None
38+
self.protocol = None
3839
for protocol in self.__PROTOCOLS__:
3940
if id.startswith(protocol):
4041
self.protocol = protocol
4142
self.id = id.replace(protocol, "")
43+
break
44+
else: # N.B. for-else loop
45+
raise ValueError(
46+
f"Unknown protocol for the drive 'id' argument '{id}`. The 'id' string "
47+
f"must start with one of the following prefixes {self.__PROTOCOLS__}"
48+
)
49+
50+
if self.protocol == "s3://" and not self.id.endswith("/"):
51+
raise ValueError(
52+
"S3 drives must end in a trailing slash (`/`) to indicate a folder is being mounted. "
53+
f"Recieved: '{id}'. Mounting a single file is not currently supported."
54+
)
4255

4356
if not self.id:
4457
raise Exception(f"The Drive id needs to start with one of the following protocols: {self.__PROTOCOLS__}")
4558

46-
if "/" in self.id:
59+
if self.protocol != "s3://" and "/" in self.id:
4760
raise Exception(f"The id should be unique to identify your drive. Found `{self.id}`.")
4861

4962
self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else os.getcwd()
@@ -75,6 +88,10 @@ def put(self, path: str) -> None:
7588
raise Exception("The component name needs to be known to put a path to the Drive.")
7689
if _is_flow_context():
7790
raise Exception("The flow isn't allowed to put files into a Drive.")
91+
if self.protocol == "s3://":
92+
raise PermissionError(
93+
"S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
94+
)
7895

7996
self._validate_path(path)
8097

@@ -98,6 +115,10 @@ def list(self, path: Optional[str] = ".", component_name: Optional[str] = None)
98115
"""
99116
if _is_flow_context():
100117
raise Exception("The flow isn't allowed to list files from a Drive.")
118+
if self.protocol == "s3://":
119+
raise PermissionError(
120+
"S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?"
121+
)
101122

102123
if component_name:
103124
paths = [
@@ -142,6 +163,10 @@ def get(
142163
"""
143164
if _is_flow_context():
144165
raise Exception("The flow isn't allowed to get files from a Drive.")
166+
if self.protocol == "s3://":
167+
raise PermissionError(
168+
"S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
169+
)
145170

146171
if component_name:
147172
shared_path = self._to_shared_path(
@@ -189,6 +214,10 @@ def delete(self, path: str) -> None:
189214
"""
190215
if not self.component_name:
191216
raise Exception("The component name needs to be known to delete a path to the Drive.")
217+
if self.protocol == "s3://":
218+
raise PermissionError(
219+
"S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?"
220+
)
192221

193222
shared_path = self._to_shared_path(
194223
path,

tests/tests_app/storage/test_drive.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from lightning_app.utilities.component import _set_flow_context
1212

1313

14-
class SyncWorkA(LightningWork):
14+
class SyncWorkLITDriveA(LightningWork):
1515
def __init__(self, tmpdir):
1616
super().__init__()
1717
self.tmpdir = tmpdir
@@ -25,35 +25,35 @@ def run(self, drive: Drive):
2525
os.remove(f"{self.tmpdir}/a.txt")
2626

2727

28-
class SyncWorkB(LightningWork):
28+
class SyncWorkLITDriveB(LightningWork):
2929
def run(self, drive: Drive):
3030
assert not os.path.exists("a.txt")
3131
drive.get("a.txt")
3232
assert os.path.exists("a.txt")
3333

3434

35-
class SyncFlow(LightningFlow):
35+
class SyncFlowLITDrives(LightningFlow):
3636
def __init__(self, tmpdir):
3737
super().__init__()
3838
self.log_dir = Drive("lit://log_dir")
39-
self.work_a = SyncWorkA(str(tmpdir))
40-
self.work_b = SyncWorkB()
39+
self.work_a = SyncWorkLITDriveA(str(tmpdir))
40+
self.work_b = SyncWorkLITDriveB()
4141

4242
def run(self):
4343
self.work_a.run(self.log_dir)
4444
self.work_b.run(self.log_dir)
4545
self._exit()
4646

4747

48-
def test_synchronization_drive(tmpdir):
48+
def test_synchronization_lit_drive(tmpdir):
4949
if os.path.exists("a.txt"):
5050
os.remove("a.txt")
51-
app = LightningApp(SyncFlow(tmpdir))
51+
app = LightningApp(SyncFlowLITDrives(tmpdir))
5252
MultiProcessRuntime(app, start_server=False).dispatch()
5353
os.remove("a.txt")
5454

5555

56-
class Work(LightningWork):
56+
class LITDriveWork(LightningWork):
5757
def __init__(self):
5858
super().__init__(parallel=True)
5959
self.drive = None
@@ -75,7 +75,7 @@ def run(self, *args, **kwargs):
7575
self.counter += 1
7676

7777

78-
class Work2(LightningWork):
78+
class LITDriveWork2(LightningWork):
7979
def __init__(self):
8080
super().__init__(parallel=True)
8181

@@ -86,11 +86,11 @@ def run(self, drive: Drive, **kwargs):
8686
assert drive.list(".", component_name=self.name) == []
8787

8888

89-
class Flow(LightningFlow):
89+
class LITDriveFlow(LightningFlow):
9090
def __init__(self):
9191
super().__init__()
92-
self.work = Work()
93-
self.work2 = Work2()
92+
self.work = LITDriveWork()
93+
self.work2 = LITDriveWork2()
9494

9595
def run(self):
9696
self.work.run("0")
@@ -102,15 +102,15 @@ def run(self):
102102
self._exit()
103103

104104

105-
def test_drive_transferring_files():
106-
app = LightningApp(Flow())
105+
def test_lit_drive_transferring_files():
106+
app = LightningApp(LITDriveFlow())
107107
MultiProcessRuntime(app, start_server=False).dispatch()
108108
os.remove("a.txt")
109109

110110

111-
def test_drive():
112-
with pytest.raises(Exception, match="The Drive id needs to start with one of the following protocols"):
113-
Drive("this_drive_id")
111+
def test_lit_drive():
112+
with pytest.raises(Exception, match="Unknown protocol for the drive 'id' argument"):
113+
Drive("invalid_drive_id")
114114

115115
with pytest.raises(
116116
Exception, match="The id should be unique to identify your drive. Found `this_drive_id/something_else`."
@@ -213,19 +213,56 @@ def test_drive():
213213
os.remove("a.txt")
214214

215215

216-
def test_maybe_create_drive():
216+
def test_s3_drives():
217+
drive = Drive("s3://foo/", allow_duplicates=True)
218+
drive.component_name = "root.work"
217219

218-
drive = Drive("lit://drive_3", allow_duplicates=False)
220+
with pytest.raises(
221+
Exception, match="S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
222+
):
223+
drive.put("a.txt")
224+
with pytest.raises(
225+
Exception,
226+
match="S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?",
227+
):
228+
drive.list("a.txt")
229+
with pytest.raises(
230+
Exception, match="S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
231+
):
232+
drive.get("a.txt")
233+
with pytest.raises(
234+
Exception,
235+
match="S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?",
236+
):
237+
drive.delete("a.txt")
238+
239+
_set_flow_context()
240+
with pytest.raises(Exception, match="The flow isn't allowed to put files into a Drive."):
241+
drive.put("a.txt")
242+
with pytest.raises(Exception, match="The flow isn't allowed to list files from a Drive."):
243+
drive.list("a.txt")
244+
with pytest.raises(Exception, match="The flow isn't allowed to get files from a Drive."):
245+
drive.get("a.txt")
246+
247+
248+
def test_create_s3_drive_without_trailing_slash_fails():
249+
with pytest.raises(ValueError, match="S3 drives must end in a trailing slash"):
250+
Drive("s3://foo")
251+
252+
253+
@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
254+
def test_maybe_create_drive(drive_id):
255+
drive = Drive(drive_id, allow_duplicates=False)
219256
drive.component_name = "root.work1"
220257
new_drive = _maybe_create_drive(drive.component_name, drive.to_dict())
221258
assert new_drive.protocol == drive.protocol
222259
assert new_drive.id == drive.id
223260
assert new_drive.component_name == drive.component_name
224261

225262

226-
def test_drive_deepcopy():
227-
228-
drive = Drive("lit://drive", allow_duplicates=True)
263+
@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
264+
def test_drive_deepcopy(drive_id):
265+
drive = Drive(drive_id, allow_duplicates=True)
229266
drive.component_name = "root.work1"
230267
new_drive = deepcopy(drive)
231268
assert new_drive.id == drive.id

0 commit comments

Comments
 (0)