Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AutoScaler` failing due to port collision across works ([#15966](https://github.com/Lightning-AI/lightning/pull/15966))


- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))


## [1.8.4] - 2022-12-08

### Added
Expand Down
19 changes: 13 additions & 6 deletions src/lightning_app/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,19 @@ def __setattr__(self, name: str, value: Any) -> None:
elif isinstance(value, (Dict, List)):
self._structures.add(name)
_set_child_name(self, value, name)
if getattr(self, "_backend", None) is not None:
value._backend = self._backend
for flow in value.flows:
LightningFlow._attach_backend(flow, self._backend)
for work in value.works:
self._backend._wrap_run_method(_LightningAppRef().get_current(), work)

_backend = getattr(self, "backend", None)
if _backend is not None:
value._backend = _backend

for flow in value.flows:
if _backend is not None:
LightningFlow._attach_backend(flow, _backend)

for work in value.works:
work._register_cloud_compute()
if _backend is not None:
_backend._wrap_run_method(_LightningAppRef().get_current(), work)

elif isinstance(value, Path):
# In the init context, the full name of the Flow and Work is not known, i.e., we can't serialize
Expand Down
29 changes: 28 additions & 1 deletion tests/tests_app/core/test_lightning_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import pytest
from deepdiff import DeepDiff, Delta

from lightning_app import LightningApp
import lightning_app
from lightning_app import CloudCompute, LightningApp
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.runners import MultiProcessRuntime
Expand Down Expand Up @@ -901,3 +902,29 @@ def run_patch(method):
state = app.api_publish_state_queue.put._mock_call_args[0][0]
call_hash = state["works"]["w"]["calls"]["latest_call_hash"]
assert state["works"]["w"]["calls"][call_hash]["statuses"][0]["stage"] == "succeeded"


def test_structures_register_work_cloudcompute():
class MyDummyWork(LightningWork):
def run(self):
return

class MyDummyFlow(LightningFlow):
def __init__(self):
super().__init__()
self.w_list = LList(*[MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)])
self.w_dict = LDict(**{str(i): MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)})

def run(self):
for w in self.w_list:
w.run()

for w in self.w_dict.values():
w.run()

MyDummyFlow()
assert len(lightning_app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE) == 10
for v in lightning_app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE.values():
assert len(v.component_names) == 1
assert v.component_names[0][:-1] in ("root.w_list.", "root.w_dict.")
assert v.component_names[0][-1].isdigit()