66from unittest import mock
77
88from fastapi .testclient import TestClient
9+ from pydantic import BaseModel
910from PIL import Image
1011import pytest
1112
@@ -152,6 +153,10 @@ def predict(
152153 "choices" : {"$ref" : "#/components/schemas/choices" },
153154 },
154155 },
156+ "Output" : {
157+ "title" : "Output" ,
158+ "type" : "string" ,
159+ },
155160 "Request" : {
156161 "title" : "Request" ,
157162 "type" : "object" ,
@@ -169,7 +174,7 @@ def predict(
169174 "type" : "object" ,
170175 "properties" : {
171176 "status" : {"$ref" : "#/components/schemas/Status" },
172- "output" : {"title " : "Output" , "type" : "string " },
177+ "output" : {"$ref " : "#/components/schemas/ Output" },
173178 "error" : {"title" : "Error" , "type" : "string" },
174179 },
175180 "description" : "The status of a prediction." ,
@@ -203,6 +208,150 @@ def predict(
203208 }
204209
205210
211+ def test_openapi_specification_with_custom_user_defined_output_type ():
212+ # Calling this `MyOutput` to test if cog renames it to `Output` in the schema
213+ class MyOutput (BaseModel ):
214+ foo_number : int = "42"
215+ foo_string : str = "meaning of life"
216+
217+ class Predictor (cog .Predictor ):
218+ def predict (
219+ self ,
220+ ) -> MyOutput :
221+ pass
222+
223+ client = make_client (Predictor ())
224+ resp = client .get ("/openapi.json" )
225+ assert resp .status_code == 200
226+ print (resp .json ())
227+
228+ assert resp .json () == {
229+ "openapi" : "3.0.2" ,
230+ "info" : {"title" : "Cog" , "version" : "0.1.0" },
231+ "paths" : {
232+ "/" : {
233+ "get" : {
234+ "summary" : "Root" ,
235+ "operationId" : "root__get" ,
236+ "responses" : {
237+ "200" : {
238+ "description" : "Successful Response" ,
239+ "content" : {"application/json" : {"schema" : {}}},
240+ }
241+ },
242+ }
243+ },
244+ "/predictions" : {
245+ "post" : {
246+ "summary" : "Predict" ,
247+ "operationId" : "predict_predictions_post" ,
248+ "requestBody" : {
249+ "content" : {
250+ "application/json" : {
251+ "schema" : {"$ref" : "#/components/schemas/Request" }
252+ }
253+ }
254+ },
255+ "responses" : {
256+ "200" : {
257+ "description" : "Successful Response" ,
258+ "content" : {
259+ "application/json" : {
260+ "schema" : {"$ref" : "#/components/schemas/Response" }
261+ }
262+ },
263+ },
264+ "422" : {
265+ "description" : "Validation Error" ,
266+ "content" : {
267+ "application/json" : {
268+ "schema" : {
269+ "$ref" : "#/components/schemas/HTTPValidationError"
270+ }
271+ }
272+ },
273+ },
274+ },
275+ }
276+ },
277+ },
278+ "components" : {
279+ "schemas" : {
280+ "HTTPValidationError" : {
281+ "title" : "HTTPValidationError" ,
282+ "type" : "object" ,
283+ "properties" : {
284+ "detail" : {
285+ "title" : "Detail" ,
286+ "type" : "array" ,
287+ "items" : {"$ref" : "#/components/schemas/ValidationError" },
288+ }
289+ },
290+ },
291+ "Input" : {"title" : "Input" , "type" : "object" , "properties" : {}},
292+ "MyOutput" : {
293+ "title" : "MyOutput" ,
294+ "type" : "object" ,
295+ "properties" : {
296+ "foo_number" : {
297+ "title" : "Foo Number" ,
298+ "type" : "integer" ,
299+ "default" : "42" ,
300+ },
301+ "foo_string" : {
302+ "title" : "Foo String" ,
303+ "type" : "string" ,
304+ "default" : "meaning of life" ,
305+ },
306+ },
307+ },
308+ "Output" : {"$ref" : "#/components/schemas/MyOutput" , "title" : "Output" },
309+ "Request" : {
310+ "title" : "Request" ,
311+ "type" : "object" ,
312+ "properties" : {
313+ "input" : {"$ref" : "#/components/schemas/Input" },
314+ "output_file_prefix" : {
315+ "title" : "Output File Prefix" ,
316+ "type" : "string" ,
317+ },
318+ },
319+ },
320+ "Response" : {
321+ "title" : "Response" ,
322+ "required" : ["status" ],
323+ "type" : "object" ,
324+ "properties" : {
325+ "status" : {"$ref" : "#/components/schemas/Status" },
326+ "output" : {"$ref" : "#/components/schemas/Output" },
327+ "error" : {"title" : "Error" , "type" : "string" },
328+ },
329+ "description" : "The status of a prediction." ,
330+ },
331+ "Status" : {
332+ "title" : "Status" ,
333+ "enum" : ["processing" , "success" , "failed" ],
334+ "description" : "An enumeration." ,
335+ },
336+ "ValidationError" : {
337+ "title" : "ValidationError" ,
338+ "required" : ["loc" , "msg" , "type" ],
339+ "type" : "object" ,
340+ "properties" : {
341+ "loc" : {
342+ "title" : "Location" ,
343+ "type" : "array" ,
344+ "items" : {"type" : "string" },
345+ },
346+ "msg" : {"title" : "Message" , "type" : "string" },
347+ "type" : {"title" : "Error Type" , "type" : "string" },
348+ },
349+ },
350+ }
351+ },
352+ }
353+
354+
206355def test_yielding_strings_from_generator_predictors ():
207356 class Predictor (cog .Predictor ):
208357 def predict (self ) -> Generator [str , None , None ]:
0 commit comments