Skip to content

Commit 90a4c02

Browse files
authored
Fix cloudcomputes registration for structures (#15964)
* fix cloudcomputes * updates cloudcompute registration * changelog
1 parent e56e7f1 commit 90a4c02

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Fixed `AutoScaler` failing due to port collision across works ([#15966](https://github.com/Lightning-AI/lightning/pull/15966))
3838

3939

40+
- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))
41+
42+
4043
## [1.8.4] - 2022-12-08
4144

4245
### Added

src/lightning_app/core/flow.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,19 @@ def __setattr__(self, name: str, value: Any) -> None:
173173
elif isinstance(value, (Dict, List)):
174174
self._structures.add(name)
175175
_set_child_name(self, value, name)
176-
if getattr(self, "_backend", None) is not None:
177-
value._backend = self._backend
178-
for flow in value.flows:
179-
LightningFlow._attach_backend(flow, self._backend)
180-
for work in value.works:
181-
self._backend._wrap_run_method(_LightningAppRef().get_current(), work)
176+
177+
_backend = getattr(self, "backend", None)
178+
if _backend is not None:
179+
value._backend = _backend
180+
181+
for flow in value.flows:
182+
if _backend is not None:
183+
LightningFlow._attach_backend(flow, _backend)
184+
185+
for work in value.works:
186+
work._register_cloud_compute()
187+
if _backend is not None:
188+
_backend._wrap_run_method(_LightningAppRef().get_current(), work)
182189

183190
elif isinstance(value, Path):
184191
# In the init context, the full name of the Flow and Work is not known, i.e., we can't serialize

tests/tests_app/core/test_lightning_flow.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import pytest
1111
from deepdiff import DeepDiff, Delta
1212

13-
from lightning_app import LightningApp
13+
import lightning_app
14+
from lightning_app import CloudCompute, LightningApp
1415
from lightning_app.core.flow import LightningFlow
1516
from lightning_app.core.work import LightningWork
1617
from lightning_app.runners import MultiProcessRuntime
@@ -901,3 +902,29 @@ def run_patch(method):
901902
state = app.api_publish_state_queue.put._mock_call_args[0][0]
902903
call_hash = state["works"]["w"]["calls"]["latest_call_hash"]
903904
assert state["works"]["w"]["calls"][call_hash]["statuses"][0]["stage"] == "succeeded"
905+
906+
907+
def test_structures_register_work_cloudcompute():
908+
class MyDummyWork(LightningWork):
909+
def run(self):
910+
return
911+
912+
class MyDummyFlow(LightningFlow):
913+
def __init__(self):
914+
super().__init__()
915+
self.w_list = LList(*[MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)])
916+
self.w_dict = LDict(**{str(i): MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)})
917+
918+
def run(self):
919+
for w in self.w_list:
920+
w.run()
921+
922+
for w in self.w_dict.values():
923+
w.run()
924+
925+
MyDummyFlow()
926+
assert len(lightning_app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE) == 10
927+
for v in lightning_app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE.values():
928+
assert len(v.component_names) == 1
929+
assert v.component_names[0][:-1] in ("root.w_list.", "root.w_dict.")
930+
assert v.component_names[0][-1].isdigit()

0 commit comments

Comments
 (0)