diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 32ae69906..0d808fcfb 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -287,12 +287,20 @@ def split_keys(keys):
         return state_dict
 
     @endpoint
-    async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
+    async def generate(
+        self,
+        prompt: str,
+        *,
+        priority: int = 0,
+        sampling_params: SamplingParams | None = None,
+    ) -> list[Completion]:
         """Generate a response for the given prompt
 
         Args:
             prompt (str): The prompt to generate a response for.
             priority (int, optional): The priority of the request. Defaults to 0.
+            sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
+                If not provided, uses self.sampling_params.
 
         Returns:
             list[Completion]: n completions from vLLM based on your prompt.
@@ -301,12 +309,18 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         t.start()
         record_metric("generator/generate/count_requests", 1, Reduce.SUM)
 
+        if sampling_params is not None:
+            # as in `post_init`
+            sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+
+        params = sampling_params or self.sampling_params
+
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)
 
         tokenization_kwargs = {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
-        truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
+        truncate_prompt_tokens = params.truncate_prompt_tokens
         _validate_truncation_size(
             self.vllm_config.model_config.max_model_len,
             truncate_prompt_tokens,
@@ -315,7 +329,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         prompt_str, request = self.processor.process_inputs(
             request_id=request_id,
             prompt={"prompt": prompt},
-            params=self.sampling_params,
+            params=params,
             arrival_time=None,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
@@ -331,21 +345,21 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
             await self.request_lock.wait_for(lambda: self.accepting_requests)
 
             # Explicitly keeping the redundant logic to make it easier to pick up vLLM changes
-            if (num_samples := self.sampling_params.n) == 1:
+            if (num_samples := params.n) == 1:
                 self.output_processor.add_request(request, prompt_str, None, 0)
                 request, _ = self._preprocess_add_request(request)
                 request_fut = asyncio.Future()
                 self.requests[request_id] = (None, request_fut)
                 self.scheduler.add_request(request)
             else:
-                parent_req = ParentRequest(request_id, self.sampling_params)
+                parent_req = ParentRequest(request_id, params)
                 for idx in range(num_samples):
                     # Note: `get_child_info` mutates ParentRequest to track the
                     # generated child request
-                    child_request_id, params = parent_req.get_child_info(idx)
+                    child_request_id, params_child = parent_req.get_child_info(idx)
                     child_request = request if idx == num_samples - 1 else copy(request)
                     child_request.request_id = child_request_id
-                    child_request.sampling_params = params
+                    child_request.sampling_params = params_child
                     self.output_processor.add_request(
                         child_request, prompt_str, parent_req, idx
                     )
diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py
index a11ab50be..dd74591f1 100644
--- a/src/forge/controller/launcher.py
+++ b/src/forge/controller/launcher.py
@@ -17,8 +17,6 @@
 import monarch
 import torchx.specs as specs
-
-from forge.types import Launcher, LauncherConfig
 from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints
 from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
@@ -29,6 +27,8 @@ from monarch.tools.commands import create, info
 from monarch.tools.config import Config, Workspace
 
+from forge.types import Launcher, LauncherConfig
+
 _MAST_AVAILABLE = False
 try:
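
For reference, a minimal usage sketch of the new per-request override (not part of the patch). Only the `sampling_params` keyword comes from the change above; the handle name `generator` and the `.route(...)` endpoint-calling convention are assumptions and may differ in your deployment.

# Illustrative sketch, not part of the patch. Assumes a deployed Generator
# handle named `generator` and the `.route(...)` endpoint-calling convention.
from vllm import SamplingParams

async def sample_with_override(generator, prompt: str):
    # Override the generator's defaults for this request only: four samples
    # at a higher temperature. The endpoint sets output_kind itself, so the
    # caller only supplies the fields it cares about.
    params = SamplingParams(n=4, temperature=1.0, max_tokens=256)
    completions = await generator.generate.route(prompt, sampling_params=params)

    # Requests that omit sampling_params keep the previous behavior and fall
    # back to the generator's configured self.sampling_params.
    default_completions = await generator.generate.route(prompt)
    return completions, default_completions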