33from typing_extensions import Protocol , runtime_checkable
44
55from lightning_app .components .multi_node .base import MultiNode
6+ from lightning_app .core .queues import MultiProcessQueue
67from lightning_app .core .work import LightningWork
7- from lightning_app .utilities .app_helpers import is_static_method
88from lightning_app .utilities .packaging .cloud_compute import CloudCompute
9- from lightning_app .utilities .proxies import WorkRunExecutor
9+ from lightning_app .utilities .proxies import _proxy_setattr , unwrap , WorkRunExecutor , WorkStateObserver
1010
1111
1212@runtime_checkable
@@ -22,6 +22,9 @@ def run(
2222
2323
class _PyTorchSpawnRunExecutor(WorkRunExecutor):
    """Run executor that launches the work's ``run`` under ``torch.multiprocessing.spawn``."""

    # NOTE(review): presumably tells the base ``WorkRunExecutor`` not to start the
    # state observer itself — ``dispatch_run`` starts a ``WorkStateObserver`` manually,
    # and only on local rank 0. TODO confirm against the base class.
    enable_start_observer: bool = False
2528 def __call__ (
2629 self ,
2730 main_address : str ,
@@ -31,10 +34,31 @@ def __call__(
3134 ):
3235 import torch
3336
34- nprocs = torch .cuda .device_count () if torch .cuda .is_available () else 1
35- torch .multiprocessing .spawn (
36- self .run , args = (self .work_run , main_address , main_port , num_nodes , node_rank , nprocs ), nprocs = nprocs
37- )
37+ with self .enable_spawn ():
38+ nprocs = torch .cuda .device_count () if torch .cuda .is_available () else 1
39+ queue = self .delta_queue if isinstance (self .delta_queue , MultiProcessQueue ) else self .delta_queue .to_dict ()
40+ torch .multiprocessing .spawn (
41+ self .dispatch_run ,
42+ args = (self .__class__ , self .work , queue , main_address , main_port , num_nodes , node_rank , nprocs ),
43+ nprocs = nprocs ,
44+ )
45+
    @staticmethod
    def dispatch_run(local_rank, cls, work, delta_queue, *args, **kwargs):
        """Per-process entry point invoked by ``torch.multiprocessing.spawn``.

        ``spawn`` prepends ``local_rank``; the remaining arguments are the
        ``args`` tuple assembled in ``__call__`` (executor class, the work,
        the delta queue — possibly serialized to a dict — plus the
        distributed-setup arguments forwarded to ``cls.run``).
        """
        if local_rank == 0:
            # ``__call__`` converts a non-multiprocess delta queue to a dict via
            # ``to_dict()`` before spawning; rebuild live queue objects here on
            # the main rank only.
            if isinstance(delta_queue, dict):
                delta_queue = cls.process_queue(delta_queue)
                work._request_queue = cls.process_queue(work._request_queue)
                work._response_queue = cls.process_queue(work._response_queue)

            # Rank 0 attaches an observer and proxies attribute writes on the
            # work — presumably so state changes flow back through the delta
            # queue; TODO confirm against WorkStateObserver / _proxy_setattr.
            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
            state_observer.start()
            _proxy_setattr(work, delta_queue, state_observer)

        # Every rank executes the user's run method, unwrapped from any proxy.
        cls.run(local_rank, unwrap(work.run), *args, **kwargs)

        if local_rank == 0:
            # Zero-timeout join: best-effort, non-blocking observer shutdown.
            state_observer.join(0)
3862
3963 @staticmethod
4064 def run (
@@ -46,6 +70,7 @@ def run(
4670 node_rank : int ,
4771 nprocs : int ,
4872 ):
73+
4974 import torch
5075
5176 # 1. Setting distributed environment
@@ -76,11 +101,6 @@ def __init__(
76101 ** work_kwargs : Any ,
77102 ) -> None :
78103 assert issubclass (work_cls , _PyTorchSpawnWorkProtocol )
79- if not is_static_method (work_cls , "run" ):
80- raise TypeError (
81- f"The provided { work_cls } run method needs to be static for now."
82- "HINT: Remove `self` and add staticmethod decorator."
83- )
84104
85105 # Note: Private way to modify the work run executor
86106 # Probably exposed to the users in the future if needed.
0 commit comments