164 changes: 163 additions & 1 deletion rllm/trainer/verl/agent_workflow_trainer.py
@@ -1,5 +1,7 @@
import asyncio
import json
import math
import os
import threading
import uuid
from collections import Counter, defaultdict
@@ -8,7 +10,9 @@

import numpy as np
import torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Sampler
from torchdata.stateful_dataloader import StatefulDataLoader

from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout.verl_engine import VerlEngine
@@ -481,6 +485,8 @@ def _validate_agent(self):
uid_lst = []
workflow_metrics_by_source = defaultdict(lambda: defaultdict(list))

for test_data in self.val_dataloader:
    test_batch = DataProto.from_single_dict(test_data)
    test_batch.non_tensor_batch["task_ids"] = np.array([str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object)
@@ -785,3 +791,159 @@ def visualize_trajectory_last_step(self, tensor_batch, sample_idx=0, max_samples
    response_parts.append(f"{bg}{fg}{tok}\x1b[0m")

print("".join(response_parts))

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Sampler | None):
    """
    Creates the train and validation dataloaders.
    """
    # TODO: we have to make sure the batch size is divisible by the dp size
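    # Editorial aside, not in the PR: one way to honor the TODO above would be
    # a fail-fast check (dp_size is a hypothetical name, not defined here):
    #   assert self.config.data.train_batch_size % dp_size == 0, \
    #       f"train_batch_size must be divisible by dp_size={dp_size}"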
    from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler

    if train_dataset is None:
        train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor)
    if val_dataset is None:
        val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor)

    num_workers = min(self.config.data["dataloader_num_workers"], os.cpu_count())

    def convert_extra_info_to_dict(example):
        if "extra_info" in example and isinstance(example["extra_info"], str):
            example["extra_info"] = json.loads(example["extra_info"])
        return example

    if train_dataset is not None:
        train_dataset.dataframe = train_dataset.dataframe.map(convert_extra_info_to_dict, num_proc=num_workers)
    if val_dataset is not None:
        val_dataset.dataframe = val_dataset.dataframe.map(convert_extra_info_to_dict, num_proc=num_workers)

    self.train_dataset, self.val_dataset = train_dataset, val_dataset

    if train_sampler is None:
        train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
    if collate_fn is None:
        from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

        collate_fn = default_collate_fn

    self.train_dataloader = StatefulDataLoader(
        dataset=self.train_dataset,
        batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate_fn,
        sampler=train_sampler,
    )
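
    # Editorial note, not in the PR: torchdata's StatefulDataLoader exposes
    # state_dict()/load_state_dict(), which is what makes mid-epoch resumption
    # of these loaders possible across checkpoints.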

    val_batch_size = self.config.data.val_batch_size  # Prefer config value if set
    if val_batch_size is None:
        val_batch_size = len(self.val_dataset)

    self.val_dataloader = StatefulDataLoader(
        dataset=self.val_dataset,
        batch_size=val_batch_size,
        num_workers=num_workers,
        shuffle=self.config.data.get("validation_shuffle", True),
        drop_last=False,
        collate_fn=collate_fn,
    )

    assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
    assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"

    print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}")

    total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

    if self.config.trainer.total_training_steps is not None:
        total_training_steps = self.config.trainer.total_training_steps

    self.total_training_steps = total_training_steps
    print(f"Total training steps: {self.total_training_steps}")

    try:
        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
                self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            if OmegaConf.select(self.config, "critic.optim"):
                self.config.critic.optim.total_training_steps = total_training_steps
    except Exception as e:
        print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
