Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions rllm/engine/agent_execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
messages = agent.chat_completions
prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True)
prompt_token_len = len(prompt_tokens)
# Note, this should never happen!

# Check if initial prompt already exceeds max length
# This can happen if:
# 1. Dataset filtering didn't catch this sample (e.g., different tokenization)
# 2. The checkpoint contains a cached dataset that wasn't filtered (workaround: delete the checkpoint's data.pt)
if prompt_token_len > self.max_prompt_length:
agent.reset()
raise Exception(f"Trajectory {idx}: initial prompt length {prompt_token_len} already exceeded max_prompt_length {self.max_prompt_length}, retrying")
logger.warning(f"Trajectory {idx}: Initial prompt length {prompt_token_len} exceeds max_prompt_length {self.max_prompt_length}. Skipping this sample entirely (no trajectory will be returned). First 200 chars of prompt: {self.chat_parser.parse(messages[:1], add_generation_prompt=False)[:200]}...")

# Close the environment and return None to skip this trajectory entirely
await loop.run_in_executor(self.executor, env.close)
return None

for step_idx in range(self.max_steps):
# Get action from agent
Expand Down Expand Up @@ -410,7 +417,11 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs):
for _ in range(self.retry_limit):
try:
return await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200)
result = await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200)
# A None result means the trajectory was skipped (e.g., overlong prompt); return it as-is instead of retrying
if result is None:
return None
return result
except Exception:
traceback.print_exc()
continue
Expand All @@ -426,6 +437,10 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text",
self.executor = ThreadPoolExecutor(max_workers=self.max_env_workers)
semaphore = asyncio.Semaphore(self.n_parallel_agents)

# Reset the skipped/valid index bookkeeping for this run; both are filled in only after all trajectories have completed
self._last_skipped_indices = []
self._last_valid_indices = None # Will be computed after all trajectories complete

if self.engine_name == "verl":
self.rollout_engine.wake_up()

Expand All @@ -445,22 +460,55 @@ async def launch_one_trajectory_task(env_idx: int):

traceback.print_exc()
raise e
return result
# Return tuple (env_idx, result) so we can track which env returned None
return (env_idx, result)

# Create all N conceptual tasks. Their execution will be throttled by the semaphore
# and the availability of agent/env indices.
tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))]

# Track results by index to maintain order and identify skipped trajectories
results_by_idx = {}
tasks_completed = 0
skipped_count = 0

for coro in asyncio.as_completed(tasks_to_run):
try:
result = await coro
env_idx, result = await coro
tasks_completed += 1
colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
yield result

# Store result with its env_idx (None if skipped)
if result is None:
skipped_count += 1
results_by_idx[env_idx] = None # Store None to mark as skipped
colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed ({skipped_count} skipped due to overlong prompts)", "cyan")
else:
results_by_idx[env_idx] = result
colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
except Exception as e:
raise e

# Verify all tasks completed and are stored
if len(results_by_idx) != len(self.envs):
missing = sorted(set(range(len(self.envs))) - set(results_by_idx.keys()))
raise RuntimeError(f"Not all trajectories were stored! Missing indices: {missing}. Expected {len(self.envs)} but got {len(results_by_idx)}")

# Yield all trajectories in order (0 to len(self.envs)-1)
# None values indicate skipped trajectories
skipped_indices = []
for idx in range(len(self.envs)):
# All indices should be in results_by_idx after the check above
result = results_by_idx[idx]
if result is None:
skipped_indices.append(idx)
yield result # Yield result (None for skipped, trajectory dict otherwise)

# Store skipped indices and valid indices as instance variables for trainer to access
self._last_skipped_indices = skipped_indices
# Compute valid indices (complement of skipped indices) for easier batch filtering
total_count = len(self.envs)
self._last_valid_indices = [i for i in range(total_count) if i not in skipped_indices]

if self.engine_name == "verl":
self.rollout_engine.sleep()

Expand Down
49 changes: 45 additions & 4 deletions rllm/trainer/verl/agent_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def fit_agent(self):

for epoch in range(self.config.trainer.total_epochs):
pprint(f"epoch {epoch}, step {self.global_steps} started")

for batch_dict in self.train_dataloader:
batch: DataProto = DataProto.from_single_dict(batch_dict)
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
Expand Down Expand Up @@ -180,6 +181,28 @@ def fit_agent(self):
batch = self._pad_dataproto_to_world_size(batch=batch)
else:
final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info)

# If some trajectories were skipped (overlong prompts), filter the batch to match
# Get valid indices directly from AgentExecutionEngine (handled internally)
valid_indices = getattr(self.agent_execution_engine, "_last_valid_indices", None)
skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", [])

# Reconcile valid_indices with the number of trajectories actually collected: truncate surplus indices, fail if there are too few
num_trajectories = len(final_gen_batch_output.batch)
if valid_indices is not None and len(valid_indices) != num_trajectories:
if len(valid_indices) > num_trajectories:
valid_indices = valid_indices[:num_trajectories]
else:
raise RuntimeError(f"Fewer valid indices ({len(valid_indices)}) than trajectories ({num_trajectories}).")

if valid_indices is not None and len(valid_indices) < len(batch.batch):
# Filter batch to only include valid samples (matching the number of trajectories collected)
batch = batch.select_idxs(valid_indices)

# Final sanity check: batch sizes must match before union
if len(batch.batch) != len(final_gen_batch_output.batch):
raise RuntimeError(f"Batch size mismatch before union: batch has {len(batch.batch)} samples, final_gen_batch_output has {len(final_gen_batch_output.batch)} samples. valid_indices: {len(valid_indices) if valid_indices else 'None'}, skipped_indices: {len(skipped_indices)}")

batch = batch.union(final_gen_batch_output)
metrics.update(generate_metrics)

Expand Down Expand Up @@ -551,16 +574,29 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
trajectories = []
if self.async_rollout_mode:
gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Token")
for _, trajectory in enumerate(gen_seq_generator):
trajectories.append(trajectory)
for trajectory in gen_seq_generator:
# Skip None trajectories (overlong prompts) - these are handled by AgentExecutionEngine
if trajectory is not None:
trajectories.append(trajectory)
else:
raise ValueError("Only async rollout mode is supported")

# Check if all trajectories were skipped
if not trajectories:
raise RuntimeError("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering.")

# Get skipped indices from AgentExecutionEngine (handled internally)
skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", [])
if skipped_indices:
print(f"Skipped {len(skipped_indices)} trajectories due to overlong prompts at env indices: {skipped_indices}")

# Sort trajectories by their idx, to ensure they are in order.
trajectories.sort(key=lambda x: x["idx"])

with marked_timer("transform_trajectory", timing_raw):
# Transform the raw trajectories into DataProto format.
final_gen_batch_output, metrics = self._transform_agent_trajectories(trajectories)

return final_gen_batch_output, metrics

def generate_agent_steps(self, timing_raw=None, meta_info=None, uids=None):
Expand Down Expand Up @@ -833,18 +869,23 @@ def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mod
timing_raw = {}
queue = Queue()

# Create a unique sentinel object to signal completion
# (Cannot use None since None is used for skipped trajectories)
_SENTINEL = object()

def runner():
async def consume():
async for item in self.agent_execution_engine.trajectory_generator(timing_raw=timing_raw, mode=mode, meta_info=meta_info):
queue.put(item)
queue.put(None) # sentinel to signal done
# Use a special sentinel object instead of None (since None is used for skipped trajectories)
queue.put(_SENTINEL) # sentinel to signal done

asyncio.run(consume())

Thread(target=runner, daemon=True).start()
while True:
item = queue.get()
if item is None:
if item is _SENTINEL:
break
yield item

Expand Down