|
3 | 3 | from pathlib import Path |
4 | 4 | from typing import Any, Dict, Optional |
5 | 5 |
|
| 6 | +import torch |
6 | 7 | import uvicorn |
7 | 8 | from fastapi import FastAPI |
8 | 9 | from pydantic import BaseModel |
@@ -105,7 +106,7 @@ def predict(self, request): |
105 | 106 | self._input_type = input_type |
106 | 107 | self._output_type = output_type |
107 | 108 |
|
108 | | - def setup(self) -> None: |
| 109 | + def setup(self, *args, **kwargs) -> None: |
109 | 110 | """This method is called before the server starts. Override this if you need to download the model or |
110 | 111 | initialize the weights, setting up pipelines etc. |
111 | 112 |
|
@@ -154,7 +155,8 @@ def _attach_predict_fn(self, fastapi_app: FastAPI) -> None: |
154 | 155 | output_type: type = self.configure_output_type() |
155 | 156 |
|
156 | 157 | def predict_fn(request: input_type): # type: ignore |
157 | | - return self.predict(request) |
| 158 | + with torch.inference_mode(): |
| 159 | + return self.predict(request) |
158 | 160 |
|
159 | 161 | fastapi_app.post("/predict", response_model=output_type)(predict_fn) |
160 | 162 |
|
@@ -207,7 +209,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: |
207 | 209 |
|
208 | 210 | Normally, you don't need to override this method. |
209 | 211 | """ |
210 | | - self.setup() |
| 212 | + self.setup(*args, **kwargs) |
211 | 213 |
|
212 | 214 | fastapi_app = FastAPI() |
213 | 215 | self._attach_predict_fn(fastapi_app) |
|
0 commit comments