Changes from all commits
40 commits
033fa60
Add LanguageReward for training models to think in target language
casteryh Oct 31, 2025
b12ed15
Update system prompt to instruct model to think in Japanese
casteryh Oct 31, 2025
b15f171
Add fallback reward for correct language without thinking blocks
casteryh Oct 31, 2025
a2c0237
Add debug logging and troubleshooting guide for LanguageReward
casteryh Oct 31, 2025
afca75c
Add debug printing to LanguageReward and strengthen system prompt
casteryh Oct 31, 2025
2625b28
Refactor to use configurable Japanese tags <思考> instead of English <t…
casteryh Oct 31, 2025
1a4d5fb
Remove old debug code from main.py
casteryh Oct 31, 2025
4e87a4d
Weaken system prompt to rely more on RL rewards
casteryh Oct 31, 2025
abb653e
Remove sandbox config and reference apps/grpo configs instead
casteryh Oct 31, 2025
7b4829c
Simplify LanguageReward logic to focus on language detection only
casteryh Oct 31, 2025
0ed798c
Add langid to dev dependencies for CI
casteryh Oct 31, 2025
5a3193e
Remove debug script
casteryh Oct 31, 2025
93a65b2
Clarify why English training won't work in TROUBLESHOOTING
casteryh Nov 1, 2025
f72be7f
Add unit test for ThinkingReward custom tag
casteryh Nov 1, 2025
6186f9f
Bump LanguageReward match_reward to 2.0
casteryh Nov 1, 2025
c640d37
Set KL divergence coefficient to zero in loss function
casteryh Nov 1, 2025
7fde86d
Change KL divergence coefficient to 1e-3
casteryh Nov 1, 2025
ffb6c43
Change KL divergence coefficient to 1e-4
casteryh Nov 1, 2025
7ffa20e
Enable multi-epoch training in sandbox/grpo_language app
casteryh Nov 2, 2025
1bf3cca
Fix recursive endpoint call - use while loop instead
casteryh Nov 2, 2025
7758b48
Simplify multi-epoch fix - use return next() instead of while loop
casteryh Nov 2, 2025
f71dbb6
fix
casteryh Nov 3, 2025
735af9a
change logging
casteryh Nov 3, 2025
ef39e46
git mv
casteryh Nov 3, 2025
2e2981a
Update src/forge/data/rewards.py
casteryh Nov 19, 2025
869c76d
remove README.md & TROUBLESHOOTING.md
casteryh Nov 19, 2025
5f6a788
truncate at 1000 instead
casteryh Nov 19, 2025
04182d4
merge into main app & fix test
casteryh Nov 19, 2025
dc67ea2
merge into main app
casteryh Nov 19, 2025
d02b866
Merge branch 'main' into language-reward-feature
casteryh Nov 20, 2025
866dff7
fix endpoint
casteryh Nov 20, 2025
20d8644
Merge branch 'main' into language-reward-feature
casteryh Nov 20, 2025
ddefe22
update config for 1.7b
casteryh Nov 21, 2025
22891bf
remove unnecessary tests
casteryh Nov 21, 2025
bf3dbf0
simplify sampling, add todo
casteryh Nov 21, 2025
4318d83
increase context len, decrease bsz for 8b (working)
casteryh Nov 21, 2025
2cc6c6b
kl coeff = 1e-5
casteryh Nov 21, 2025
e7d1bfc
fix rewards
casteryh Nov 21, 2025
5d051bd
fix debug samples
casteryh Nov 21, 2025
71cf182
fix test
casteryh Nov 21, 2025
29 changes: 22 additions & 7 deletions apps/grpo/main.py
@@ -26,7 +26,7 @@
from forge.actors.trainer import TitanTrainer
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.rewards import LanguageReward, MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
@@ -129,7 +129,7 @@ def simple_grpo_loss(
ref_logprobs: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
-    beta: float = 0.1,
+    beta: float = 1e-5,
) -> torch.Tensor:
logprobs: torch.Tensor = compute_logprobs(logits, response)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
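
Aside (not part of the diff): the `kl` expression above is the k3 KL estimator. Writing $r = \log \pi_{\text{ref}}(y) - \log \pi_\theta(y)$ (i.e. `ref_logprobs - logprobs`), the per-token estimate is

$$\hat{k}_3 = e^{r} - r - 1 \ge 0,$$

an unbiased, always-nonnegative estimator of $\mathrm{KL}(\pi_\theta \,\|\, \pi_{\text{ref}})$. Lowering `beta` from 0.1 to 1e-5, after the 0, 1e-3, and 1e-4 trials visible in the commit list, all but removes the KL penalty toward the reference model.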
@@ -237,10 +237,15 @@ async def setup(self):
self._epoch = 0

def gsm8k_transform(sample):
-            system_prompt = """
-            Put all your scratchpad work between <think> and </think> tags.
-            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
-            """
+            system_prompt = """You are a helpful AI assistant that solves math problems.
+
+            Please show your reasoning inside <思考></思考> tags, then provide your final numerical answer inside <answer></answer> tags.
+
+            Example:
+            Question: What is 12 + 5?
+            <思考>12と5を足します。12 + 5 = 17です。</思考>
+            <answer>17</answer>
+            """
request: str = sample["question"]
as_chat = [
{"role": "system", "content": system_prompt},
@@ -359,7 +364,17 @@ async def main(cfg: DictConfig):
ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[MathReward(), ThinkingReward()]
+            reward_functions=[
+                MathReward(),
+                ThinkingReward(tag="思考"),  # Use Japanese tag
+                LanguageReward(
+                    target_language="ja",
+                    tag="思考",
+                    match_reward=2.0,
+                    debug=True,
+                    debug_sample_rate=0.1,
+                ),  # Japanese language reward with debug
+            ]
),
)

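For orientation (not part of the diff), here is a minimal sketch of how the reward stack configured above scores one completion. It assumes each reward function is invoked as `fn(prompt, response, target)` and that the per-function scores are summed; both points are assumptions about RewardActor, not confirmed by this diff. Requires `langid` to be installed.

```python
from forge.data.rewards import LanguageReward, MathReward, ThinkingReward

# Same stack as in main() above (debug flags omitted for brevity).
rewards = [
    MathReward(),
    ThinkingReward(tag="思考"),
    LanguageReward(target_language="ja", tag="思考", match_reward=2.0),
]

prompt = "Question: What is 12 + 5?"
# A completion in the format the new system prompt asks for.
response = "<思考>12と5を足します。12 + 5 = 17です。</思考><answer>17</answer>"

# Hypothetical aggregation: sum of the individual rewards.
total = sum(fn(prompt, response, "17") for fn in rewards)
print(total)  # expected 1.0 (math) + 1.0 (thinking) + 2.0 (language) = 4.0
```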
2 changes: 1 addition & 1 deletion apps/grpo/qwen3_1_7b.yaml
@@ -5,7 +5,7 @@
group_size: 8
local_batch_size: 16 # per-device batch size
max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default

4 changes: 2 additions & 2 deletions apps/grpo/qwen3_8b.yaml
@@ -3,9 +3,9 @@

# Global configuration
group_size: 8
-local_batch_size: 12 # per-device batch size
+local_batch_size: 8 # per-device batch size
max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default

1 change: 1 addition & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ dev = [
"anyio",
"pytest-asyncio",
"multiprocess",
"langid",
]
docs = [
"sphinx==7.2.6",
147 changes: 143 additions & 4 deletions src/forge/data/rewards.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random
import re


@@ -57,15 +58,28 @@ def _to_float(self, text: str) -> float | None:


class ThinkingReward:
"""Reward class for evaluating use of <think> tags in reasoning."""
"""Reward class for evaluating use of thinking tags in reasoning.

def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
Args:
partial_reward: Reward for partial tag usage (incomplete/malformed)
full_reward: Reward for well-formed thinking blocks with content
tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
"""

def __init__(
self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think"
):
self.partial_reward = partial_reward
self.full_reward = full_reward
self.tag = tag
# Build regex patterns for the specified tag
self._THINK_BLOCK_RE = re.compile(
r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>",
re.IGNORECASE | re.DOTALL,
)
self._THINK_TAG_ATTEMPT_RE = re.compile(
rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE
)
self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)

def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
"""Compute thinking reward."""
@@ -80,3 +94,128 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
elif has_attempt:
return self.partial_reward
return 0.0
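
A quick illustration (not from the diff) of the custom-tag behavior that the new ThinkingReward unit test exercises; the expected values assume the defaults partial_reward=0.2 and full_reward=1.0:

```python
reward = ThinkingReward(tag="思考")

print(reward("", "<思考>順を追って考えます。</思考>答えは 17"))  # well-formed block -> expected 1.0
print(reward("", "<思考>途中で切れた"))  # lone opening tag counts as an attempt -> expected 0.2
print(reward("", "no tags at all"))  # no attempt -> expected 0.0
```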


class LanguageReward:
"""Reward class for evaluating the language used in responses.

This reward uses langid to detect the language and rewards responses that use
the target language. The detection strategy depends on the format:
- If exactly one thinking block: detect language of the block content
- Otherwise (no blocks or multiple blocks): detect language of whole response

Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward.
This reward focuses purely on language detection.

Args:
target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
match_reward: Reward when detected language matches target (default: 1.0)
no_match_reward: Reward when language doesn't match (default: 0.0)
tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
debug: If True, print debug samples showing model outputs and detected language
debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)

Note: Requires langid to be installed. Install with: pip install langid
"""

def __init__(
self,
target_language: str = "ja",
match_reward: float = 1.0,
no_match_reward: float = 0.0,
tag: str = "思考",
debug: bool = False,
debug_sample_rate: float = 0.1,
):
self.target_language = target_language
self.match_reward = match_reward
self.no_match_reward = no_match_reward
self.tag = tag
self.debug = debug
self.debug_sample_rate = debug_sample_rate
self._debug_counter = 0
# Build regex pattern for the specified tag
self._THINK_BLOCK_RE = re.compile(
rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL
)

# Lazy import langid with helpful error message
try:
import langid

self._langid = langid
except ImportError:
raise ImportError(
"langid is required for LanguageReward but is not installed. "
"Please install it with: pip install langid"
) from None
Comment on lines +147 to +151

Contributor:
Why do you want to surface an import error that late into the program? If this is not a "default" app we want everyone to install, is there a better option for managing the module dependencies? I think we'll get to this point soon to support different training backends.

Contributor Author:
> Why do you want to surface an import error that late into the program?

I don't know a better place to check for an import error than the initializer.

> If this is not a "default" app we want everyone to install, is there a better option for managing the module dependencies?

One option is to add this as an optional dependency in pyproject.toml, but I am not sure we should do that in this particular case.

> I think we'll get to this point soon to support different training backends.

In that case, I think an optional dependency is the way to go, so the user can run `pip install forge[huggingface]` to enable the Hugging Face trainer, for example.

Contributor Author:
cc @joecummings what do you think?

Contributor:
I understand the rationale -- is it possible to do something like `pip install forge[grpo]` for this dependency? We'll need to turn the app into integration tests that run on CI, and we need a way to specify the dependency statically.
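
A hypothetical sketch of the extras-group idea discussed above (the group name `grpo` comes from the comment; the exact table layout is an assumption, not part of this PR):

```toml
[project.optional-dependencies]
grpo = ["langid"]
```

Users would then opt in with `pip install forge[grpo]`, and CI could install the same extra to run the app as an integration test.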


def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
"""Compute language reward based on detected language.

Detection strategy:
- If exactly one thinking block: detect language of block content
- Otherwise: detect language of whole response

Args:
prompt: The input prompt (unused but kept for signature consistency)
response: The model response
target: Optional target string (unused but kept for signature consistency)

Returns:
match_reward if detected language matches target, no_match_reward otherwise
"""

# TODO: refactor pending https://github.com/meta-pytorch/torchforge/issues/187
should_debug = self.debug and (random.random() < self.debug_sample_rate)

if not response:
if should_debug:
print(
f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}"
)
return self.no_match_reward

# Extract all thinking blocks
matches = self._THINK_BLOCK_RE.findall(response)

# Determine what text to analyze
if len(matches) == 1:
# Single block: detect language of block content only
text_to_analyze = matches[0].strip()
detection_mode = "single block"
else:
# No blocks or multiple blocks: detect language of whole response
text_to_analyze = response.strip()
detection_mode = f"{len(matches)} blocks, using whole response"

# Remove extra whitespace
text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip()

if not text_to_analyze:
if should_debug:
print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}")
return self.no_match_reward

# Detect language using langid
detected_lang, confidence = self._langid.classify(text_to_analyze)

# Check if language matches target
reward = (
self.match_reward
if detected_lang == self.target_language
else self.no_match_reward
)

if should_debug:
sample = text_to_analyze[:1000].replace("\n", " ")
match_symbol = "✓" if detected_lang == self.target_language else "✗"
print(
f"\n[LanguageReward] Detection mode: {detection_mode}"
f"\n Target: {self.target_language} | Detected: {detected_lang} | "
f"Confidence: {confidence:.2f}"
f"\n Sample: {sample}..."
f"\n → Reward: {reward} {match_symbol}"
)

return reward
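
A minimal usage sketch (not part of the diff; assumes `langid` is installed and classifies the samples as indicated in the comments):

```python
lr = LanguageReward(target_language="ja", tag="思考")

# Exactly one 思考 block: only the block's content is classified.
print(lr("", "<思考>12と5を足します。</思考><answer>17</answer>"))  # expected 1.0

# No 思考 block: the whole response is classified (here, presumably English).
print(lr("", "Let me think step by step. 12 + 5 = 17."))  # expected 0.0
```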