11import asyncio
2+ from contextlib import ExitStack
23from typing import List , Tuple
34
45import pytest
56
67from vllm import SamplingParams
78from vllm .engine .arg_utils import AsyncEngineArgs
89from vllm .platforms import current_platform
10+ from vllm .sampling_params import RequestOutputKind
911from vllm .v1 .engine .async_llm import AsyncLLM
1012
1113if not current_platform .is_cuda ():
1820
1921
async def generate(engine: AsyncLLM, request_id: str,
                   output_kind: RequestOutputKind,
                   max_tokens: int) -> Tuple[int, str]:
    """Run a single request to completion and count its output tokens.

    Returns a ``(num_tokens, request_id)`` tuple. With
    ``RequestOutputKind.DELTA`` each streamed result carries only the newly
    produced tokens, so counts are accumulated; with ``FINAL_ONLY`` (or any
    cumulative kind) each result carries the full output so far, so the
    latest count simply replaces the previous one.
    """
    params = SamplingParams(max_tokens=max_tokens,
                            output_kind=output_kind,
                            temperature=0)
    token_count = 0
    stream = engine.generate(request_id=request_id,
                             prompt="Hello my name is Robert and",
                             sampling_params=params)
    async for result in stream:
        produced = len(result.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            token_count += produced
        else:
            token_count = produced

        # Yield to the event loop so sibling request tasks make progress.
        await asyncio.sleep(0.)

    return token_count, request_id
3242
3343
44+ @pytest .mark .parametrize (
45+ "output_kind" , [RequestOutputKind .DELTA , RequestOutputKind .FINAL_ONLY ])
3446@pytest .mark .asyncio
35- async def test_load (monkeypatch ):
47+ async def test_load (monkeypatch , output_kind : RequestOutputKind ):
3648 # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
3749 # so that in the future when we switch, we don't have to change all the
3850 # tests.
39- with monkeypatch .context () as m :
51+ with monkeypatch .context () as m , ExitStack () as after :
4052 m .setenv ("VLLM_USE_V1" , "1" )
4153
4254 engine = AsyncLLM .from_engine_args (ENGINE_ARGS )
55+ after .callback (engine .shutdown )
4356
4457 NUM_REQUESTS = 10000
4558 NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
5164 for request_id in request_ids :
5265 tasks .append (
5366 asyncio .create_task (
54- generate (engine , request_id , NUM_EXPECTED_TOKENS )))
67+ generate (engine , request_id , output_kind ,
68+ NUM_EXPECTED_TOKENS )))
5569
5670 # Confirm that we got all the EXPECTED tokens from the requests.
57- for task in tasks :
71+ done , pending = await asyncio .wait (tasks ,
72+ return_when = asyncio .FIRST_EXCEPTION )
73+ for task in pending :
74+ task .cancel ()
75+ for task in done :
5876 num_generated_tokens , request_id = await task
5977 assert num_generated_tokens == NUM_EXPECTED_TOKENS , (
6078 f"{ request_id } generated { num_generated_tokens } but "
6179 f"expected { NUM_EXPECTED_TOKENS } " )
6280
6381 assert not engine .output_processor .has_unfinished_requests ()
64- engine .shutdown ()
6582
6683
84+ @pytest .mark .parametrize (
85+ "output_kind" , [RequestOutputKind .DELTA , RequestOutputKind .FINAL_ONLY ])
6786@pytest .mark .asyncio
68- async def test_abort (monkeypatch ):
87+ async def test_abort (monkeypatch , output_kind : RequestOutputKind ):
6988
70- with monkeypatch .context () as m :
89+ with monkeypatch .context () as m , ExitStack () as after :
7190 m .setenv ("VLLM_USE_V1" , "1" )
7291
7392 engine = AsyncLLM .from_engine_args (ENGINE_ARGS )
93+ after .callback (engine .shutdown )
7494
7595 NUM_REQUESTS = 100
7696 NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
83103 for request_id in request_ids :
84104 tasks .append (
85105 asyncio .create_task (
86- generate (engine , request_id , NUM_EXPECTED_TOKENS )))
106+ generate (engine , request_id , output_kind ,
107+ NUM_EXPECTED_TOKENS )))
87108
88109 # API server cancels requests when they disconnect.
89110 for idx in REQUEST_IDS_TO_ABORT :
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
108129 # Confirm we can do another generation.
109130 request_id = f"request-{ REQUEST_IDS_TO_ABORT [0 ]} "
110131 task = asyncio .create_task (
111- generate (engine , request_id , NUM_EXPECTED_TOKENS ))
132+ generate (engine , request_id , output_kind , NUM_EXPECTED_TOKENS ))
112133 num_generated_tokens , request_id = await task
113134 assert num_generated_tokens == NUM_EXPECTED_TOKENS
114135 assert not engine .output_processor .has_unfinished_requests ()
115-
116- engine .shutdown ()
0 commit comments