@@ -1,19 +1,18 @@
 import abc
 import base64
 import os
+import platform
 from pathlib import Path
 from typing import Any, Dict, Optional

 import uvicorn
 from fastapi import FastAPI
-from lightning_utilities.core.imports import module_available
+from lightning_utilities.core.imports import compare_version, module_available
 from pydantic import BaseModel

-from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
 from lightning_app.utilities.imports import _is_torch_available, requires
-from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

 logger = Logger(__name__)

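Aside: the two new imports feed the device-detection helper introduced below. platform is used to identify Apple Silicon, and compare_version gates the MPS path on torch >= 1.12, the first release that ships the MPS backend. A minimal sketch of the compare_version call, assuming lightning_utilities and torch are installed:

import operator

from lightning_utilities.core.imports import compare_version

# True when the installed torch satisfies the comparison, here torch >= 1.12.0.
if compare_version("torch", operator.ge, "1.12.0"):
    print("this torch build can expose the MPS backend")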
@@ -27,44 +26,19 @@
 __doctest_skip__ += ["PythonServer", "PythonServer.*"]


-class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+def _get_device():
+    import operator

-    """This Executor enables to move PyTorch tensors on GPU.
+    import torch

-    Without this executor, it would raise the following exception:
-    RuntimeError: Cannot re-initialize CUDA in forked subprocess.
-    To use CUDA with multiprocessing, you must use the 'spawn' start method
-    """
+    _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")

-    enable_start_observer: bool = False
+    local_rank = int(os.getenv("LOCAL_RANK", "0"))

-    def __call__(self, *args: Any, **kwargs: Any):
-        import torch
-
-        with self.enable_spawn():
-            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
-            torch.multiprocessing.spawn(
-                self.dispatch_run,
-                args=(self.__class__, self.work, queue, args, kwargs),
-                nprocs=1,
-            )
-
-    @staticmethod
-    def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
-        if local_rank == 0:
-            if isinstance(delta_queue, dict):
-                delta_queue = cls.process_queue(delta_queue)
-                work._request_queue = cls.process_queue(work._request_queue)
-                work._response_queue = cls.process_queue(work._response_queue)
-
-            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
-            state_observer.start()
-            _proxy_setattr(work, delta_queue, state_observer)
-
-        unwrap(work.run)(*args, **kwargs)
-
-        if local_rank == 0:
-            state_observer.join(0)
+    if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
+        return torch.device("mps", local_rank)
+    else:
+        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")


 class _DefaultInputData(BaseModel):
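The new _get_device helper centralizes device selection for serving: it prefers MPS on Apple Silicon when torch >= 1.12, falls back to CUDA when available, and otherwise returns the CPU device, using LOCAL_RANK (default "0") as the device index. A rough usage sketch, assuming torch is installed and _get_device is importable from this module:

import torch

# With LOCAL_RANK unset, the helper defaults to local rank 0.
device = _get_device()
# -> mps:0 on Apple Silicon with torch >= 1.12
# -> cuda:0 when CUDA is available
# -> cpu otherwise
x = torch.zeros(4, device=device)
print(device, x.device)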
@@ -95,6 +69,9 @@ def _get_sample_data() -> Dict[Any, Any]:


 class PythonServer(LightningWork, abc.ABC):
+
+    _start_method = "spawn"
+
     @requires(["torch", "lightning_api_access"])
     def __init__(  # type: ignore
         self,
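Setting _start_method = "spawn" on the class appears to be what makes the removed _PyTorchSpawnRunExecutor unnecessary: per its old docstring, CUDA cannot be re-initialized in a forked subprocess, so the work process is now started with the spawn method instead of wrapping run in torch.multiprocessing.spawn. A standalone sketch of that constraint (hypothetical worker, not part of this file):

import torch.multiprocessing as mp


def _worker(rank: int) -> None:
    # Initializing CUDA here is only safe because the child was spawned;
    # a forked child would raise
    # "RuntimeError: Cannot re-initialize CUDA in forked subprocess".
    import torch

    if torch.cuda.is_available():
        torch.zeros(1, device=f"cuda:{rank}")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_worker, args=(0,))
    proc.start()
    proc.join()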
@@ -160,11 +137,6 @@ def predict(self, request):
         self._input_type = input_type
         self._output_type = output_type

-        # Note: Enable to run inference on GPUs.
-        self._run_executor_cls = (
-            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
-        )
-
     def setup(self, *args, **kwargs) -> None:
         """This method is called before the server starts. Override this if you need to download the model or
         initialize the weights, setting up pipelines etc.
@@ -210,13 +182,16 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
         return out

     def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
-        from torch import inference_mode
+        from torch import inference_mode, no_grad

         input_type: type = self.configure_input_type()
         output_type: type = self.configure_output_type()

+        device = _get_device()
+        context = no_grad if device.type == "mps" else inference_mode
+
         def predict_fn(request: input_type):  # type: ignore
-            with inference_mode():
+            with context():
                 return self.predict(request)

         fastapi_app.post("/predict", response_model=output_type)(predict_fn)
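The predict endpoint now chooses its autograd-disabling context from the resolved device: plain no_grad on MPS and inference_mode everywhere else, presumably because inference_mode was not fully supported on the MPS backend at the time. A minimal sketch of the same pattern outside the server, assuming torch is installed and _get_device is importable:

import torch

device = _get_device()
context = torch.no_grad if device.type == "mps" else torch.inference_mode

with context():
    # nothing computed in this block records gradients
    logits = torch.randn(1, 10, device=device).softmax(dim=-1)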