|
7 | 7 |
|
8 | 8 | import uvicorn |
9 | 9 | from fastapi import FastAPI |
10 | | -from lightning_utilities.core.imports import compare_version, module_available |
| 10 | +from lightning_utilities.core.imports import compare_version |
11 | 11 | from pydantic import BaseModel |
12 | 12 |
|
13 | 13 | from lightning_app.core.work import LightningWork |
|
16 | 16 |
|
17 | 17 | logger = Logger(__name__) |
18 | 18 |
|
19 | | -__doctest_skip__ = [] |
20 | | -# Skip doctests if requirements aren't available |
21 | | -if not module_available("lightning_api_access"): |
22 | | - __doctest_skip__ += ["PythonServer", "PythonServer.*"] |
| 19 | +__doctest_skip__ = ["PythonServer", "PythonServer.*"] |
| 20 | + |
23 | 21 |
|
24 | 22 | # Skip doctests if requirements aren't available |
25 | 23 | if not _is_torch_available(): |
@@ -72,7 +70,7 @@ class PythonServer(LightningWork, abc.ABC): |
72 | 70 |
|
73 | 71 | _start_method = "spawn" |
74 | 72 |
|
75 | | - @requires(["torch", "lightning_api_access"]) |
| 73 | + @requires(["torch"]) |
76 | 74 | def __init__( # type: ignore |
77 | 75 | self, |
78 | 76 | input_type: type = _DefaultInputData, |
@@ -193,29 +191,32 @@ def predict_fn(request: input_type): # type: ignore |
193 | 191 | fastapi_app.post("/predict", response_model=output_type)(predict_fn) |
194 | 192 |
|
195 | 193 | def configure_layout(self) -> None: |
196 | | - if module_available("lightning_api_access"): |
| 194 | + try: |
197 | 195 | from lightning_api_access import APIAccessFrontend |
198 | | - |
199 | | - class_name = self.__class__.__name__ |
200 | | - url = f"{self.url}/predict" |
201 | | - |
202 | | - try: |
203 | | - request = self._get_sample_dict_from_datatype(self.configure_input_type()) |
204 | | - response = self._get_sample_dict_from_datatype(self.configure_output_type()) |
205 | | - except TypeError: |
206 | | - return None |
207 | | - |
208 | | - return APIAccessFrontend( |
209 | | - apis=[ |
210 | | - { |
211 | | - "name": class_name, |
212 | | - "url": url, |
213 | | - "method": "POST", |
214 | | - "request": request, |
215 | | - "response": response, |
216 | | - } |
217 | | - ] |
218 | | - ) |
| 196 | + except ModuleNotFoundError: |
| 197 | + logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI") |
| 198 | + return |
| 199 | + |
| 200 | + class_name = self.__class__.__name__ |
| 201 | + url = f"{self.url}/predict" |
| 202 | + |
| 203 | + try: |
| 204 | + request = self._get_sample_dict_from_datatype(self.configure_input_type()) |
| 205 | + response = self._get_sample_dict_from_datatype(self.configure_output_type()) |
| 206 | + except TypeError: |
| 207 | + return None |
| 208 | + |
| 209 | + return APIAccessFrontend( |
| 210 | + apis=[ |
| 211 | + { |
| 212 | + "name": class_name, |
| 213 | + "url": url, |
| 214 | + "method": "POST", |
| 215 | + "request": request, |
| 216 | + "response": response, |
| 217 | + } |
| 218 | + ] |
| 219 | + ) |
219 | 220 |
|
220 | 221 | def run(self, *args: Any, **kwargs: Any) -> Any: |
221 | 222 | """Run method takes care of configuring and setting up a FastAPI server behind the scenes. |
|
0 commit comments