Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions stable_diffusion_api/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async def wait_task_finished(
request: Request,
status_service: StatusService,
) -> FinishedEvent:
async def get_finished_event_or_raise():
async def get_finished_event_or_raise() -> FinishedEvent:
async for ev in subscribe_to_task(task.task_id):
if isinstance(ev, AbortedEvent):
raise HTTPException(status_code=500, detail=ev.reason)
Expand All @@ -338,18 +338,21 @@ async def get_finished_event_or_raise():
async def disconnect_listener() -> None:
while not await request.is_disconnected():
await asyncio.sleep(0.1)
status_service.cancel_task(task.task_id)
raise HTTPException(status_code=499, detail="Client disconnected")

done, pending = await asyncio.wait([get_finished_event_or_raise(), disconnect_listener()],
return_when=asyncio.FIRST_COMPLETED)

for aio_task in pending:
aio_task.cancel()

if await request.is_disconnected():
status_service.cancel_task(task.task_id)
raise HTTPException(status_code=499, detail="Client disconnected")
done_task = done.pop()
exc = done_task.exception()
if exc is not None:
raise exc

event = done.pop().result()
event = done_task.result()
if event is None:
raise RuntimeError("Event stream ended unexpectedly")
return event
Expand Down
Empty file.
Loading