Add Sample-level Logging API #486
```diff
@@ -47,10 +47,13 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None
     # Processed data
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None

     @property
```

```diff
@@ -73,6 +76,32 @@ def response_tensor(self) -> torch.Tensor:
         tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor

+    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
```
|
**Contributor:** I am thinking that we could add some of the fields from `Completion`, i.e. `prompt`, `text`, `stop_reason`, `generator_version`, `metadata`. Wdyt?

**Member (Author):** `prompt` and `text` are already covered. I think we either flatten the entire completion or avoid introducing overwhelming information here; in my opinion, `prompt` and `response` are the most important fields, so I would keep it this way. After all, we are just providing an interface and one use case of the API, and users can decide what information they need. Let's add the other fields when we see a need.
```diff
+        """Convert episode to dict, optionally excluding specified fields."""
+        result = {
+            "episode_id": self.episode_id,
+            "policy_version": self.policy_version,
+            "prompt": self.request,
+            "response": self.response,
+            "target": str(self.target),
+            "reward": self.reward,
+            "advantage": self.advantage,
+            "request_len": self.request_len,
+            "response_len": self.response_len,
+            "pad_id": self.pad_id,
+            "ref_logprobs": self.ref_logprobs,
+            "completion": self.completion,
+        }
+
+        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
+            result.update(self.reward_breakdown)
```
**Comment on lines +96 to +97**

**Contributor:** This is flattening the rewards, but it is not obvious that this is what we are doing. I wonder if (a) we should flatten at all; but assuming we should, let's add a comment here. Do you see any strong argument for not leaving it as a dictionary? Not sure if a dictionary is poorly displayed in wandb.

**Member (Author):** The …

**Contributor:** I think we should delete `record_episode_sample` :X Why does it need to be flat?

**Member (Author):**

```python
def record_episode_sample(table_name: str, episode):
    """
    Record a structured sample-level log for a single episode.

    Args:
        table_name (str): logging prefix (e.g. "rollout/sample").
        episode (Episode): episode object with filled attributes.
    """
    sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
    record_metric(table_name, sample, Reduce.SAMPLE)
```

Because …

**Member (Author):** I could exclude it, but just so you know that I excluded …
```diff
+
+        if exclude:
+            for key in exclude:
+                result.pop(key, None)
+
+        return result


 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
```
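For readers skimming the diff, here is a minimal self-contained sketch of the flattening behaviour the thread above debates. `_MiniEpisode` is a made-up stand-in, not the class from this PR, and note that this sketch also guards against `exclude=None`, which the diff as written does not:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class _MiniEpisode:
    """Stripped-down stand-in for Episode, just enough to show flattening."""

    episode_id: str
    reward: float | None = None
    reward_breakdown: dict[str, float] | None = None

    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
        result: dict[str, Any] = {"episode_id": self.episode_id, "reward": self.reward}
        # Flatten per-function rewards into top-level keys so each reward
        # function gets its own column in a sample table.
        if self.reward_breakdown is not None and (
            exclude is None or "reward_breakdown" not in exclude
        ):
            result.update(self.reward_breakdown)
        if exclude:
            for key in exclude:
                result.pop(key, None)
        return result


ep = _MiniEpisode(
    "ep-0", reward=0.75, reward_breakdown={"math_reward": 1.0, "format_reward": 0.5}
)
print(ep.to_dict())
# {'episode_id': 'ep-0', 'reward': 0.75, 'math_reward': 1.0, 'format_reward': 0.5}
```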
|
|
```diff
@@ -143,8 +172,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]

     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> (dict[str, float], float):
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
```
|
|
```diff
@@ -153,6 +185,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
```
|
|
```diff
@@ -182,8 +215,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
                 Reduce.SUM,
             )

-        avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        avg_reward: float = total_rewards / len(self.reward_functions)
+        return reward_breakdown, avg_reward


 @dataclass
```
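To see the new return contract in isolation, here is a plain-function sketch of the same logic; the actor machinery and `record_metric` calls are stripped out, the `exact_match` reward function is invented for illustration, and the return annotation is spelled `tuple[...]`, the conventional form of the `(dict[str, float], float)` annotation used in the diff:

```python
from typing import Callable


def evaluate_response(
    reward_functions: list[Callable[[str, str, str], float]],
    prompt: str,
    response: str,
    target: str,
) -> tuple[dict[str, float], float]:
    """Return (per-function reward breakdown, average reward)."""
    total = 0.0
    breakdown: dict[str, float] = {}
    for reward_fn in reward_functions:
        reward = reward_fn(prompt, response, target)
        name = getattr(reward_fn, "__name__", reward_fn.__class__.__name__)
        breakdown[name] = reward
        total += reward
    return breakdown, total / len(reward_functions)


def exact_match(prompt: str, response: str, target: str) -> float:
    """Hypothetical reward function: 1.0 on an exact string match."""
    return 1.0 if response.strip() == target.strip() else 0.0


breakdown, avg = evaluate_response([exact_match], "What is 2 + 2?", "4", "4")
print(breakdown, avg)  # {'exact_match': 1.0} 1.0
```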
```diff
@@ -385,9 +418,14 @@ async def continuous_rollouts():
                 request_len=max_req_tokens,
                 response_len=max_res_tokens,
                 target=target,
+                request=prompt,
+                response=response.text,
+                completion=response,
             )
-            episode.reward = await reward_actor.evaluate_response.route(
+            (
+                episode.reward_breakdown,
+                episode.reward,
+            ) = await reward_actor.evaluate_response.route(
                 prompt=prompt, response=response.text, target=target
             )
             episodes.append(episode)
```
|
|
```diff
@@ -412,6 +450,14 @@ async def continuous_rollouts():
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)

+                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+                sample["score"] = sample["reward"]
+                record_metric(
+                    "main_samples/continuous_rollouts/sample_table",
+                    sample,
+                    Reduce.SAMPLE,
+                )
+
             rollout_count += 1
             record_metric(
                 "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
```
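For a sense of what ends up in the sample table, here is a hypothetical sketch of the dict passed to `record_metric` at this point. The keys follow `to_dict()` above, minus the excluded `ref_logprobs` and `completion`, plus the flattened per-function rewards and the `score` alias; all values are invented:

```python
# Hypothetical contents of `sample` right before record_metric is called.
sample = {
    "episode_id": "b6f3...",   # truncated for display
    "policy_version": 3,
    "prompt": "What is 2 + 2?",
    "response": "4",
    "target": "4",
    "reward": 0.75,
    "advantage": 0.12,
    "request_len": 512,
    "response_len": 128,
    "pad_id": 0,
    "math_reward": 1.0,        # flattened from reward_breakdown
    "format_reward": 0.5,      # flattened from reward_breakdown
    "score": 0.75,             # alias of "reward" added above
}
```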
**Contributor:** Do we need those? I believe this is redundant; they are already in `Completion`: https://github.com/DNXie/forge/blob/main/src/forge/data_models/completion.py#L19C5-L23C14

**Member (Author):** I think it is more straightforward to keep a record of them now, because `completion.prompt` is of type `Prompt` instead of `str`, and a `Prompt` is a sequence of conversation messages. If we don't record `request: str` here, which is the last message, we have to deal with the `Prompt` type in `to_dict`, which is really not a good way to keep it general.
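To illustrate the author's point, a small sketch under the assumption that a `Prompt` is roughly a list of role/content messages; the real type in `completion.py` may differ:

```python
# Hypothetical Prompt value: a sequence of chat messages rather than a str.
prompt_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# Capturing the last message as a plain string when the Episode is built...
request: str = prompt_messages[-1]["content"]

# ...lets to_dict() log "prompt" directly, without needing to know how to
# unpack the structured Prompt type.
print(request)  # What is 2 + 2?
```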