66from fastapi import FastAPI , Body
77from fastapi .exceptions import HTTPException
88from PIL import Image
9+ from itertools import tee
910
1011import gradio as gr
1112
1213from modules .api .models import List , Dict
1314from modules .api import api
1415
15- from src .core import core_generation_funnel
16+ from src .core import core_generation_funnel , CoreGenerationFunnelInp
1617from src .misc import SCRIPT_VERSION
1718from src import backbone
1819from src .common_constants import GenerationOptions as go
19-
20+ from src . api_constants import Api_Defaults , Api_Forced , Api_options
2021
2122def encode_to_base64 (image ):
2223 if type (image ) is str :
@@ -28,48 +29,77 @@ def encode_to_base64(image):
2829 else :
2930 return ""
3031
31-
3232def encode_np_to_base64 (image ):
3333 pil = Image .fromarray (image )
3434 return api .encode_pil_to_base64 (pil )
3535
36-
3736def to_base64_PIL (encoding : str ):
3837 return Image .fromarray (np .array (api .decode_base64_to_image (encoding )).astype ('uint8' ))
3938
4039
40+ def api_gen (depth_input_images , options ):
41+
42+ default_options = CoreGenerationFunnelInp ({Api_Defaults }).values
43+
44+ #TODO try-catch type errors here
45+ for key , value in options .items ():
46+ default_options [key ] = value
47+
48+ for key , value in Api_Forced .items ():
49+ default_options [key .lower ()] = value
50+
51+ if len (depth_input_images ) == 0 :
52+ raise HTTPException (status_code = 422 , detail = "No images supplied" )
53+
54+ print (f"Processing { str (len (depth_input_images ))} images through the API" )
55+
56+ pil_images = []
57+ for input_image in depth_input_images :
58+ pil_images .append (to_base64_PIL (input_image ))
59+ outpath = backbone .get_outpath ()
60+ gen_obj = core_generation_funnel (outpath , pil_images , None , None , options )
61+ return gen_obj
62+
4163def depth_api (_ : gr .Blocks , app : FastAPI ):
4264 @app .get ("/depth/version" )
4365 async def version ():
4466 return {"version" : SCRIPT_VERSION }
4567
4668 @app .get ("/depth/get_options" )
4769 async def get_options ():
48- return {"options" : sorted ([x .name .lower () for x in go ])}
70+ return {
71+ "api_options" : Api_options ,
72+ "gen_options" : [x .name .lower () for x in go ]
73+ }
4974
50- # TODO: some potential inputs not supported (like custom depthmaps)
5175 @app .post ("/depth/generate" )
5276 async def process (
5377 depth_input_images : List [str ] = Body ([], title = 'Input Images' ),
54- options : Dict [str , object ] = Body ("options" , title = 'Generation options' ),
78+ api_options : Dict [str , object ] = Body ({'outputs' : ["depth" ]}, title = 'Api options' , options = Api_options ),
79+ gen_options : Dict [str , object ] = Body ({}, title = 'Generation options' , options = [x .name .lower () for x in go ])
5580 ):
56- # TODO: restrict mesh options
57-
58- if len (depth_input_images ) == 0 :
59- raise HTTPException (status_code = 422 , detail = "No images supplied" )
60- print (f"Processing { str (len (depth_input_images ))} images trough the API" )
61-
62- pil_images = []
63- for input_image in depth_input_images :
64- pil_images .append (to_base64_PIL (input_image ))
65- outpath = backbone .get_outpath ()
66- gen_obj = core_generation_funnel (outpath , pil_images , None , None , options )
67-
68- results_based = []
69- for count , type , result in gen_obj :
70- if type not in ['simple_mesh' , 'inpainted_mesh' ]:
71- results_based += [encode_to_base64 (result )]
72- return {"images" : results_based , "info" : "Success" }
81+ gen_obj = api_gen (depth_input_images , gen_options )
82+ #NOTE Work around yield. (Might not be necessary, not sure if yield caches)
83+ _ , gen_obj = tee (gen_obj )
84+
85+ if len (api_options ["outputs" ])> 1 :
86+ results_based = {}
87+
88+ for type in api_options ["outputs" ]:
89+ result_per_type = []
90+
91+ for count , img_type , result in gen_obj :
92+ if img_type == type :
93+ result_per_type += result
94+
95+ if len (result_per_type )== 0 :
96+ results_based [type ] = "Check options. no img-type of " + str (type ) + " where generated"
97+ else :
98+ results_based [type ] = map (encode_to_base64 , result_per_type )
99+
100+ return {"images" : results_based , "info" : "Success" }
101+ else :
102+ return {"images" : {}, "info" : "api_options.output is empty" }
73103
74104
75105try :
0 commit comments