Skip to content

Commit 0fccc84

Browse files
zekebfirsh
andcommitted
always define components.schemas.Output
..even when the user-defined predict function returns a simple type like a string. Co-Authored-By: Ben Firshman <[email protected]> Signed-off-by: Zeke Sikelianos <[email protected]>
1 parent 5996f20 commit 0fccc84

File tree

3 files changed

+168
-12
lines changed

3 files changed

+168
-12
lines changed

python/cog/predictor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os.path
77
from pathlib import Path
88
import typing
9-
from pydantic import create_model
9+
from pydantic import create_model, BaseModel
1010
from pydantic.fields import FieldInfo
1111

1212
# Added in Python 3.8. Can be from typing if we drop support for <3.8.
@@ -117,4 +117,8 @@ def get_output_type(predictor: Predictor):
117117
if get_origin(OutputType) is Generator:
118118
OutputType = get_args(OutputType)[0]
119119

120-
return OutputType
120+
# Wrap the type in a model so Pydantic can document it in component schema
121+
class Output(BaseModel):
122+
__root__: OutputType
123+
124+
return Output

python/cog/server/http.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@ class Request(BaseModel):
3737
input: InputType = None
3838
output_file_prefix: str = None
3939

40+
# response_model is purely for generating schema.
41+
# We generate Response again in the request so we can set file output paths correctly, etc.
42+
OutputType = get_output_type(predictor)
43+
44+
@app.post(
45+
"/predictions",
46+
response_model=get_response_type(OutputType),
47+
response_model_exclude_unset=True,
48+
)
49+
50+
# The signature of this function is used by FastAPI to generate the schema.
51+
# The function body is not used to generate the schema.
4052
def predict(request: Request = Body(default=None)):
4153
if request is None or request.input is None:
4254
output = predictor.predict()
@@ -77,15 +89,6 @@ def predict(...) -> output_type:
7789
)
7890
return JSONResponse(content=encoded_response)
7991

80-
# response_model is purely for generating schema.
81-
# We generate Response again in the request so we can set file output paths correctly, etc.
82-
OutputType = get_output_type(predictor)
83-
app.post(
84-
"/predictions",
85-
response_model=get_response_type(OutputType),
86-
response_model_exclude_unset=True,
87-
)(predict)
88-
8992
return app
9093

9194

python/tests/server/test_http.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from unittest import mock
77

88
from fastapi.testclient import TestClient
9+
from pydantic import BaseModel
910
from PIL import Image
1011
import pytest
1112

@@ -152,6 +153,10 @@ def predict(
152153
"choices": {"$ref": "#/components/schemas/choices"},
153154
},
154155
},
156+
"Output": {
157+
"title": "Output",
158+
"type": "string",
159+
},
155160
"Request": {
156161
"title": "Request",
157162
"type": "object",
@@ -169,7 +174,7 @@ def predict(
169174
"type": "object",
170175
"properties": {
171176
"status": {"$ref": "#/components/schemas/Status"},
172-
"output": {"title": "Output", "type": "string"},
177+
"output": {"$ref": "#/components/schemas/Output"},
173178
"error": {"title": "Error", "type": "string"},
174179
},
175180
"description": "The status of a prediction.",
@@ -203,6 +208,150 @@ def predict(
203208
}
204209

205210

211+
def test_openapi_specification_with_custom_user_defined_output_type():
212+
# Calling this `MyOutput` to test if cog renames it to `Output` in the schema
213+
class MyOutput(BaseModel):
214+
foo_number: int = "42"
215+
foo_string: str = "meaning of life"
216+
217+
class Predictor(cog.Predictor):
218+
def predict(
219+
self,
220+
) -> MyOutput:
221+
pass
222+
223+
client = make_client(Predictor())
224+
resp = client.get("/openapi.json")
225+
assert resp.status_code == 200
226+
print(resp.json())
227+
228+
assert resp.json() == {
229+
"openapi": "3.0.2",
230+
"info": {"title": "Cog", "version": "0.1.0"},
231+
"paths": {
232+
"/": {
233+
"get": {
234+
"summary": "Root",
235+
"operationId": "root__get",
236+
"responses": {
237+
"200": {
238+
"description": "Successful Response",
239+
"content": {"application/json": {"schema": {}}},
240+
}
241+
},
242+
}
243+
},
244+
"/predictions": {
245+
"post": {
246+
"summary": "Predict",
247+
"operationId": "predict_predictions_post",
248+
"requestBody": {
249+
"content": {
250+
"application/json": {
251+
"schema": {"$ref": "#/components/schemas/Request"}
252+
}
253+
}
254+
},
255+
"responses": {
256+
"200": {
257+
"description": "Successful Response",
258+
"content": {
259+
"application/json": {
260+
"schema": {"$ref": "#/components/schemas/Response"}
261+
}
262+
},
263+
},
264+
"422": {
265+
"description": "Validation Error",
266+
"content": {
267+
"application/json": {
268+
"schema": {
269+
"$ref": "#/components/schemas/HTTPValidationError"
270+
}
271+
}
272+
},
273+
},
274+
},
275+
}
276+
},
277+
},
278+
"components": {
279+
"schemas": {
280+
"HTTPValidationError": {
281+
"title": "HTTPValidationError",
282+
"type": "object",
283+
"properties": {
284+
"detail": {
285+
"title": "Detail",
286+
"type": "array",
287+
"items": {"$ref": "#/components/schemas/ValidationError"},
288+
}
289+
},
290+
},
291+
"Input": {"title": "Input", "type": "object", "properties": {}},
292+
"MyOutput": {
293+
"title": "MyOutput",
294+
"type": "object",
295+
"properties": {
296+
"foo_number": {
297+
"title": "Foo Number",
298+
"type": "integer",
299+
"default": "42",
300+
},
301+
"foo_string": {
302+
"title": "Foo String",
303+
"type": "string",
304+
"default": "meaning of life",
305+
},
306+
},
307+
},
308+
"Output": {"$ref": "#/components/schemas/MyOutput", "title": "Output"},
309+
"Request": {
310+
"title": "Request",
311+
"type": "object",
312+
"properties": {
313+
"input": {"$ref": "#/components/schemas/Input"},
314+
"output_file_prefix": {
315+
"title": "Output File Prefix",
316+
"type": "string",
317+
},
318+
},
319+
},
320+
"Response": {
321+
"title": "Response",
322+
"required": ["status"],
323+
"type": "object",
324+
"properties": {
325+
"status": {"$ref": "#/components/schemas/Status"},
326+
"output": {"$ref": "#/components/schemas/Output"},
327+
"error": {"title": "Error", "type": "string"},
328+
},
329+
"description": "The status of a prediction.",
330+
},
331+
"Status": {
332+
"title": "Status",
333+
"enum": ["processing", "success", "failed"],
334+
"description": "An enumeration.",
335+
},
336+
"ValidationError": {
337+
"title": "ValidationError",
338+
"required": ["loc", "msg", "type"],
339+
"type": "object",
340+
"properties": {
341+
"loc": {
342+
"title": "Location",
343+
"type": "array",
344+
"items": {"type": "string"},
345+
},
346+
"msg": {"title": "Message", "type": "string"},
347+
"type": {"title": "Error Type", "type": "string"},
348+
},
349+
},
350+
}
351+
},
352+
}
353+
354+
206355
def test_yielding_strings_from_generator_predictors():
207356
class Predictor(cog.Predictor):
208357
def predict(self) -> Generator[str, None, None]:

0 commit comments

Comments
 (0)