Skip to content

Commit 9819026

Browse files
api changes + notes
typo updating depth api modified: scripts/depthmap_api.py
1 parent a232eb9 commit 9819026

File tree

3 files changed

+76
-25
lines changed

3 files changed

+76
-25
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
__pycache__/
1+
__pycache__/
2+
models/
3+
ouputs/

scripts/depthmap_api.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,20 @@
66
from fastapi import FastAPI, Body
77
from fastapi.exceptions import HTTPException
88
from PIL import Image
9+
from itertools import tee
910

1011
import gradio as gr
1112

1213
from modules.api.models import List, Dict
1314
from modules.api import api
15+
from typing_extensions import Annotated
16+
from pydantic import BaseModel
1417

15-
from src.core import core_generation_funnel
18+
from src.core import core_generation_funnel, CoreGenerationFunnelInp
1619
from src.misc import SCRIPT_VERSION
1720
from src import backbone
1821
from src.common_constants import GenerationOptions as go
19-
22+
from src.api_constants import Api_Defaults, Api_Forced, Api_options
2023

2124
def encode_to_base64(image):
2225
if type(image) is str:
@@ -28,48 +31,77 @@ def encode_to_base64(image):
2831
else:
2932
return ""
3033

31-
3234
def encode_np_to_base64(image):
3335
pil = Image.fromarray(image)
3436
return api.encode_pil_to_base64(pil)
3537

36-
3738
def to_base64_PIL(encoding: str):
3839
return Image.fromarray(np.array(api.decode_base64_to_image(encoding)).astype('uint8'))
3940

4041

42+
def api_gen(depth_input_images, options):
43+
44+
default_options = CoreGenerationFunnelInp({Api_Defaults}).values
45+
46+
#TODO try-catch type errors here
47+
for key, value in options.items():
48+
default_options[key] = value
49+
50+
for key, value in Api_Forced.items():
51+
default_options[key.lower()] = value
52+
53+
if len(depth_input_images) == 0:
54+
raise HTTPException(status_code=422, detail="No images supplied")
55+
56+
print(f"Processing {str(len(depth_input_images))} images through the API")
57+
58+
pil_images = []
59+
for input_image in depth_input_images:
60+
pil_images.append(to_base64_PIL(input_image))
61+
outpath = backbone.get_outpath()
62+
gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)
63+
return gen_obj
64+
4165
def depth_api(_: gr.Blocks, app: FastAPI):
4266
@app.get("/depth/version")
4367
async def version():
4468
return {"version": SCRIPT_VERSION}
4569

4670
@app.get("/depth/get_options")
4771
async def get_options():
48-
return {"options": sorted([x.name.lower() for x in go])}
72+
return {
73+
"api_options": Api_options,
74+
"gen_options": [x.name.lower() for x in go]
75+
}
4976

50-
# TODO: some potential inputs not supported (like custom depthmaps)
5177
@app.post("/depth/generate")
5278
async def process(
5379
depth_input_images: List[str] = Body([], title='Input Images'),
54-
options: Dict[str, object] = Body("options", title='Generation options'),
80+
api_options: Dict[str, object] = Body({'outputs': ["depth"]}, title='Api options', options= Api_options),
81+
gen_options: Dict[str, object] = Body({}, title='Generation options', options= [x.name.lower() for x in go])
5582
):
56-
# TODO: restrict mesh options
57-
58-
if len(depth_input_images) == 0:
59-
raise HTTPException(status_code=422, detail="No images supplied")
60-
print(f"Processing {str(len(depth_input_images))} images trough the API")
61-
62-
pil_images = []
63-
for input_image in depth_input_images:
64-
pil_images.append(to_base64_PIL(input_image))
65-
outpath = backbone.get_outpath()
66-
gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)
67-
68-
results_based = []
69-
for count, type, result in gen_obj:
70-
if type not in ['simple_mesh', 'inpainted_mesh']:
71-
results_based += [encode_to_base64(result)]
72-
return {"images": results_based, "info": "Success"}
83+
gen_obj = api_gen(depth_input_images, gen_options)
84+
#NOTE Work around yield. (Might not be necessary, not sure if yield caches)
85+
_, gen_obj = tee (gen_obj)
86+
87+
if len(api_options["outputs"])>1:
88+
results_based = {}
89+
90+
for type in api_options["outputs"]:
91+
result_per_type = []
92+
93+
for count, img_type, result in gen_obj:
94+
if img_type == type:
95+
result_per_type += result
96+
97+
if len(result_per_type)==0:
98+
results_based[type] = "Check options. no img-type of " + str(type) + " where generated"
99+
else:
100+
results_based[type] = map(encode_to_base64, result_per_type)
101+
102+
return {"images": results_based, "info": "Success"}
103+
else:
104+
return {"images": {}, "info": "api_options.output is empty"}
73105

74106

75107
try:

src/api_constants.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# TODO Maybe these have a better home
2+
Api_options = {
3+
'outputs': ["depth"], # list of outputs to send in response. examples ["depth", "normalmap", 'heatmap', "normal", 'background_removed'] etc
4+
#'conversions': "", #TODO implement. it's a good idea to give some options serverside for because often that's challenging in js/clientside
5+
'save':"" #TODO implement. To save on local machine. Can be very helpful for debugging.
6+
}
7+
8+
# TODO: These two are intended to be temporary
9+
Api_Defaults={
10+
"BOOST": False,
11+
"NET_SIZE_MATCH": True
12+
}
13+
#These are enforced after user inputs
14+
Api_Forced={
15+
"GEN_SIMPLE_MESH": False,
16+
"GEN_INPAINTED_MESH": False
17+
}

0 commit comments

Comments
 (0)