diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index fc16b2f9..31849703 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -203,10 +203,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te messages = agent.chat_completions prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True) prompt_token_len = len(prompt_tokens) - # Note, this should never happen! + + # Check if initial prompt already exceeds max length + # This can happen if: + # 1. Dataset filtering didn't catch this sample (e.g., different tokenization) + # 2. Checkpoint contains cached dataset that wasn't filtered (delete checkpoint's data.pt) if prompt_token_len > self.max_prompt_length: - agent.reset() - raise Exception(f"Trajectory {idx}: initial prompt length {prompt_token_len} already exceeded max_prompt_length {self.max_prompt_length}, retrying") + logger.warning(f"Trajectory {idx}: Initial prompt length {prompt_token_len} exceeds max_prompt_length {self.max_prompt_length}. Skipping this sample entirely (no trajectory will be returned). First 200 chars of prompt: {self.chat_parser.parse(messages[:1], add_generation_prompt=False)[:200]}...") + + # Close the environment and return None to skip this trajectory entirely + await loop.run_in_executor(self.executor, env.close) + return None for step_idx in range(self.max_steps): # Get action from agent @@ -410,7 +417,11 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs): for _ in range(self.retry_limit): try: - return await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200) + result = await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=7200) + # If result is None, it means the trajectory was skipped (e.g., overlong prompt) + if result is None: + return None + return result except Exception: traceback.print_exc() continue @@ -426,6 +437,10 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text", self.executor = ThreadPoolExecutor(max_workers=self.max_env_workers) semaphore = asyncio.Semaphore(self.n_parallel_agents) + # Initialize skipped indices and valid indices lists (will be populated as trajectories complete) + self._last_skipped_indices = [] + self._last_valid_indices = None # Will be computed after all trajectories complete + if self.engine_name == "verl": self.rollout_engine.wake_up() @@ -445,22 +460,55 @@ async def launch_one_trajectory_task(env_idx: int): traceback.print_exc() raise e - return result + # Return tuple (env_idx, result) so we can track which env returned None + return (env_idx, result) # Create all N conceptual tasks. Their execution will be throttled by the semaphore # and the availability of agent/env indices. tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))] + # Track results by index to maintain order and identify skipped trajectories + results_by_idx = {} tasks_completed = 0 + skipped_count = 0 + for coro in asyncio.as_completed(tasks_to_run): try: - result = await coro + env_idx, result = await coro tasks_completed += 1 - colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan") - yield result + + # Store result with its env_idx (None if skipped) + if result is None: + skipped_count += 1 + results_by_idx[env_idx] = None # Store None to mark as skipped + colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed ({skipped_count} skipped due to overlong prompts)", "cyan") + else: + results_by_idx[env_idx] = result + colorful_print(f"Number of Trajectories {tasks_completed}/{len(self.envs)} completed", "cyan") except Exception as e: raise e + # Verify all tasks completed and are stored + if len(results_by_idx) != len(self.envs): + missing = sorted(set(range(len(self.envs))) - set(results_by_idx.keys())) + raise RuntimeError(f"Not all trajectories were stored! Missing indices: {missing}. Expected {len(self.envs)} but got {len(results_by_idx)}") + + # Yield all trajectories in order (0 to len(self.envs)-1) + # None values indicate skipped trajectories + skipped_indices = [] + for idx in range(len(self.envs)): + # All indices should be in results_by_idx after the check above + result = results_by_idx[idx] + if result is None: + skipped_indices.append(idx) + yield result # Yield result (None for skipped, trajectory dict otherwise) + + # Store skipped indices and valid indices as instance variables for trainer to access + self._last_skipped_indices = skipped_indices + # Compute valid indices (complement of skipped indices) for easier batch filtering + total_count = len(self.envs) + self._last_valid_indices = [i for i in range(total_count) if i not in skipped_indices] + if self.engine_name == "verl": self.rollout_engine.sleep() diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py index 9cf908b5..7d019721 100644 --- a/rllm/trainer/verl/agent_ppo_trainer.py +++ b/rllm/trainer/verl/agent_ppo_trainer.py @@ -153,6 +153,7 @@ def fit_agent(self): for epoch in range(self.config.trainer.total_epochs): pprint(f"epoch {epoch}, step {self.global_steps} started") + for batch_dict in self.train_dataloader: batch: DataProto = DataProto.from_single_dict(batch_dict) batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) @@ -180,6 +181,28 @@ def fit_agent(self): batch = self._pad_dataproto_to_world_size(batch=batch) else: final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info) + + # If some trajectories were skipped (overlong prompts), filter the batch to match + # Get valid indices directly from AgentExecutionEngine (handled internally) + valid_indices = getattr(self.agent_execution_engine, "_last_valid_indices", None) + skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", []) + + # Ensure batch size matches the number of trajectories collected + num_trajectories = len(final_gen_batch_output.batch) + if valid_indices is not None and len(valid_indices) != num_trajectories: + if len(valid_indices) > num_trajectories: + valid_indices = valid_indices[:num_trajectories] + else: + raise RuntimeError(f"Fewer valid indices ({len(valid_indices)}) than trajectories ({num_trajectories}).") + + if valid_indices is not None and len(valid_indices) < len(batch.batch): + # Filter batch to only include valid samples (matching the number of trajectories collected) + batch = batch.select_idxs(valid_indices) + + # Final sanity check: batch sizes must match before union + if len(batch.batch) != len(final_gen_batch_output.batch): + raise RuntimeError(f"Batch size mismatch before union: batch has {len(batch.batch)} samples, final_gen_batch_output has {len(final_gen_batch_output.batch)} samples. valid_indices: {len(valid_indices) if valid_indices else 'None'}, skipped_indices: {len(skipped_indices)}") + batch = batch.union(final_gen_batch_output) metrics.update(generate_metrics) @@ -551,16 +574,29 @@ def generate_agent_trajectory(self, timing_raw=None, meta_info=None): trajectories = [] if self.async_rollout_mode: gen_seq_generator = self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Token") - for _, trajectory in enumerate(gen_seq_generator): - trajectories.append(trajectory) + for trajectory in gen_seq_generator: + # Skip None trajectories (overlong prompts) - these are handled by AgentExecutionEngine + if trajectory is not None: + trajectories.append(trajectory) else: raise ValueError("Only async rollout mode is supported") + + # Check if all trajectories were skipped + if not trajectories: + raise RuntimeError("All trajectories were skipped (likely all prompts exceed max_prompt_length). Please check your dataset and increase max_prompt_length or enable filtering.") + + # Get skipped indices from AgentExecutionEngine (handled internally) + skipped_indices = getattr(self.agent_execution_engine, "_last_skipped_indices", []) + if skipped_indices: + print(f"Skipped {len(skipped_indices)} trajectories due to overlong prompts at env indices: {skipped_indices}") + # Sort trajectories by their idx, to ensure they are in order. trajectories.sort(key=lambda x: x["idx"]) with marked_timer("transform_trajectory", timing_raw): # Transform the raw trajectories into DataProto format. final_gen_batch_output, metrics = self._transform_agent_trajectories(trajectories) + return final_gen_batch_output, metrics def generate_agent_steps(self, timing_raw=None, meta_info=None, uids=None): @@ -833,18 +869,23 @@ def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mod timing_raw = {} queue = Queue() + # Create a unique sentinel object to signal completion + # (Cannot use None since None is used for skipped trajectories) + _SENTINEL = object() + def runner(): async def consume(): async for item in self.agent_execution_engine.trajectory_generator(timing_raw=timing_raw, mode=mode, meta_info=meta_info): queue.put(item) - queue.put(None) # sentinel to signal done + # Use a special sentinel object instead of None (since None is used for skipped trajectories) + queue.put(_SENTINEL) # sentinel to signal done asyncio.run(consume()) Thread(target=runner, daemon=True).start() while True: item = queue.get() - if item is None: + if item is _SENTINEL: break yield item