28 changes: 21 additions & 7 deletions src/forge/actors/generator.py
@@ -287,12 +287,20 @@ def split_keys(keys):
return state_dict

@endpoint
async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
async def generate(
self,
prompt: str,
*,
priority: int = 0,
sampling_params: SamplingParams | None = None,
) -> list[Completion]:
"""Generate a response for the given prompt

Args:
prompt (str): The prompt to generate a response for.
priority (int, optional): The priority of the request. Defaults to 0.
sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
If not provided, uses self.sampling_params.

Returns:
list[Completion]: n completions from vLLM based on your prompt.
@@ -301,12 +309,18 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
t.start()
record_metric("generator/generate/count_requests", 1, Reduce.SUM)

if sampling_params is not None:
# As in `post_init`: return only the final output for this request (no streaming deltas).
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

params = sampling_params or self.sampling_params

self.request_id += 1 % sys.maxsize
request_id = str(self.request_id)

tokenization_kwargs = {}
# TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
truncate_prompt_tokens = params.truncate_prompt_tokens
_validate_truncation_size(
self.vllm_config.model_config.max_model_len,
truncate_prompt_tokens,
@@ -315,7 +329,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
prompt_str, request = self.processor.process_inputs(
request_id=request_id,
prompt={"prompt": prompt},
params=self.sampling_params,
params=params,
arrival_time=None,
tokenization_kwargs=tokenization_kwargs,
trace_headers=None,
@@ -331,21 +345,21 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
await self.request_lock.wait_for(lambda: self.accepting_requests)

# Explicitly keeping the redundant logic to make it easier to pick up vLLM changes
if (num_samples := self.sampling_params.n) == 1:
if (num_samples := params.n) == 1:
self.output_processor.add_request(request, prompt_str, None, 0)
request, _ = self._preprocess_add_request(request)
request_fut = asyncio.Future()
self.requests[request_id] = (None, request_fut)
self.scheduler.add_request(request)
else:
parent_req = ParentRequest(request_id, self.sampling_params)
parent_req = ParentRequest(request_id, params)
for idx in range(num_samples):
# Note: `get_child_info` mutates ParentRequest to track the
# generated child request
child_request_id, params = parent_req.get_child_info(idx)
child_request_id, params_child = parent_req.get_child_info(idx)
child_request = request if idx == num_samples - 1 else copy(request)
child_request.request_id = child_request_id
child_request.sampling_params = params
child_request.sampling_params = params_child
self.output_processor.add_request(
child_request, prompt_str, parent_req, idx
)
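Usage note (illustrative, not part of the diff): the sketch below shows how a caller might exercise the new `sampling_params` override. It assumes a spawned Generator service handle named `generator` and the `.route(...)` call convention used by forge services; the handle name, prompt strings, and parameter values are hypothetical.

```python
# Minimal sketch, assuming `generator` is an already-spawned Generator service
# handle whose endpoints are invoked via `.route(...)` (an assumption here).
from vllm import SamplingParams


async def sample_with_override(generator):
    # Per-request override: draw 4 completions at a higher temperature,
    # without touching the generator's default sampling configuration.
    custom = SamplingParams(n=4, temperature=1.0, max_tokens=256)
    overridden = await generator.generate.route(
        "Write a haiku about reinforcement learning.",
        sampling_params=custom,
    )

    # Omitting `sampling_params` keeps the previous behavior: the request
    # falls back to the generator's configured `self.sampling_params`.
    default = await generator.generate.route("Summarize the haiku above.")
    return overridden, default
```

Callers do not need to set `output_kind` on the override; the endpoint forces `RequestOutputKind.FINAL_ONLY` on it, mirroring what `post_init` does for the default parameters.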
4 changes: 2 additions & 2 deletions src/forge/controller/launcher.py
@@ -17,8 +17,6 @@
import monarch

import torchx.specs as specs

from forge.types import Launcher, LauncherConfig
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport

@@ -29,6 +27,8 @@
from monarch.tools.commands import create, info
from monarch.tools.config import Config, Workspace

from forge.types import Launcher, LauncherConfig

_MAST_AVAILABLE = False

try: