From adbfb8309cba1f3330713ff4c1debb21505f1b83 Mon Sep 17 00:00:00 2001
From: Qingyang Wu
Date: Mon, 3 Nov 2025 11:38:33 -0800
Subject: [PATCH] fix json_str/dict conversion

---
 rllm/trainer/verl/agent_workflow_trainer.py | 84 ++++++++++++++++++++-
 1 file changed, 83 insertions(+), 1 deletion(-)

diff --git a/rllm/trainer/verl/agent_workflow_trainer.py b/rllm/trainer/verl/agent_workflow_trainer.py
index cb711818..0cead14b 100644
--- a/rllm/trainer/verl/agent_workflow_trainer.py
+++ b/rllm/trainer/verl/agent_workflow_trainer.py
@@ -1,5 +1,7 @@
 import asyncio
+import json
 import math
+import os
 import threading
 import uuid
 from collections import Counter, defaultdict
@@ -8,7 +10,9 @@
 
 import numpy as np
 import torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, open_dict
+from torch.utils.data import Sampler
+from torchdata.stateful_dataloader import StatefulDataLoader
 
 from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
 from rllm.engine.rollout.verl_engine import VerlEngine
@@ -785,3 +789,81 @@ def visualize_trajectory_last_step(self, tensor_batch, sample_idx=0, max_samples
             response_parts.append(f"{bg}{fg}{tok}\x1b[0m")
 
         print("".join(response_parts))
+
+    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Sampler | None):
+        """
+        Creates the train and validation dataloaders.
+        """
+        # TODO: we have to make sure the batch size is divisible by the dp size
+        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
+
+        if train_dataset is None:
+            train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor)
+        if val_dataset is None:
+            val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor)
+
+        num_workers = min(self.config.data["dataloader_num_workers"], os.cpu_count())
+
+        def convert_extra_info_to_dict(example):
+            if "extra_info" in example and isinstance(example["extra_info"], str):
+                example["extra_info"] = json.loads(example["extra_info"])
+            return example
+
+        if train_dataset is not None:
+            train_dataset.dataframe = train_dataset.dataframe.map(convert_extra_info_to_dict, num_proc=num_workers)
+        if val_dataset is not None:
+            val_dataset.dataframe = val_dataset.dataframe.map(convert_extra_info_to_dict, num_proc=num_workers)
+
+        self.train_dataset, self.val_dataset = train_dataset, val_dataset
+
+        if train_sampler is None:
+            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
+        if collate_fn is None:
+            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
+
+            collate_fn = default_collate_fn
+
+        self.train_dataloader = StatefulDataLoader(
+            dataset=self.train_dataset,
+            batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
+            num_workers=num_workers,
+            drop_last=True,
+            collate_fn=collate_fn,
+            sampler=train_sampler,
+        )
+
+        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set
+        if val_batch_size is None:
+            val_batch_size = len(self.val_dataset)
+
+        self.val_dataloader = StatefulDataLoader(
+            dataset=self.val_dataset,
+            batch_size=val_batch_size,
+            num_workers=num_workers,
+            shuffle=self.config.data.get("validation_shuffle", True),
+            drop_last=False,
+            collate_fn=collate_fn,
+        )
+
+        assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
+        assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
+
+        print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}")
+
+        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
+
+        if self.config.trainer.total_training_steps is not None:
+            total_training_steps = self.config.trainer.total_training_steps
+
+        self.total_training_steps = total_training_steps
+        print(f"Total training steps: {self.total_training_steps}")
+
+        try:
+            OmegaConf.set_struct(self.config, True)
+            with open_dict(self.config):
+                if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
+                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
+                if OmegaConf.select(self.config, "critic.optim"):
+                    self.config.critic.optim.total_training_steps = total_training_steps
+        except Exception as e:
+            print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
+ + print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")