54 changes: 50 additions & 4 deletions apps/grpo/main.py
@@ -47,10 +47,13 @@ class Episode:
    request_len: int
    response_len: int
    target: Any | None = None
    request: str | None = None
    response: str | None = None
Comment on lines +50 to +51

Contributor:

Do we need these? They are already in Completion, so keeping them here seems redundant.

https://github.com/DNXie/forge/blob/main/src/forge/data_models/completion.py#L19C5-L23C14

@DNXie (Member, Author), Nov 17, 2025:

I think it is more straightforward to record them here, because completion.prompt is of type Prompt rather than str: a Prompt is a sequence of conversation messages. If we don't record request: str (the last message) here, to_dict has to handle the Prompt type, which is not a good way to keep it general.
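
For context, a minimal sketch of the Prompt handling that to_dict would otherwise need (the message structure below is an assumption; the real Prompt type may differ):

# Hypothetical sketch, assuming a Prompt is a sequence of messages
# that each carry a string `content` field.
def last_message_text(prompt) -> str:
    return prompt[-1].content  # the request is the last message

Storing request: str directly on the Episode keeps this Prompt-specific logic out of to_dict.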

    # Processed data
    completion: Completion | None = None
    ref_logprobs: torch.Tensor | None = None
    reward: float | None = None
    reward_breakdown: dict[str, float] | None = None
    advantage: float | None = None

    @property
@@ -73,6 +76,32 @@ def response_tensor(self) -> torch.Tensor:
        tensor = F.pad(tensor, (0, diff), value=self.pad_id)
        return tensor

    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
Contributor:

I am thinking that we could add some of the fields from Completion, i.e. prompt, text, stop_reason, generator_version, metadata. Wdyt?

@DNXie (Member, Author), Nov 17, 2025:

prompt and text are already covered. I think we should either flatten the entire completion or avoid introducing overwhelming information here. In my opinion, prompt and response are the most important fields, so I think we can keep it this way. After all, we are just providing an interface and one use case of the API; users can decide what information they need. Let's add the other fields when we see a need.
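
If flattening the entire completion is ever needed, one hedged sketch (assuming Completion is a dataclass; the real fields may differ):

# Hypothetical: flatten every Completion field into the sample dict
# under a "completion_" prefix.
from dataclasses import asdict, is_dataclass

if is_dataclass(self.completion):
    for key, value in asdict(self.completion).items():
        result[f"completion_{key}"] = value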

"""Convert episode to dict, optionally excluding specified fields."""
result = {
"episode_id": self.episode_id,
"policy_version": self.policy_version,
"prompt": self.request,
"response": self.response,
"target": str(self.target),
"reward": self.reward,
"advantage": self.advantage,
"request_len": self.request_len,
"response_len": self.response_len,
"pad_id": self.pad_id,
"ref_logprobs": self.ref_logprobs,
"completion": self.completion,
}

if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
result.update(self.reward_breakdown)
Comment on lines +96 to +97

Contributor:

This is flattening the rewards, but it is not obvious that this is what we are doing. I wonder whether (a) we should flatten at all; assuming we should, let's add a comment here. Do you see any strong argument against leaving it as a dictionary? Not sure if a nested dict is poorly displayed in wandb.

Member Author:

The record_episode_sample function needs this to be flattened. I could flatten it inside record_episode_sample instead, wdyt?

@felipemello1 (Contributor), Nov 11, 2025:

i think we should delete record_episode_sample :X

Why does it need it to be flat?

@DNXie (Member, Author), Nov 17, 2025:

def record_episode_sample(table_name: str, episode):
    """
    Record a structured sample-level log for a single episode.
    Args:
        table_name (str): logging prefix (e.g. "rollout/sample").
        episode (Episode): episode object with filled attributes.
    """
    sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
    record_metric(table_name, sample, Reduce.SAMPLE)

Because record_episode_sample calls record_metric, where sample needs to be a flattened dict. I don't think deleting the API is a good idea: it does some ad-hoc processing of the dict that we don't want to expose to users.
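
To make the flattening concrete, a minimal sketch (assuming reward_breakdown maps reward-function names to floats, as built in evaluate_response):

# Before: reward_breakdown is a nested dict that record_metric cannot
# log as a single flat sample row.
result = {"episode_id": "ep-0", "reward": 0.5}
reward_breakdown = {"math_reward": 1.0, "format_reward": 0.0}

# After update, every value sits at the top level of one dict:
# {"episode_id": "ep-0", "reward": 0.5, "math_reward": 1.0, "format_reward": 0.0}
result.update(reward_breakdown)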

Member Author:

I could exclude it, but just so you know, the reason I excluded ref_logprobs is that it requires further processing to be logged; otherwise it would lead to an empty table.
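
A hypothetical sketch of such processing (the mean-reduction below is an assumption for illustration, not what this PR does):

# Hypothetical: summarize the per-token ref_logprobs tensor to a scalar
# before logging; a raw tensor would otherwise render as an empty table.
sample = episode.to_dict(exclude=["completion"])
if sample["ref_logprobs"] is not None:
    sample["ref_logprobs"] = sample["ref_logprobs"].mean().item()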


        if exclude:
            for key in exclude:
                result.pop(key, None)

        return result


# Represents the group (G) of episodes in GRPO
Group = list[Episode]
@@ -143,8 +172,11 @@ class RewardActor(ForgeActor):
    reward_functions: list[Callable]

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
    async def evaluate_response(
        self, prompt: str, response: str, target: str
    ) -> tuple[dict[str, float], float]:
        total_rewards = 0.0
        reward_breakdown = {}  # reward breakdown by function
        for reward_fn in self.reward_functions:
            reward = reward_fn(prompt, response, target)
            total_rewards += reward
@@ -153,6 +185,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
            reward_fn_name = getattr(
                reward_fn, "__name__", reward_fn.__class__.__name__
            )
            reward_breakdown[reward_fn_name] = reward
            # per function reward
            record_metric(
                f"reward/evaluate_response/sum_{reward_fn_name}_reward",
Expand Down Expand Up @@ -182,8 +215,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
                Reduce.SUM,
            )

        avg_reward = total_rewards / len(self.reward_functions)
        return avg_reward
        avg_reward: float = total_rewards / len(self.reward_functions)
        return reward_breakdown, avg_reward


@dataclass
@@ -385,9 +418,14 @@ async def continuous_rollouts():
                request_len=max_req_tokens,
                response_len=max_res_tokens,
                target=target,
                request=prompt,
                response=response.text,
                completion=response,
            )
            episode.reward = await reward_actor.evaluate_response.route(
            (
                episode.reward_breakdown,
                episode.reward,
            ) = await reward_actor.evaluate_response.route(
                prompt=prompt, response=response.text, target=target
            )
            episodes.append(episode)
@@ -412,6 +450,14 @@ async def continuous_rollouts():
                episode.advantage = advantage
                await replay_buffer.add.call_one(episode)

                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
                sample["score"] = sample["reward"]
                record_metric(
                    "main_samples/continuous_rollouts/sample_table",
                    sample,
                    Reduce.SAMPLE,
                )

            rollout_count += 1
            record_metric(
                "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
2 changes: 2 additions & 0 deletions src/forge/observability/__init__.py
@@ -24,6 +24,7 @@
    record_metric,
    Reduce,
    reduce_metrics_states,
    SampleAccumulator,
    StdAccumulator,
    SumAccumulator,
    WandbBackend,
@@ -64,4 +65,5 @@
"MaxAccumulator",
"MinAccumulator",
"StdAccumulator",
"SampleAccumulator",
]
14 changes: 13 additions & 1 deletion src/forge/observability/metric_actors.py
@@ -18,6 +18,7 @@
    LoggerBackend,
    LoggingMode,
    MetricCollector,
    Reduce,
    reduce_metrics_states,
)

@@ -432,9 +433,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
        # Reduce metrics from states
        reduced_metrics = reduce_metrics_states(all_local_states)

        # Split into scalar metrics and sample metrics
        scalar_metrics = [
            m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
        ]
        sample_metrics = [
            m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
        ]

        # Log to global backends
        for backend_name, backend in self.global_logger_backends.items():
            await backend.log_batch(reduced_metrics, global_step)
            if scalar_metrics:
                await backend.log_batch(scalar_metrics, global_step)
            if sample_metrics:
                await backend.log_samples(sample_metrics, global_step)

    @endpoint
    async def has_fetcher(self, proc_id: str) -> bool: