From d01af13c98dbf2664a08f479ca846a0507a89540 Mon Sep 17 00:00:00 2001
From: erranlli
Date: Fri, 31 Oct 2025 17:35:00 +0800
Subject: [PATCH 1/2] fix format

---
 rllm/engine/agent_execution_engine.py  | 27 ++++++++++++++++---
 rllm/trainer/verl/agent_ppo_trainer.py | 36 ++++++++++++++++++++++++--
 2 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py
index fc16b2f90..d821820e9 100644
--- a/rllm/engine/agent_execution_engine.py
+++ b/rllm/engine/agent_execution_engine.py
@@ -203,10 +203,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
         messages = agent.chat_completions
         prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True)
         prompt_token_len = len(prompt_tokens)
-        # Note, this should never happen!
+
+        # Check if the initial prompt already exceeds the max length.
+        # This can happen if:
+        # 1. Dataset filtering didn't catch this sample (e.g., different tokenization)
+        # 2. The checkpoint contains a cached dataset that wasn't filtered (delete the checkpoint's data.pt)
         if prompt_token_len > self.max_prompt_length:
-            agent.reset()
-            raise Exception(f"Trajectory {idx}: initial prompt length {prompt_token_len} already exceeded max_prompt_length {self.max_prompt_length}, retrying")
+            logger.warning(f"Trajectory {idx}: Initial prompt length {prompt_token_len} exceeds max_prompt_length {self.max_prompt_length}. Skipping this sample entirely (no trajectory will be returned). First 200 chars of prompt: {self.chat_parser.parse(messages[:1], add_generation_prompt=False)[:200]}...")
+
+            # Close the environment and return None to skip this trajectory entirely
+            await loop.run_in_executor(self.executor, env.close)
+            return None

         for step_idx in range(self.max_steps):
             # Get action from agent
@@ -410,7 +417,11 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
     async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs):
         for _ in range(self.retry_limit):
             try:
-                return await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200)
+                result = await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200)
+                # A None result means the trajectory was skipped (e.g., overlong prompt); do not retry it
+                if result is None:
+                    return None
+                return result
             except Exception:
                 traceback.print_exc()
                 continue
@@ -452,10 +463,18 @@ async def launch_one_trajectory_task(env_idx: int):
         tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))]

         tasks_completed = 0
+        skipped_count = 0
         for coro in asyncio.as_completed(tasks_to_run):
             try:
                 result = await coro
                 tasks_completed += 1
+
+                # Skip None results (trajectories that were skipped due to overlong prompts)
+                if result is None:
+                    skipped_count += 1
+                    colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed ({skipped_count} skipped due to overlong prompts)", "cyan")
+                    continue
+
                 colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
                 yield result
             except Exception as e:
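The new comment block attributes stray overlong samples to tokenization drift between dataset filtering and rollout. For reference, a minimal length pre-filter might look like the sketch below; the helper name, sample schema, and model id are hypothetical (not rllm or verl APIs), and the essential point is that the filter must tokenize with the same tokenizer and chat template as the rollout engine.

```python
from transformers import AutoTokenizer

def filter_overlong(samples, tokenizer, max_prompt_length):
    """Drop samples whose rendered chat prompt exceeds max_prompt_length tokens."""
    kept = []
    for sample in samples:
        # Tokenize exactly as rollout does: same tokenizer, same chat template,
        # same add_generation_prompt setting; otherwise overlong samples slip through.
        ids = tokenizer.apply_chat_template(sample["messages"], add_generation_prompt=True)
        if len(ids) <= max_prompt_length:
            kept.append(sample)
    return kept

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder model
data = [{"messages": [{"role": "user", "content": "hi"}]}]
print(len(filter_overlong(data, tokenizer, max_prompt_length=2048)))  # 1
```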
diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py
index 9cf908b57..3ecc280af 100644
--- a/rllm/trainer/verl/agent_ppo_trainer.py
+++ b/rllm/trainer/verl/agent_ppo_trainer.py
@@ -180,6 +180,18 @@ def fit_agent(self):
                     batch = self._pad_dataproto_to_world_size(batch=batch)
                 else:
                     final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info)
+
+                    # If some trajectories were skipped (overlong prompts), filter the batch to match
+                    if "skipped_indices" in final_gen_batch_output.meta_info:
+                        skipped_indices = final_gen_batch_output.meta_info.pop("skipped_indices")
+                        # Create mask for valid (non-skipped) indices
+                        valid_mask = np.ones(len(batch.batch), dtype=bool)
+                        valid_mask[skipped_indices] = False
+                        # Filter batch to only include valid samples
+                        valid_indices = np.where(valid_mask)[0]
+                        batch = batch.select_idxs(valid_indices)
+                        print(f"Filtered batch from {len(valid_mask)} to {len(valid_indices)} samples after skipping {len(skipped_indices)} overlong prompts")
+
                 batch = batch.union(final_gen_batch_output)
                 metrics.update(generate_metrics)
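In isolation, the mask-to-indices idiom in the hunk above works like this toy sketch, with a plain numpy array standing in for the DataProto batch and made-up indices:

```python
import numpy as np

skipped_indices = [1, 3]                           # positions whose prompts were overlong
batch = np.array(["s0", "s1", "s2", "s3", "s4"])   # stand-in for batch.batch
valid_mask = np.ones(len(batch), dtype=bool)       # start with everything valid
valid_mask[skipped_indices] = False                # knock out the skipped positions
valid_indices = np.where(valid_mask)[0]            # -> array([0, 2, 4])
print(batch[valid_indices])                        # ['s0' 's2' 's4'] (select_idxs analogue)
```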
@@ -551,16 +563,36 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
         trajectories = []
         if self.async_rollout_mode:
             gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Token")
-            for _, trajectory in enumerate(gen_seq_generator):
-                trajectories.append(trajectory)
+            for trajectory in gen_seq_generator:
+                # Skip None trajectories (overlong prompts)
+                if trajectory is not None:
+                    trajectories.append(trajectory)
         else:
             raise ValueError("Only async rollout mode is supported")
+
+        # Check if all trajectories were skipped
+        if not trajectories:
+            raise RuntimeError("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering.")
+
         # Sort trajectories by their idx, to ensure they are in order.
         trajectories.sort(key=lambda x: x["idx"])

+        # Determine which indices were skipped by checking missing idx values
+        # Expected indices are 0 to (batch_size * rollout.n - 1)
+        expected_count = len(self.agent_execution_engine.envs)
+        actual_indices = set(t["idx"] for t in trajectories)
+        expected_indices = set(range(expected_count))
+        skipped_indices = sorted(expected_indices - actual_indices)
+
+        if skipped_indices:
+            print(f"Skipped {len(skipped_indices)} trajectories due to overlong prompts at env indices: {skipped_indices}")
+
         with marked_timer("transform_trajectory", timing_raw):
             # Transform the raw trajectories into DataProto format.
             final_gen_batch_output, metrics = self._transform_agent_trajectories(trajectories)
+            # Store skipped indices in meta_info for potential filtering of original batch
+            if skipped_indices:
+                final_gen_batch_output.meta_info["skipped_indices"] = skipped_indices
         return final_gen_batch_output, metrics

     def generate_agent_steps(self, timing_raw=None, meta_info=None, uids=None):
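Commit 1 recovers the skipped indices purely from the surviving trajectories' "idx" fields, via the set difference above. A toy illustration with made-up values:

```python
trajectories = [{"idx": 0}, {"idx": 2}, {"idx": 3}]        # env 1 returned None and was dropped
expected_indices = set(range(4))                           # batch_size * rollout.n == 4 here
actual_indices = {t["idx"] for t in trajectories}
skipped_indices = sorted(expected_indices - actual_indices)
assert skipped_indices == [1]
```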
From 25577bbd1404a7f065e4974cb8b1e76508d2cc67 Mon Sep 17 00:00:00 2001
From: erranli
Date: Fri, 7 Nov 2025 09:37:59 +0000
Subject: [PATCH 2/2] minor fix

---
 rllm/engine/agent_execution_engine.py  | 43 +++++++++++++++----
 rllm/trainer/verl/agent_ppo_trainer.py | 57 +++++++++++++++-----------
 2 files changed, 69 insertions(+), 31 deletions(-)

diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py
index d821820e9..318497034 100644
--- a/rllm/engine/agent_execution_engine.py
+++ b/rllm/engine/agent_execution_engine.py
@@ -437,6 +437,10 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text",
         self.executor = ThreadPoolExecutor(max_workers=self.max_env_workers)
         semaphore = asyncio.Semaphore(self.n_parallel_agents)

+        # Initialize skipped/valid index bookkeeping (populated as trajectories complete)
+        self._last_skipped_indices = []
+        self._last_valid_indices = None  # Computed after all trajectories complete
+
         if self.engine_name == "verl":
             self.rollout_engine.wake_up()

@@ -456,30 +460,55 @@ async def launch_one_trajectory_task(env_idx: int):
                     traceback.print_exc()
                     raise e

-            return result
+            # Return a (env_idx, result) tuple so we can track which env returned None
+            return (env_idx, result)

         # Create all N conceptual tasks. Their execution will be throttled by the semaphore
         # and the availability of agent/env indices.
         tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))]

+        # Track results by index to maintain order and identify skipped trajectories
+        results_by_idx = {}
         tasks_completed = 0
         skipped_count = 0
+
         for coro in asyncio.as_completed(tasks_to_run):
             try:
-                result = await coro
+                env_idx, result = await coro
                 tasks_completed += 1

-                # Skip None results (trajectories that were skipped due to overlong prompts)
+                # Store the result under its env_idx (None if skipped)
                 if result is None:
                     skipped_count += 1
+                    results_by_idx[env_idx] = None  # None marks a skipped trajectory
                     colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed ({skipped_count} skipped due to overlong prompts)", "cyan")
-                    continue
-
-                colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
-                yield result
+                else:
+                    results_by_idx[env_idx] = result
+                    colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan")
             except Exception as e:
                 raise e

+        # Verify that every task completed and was stored
+        if len(results_by_idx) != len(self.envs):
+            missing = sorted(set(range(len(self.envs))) - set(results_by_idx.keys()))
+            raise RuntimeError(f"Not all trajectories were stored! Missing indices: {missing}. Expected {len(self.envs)} but got {len(results_by_idx)}")
+
+        # Yield all trajectories in order (0 to len(self.envs) - 1);
+        # None values indicate skipped trajectories
+        skipped_indices = []
+        for idx in range(len(self.envs)):
+            # All indices are present in results_by_idx after the check above
+            result = results_by_idx[idx]
+            if result is None:
+                skipped_indices.append(idx)
+            yield result  # None for skipped, trajectory dict otherwise
+
+        # Store skipped and valid indices as instance variables for the trainer to access
+        self._last_skipped_indices = skipped_indices
+        # Valid indices are the complement of the skipped indices, precomputed for batch filtering
+        total_count = len(self.envs)
+        self._last_valid_indices = [i for i in range(total_count) if i not in skipped_indices]
+
         if self.engine_name == "verl":
             self.rollout_engine.sleep()
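The rewritten generator above consumes tasks with asyncio.as_completed for throughput but buffers results by env index and re-emits them in order, so batch positions stay aligned downstream. A self-contained sketch of that pattern (the worker and its payloads are hypothetical):

```python
import asyncio
import random

async def worker(i: int):
    await asyncio.sleep(random.random() / 10)   # tasks finish in arbitrary order
    return i, (None if i == 1 else {"idx": i})  # index 1 simulates a skipped prompt

async def main():
    results_by_idx = {}
    for coro in asyncio.as_completed([worker(i) for i in range(4)]):
        i, result = await coro                  # arrives out of order
        results_by_idx[i] = result
    for i in range(4):                          # emit strictly in index order
        print(i, results_by_idx[i])             # None still occupies slot 1

asyncio.run(main())
```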
diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py
index 3ecc280af..7d0197219 100644
--- a/rllm/trainer/verl/agent_ppo_trainer.py
+++ b/rllm/trainer/verl/agent_ppo_trainer.py
@@ -153,6 +153,7 @@ def fit_agent(self):

         for epoch in range(self.config.trainer.total_epochs):
             pprint(f"epoch {epoch}, step {self.global_steps} started")
+
             for batch_dict in self.train_dataloader:
                 batch: DataProto = DataProto.from_single_dict(batch_dict)
                 batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
@@ -182,15 +183,25 @@ def fit_agent(self):
                     final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info)

                     # If some trajectories were skipped (overlong prompts), filter the batch to match.
-                    if "skipped_indices" in final_gen_batch_output.meta_info:
-                        skipped_indices = final_gen_batch_output.meta_info.pop("skipped_indices")
-                        # Create mask for valid (non-skipped) indices
-                        valid_mask = np.ones(len(batch.batch), dtype=bool)
-                        valid_mask[skipped_indices] = False
-                        # Filter batch to only include valid samples
-                        valid_indices = np.where(valid_mask)[0]
+                    # Get the valid indices directly from the AgentExecutionEngine (tracked internally)
+                    valid_indices = getattr(self.agent_execution_engine, "_last_valid_indices", None)
+                    skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", [])
+
+                    # Ensure the batch size matches the number of trajectories collected
+                    num_trajectories = len(final_gen_batch_output.batch)
+                    if valid_indices is not None and len(valid_indices) != num_trajectories:
+                        if len(valid_indices) > num_trajectories:
+                            valid_indices = valid_indices[:num_trajectories]
+                        else:
+                            raise RuntimeError(f"Fewer valid indices ({len(valid_indices)}) than trajectories ({num_trajectories}).")
+
+                    if valid_indices is not None and len(valid_indices) < len(batch.batch):
+                        # Filter the batch down to the valid samples (matching the trajectories collected)
                         batch = batch.select_idxs(valid_indices)
-                        print(f"Filtered batch from {len(valid_mask)} to {len(valid_indices)} samples after skipping {len(skipped_indices)} overlong prompts")
+
+                    # Final sanity check: batch sizes must match before union
+                    if len(batch.batch) != len(final_gen_batch_output.batch):
+                        raise RuntimeError(f"Batch size mismatch before union: batch has {len(batch.batch)} samples, final_gen_batch_output has {len(final_gen_batch_output.batch)} samples. valid_indices: {len(valid_indices) if valid_indices is not None else 'None'}, skipped_indices: {len(skipped_indices)}")

                     batch = batch.union(final_gen_batch_output)
                     metrics.update(generate_metrics)
@@ -564,7 +575,7 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
         if self.async_rollout_mode:
             gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Token")
             for trajectory in gen_seq_generator:
-                # Skip None trajectories (overlong prompts)
+                # Skip None trajectories (overlong prompts); these are tracked by the AgentExecutionEngine
                 if trajectory is not None:
                     trajectories.append(trajectory)
         else:
@@ -574,25 +585,18 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None):
         if not trajectories:
             raise RuntimeError("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering.")

-        # Sort trajectories by their idx, to ensure they are in order.
-        trajectories.sort(key=lambda x: x["idx"])
-
-        # Determine which indices were skipped by checking missing idx values
-        # Expected indices are 0 to (batch_size * rollout.n - 1)
-        expected_count = len(self.agent_execution_engine.envs)
-        actual_indices = set(t["idx"] for t in trajectories)
-        expected_indices = set(range(expected_count))
-        skipped_indices = sorted(expected_indices - actual_indices)
-
+        # Get the skipped indices from the AgentExecutionEngine (tracked internally)
+        skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", [])
         if skipped_indices:
             print(f"Skipped {len(skipped_indices)} trajectories due to overlong prompts at env indices: {skipped_indices}")

+        # Sort trajectories by their idx, to ensure they are in order.
+        trajectories.sort(key=lambda x: x["idx"])
+
         with marked_timer("transform_trajectory", timing_raw):
             # Transform the raw trajectories into DataProto format.
             final_gen_batch_output, metrics = self._transform_agent_trajectories(trajectories)
-            # Store skipped indices in meta_info for potential filtering of original batch
-            if skipped_indices:
-                final_gen_batch_output.meta_info["skipped_indices"] = skipped_indices
+
         return final_gen_batch_output, metrics

     def generate_agent_steps(self, timing_raw=None, meta_info=None, uids=None):
@@ -865,18 +869,23 @@ def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mod
         timing_raw = {}
         queue = Queue()

+        # Create a unique sentinel object to signal completion
+        # (None cannot be used, since None already marks skipped trajectories)
+        _SENTINEL = object()
+
         def runner():
             async def consume():
                 async for item in self.agent_execution_engine.trajectory_generator(timing_raw=timing_raw, mode=mode, meta_info=meta_info):
                     queue.put(item)
-                queue.put(None)  # sentinel to signal done
+                # Put the sentinel object, not None, so skipped trajectories still flow through
+                queue.put(_SENTINEL)  # sentinel to signal done

             asyncio.run(consume())

         Thread(target=runner, daemon=True).start()

         while True:
             item = queue.get()
-            if item is None:
+            if item is _SENTINEL:
                 break
             yield item
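Since None is now a legitimate payload (a skipped slot), the final hunk switches the queue's end-of-stream marker to a unique object(). The bridge from an async generator to a sync one then looks like this reduced sketch, where the async source is a stand-in for trajectory_generator:

```python
import asyncio
from queue import Queue
from threading import Thread

_SENTINEL = object()  # unique end-of-stream marker; None must remain a valid item

async def source():
    for item in (1, None, 3):  # None stands for a skipped trajectory
        yield item

def sync_iter():
    queue: Queue = Queue()

    def runner():
        async def consume():
            async for item in source():
                queue.put(item)
            queue.put(_SENTINEL)  # signal completion without colliding with None
        asyncio.run(consume())

    Thread(target=runner, daemon=True).start()
    while True:
        item = queue.get()
        if item is _SENTINEL:
            break
        yield item

print(list(sync_iter()))  # [1, None, 3]
```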