66from fastapi import FastAPI , Body
77from fastapi .exceptions import HTTPException
88from PIL import Image
9+ from itertools import tee
910
1011import gradio as gr
1112
1213from modules .api .models import List , Dict
1314from modules .api import api
15+ from typing_extensions import Annotated
16+ from pydantic import BaseModel
1417
15- from src .core import core_generation_funnel
18+ from src .core import core_generation_funnel , CoreGenerationFunnelInp
1619from src .misc import SCRIPT_VERSION
1720from src import backbone
1821from src .common_constants import GenerationOptions as go
19-
22+ from src . api_constants import Api_Defaults , Api_Forced , Api_options
2023
2124def encode_to_base64 (image ):
2225 if type (image ) is str :
@@ -28,48 +31,77 @@ def encode_to_base64(image):
2831 else :
2932 return ""
3033
31-
3234def encode_np_to_base64 (image ):
3335 pil = Image .fromarray (image )
3436 return api .encode_pil_to_base64 (pil )
3537
36-
3738def to_base64_PIL (encoding : str ):
3839 return Image .fromarray (np .array (api .decode_base64_to_image (encoding )).astype ('uint8' ))
3940
4041
42+ def api_gen (depth_input_images , options ):
43+
44+ default_options = CoreGenerationFunnelInp ({Api_Defaults }).values
45+
46+ #TODO try-catch type errors here
47+ for key , value in options .items ():
48+ default_options [key ] = value
49+
50+ for key , value in Api_Forced .items ():
51+ default_options [key .lower ()] = value
52+
53+ if len (depth_input_images ) == 0 :
54+ raise HTTPException (status_code = 422 , detail = "No images supplied" )
55+
56+ print (f"Processing { str (len (depth_input_images ))} images through the API" )
57+
58+ pil_images = []
59+ for input_image in depth_input_images :
60+ pil_images .append (to_base64_PIL (input_image ))
61+ outpath = backbone .get_outpath ()
62+ gen_obj = core_generation_funnel (outpath , pil_images , None , None , options )
63+ return gen_obj
64+
4165def depth_api (_ : gr .Blocks , app : FastAPI ):
4266 @app .get ("/depth/version" )
4367 async def version ():
4468 return {"version" : SCRIPT_VERSION }
4569
4670 @app .get ("/depth/get_options" )
4771 async def get_options ():
48- return {"options" : sorted ([x .name .lower () for x in go ])}
72+ return {
73+ "api_options" : Api_options ,
74+ "gen_options" : [x .name .lower () for x in go ]
75+ }
4976
50- # TODO: some potential inputs not supported (like custom depthmaps)
5177 @app .post ("/depth/generate" )
5278 async def process (
5379 depth_input_images : List [str ] = Body ([], title = 'Input Images' ),
54- options : Dict [str , object ] = Body ("options" , title = 'Generation options' ),
80+ api_options : Dict [str , object ] = Body ({'outputs' : ["depth" ]}, title = 'Api options' , options = Api_options ),
81+ gen_options : Dict [str , object ] = Body ({}, title = 'Generation options' , options = [x .name .lower () for x in go ])
5582 ):
56- # TODO: restrict mesh options
57-
58- if len (depth_input_images ) == 0 :
59- raise HTTPException (status_code = 422 , detail = "No images supplied" )
60- print (f"Processing { str (len (depth_input_images ))} images trough the API" )
61-
62- pil_images = []
63- for input_image in depth_input_images :
64- pil_images .append (to_base64_PIL (input_image ))
65- outpath = backbone .get_outpath ()
66- gen_obj = core_generation_funnel (outpath , pil_images , None , None , options )
67-
68- results_based = []
69- for count , type , result in gen_obj :
70- if type not in ['simple_mesh' , 'inpainted_mesh' ]:
71- results_based += [encode_to_base64 (result )]
72- return {"images" : results_based , "info" : "Success" }
83+ gen_obj = api_gen (depth_input_images , gen_options )
84+ #NOTE Work around yield. (Might not be necessary, not sure if yield caches)
85+ _ , gen_obj = tee (gen_obj )
86+
87+ if len (api_options ["outputs" ])> 1 :
88+ results_based = {}
89+
90+ for type in api_options ["outputs" ]:
91+ result_per_type = []
92+
93+ for count , img_type , result in gen_obj :
94+ if img_type == type :
95+ result_per_type += result
96+
97+ if len (result_per_type )== 0 :
98+ results_based [type ] = "Check options. no img-type of " + str (type ) + " where generated"
99+ else :
100+ results_based [type ] = map (encode_to_base64 , result_per_type )
101+
102+ return {"images" : results_based , "info" : "Success" }
103+ else :
104+ return {"images" : {}, "info" : "api_options.output is empty" }
73105
74106
75107try :
0 commit comments