From 37c2ac99f76b8e890a5139ab454f66c908353b84 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 21 Oct 2025 15:18:03 -0700
Subject: [PATCH 01/17] add accumulator and test

---
 src/forge/observability/__init__.py |  5 ++
 src/forge/observability/metrics.py  | 99 ++++++++++++++++++++++++++++-
 2 files changed, 103 insertions(+), 1 deletion(-)

diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py
index 555aa761e..474483426 100644
--- a/src/forge/observability/__init__.py
+++ b/src/forge/observability/__init__.py
@@ -24,8 +24,10 @@
     record_metric,
     Reduce,
     reduce_metrics_states,
+    SampleAccumulator,
     StdAccumulator,
     SumAccumulator,
+    TopBottomKFilter,
     WandbBackend,
 )
 from .perf_tracker import trace, Tracer
@@ -64,4 +66,7 @@
     "MaxAccumulator",
     "MinAccumulator",
     "StdAccumulator",
+    "SampleAccumulator",
+    # Filter classes
+    "TopBottomKFilter",
 ]
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 774fb1cc4..7011981e2 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -4,13 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import heapq
+import itertools
 import logging
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any
+from typing import Any, Dict, List

 import pytz

@@ -68,6 +70,7 @@ class Reduce(Enum):
     MAX = "max"
     MIN = "min"
     STD = "std"
+    SAMPLE = "sample"

     @property
     def accumulator_class(self):
@@ -77,6 +80,7 @@ def accumulator_class(self):
             Reduce.MAX: MaxAccumulator,
             Reduce.MIN: MinAccumulator,
             Reduce.STD: StdAccumulator,
+            Reduce.SAMPLE: SampleAccumulator,
         }
         return mapping[self]

@@ -182,6 +186,55 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
     return reduced_metrics


+#################
+# SampleFilters #
+#################
+
+
+class TopBottomKFilter:
+    """Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
+
+    def __init__(self, top_k=1, bottom_k=1, key="reward"):
+        self.top_k = top_k
+        self.bottom_k = bottom_k
+        self.key = key
+        self._top_heap = []  # min-heap for top-k
+        self._bottom_heap = []  # max-heap for bottom-k (store -value)
+        self._counter = itertools.count()  # tie-breaker id generator
+
+    def filter_append(self, sample: Dict) -> bool:
+        val = sample.get(self.key, 0.0)
+        idx = next(self._counter)  # unique tiebreaker
+
+        # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
+        # maintain top-k
+        if self.top_k > 0:
+            if len(self._top_heap) < self.top_k:
+                heapq.heappush(self._top_heap, (val, idx, sample))
+            else:
+                heapq.heappushpop(self._top_heap, (val, idx, sample))
+
+        # maintain bottom-k
+        if self.bottom_k > 0:
+            if len(self._bottom_heap) < self.bottom_k:
+                heapq.heappush(self._bottom_heap, (-val, idx, sample))
+            else:
+                heapq.heappushpop(self._bottom_heap, (-val, idx, sample))
+
+        # always return False here because we don't store in samples list
+        return False
+
+    def filter_flush(self, samples: List[Dict]) -> List[Dict]:
+        tops = [s for _, _, s in self._top_heap]
+        bottoms = [s for _, _, s in self._bottom_heap]
+        return bottoms + tops
+
+    def reset(self):
+        self._top_heap = []
+        self._bottom_heap = []
+        self._counter = itertools.count()
+
+
 ################
 # Accumulators #
 ################
@@ -392,6 +445,50 @@ def reset(self) -> None:
         self.count = 0


+class SampleAccumulator(MetricAccumulator):
+    """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
+    Optionally uses a sample filter to decide what to keep at append/flush time.
+    """
+
+    def __init__(self, reduction: Reduce):
+        super().__init__(reduction)
+        self.samples: List[Dict[str, Any]] = []
+        self.filter = TopBottomKFilter()
+
+    def append(self, value: dict) -> None:
+        if not isinstance(value, dict):
+            raise ValueError(f"Expected dict, got {type(value)}")
+
+        # Only keep the sample if filter_append returns True
+        if self.filter.filter_append(value):
+            self.samples.append(value)
+
+    def get_value(self) -> list[dict]:
+        """Return locally collected (and optionally filtered) samples."""
+        # Apply flush-time filter (e.g. heap selection, threshold trimming)
+        return self.filter.filter_flush(self.samples)
+
+    def get_state(self) -> Dict[str, Any]:
+        """Serialize accumulator state for cross-rank reduction."""
+        return {
+            "reduction_type": self.reduction_type.value,
+            "samples": self.get_value(),
+        }
+
+    @classmethod
+    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:
+        """Merge sample states across ranks."""
+        merged = []
+        for s in states:
+            merged.extend(s.get("samples", []))
+        return merged
+
+    def reset(self) -> None:
+        """Clear local samples and reset filter state."""
+        self.samples.clear()
+        self.filter.reset()
+
+
 #############
 # Collector #
 #############

From 92343899c0fa7f059b2506a1219f604012186826 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 21 Oct 2025 16:02:28 -0700
Subject: [PATCH 02/17] functions, tests

---
 apps/grpo/main.py                             |  16 ++-
 src/forge/observability/__init__.py           |   2 +
 src/forge/observability/metric_actors.py      |  13 +-
 src/forge/observability/metrics.py            | 135 ++++++++++++++++--
 .../unit_tests/observability/test_metrics.py  |  69 ++++++---
 5 files changed, 203 insertions(+), 32 deletions(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 82152e1b8..06d2968dd 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -29,7 +29,7 @@
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
-from forge.observability.metrics import record_metric, Reduce
+from forge.observability.metrics import record_episode_sample, record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import LauncherConfig, ProvisionerConfig

@@ -51,6 +51,7 @@ class Episode:
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None

     @property
@@ -143,8 +144,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]

     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> dict[str, float]:
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -153,6 +157,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward  # per function reward

             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -183,7 +188,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
         )

         avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        reward_breakdown["reward"] = avg_reward
+        return reward_breakdown


 @dataclass
@@ -387,9 +393,10 @@ async def continuous_rollouts():
                     target=target,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                episode.reward_breakdown = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episode.reward = episode.reward_breakdown["reward"]
                 episodes.append(episode)

             # Build input_ids for reference logprobs
@@ -411,6 +418,7 @@ async def continuous_rollouts():
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
+                record_episode_sample("rollout/sample", episode)

             rollout_count += 1
             record_metric(
diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py
index 474483426..e80bd1860 100644
--- a/src/forge/observability/__init__.py
+++ b/src/forge/observability/__init__.py
@@ -21,6 +21,7 @@
     MetricAccumulator,
     MetricCollector,
     MinAccumulator,
+    record_episode_sample,
     record_metric,
     Reduce,
     reduce_metrics_states,
@@ -37,6 +38,7 @@
     # Main API functions
     "record_metric",
     "reduce_metrics_states",
+    "record_episode_sample",
     "get_logger_backend_class",
     "get_or_create_metric_logger",
     # Performance tracking
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 83fdfb69b..438db4f4a 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -432,9 +432,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:

         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)
+        # Split into scalar metrics and sample metrics
+        scalar_metrics = [
+            m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
+        ]
+        sample_metrics = {
+            m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
+        }
+
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
-            await backend.log_batch(reduced_metrics, global_step)
+            if scalar_metrics:
+                await backend.log_batch(scalar_metrics, global_step)
+            if sample_metrics:
+                await backend.log_samples(sample_metrics, global_step)

     @endpoint
     async def has_fetcher(self, proc_id: str) -> bool:
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 7011981e2..c095aba36 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -139,12 +139,32 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
         list[Metric]: List of reduced metrics

     Example:
-        states = [
-            {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
-            {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
-        ]
-        reduce_metrics_states(states)
-        >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
+        >>> states = [
+        ...     {
+        ...         "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
+        ...         "reward/sample": {
+        ...             "reduction_type": "sample",
+        ...             "samples": [{"episode_id": 1, "reward": 0.5}],
+        ...         },
+        ...     },
+        ...     {
+        ...         "loss": {"count": 10, "sum": 16, "reduction_type": "mean"},
+        ...         "reward/sample": {
+        ...             "reduction_type": "sample",
+        ...             "samples": [{"episode_id": 2, "reward": 1.0}],
+        ...         },
+        ...     },
+        ... ]
+        >>> metrics = reduce_metrics_states(states)
+        >>> for m in metrics:
+        ...     print(m)
+        Metric(key='loss', value=2.0, reduction=Reduce.MEAN)
+        Metric(
+            key='reward/sample',
+            value=[{'episode_id': 1, 'reward': 0.5},
+                   {'episode_id': 2, 'reward': 1.0}],
+            reduction=Reduce.SAMPLE,
+        )

     Raises:
         ValueError: on mismatched reduction types for the same metric key.
@@ -186,6 +206,31 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
     return reduced_metrics


+def record_episode_sample(table_name: str, episode):
+    """
+    Record a structured sample-level log for a single episode.
+    Args:
+        table_name (str): logging prefix (e.g. "rollout/sample").
+        episode (Episode): episode object with filled attributes.
+    """
+    sample = {
+        "episode_id": episode.episode_id,
+        "policy_version": episode.policy_version,
+        "prompt": episode.request,
+        "response": episode.response,
+        "target": str(episode.target),
+        **(
+            episode.reward_breakdown or {}
+        ),  # per-fn breakdown including the average reward
+        "advantage": episode.advantage,
+        "request_len": episode.request_len,
+        "response_len": episode.response_len,
+        "pad_id": episode.pad_id,
+    }
+
+    record_metric(table_name, sample, Reduce.SAMPLE)
+
+
 #################
 # SampleFilters #
 #################
@@ -656,7 +701,12 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            backend.log_stream(metric=metric, global_step=self.global_step)
+            if metric.reduction == Reduce.SAMPLE:
+                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
+                sample = {metric.key: [metric.value]}
+                asyncio.create_task(backend.log_samples(sample, self.global_step))
+            else:
+                backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key
@@ -711,8 +761,21 @@ async def flush(

         if self.per_rank_reduce_backends:
             metrics_for_backends = reduce_metrics_states([states])
+            # Split into scalar metrics and sample metrics
+            scalar_metrics = [
+                m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
+            ]
+            sample_metrics = {
+                m.key: m.value
+                for m in metrics_for_backends
+                if m.reduction == Reduce.SAMPLE
+            }
+
             for backend in self.per_rank_reduce_backends:
-                await backend.log_batch(metrics_for_backends, global_step)
+                if scalar_metrics:
+                    await backend.log_batch(scalar_metrics, global_step)
+                if sample_metrics:
+                    await backend.log_samples(sample_metrics, global_step)

         # Update step counter for streaming backends
         # Note: This is incremented AFTER flush completes, so metrics recorded between
@@ -846,6 +909,16 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
     async def finish(self) -> None:
         pass

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        """Pretty-print sample-level logs to console."""
+        import json
+
+        logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
+        for table_name, table_rows in samples.items():
+            logger.info(f"[{table_name}] ({len(table_rows)} samples)")
+            logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
+        logger.info("==============================================\n")
+

 class WandbBackend(LoggerBackend):
     """
@@ -882,6 +955,7 @@ def __init__(
         )
         self.run = None
         self.process_name = None
+        self._tables: dict[str, "wandb.Table"] = {}

     async def init(
         self,
@@ -992,13 +1066,58 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
         # note: here we dont use step since wandb keeps only the latest value for each step
         self.run.log(log_data)

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        """Log sample-level data incrementally to persistent WandB Tables."""
+        import wandb
+
+        if not self.run:
+            return
+
+        for table_name, table_rows in samples.items():
+            if not table_rows:
+                continue
+
+            # If table doesn't exist yet, create it in INCREMENTAL mode
+            if table_name not in self._tables:
+                columns = list(table_rows[0].keys())
+                table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
+                self._tables[table_name] = table
+                logger.info(
+                    f"WandbBackend: Created new incremental table: {table_name}"
+                )
+            else:
+                table = self._tables[table_name]
+
+            # Add rows (fill missing columns with None)
+            for s in table_rows:
+                values = [s.get(c) for c in table.columns]
+                table.add_data(*values)
+
+            # Log the same table object (INCREMENTAL update)
+            self.run.log({f"{table_name}_table": table})
+            logger.info(
+                f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
+            )
+
     def get_metadata_for_secondary_ranks(self) -> dict[str, Any]:
         if self.run and self.per_rank_share_run:
             return {"shared_run_id": self.run.id}
         return {}

     async def finish(self) -> None:
+        import wandb
+
         if self.run:
+            # Convert each incremental table to immutable before finishing
+            for table_name, incr_table in self._tables.items():
+                final_table = wandb.Table(
+                    columns=incr_table.columns,
+                    data=incr_table.data,
+                    log_mode="IMMUTABLE",
+                )
+                self.run.log({table_name: final_table})
+                logger.info(f"WandbBackend: Finalized table {table_name}")
+
             self.run.finish()
             logger.info(f"WandbBackend {self.process_name}: Finished run")
diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py
index 948626a73..63f7046db 100644
--- a/tests/unit_tests/observability/test_metrics.py
+++ b/tests/unit_tests/observability/test_metrics.py
@@ -115,33 +115,64 @@ def test_empty_states(self):

     def test_single_state(self):
         """Test reduce_metrics_states with single state."""
-        states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}]
-        result = reduce_metrics_states(states)
-        assert len(result) == 1
-        assert result[0].key == "loss"
-        assert result[0].value == 5.0
-        assert result[0].reduction == Reduce.MEAN
+        states = [
+            {
+                "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 1, "reward": 0.5}],
+                },
+            }
+        ]
+        metrics = reduce_metrics_states(states)
+        assert len(metrics) == 2
+
+        # Convert to dict for easier testing
+        result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+        assert result_dict["loss"][0] == 5.0
+        assert result_dict["loss"][1] == Reduce.MEAN
+
+        assert result_dict["rollout/sample"][0] == [{"id": 1, "reward": 0.5}]
+        assert result_dict["rollout/sample"][1] == Reduce.SAMPLE

     def test_multiple_states(self):
         """Test reduce_metrics_states with multiple states."""
         states = [
-            {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}},
-            {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}},
+            {
+                "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 1, "reward": 0.5}],
+                },
+            },
+            {
+                "loss": {"reduction_type": "mean", "sum": 20.0, "count": 3},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 2, "reward": 0.8}],
+                },
+            },
             {"accuracy": {"reduction_type": "sum", "total": 15.0}},
         ]
-        result = reduce_metrics_states(states)
+        metrics = reduce_metrics_states(states)
+
+        assert len(metrics) == 3

         # Convert to dict for easier testing
-        result_dict = {metric.key: metric.value for metric in result}
-        assert result_dict["loss"] == 30.0 / 5.0  # 6.0
-        assert result_dict["accuracy"] == 15.0
-
-        # Also check reduction types
-        for metric in result:
-            if metric.key == "loss":
-                assert metric.reduction == Reduce.MEAN
-            elif metric.key == "accuracy":
-                assert metric.reduction == Reduce.SUM
+        result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+        # Check scalar reductions
+        assert result_dict["loss"][0] == 30.0 / 5.0  # 6.0
+        assert result_dict["loss"][1] == Reduce.MEAN
+        assert result_dict["accuracy"][0] == 15.0
+        assert result_dict["accuracy"][1] == Reduce.SUM
+
+        # Check sample concatenation
+        assert result_dict["rollout/sample"][0] == [
+            {"id": 1, "reward": 0.5},
+            {"id": 2, "reward": 0.8},
+        ]
+        assert result_dict["rollout/sample"][1] == Reduce.SAMPLE

     def test_mismatched_reduction_types_raises_error(self):
         """Test reduce_metrics_states raises error for mismatched reduction types."""

From e7e42b9f11cd29b7ccbd2e2ce2bda49fd2e565ae Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 21 Oct 2025 19:02:38 -0700
Subject: [PATCH 03/17] fix error + some debug messages

---
 src/forge/observability/metric_actors.py |  3 +++
 src/forge/observability/metrics.py       | 30 ++++++++++++++++++------
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 438db4f4a..20e5c228a 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -18,6 +18,7 @@
     LoggerBackend,
     LoggingMode,
     MetricCollector,
+    Reduce,
     reduce_metrics_states,
 )

@@ -432,6 +433,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:

         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)
+        print(f"[DEBUG] reduced_metrics: {reduced_metrics}")
         # Split into scalar metrics and sample metrics
         scalar_metrics = [
             m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
@@ -443,6 +445,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
             if scalar_metrics:
+                print(f"[DEBUG] calling log_batch from GlobalLoggerActor")
                 await backend.log_batch(scalar_metrics, global_step)
             if sample_metrics:
                 await backend.log_samples(sample_metrics, global_step)
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index c095aba36..65bda22f5 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -228,7 +228,15 @@ def record_episode_sample(table_name: str, episode):
         "pad_id": episode.pad_id,
     }

+    print(
+        "[DEBUG] Adding sample to table via record_metric, episode_id: ",
+        episode.episode_id,
+    )
     record_metric(table_name, sample, Reduce.SAMPLE)
+    print(
+        "[DEBUG] Added sample to table via record_metric, episode_id: ",
+        episode.episode_id,
+    )


 #################
@@ -499,11 +507,13 @@ def __init__(self, reduction: Reduce):
         super().__init__(reduction)
         self.samples: List[Dict[str, Any]] = []
         self.filter = TopBottomKFilter()
+        self.is_reset = True

     def append(self, value: dict) -> None:
         if not isinstance(value, dict):
             raise ValueError(f"Expected dict, got {type(value)}")

+        self.is_reset = False
         # Only keep the sample if filter_append returns True
         if self.filter.filter_append(value):
             self.samples.append(value)
@@ -511,7 +521,8 @@ def append(self, value: dict) -> None:
     def get_value(self) -> list[dict]:
         """Return locally collected (and optionally filtered) samples."""
         # Apply flush-time filter (e.g. heap selection, threshold trimming)
-        return self.filter.filter_flush(self.samples)
+        results = self.filter.filter_flush(self.samples)
+        return results

     def get_state(self) -> Dict[str, Any]:
         """Serialize accumulator state for cross-rank reduction."""
@@ -530,6 +541,7 @@ def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dic

     def reset(self) -> None:
         """Clear local samples and reset filter state."""
+        self.is_reset = True
         self.samples.clear()
         self.filter.reset()

@@ -701,12 +713,12 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            if metric.reduction == Reduce.SAMPLE:
-                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
-                sample = {metric.key: [metric.value]}
-                asyncio.create_task(backend.log_samples(sample, self.global_step))
-            else:
-                backend.log_stream(metric=metric, global_step=self.global_step)
+            # if metric.reduction == Reduce.SAMPLE:
+            #     # Wrap singleton Metric into expected {key: [list_of_dicts]} format
+            #     sample = {metric.key: [metric.value]}
+            #     asyncio.create_task(backend.log_samples(sample, self.global_step))
+            # else:
+            backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key
@@ -773,6 +785,7 @@ async def flush(

             for backend in self.per_rank_reduce_backends:
                 if scalar_metrics:
+                    print(f"[DEBUG] calling log_batch from MetricCollector")
                     await backend.log_batch(scalar_metrics, global_step)
                 if sample_metrics:
                     await backend.log_samples(sample_metrics, global_step)
@@ -895,6 +908,7 @@ async def init(
     async def log_batch(
         self, metrics: list[Metric], global_step: int, *args, **kwargs
     ) -> None:
+        print(f"[DEBUG] calling log_batch with {len(metrics)} metrics")
         metrics_str = "\n".join(
             f"  {metric.key}: {metric.value}"
             for metric in sorted(metrics, key=lambda m: m.key)
@@ -913,6 +927,8 @@ def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         """Pretty-print sample-level logs to console."""
         import json

+        print(f"[DEBUG] calling log_samples with {len(samples)} samples")
+
         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
         for table_name, table_rows in samples.items():
             logger.info(f"[{table_name}] ({len(table_rows)} samples)")

From b2c8d88440be065f9b7ec2eda63e9cb33dc9c950 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 21 Oct 2025 19:29:51 -0700
Subject: [PATCH 04/17] fix hanging issue: missing entries

---
 apps/grpo/main.py                  | 4 ++++
 src/forge/observability/metrics.py | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 06d2968dd..3f6171625 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -47,6 +47,8 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None

     # Processed data
     completion: Completion | None = None
@@ -391,6 +393,8 @@ async def continuous_rollouts():
                     request_len=max_req_tokens,
                     response_len=max_res_tokens,
                     target=target,
+                    request=prompt,
+                    response=response.text,
                     completion=response,
                 )
                 episode.reward_breakdown = await reward_actor.evaluate_response.route(
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 65bda22f5..f16be3bd9 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -231,6 +231,8 @@ def record_episode_sample(table_name: str, episode):
     print(
         "[DEBUG] Adding sample to table via record_metric, episode_id: ",
         episode.episode_id,
+        # "episode: ",
+        # episode,
     )
     record_metric(table_name, sample, Reduce.SAMPLE)
     print(

From 98054fc10f75931c6443685d08e781ab34464395 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Fri, 24 Oct 2025 10:11:23 -0700
Subject: [PATCH 05/17] per_rank_no_reduce mode

---
 src/forge/observability/metrics.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index f16be3bd9..7aea6cd1c 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import asyncio
 import heapq
 import itertools
 import logging
@@ -715,12 +716,13 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            # if metric.reduction == Reduce.SAMPLE:
-            #     # Wrap singleton Metric into expected {key: [list_of_dicts]} format
-            #     sample = {metric.key: [metric.value]}
-            #     asyncio.create_task(backend.log_samples(sample, self.global_step))
-            # else:
-            backend.log_stream(metric=metric, global_step=self.global_step)
+
+            if metric.reduction == Reduce.SAMPLE:
+                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
+                sample = {metric.key: [metric.value]}
+                asyncio.create_task(backend.log_samples(sample, self.global_step))
+            else:
+                backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key

From 1c8e0ce3668e7ffa25a86d7aee0cabf1f9760f7b Mon Sep 17 00:00:00 2001
From: DNXie
Date: Fri, 24 Oct 2025 10:13:14 -0700
Subject: [PATCH 06/17] clean up debug prints

---
 src/forge/observability/metric_actors.py |  2 --
 src/forge/observability/metrics.py       | 17 +----------------
 2 files changed, 1 insertion(+), 18 deletions(-)

diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 20e5c228a..2fac2b1e9 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -433,7 +433,6 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:

         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)
-        print(f"[DEBUG] reduced_metrics: {reduced_metrics}")
         # Split into scalar metrics and sample metrics
         scalar_metrics = [
             m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
@@ -445,7 +444,6 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
             if scalar_metrics:
-                print(f"[DEBUG] calling log_batch from GlobalLoggerActor")
                 await backend.log_batch(scalar_metrics, global_step)
             if sample_metrics:
                 await backend.log_samples(sample_metrics, global_step)
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 7aea6cd1c..a42667c1e 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -228,18 +228,7 @@ def record_episode_sample(table_name: str, episode):
         "response_len": episode.response_len,
         "pad_id": episode.pad_id,
     }
-
-    print(
-        "[DEBUG] Adding sample to table via record_metric, episode_id: ",
-        episode.episode_id,
-        # "episode: ",
-        # episode,
-    )
     record_metric(table_name, sample, Reduce.SAMPLE)
-    print(
-        "[DEBUG] Added sample to table via record_metric, episode_id: ",
-        episode.episode_id,
-    )


 #################
@@ -789,7 +778,6 @@ async def flush(

             for backend in self.per_rank_reduce_backends:
                 if scalar_metrics:
-                    print(f"[DEBUG] calling log_batch from MetricCollector")
                     await backend.log_batch(scalar_metrics, global_step)
                 if sample_metrics:
                     await backend.log_samples(sample_metrics, global_step)
@@ -912,7 +900,6 @@ async def init(
     async def log_batch(
         self, metrics: list[Metric], global_step: int, *args, **kwargs
     ) -> None:
-        print(f"[DEBUG] calling log_batch with {len(metrics)} metrics")
         metrics_str = "\n".join(
             f"  {metric.key}: {metric.value}"
             for metric in sorted(metrics, key=lambda m: m.key)
@@ -931,12 +918,10 @@ def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         """Pretty-print sample-level logs to console."""
         import json

-        print(f"[DEBUG] calling log_samples with {len(samples)} samples")
-
         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
         for table_name, table_rows in samples.items():
             logger.info(f"[{table_name}] ({len(table_rows)} samples)")
-            logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
+            logger.info(json.dumps(table_rows))
         logger.info("==============================================\n")

From 53e9c3a6ab233ec07d51a4c31763a7e90b00cde7 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Fri, 24 Oct 2025 10:17:01 -0700
Subject: [PATCH 07/17] json pprint

---
 src/forge/observability/metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index a42667c1e..59a0eef04 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -921,7 +921,7 @@ def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
         for table_name, table_rows in samples.items():
             logger.info(f"[{table_name}] ({len(table_rows)} samples)")
-            logger.info(json.dumps(table_rows))
+            logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
         logger.info("==============================================\n")

From 291cd1feb944fc91c595ced54e1eb162487c040d Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 17:38:16 -0800
Subject: [PATCH 08/17] move avg reward out

---
 apps/grpo/main.py                  | 13 +++++++------
 src/forge/observability/metrics.py |  1 +
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 3f6171625..df57c3f2d 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -148,7 +148,7 @@ class RewardActor(ForgeActor):
     @endpoint
     async def evaluate_response(
         self, prompt: str, response: str, target: str
-    ) -> dict[str, float]:
+    ) -> tuple[dict[str, float], float]:
         total_rewards = 0.0
         reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
@@ -189,9 +189,8 @@ async def evaluate_response(
             Reduce.SUM,
         )

-        avg_reward = total_rewards / len(self.reward_functions)
-        reward_breakdown["reward"] = avg_reward
-        return reward_breakdown
+        avg_reward: float = total_rewards / len(self.reward_functions)
+        return reward_breakdown, avg_reward


 @dataclass
@@ -397,10 +396,12 @@ async def continuous_rollouts():
                     response=response.text,
                     completion=response,
                 )
-                episode.reward_breakdown = await reward_actor.evaluate_response.route(
+                (
+                    episode.reward_breakdown,
+                    episode.reward,
+                ) = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
-                episode.reward = episode.reward_breakdown["reward"]
                 episodes.append(episode)

diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 59a0eef04..515392304 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -223,6 +223,7 @@ def record_episode_sample(table_name: str, episode):
         **(
             episode.reward_breakdown or {}
         ),  # per-fn breakdown including the average reward
+        "reward": episode.reward,
         "advantage": episode.advantage,
         "request_len": episode.request_len,
         "response_len": episode.response_len,

From 1cc4a2c829029326e7cc31269fd4fb39de395d6e Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 17:39:52 -0800
Subject: [PATCH 09/17] rename table name

---
 apps/grpo/main.py                  | 4 +++-
 src/forge/observability/metrics.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index df57c3f2d..4d18b1f18 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -423,7 +423,9 @@ async def continuous_rollouts():
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
-                record_episode_sample("rollout/sample", episode)
+                record_episode_sample(
+                    "main_samples/continuous_rollouts/sample_table", episode
+                )

             rollout_count += 1
             record_metric(
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 515392304..3d928cf65 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -1100,7 +1100,7 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
                 table.add_data(*values)

             # Log the same table object (INCREMENTAL update)
-            self.run.log({f"{table_name}_table": table})
+            self.run.log({f"{table_name}": table})
             logger.info(
                 f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
             )

From b413d7ba20e993034be06ea3aabbcbaa856a0039 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 17:40:23 -0800
Subject: [PATCH 10/17] simplify docstring

---
 src/forge/observability/metrics.py | 32 ++++++------------------------
 1 file changed, 6 insertions(+), 26 deletions(-)

diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 3d928cf65..06021bd0e 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -140,32 +140,12 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
         list[Metric]: List of reduced metrics

     Example:
-        >>> states = [
-        ...     {
-        ...         "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
-        ...         "reward/sample": {
-        ...             "reduction_type": "sample",
-        ...             "samples": [{"episode_id": 1, "reward": 0.5}],
-        ...         },
-        ...     },
-        ...     {
-        ...         "loss": {"count": 10, "sum": 16, "reduction_type": "mean"},
-        ...         "reward/sample": {
-        ...             "reduction_type": "sample",
-        ...             "samples": [{"episode_id": 2, "reward": 1.0}],
-        ...         },
-        ...     },
-        ... ]
-        >>> metrics = reduce_metrics_states(states)
-        >>> for m in metrics:
-        ...     print(m)
-        Metric(key='loss', value=2.0, reduction=Reduce.MEAN)
-        Metric(
-            key='reward/sample',
-            value=[{'episode_id': 1, 'reward': 0.5},
-                   {'episode_id': 2, 'reward': 1.0}],
-            reduction=Reduce.SAMPLE,
-        )
+        states = [
+            {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
+            {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
+        ]
+        reduce_metrics_states(states)
+        >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]

     Raises:
         ValueError: on mismatched reduction types for the same metric key.
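
Note: the SAMPLE reduction exercised in the patches above merges cross-rank
states by plain concatenation of each rank's "samples" list (see
SampleAccumulator.get_reduced_value_from_states in PATCH 01). A minimal,
self-contained sketch of that merge; the helper name is hypothetical and this
is a standalone illustration, not the forge API:

    from typing import Any

    def merge_sample_states(
        states: list[dict[str, dict[str, Any]]], key: str
    ) -> list[dict]:
        # Concatenate the per-rank "samples" lists stored under one metric key.
        merged: list[dict] = []
        for state in states:
            merged.extend(state.get(key, {}).get("samples", []))
        return merged

    states = [
        {"rollout/sample": {"reduction_type": "sample",
                            "samples": [{"episode_id": 1, "reward": 0.5}]}},
        {"rollout/sample": {"reduction_type": "sample",
                            "samples": [{"episode_id": 2, "reward": 1.0}]}},
    ]
    assert merge_sample_states(states, "rollout/sample") == [
        {"episode_id": 1, "reward": 0.5},
        {"episode_id": 2, "reward": 1.0},
    ]

Per-rank top-k/bottom-k filtering already happened at serialization time
(get_state calls get_value), so this merge does no ranking of its own.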

From 0635c9a7dc2297faf5e8c7111a9f45b07ef4cd4b Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 17:45:47 -0800
Subject: [PATCH 11/17] add to_dict

---
 apps/grpo/main.py                  | 24 ++++++++++++++++++++++++
 src/forge/observability/metrics.py | 16 +---------------
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 4d18b1f18..5bec30e37 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -76,6 +76,30 @@ def response_tensor(self) -> torch.Tensor:
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor

+    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
+        """Convert episode to dict, optionally excluding specified fields."""
+        result = {
+            "episode_id": self.episode_id,
+            "policy_version": self.policy_version,
+            "prompt": self.request,
+            "response": self.response,
+            "target": str(self.target),
+            "reward": self.reward,
+            "advantage": self.advantage,
+            "request_len": self.request_len,
+            "response_len": self.response_len,
+            "pad_id": self.pad_id,
+        }
+
+        if self.reward_breakdown is not None:
+            result.update(self.reward_breakdown)
+
+        if exclude:
+            for key in exclude:
+                result.pop(key, None)
+
+        return result
+

 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 06021bd0e..a4afedfeb 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -194,21 +194,7 @@ def record_episode_sample(table_name: str, episode):
         table_name (str): logging prefix (e.g. "rollout/sample").
         episode (Episode): episode object with filled attributes.
     """
-    sample = {
-        "episode_id": episode.episode_id,
-        "policy_version": episode.policy_version,
-        "prompt": episode.request,
-        "response": episode.response,
-        "target": str(episode.target),
-        **(
-            episode.reward_breakdown or {}
-        ),  # per-fn breakdown including the average reward
-        "reward": episode.reward,
-        "advantage": episode.advantage,
-        "request_len": episode.request_len,
-        "response_len": episode.response_len,
-        "pad_id": episode.pad_id,
-    }
+    sample = episode.to_dict()
     record_metric(table_name, sample, Reduce.SAMPLE)

From 755d63202adfa54b32981f7a89350b017dd35670 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 17:51:01 -0800
Subject: [PATCH 12/17] some misc

---
 src/forge/observability/metrics.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index a4afedfeb..f88e67f7b 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -1054,7 +1054,7 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
                 columns = list(table_rows[0].keys())
                 table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
                 self._tables[table_name] = table
-                logger.info(
+                logger.debug(
                     f"WandbBackend: Created new incremental table: {table_name}"
                 )
             else:
@@ -1066,10 +1066,10 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
                 table.add_data(*values)

             # Log the same table object (INCREMENTAL update)
+            # table_name has to end with _table to be recognized by wandb
+            if not table_name.endswith("_table"):
+                table_name += "_table"
             self.run.log({f"{table_name}": table})
-            logger.info(
-                f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
-            )

     def get_metadata_for_secondary_ranks(self) -> dict[str, Any]:
         if self.run and self.per_rank_share_run:
@@ -1080,7 +1080,11 @@ async def finish(self) -> None:
         import wandb

         if self.run:
-            # Convert each incremental table to immutable before finishing
+            """
+            Convert each incremental table to immutable before finishing
+            as recommended by wandb:
+            https://docs.wandb.ai/models/tables/log_tables#incremental-mode
+            """
             for table_name, incr_table in self._tables.items():
                 final_table = wandb.Table(
                     columns=incr_table.columns,
@@ -1088,7 +1092,7 @@ async def finish(self) -> None:
                     log_mode="IMMUTABLE",
                 )
                 self.run.log({table_name: final_table})
-                logger.info(f"WandbBackend: Finalized table {table_name}")
+                logger.debug(f"WandbBackend: Finalized table {table_name}")

             self.run.finish()
             logger.info(f"WandbBackend {self.process_name}: Finished run")

From 9d2a0cb4389a7f33df95b4fab5d3e78a65b74956 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 17:52:28 -0800
Subject: [PATCH 13/17] iterate over rows to get columns

---
 src/forge/observability/metrics.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index f88e67f7b..8c61c0c74 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -7,6 +7,7 @@
 import asyncio
 import heapq
 import itertools
+import json
 import logging
 import os
 from abc import ABC, abstractmethod
@@ -883,7 +884,6 @@ async def finish(self) -> None:

     async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         """Pretty-print sample-level logs to console."""
-        import json

         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
         for table_name, table_rows in samples.items():
@@ -1051,11 +1051,15 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:

             # If table doesn't exist yet, create it in INCREMENTAL mode
             if table_name not in self._tables:
-                columns = list(table_rows[0].keys())
+                # Collect all unique columns from all rows
+                columns = set()
+                for row in table_rows:
+                    columns.update(row.keys())
+                columns = sorted(columns)  # Sort for consistent column ordering
                 table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
                 self._tables[table_name] = table
                 logger.debug(
-                    f"WandbBackend: Created new incremental table: {table_name}"
+                    f"WandbBackend: Created new incremental table: {table_name} with columns: {columns}"
                 )
             else:
                 table = self._tables[table_name]

From 2b0496e2d85bebd4f0 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 20:01:14 -0800
Subject: [PATCH 14/17] log_samples takes list of Metric

---
 apps/grpo/main.py                        |  4 +++-
 src/forge/observability/metric_actors.py |  2 +-
 src/forge/observability/metrics.py       | 16 ++++++++--------
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 5bec30e37..a9788fb87 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -89,9 +89,11 @@ def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
             "request_len": self.request_len,
             "response_len": self.response_len,
             "pad_id": self.pad_id,
+            "ref_logprobs": self.ref_logprobs,
+            "completion": self.completion,
         }

-        if self.reward_breakdown is not None:
+        if self.reward_breakdown is not None and "reward_breakdown" not in (exclude or []):
             result.update(self.reward_breakdown)

         if exclude:
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 2fac2b1e9..346cfd78a 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -438,7 +438,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
             m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
         ]
         sample_metrics = {
-            m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
+            m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
         }

         # Log to global backends
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 8c61c0c74..f68bbf1ef 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
         table_name (str): logging prefix (e.g. "rollout/sample").
         episode (Episode): episode object with filled attributes.
     """
-    sample = episode.to_dict()
+    sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
     record_metric(table_name, sample, Reduce.SAMPLE)

@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:

         for backend in self.per_rank_no_reduce_backends:
             if metric.reduction == Reduce.SAMPLE:
-                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
-                sample = {metric.key: [metric.value]}
-                asyncio.create_task(backend.log_samples(sample, self.global_step))
+                asyncio.create_task(backend.log_samples([metric], self.global_step))
             else:
                 backend.log_stream(metric=metric, global_step=self.global_step)

@@ -882,11 +880,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
     async def finish(self) -> None:
         pass

-    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+    async def log_samples(self, samples: List[Metric], step: int) -> None:
         """Pretty-print sample-level logs to console."""

         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
-        for table_name, table_rows in samples.items():
+        for sample in samples:
+            table_name, table_rows = sample.key, sample.value
             logger.info(f"[{table_name}] ({len(table_rows)} samples)")
             logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
         logger.info("==============================================\n")
@@ -1038,14 +1037,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
         # note: here we dont use step since wandb keeps only the latest value for each step
         self.run.log(log_data)

-    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+    async def log_samples(self, samples: List[Metric], step: int) -> None:
         """Log sample-level data incrementally to persistent WandB Tables."""
         import wandb

         if not self.run:
             return

-        for table_name, table_rows in samples.items():
+        for sample in samples:
+            table_name, table_rows = sample.key, sample.value
             if not table_rows:
                 continue

From 62fd0cc66b4b8556f7adbcc767e64ccb9b18a7b0 Mon Sep 17 00:00:00 2001
From: DNXie
Date: Tue, 4 Nov 2025 20:24:49 -0800
Subject: [PATCH 15/17] merge filter into sampler

---
 src/forge/observability/__init__.py      |  3 -
 src/forge/observability/metric_actors.py |  4 +-
 src/forge/observability/metrics.py       | 98 +++++++++--------------
 3 files changed, 39 insertions(+), 66 deletions(-)

diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py
index e80bd1860..17a666aa2 100644
--- a/src/forge/observability/__init__.py
+++ b/src/forge/observability/__init__.py
@@ -28,7 +28,6 @@
     SampleAccumulator,
     StdAccumulator,
     SumAccumulator,
-    TopBottomKFilter,
     WandbBackend,
 )
 from .perf_tracker import trace, Tracer
@@ -69,6 +68,4 @@
     "MinAccumulator",
     "StdAccumulator",
     "SampleAccumulator",
-    # Filter classes
-    "TopBottomKFilter",
 ]
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 346cfd78a..c4ba5f142 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -437,9 +437,9 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         scalar_metrics = [
             m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
         ]
-        sample_metrics = {
+        sample_metrics = [
             m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
-        }
+        ]

         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index f68bbf1ef..4bda26e5d 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -199,55 +199,6 @@ def record_episode_sample(table_name: str, episode):
     record_metric(table_name, sample, Reduce.SAMPLE)


-#################
-# SampleFilters #
-#################
-
-
-class TopBottomKFilter:
-    """Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
-
-    def __init__(self, top_k=1, bottom_k=1, key="reward"):
-        self.top_k = top_k
-        self.bottom_k = bottom_k
-        self.key = key
-        self._top_heap = []  # min-heap for top-k
-        self._bottom_heap = []  # max-heap for bottom-k (store -value)
-        self._counter = itertools.count()  # tie-breaker id generator
-
-    def filter_append(self, sample: Dict) -> bool:
-        val = sample.get(self.key, 0.0)
-        idx = next(self._counter)  # unique tiebreaker
-
-        # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
-        # maintain top-k
-        if self.top_k > 0:
-            if len(self._top_heap) < self.top_k:
-                heapq.heappush(self._top_heap, (val, idx, sample))
-            else:
-                heapq.heappushpop(self._top_heap, (val, idx, sample))
-
-        # maintain bottom-k
-        if self.bottom_k > 0:
-            if len(self._bottom_heap) < self.bottom_k:
-                heapq.heappush(self._bottom_heap, (-val, idx, sample))
-            else:
-                heapq.heappushpop(self._bottom_heap, (-val, idx, sample))
-
-        # always return False here because we don't store in samples list
-        return False
-
-    def filter_flush(self, samples: List[Dict]) -> List[Dict]:
-        tops = [s for _, _, s in self._top_heap]
-        bottoms = [s for _, _, s in self._bottom_heap]
-        return bottoms + tops
-
-    def reset(self):
-        self._top_heap = []
-        self._bottom_heap = []
-        self._counter = itertools.count()
-
-
 ################
 # Accumulators #
 ################
@@ -459,14 +410,23 @@ def reset(self) -> None:


 class SampleAccumulator(MetricAccumulator):
-    """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
-    Optionally uses a sample filter to decide what to keep at append/flush time.
+    """Accumulator for sample-level metrics with top-k and bottom-k filtering.
+
+    Keeps the top-k and bottom-k samples by a given key (e.g., reward).
+    Useful for logging only the best and worst samples from a batch.
""" - def __init__(self, reduction: Reduce): + def __init__( + self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "reward" + ): super().__init__(reduction) self.samples: List[Dict[str, Any]] = [] - self.filter = TopBottomKFilter() + self.top_k = top_k + self.bottom_k = bottom_k + self.key = key + self._top_heap = [] # min-heap for top-k + self._bottom_heap = [] # max-heap for bottom-k (store -value) + self._counter = itertools.count() # tie-breaker id generator self.is_reset = True def append(self, value: dict) -> None: @@ -474,15 +434,29 @@ def append(self, value: dict) -> None: raise ValueError(f"Expected dict, got {type(value)}") self.is_reset = False - # Only keep the sample if filter_append returns True - if self.filter.filter_append(value): - self.samples.append(value) + val = value.get(self.key, 0.0) + idx = next(self._counter) # unique tiebreaker + + # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none). + # maintain top-k + if self.top_k > 0: + if len(self._top_heap) < self.top_k: + heapq.heappush(self._top_heap, (val, idx, value)) + else: + heapq.heappushpop(self._top_heap, (val, idx, value)) + + # maintain bottom-k + if self.bottom_k > 0: + if len(self._bottom_heap) < self.bottom_k: + heapq.heappush(self._bottom_heap, (-val, idx, value)) + else: + heapq.heappushpop(self._bottom_heap, (-val, idx, value)) def get_value(self) -> list[dict]: - """Return locally collected (and optionally filtered) samples.""" - # Apply flush-time filter (e.g. heap selection, threshold trimming) - results = self.filter.filter_flush(self.samples) - return results + """Return top-k and bottom-k filtered samples.""" + tops = [s for _, _, s in self._top_heap] + bottoms = [s for _, _, s in self._bottom_heap] + return bottoms + tops def get_state(self) -> Dict[str, Any]: """Serialize accumulator state for cross-rank reduction.""" @@ -503,7 +477,9 @@ def reset(self) -> None: """Clear local samples and reset filter state.""" self.is_reset = True self.samples.clear() - self.filter.reset() + self._top_heap = [] + self._bottom_heap = [] + self._counter = itertools.count() ############# From 38326d72e52ef00e2a2b6de8cc634095a58a1305 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 5 Nov 2025 10:01:43 -0800 Subject: [PATCH 16/17] debug --- src/forge/observability/metrics.py | 36 +++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4bda26e5d..ee16a5419 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -141,12 +141,25 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri list[Metric]: List of reduced metrics Example: - states = [ - {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}}, - {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, + >>> states = [ + ... { + ... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"}, + ... "reward/sample": { + ... "reduction_type": "sample", + ... "samples": [{"episode_id": 1, "reward": 0.5}], + ... }, + ... }, + ... {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, + ... 
] + >>> reduce_metrics_states(states) + [ + Metric(key='loss', value=2.0, reduction=Reduce.MEAN), + Metric( + key='reward/sample', + value=[{'episode_id': 1, 'reward': 0.5}], + reduction=Reduce.SAMPLE, + ) ] - reduce_metrics_states(states) - >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] Raises: ValueError: on mismatched reduction types for the same metric key. @@ -649,7 +662,6 @@ def push(self, metric: Metric) -> None: # For PER_RANK_NO_REDUCE backends: stream without reduce for backend in self.per_rank_no_reduce_backends: - if metric.reduction == Reduce.SAMPLE: asyncio.create_task(backend.log_samples([metric], self.global_step)) else: @@ -712,11 +724,9 @@ async def flush( scalar_metrics = [ m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE ] - sample_metrics = { - m.key: m.value - for m in metrics_for_backends - if m.reduction == Reduce.SAMPLE - } + sample_metrics = [ + m for m in metrics_for_backends if m.reduction == Reduce.SAMPLE + ] for backend in self.per_rank_reduce_backends: if scalar_metrics: @@ -1025,6 +1035,10 @@ async def log_samples(self, samples: List[Metric], step: int) -> None: if not table_rows: continue + # Convert to list if single sample. This happens when logging stream + if isinstance(table_rows, dict): + table_rows = [table_rows] + # If table doesn't exist yet, create it in INCREMENTAL mode if table_name not in self._tables: # Collect all unique columns from all rows From 14f0a0f8fd5d3a80f3ddc4130ec7600c18f6f719 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 17 Nov 2025 10:23:36 -0800 Subject: [PATCH 17/17] resolve comment --- apps/grpo/main.py | 11 ++++++++--- src/forge/observability/__init__.py | 2 -- src/forge/observability/metrics.py | 22 ++++++++-------------- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index a9788fb87..accce1a83 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -29,7 +29,7 @@ from forge.data.rewards import MathReward, ThinkingReward from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger -from forge.observability.metrics import record_episode_sample, record_metric, Reduce +from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig @@ -449,8 +449,13 @@ async def continuous_rollouts(): for episode, advantage in zip(episodes, advantages): episode.advantage = advantage await replay_buffer.add.call_one(episode) - record_episode_sample( - "main_samples/continuous_rollouts/sample_table", episode + + sample = episode.to_dict(exclude=["ref_logprobs", "completion"]) + sample["score"] = sample["reward"] + record_metric( + "main_samples/continuous_rollouts/sample_table", + sample, + Reduce.SAMPLE, ) rollout_count += 1 diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 17a666aa2..988673e3c 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -21,7 +21,6 @@ MetricAccumulator, MetricCollector, MinAccumulator, - record_episode_sample, record_metric, Reduce, reduce_metrics_states, @@ -37,7 +36,6 @@ # Main API functions "record_metric", "reduce_metrics_states", - "record_episode_sample", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index ee16a5419..5d14cf34a 100644 --- 
a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -201,17 +201,6 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri return reduced_metrics -def record_episode_sample(table_name: str, episode): - """ - Record a structured sample-level log for a single episode. - Args: - table_name (str): logging prefix (e.g. "rollout/sample"). - episode (Episode): episode object with filled attributes. - """ - sample = episode.to_dict(exclude=["ref_logprobs", "completion"]) - record_metric(table_name, sample, Reduce.SAMPLE) - - ################ # Accumulators # ################ @@ -430,7 +419,7 @@ class SampleAccumulator(MetricAccumulator): """ def __init__( - self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "reward" + self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "score" ): super().__init__(reduction) self.samples: List[Dict[str, Any]] = [] @@ -869,12 +858,10 @@ async def finish(self) -> None: async def log_samples(self, samples: List[Metric], step: int) -> None: """Pretty-print sample-level logs to console.""" - logger.info(f"========== SAMPLE LOGS STEP {step} ==========") for sample in samples: table_name, table_rows = sample.key, sample.value logger.info(f"[{table_name}] ({len(table_rows)} samples)") logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False)) - logger.info("==============================================\n") class WandbBackend(LoggerBackend): @@ -1056,6 +1043,13 @@ async def log_samples(self, samples: List[Metric], step: int) -> None: # Add rows (fill missing columns with None) for s in table_rows: + # Check for extra columns not in the table schema + extra_columns = set(s.keys()) - set(table.columns) + if extra_columns: + logger.warning( + f"WandbBackend: Row has extra columns not in table '{table_name}': {sorted(extra_columns)}. " + f"These will be ignored." + ) values = [s.get(c) for c in table.columns] table.add_data(*values)
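
Note: after PATCH 17, each SampleAccumulator keeps only the top-k and bottom-k
samples per metric key, ranked by the "score" field. The heap scheme it uses
can be exercised standalone; a minimal sketch with hypothetical sample values
(k=1 per side, matching the defaults):

    import heapq
    import itertools

    top_k, bottom_k, key = 1, 1, "score"
    top_heap, bottom_heap = [], []  # min-heap for top-k; negated values for bottom-k
    counter = itertools.count()  # unique tie-breaker so dicts are never compared

    for sample in [{"id": i, "score": s} for i, s in enumerate([0.2, 0.9, 0.5])]:
        val, idx = sample.get(key, 0.0), next(counter)
        if len(top_heap) < top_k:
            heapq.heappush(top_heap, (val, idx, sample))
        else:
            heapq.heappushpop(top_heap, (val, idx, sample))
        if len(bottom_heap) < bottom_k:
            heapq.heappush(bottom_heap, (-val, idx, sample))
        else:
            heapq.heappushpop(bottom_heap, (-val, idx, sample))

    kept = [s for _, _, s in bottom_heap] + [s for _, _, s in top_heap]
    print(kept)  # [{'id': 0, 'score': 0.2}, {'id': 1, 'score': 0.9}]

The counter matters: heap entries are (value, idx, sample) tuples, and without
a unique idx two equal scores would force heapq to compare the sample dicts
themselves, raising a TypeError.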