diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 1c1c2bd4a..18be2a439 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -26,7 +26,7 @@
from forge.actors.trainer import TitanTrainer
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.rewards import LanguageReward, MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
@@ -129,7 +129,7 @@ def simple_grpo_loss(
ref_logprobs: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
- beta: float = 0.1,
+ beta: float = 1e-5,
) -> torch.Tensor:
logprobs: torch.Tensor = compute_logprobs(logits, response)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
@@ -237,10 +237,15 @@ async def setup(self):
self._epoch = 0
def gsm8k_transform(sample):
- system_prompt = """
- Put all your scratchpad work between and tags.
- Your final answer should be between and tags otherwise it will not be scored.
- """
+ system_prompt = """You are a helpful AI assistant that solves math problems.
+
+Please show your reasoning inside <思考>思考> tags, then provide your final numerical answer inside tags.
+
+Example:
+Question: What is 12 + 5?
+<思考>12と5を足します。12 + 5 = 17です。思考>
+17
+"""
request: str = sample["question"]
as_chat = [
{"role": "system", "content": system_prompt},
@@ -359,7 +364,17 @@ async def main(cfg: DictConfig):
ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
- reward_functions=[MathReward(), ThinkingReward()]
+ reward_functions=[
+ MathReward(),
+ ThinkingReward(tag="思考"), # Use Japanese tag
+ LanguageReward(
+ target_language="ja",
+ tag="思考",
+ match_reward=2.0,
+ debug=True,
+ debug_sample_rate=0.1,
+ ), # Japanese language reward with debug
+ ]
),
)
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index c6fc1613b..ebdd27787 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -5,7 +5,7 @@
group_size: 8
local_batch_size: 16 # per-device batch size
max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index a2815c5c0..c59408ef2 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -3,9 +3,9 @@
# Global configuration
group_size: 8
-local_batch_size: 12 # per-device batch size
+local_batch_size: 8 # per-device batch size
max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default
diff --git a/pyproject.toml b/pyproject.toml
index 7a553f172..48bd8c238 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,7 @@ dev = [
"anyio",
"pytest-asyncio",
"multiprocess",
+ "langid",
]
docs = [
"sphinx==7.2.6",
diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py
index 23a0002df..b30eb0408 100644
--- a/src/forge/data/rewards.py
+++ b/src/forge/data/rewards.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+import random
import re
@@ -57,15 +58,28 @@ def _to_float(self, text: str) -> float | None:
class ThinkingReward:
- """Reward class for evaluating use of tags in reasoning."""
+ """Reward class for evaluating use of thinking tags in reasoning.
- def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
+ Args:
+ partial_reward: Reward for partial tag usage (incomplete/malformed)
+ full_reward: Reward for well-formed thinking blocks with content
+ tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
+ """
+
+ def __init__(
+ self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think"
+ ):
self.partial_reward = partial_reward
self.full_reward = full_reward
+ self.tag = tag
+ # Build regex patterns for the specified tag
self._THINK_BLOCK_RE = re.compile(
- r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
+ rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>",
+ re.IGNORECASE | re.DOTALL,
+ )
+ self._THINK_TAG_ATTEMPT_RE = re.compile(
+ rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE
)
- self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)
def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
"""Compute thinking reward."""
@@ -80,3 +94,128 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo
elif has_attempt:
return self.partial_reward
return 0.0
+
+
+class LanguageReward:
+ """Reward class for evaluating the language used in responses.
+
+ This reward uses langid to detect the language and rewards responses that use
+ the target language. The detection strategy depends on the format:
+ - If exactly one thinking block: detect language of the block content
+ - Otherwise (no blocks or multiple blocks): detect language of whole response
+
+ Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward.
+ This reward focuses purely on language detection.
+
+ Args:
+ target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
+ match_reward: Reward when detected language matches target (default: 1.0)
+ no_match_reward: Reward when language doesn't match (default: 0.0)
+ tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
+ debug: If True, print debug samples showing model outputs and detected language
+ debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)
+
+ Note: Requires langid to be installed. Install with: pip install langid
+ """
+
+ def __init__(
+ self,
+ target_language: str = "ja",
+ match_reward: float = 1.0,
+ no_match_reward: float = 0.0,
+ tag: str = "思考",
+ debug: bool = False,
+ debug_sample_rate: float = 0.1,
+ ):
+ self.target_language = target_language
+ self.match_reward = match_reward
+ self.no_match_reward = no_match_reward
+ self.tag = tag
+ self.debug = debug
+ self.debug_sample_rate = debug_sample_rate
+ self._debug_counter = 0
+ # Build regex pattern for the specified tag
+ self._THINK_BLOCK_RE = re.compile(
+ rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL
+ )
+
+ # Lazy import langid with helpful error message
+ try:
+ import langid
+
+ self._langid = langid
+ except ImportError:
+ raise ImportError(
+ "langid is required for LanguageReward but is not installed. "
+ "Please install it with: pip install langid"
+ ) from None
+
+ def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
+ """Compute language reward based on detected language.
+
+ Detection strategy:
+ - If exactly one thinking block: detect language of block content
+ - Otherwise: detect language of whole response
+
+ Args:
+ prompt: The input prompt (unused but kept for signature consistency)
+ response: The model response
+ target: Optional target string (unused but kept for signature consistency)
+
+ Returns:
+ match_reward if detected language matches target, no_match_reward otherwise
+ """
+
+ # TODO: refactor pending https://github.com/meta-pytorch/torchforge/issues/187
+ should_debug = self.debug and (random.random() < self.debug_sample_rate)
+
+ if not response:
+ if should_debug:
+ print(
+ f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}"
+ )
+ return self.no_match_reward
+
+ # Extract all thinking blocks
+ matches = self._THINK_BLOCK_RE.findall(response)
+
+ # Determine what text to analyze
+ if len(matches) == 1:
+ # Single block: detect language of block content only
+ text_to_analyze = matches[0].strip()
+ detection_mode = "single block"
+ else:
+ # No blocks or multiple blocks: detect language of whole response
+ text_to_analyze = response.strip()
+ detection_mode = f"{len(matches)} blocks, using whole response"
+
+ # Remove extra whitespace
+ text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip()
+
+ if not text_to_analyze:
+ if should_debug:
+ print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}")
+ return self.no_match_reward
+
+ # Detect language using langid
+ detected_lang, confidence = self._langid.classify(text_to_analyze)
+
+ # Check if language matches target
+ reward = (
+ self.match_reward
+ if detected_lang == self.target_language
+ else self.no_match_reward
+ )
+
+ if should_debug:
+ sample = text_to_analyze[:1000].replace("\n", " ")
+ match_symbol = "✓" if detected_lang == self.target_language else "✗"
+ print(
+ f"\n[LanguageReward] Detection mode: {detection_mode}"
+ f"\n Target: {self.target_language} | Detected: {detected_lang} | "
+ f"Confidence: {confidence:.2f}"
+ f"\n Sample: {sample}..."
+ f"\n → Reward: {reward} {match_symbol}"
+ )
+
+ return reward
diff --git a/tests/unit_tests/rl/test_language_reward.py b/tests/unit_tests/rl/test_language_reward.py
new file mode 100644
index 000000000..00bb28d29
--- /dev/null
+++ b/tests/unit_tests/rl/test_language_reward.py
@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import unittest
+from unittest.mock import patch
+
+
+class TestLanguageReward(unittest.TestCase):
+ def setUp(self):
+ """Set up test fixtures before each test method."""
+ # Import after patching to avoid ImportError
+ from forge.data.rewards import LanguageReward
+
+ self.LanguageReward = LanguageReward
+ self.reward_en = LanguageReward(target_language="en")
+ self.reward_ja = LanguageReward(target_language="ja")
+
+ def test_init_default_values(self):
+ """Test LanguageReward initialization with default values."""
+ reward = self.LanguageReward()
+ self.assertEqual(reward.target_language, "ja")
+ self.assertEqual(reward.match_reward, 1.0)
+ self.assertEqual(reward.no_match_reward, 0.0)
+
+ def test_init_custom_values(self):
+ """Test LanguageReward initialization with custom values."""
+ reward = self.LanguageReward(
+ target_language="ja",
+ match_reward=0.9,
+ no_match_reward=0.1,
+ )
+ self.assertEqual(reward.target_language, "ja")
+ self.assertEqual(reward.match_reward, 0.9)
+ self.assertEqual(reward.no_match_reward, 0.1)
+
+ def test_init_missing_langid(self):
+ """Test LanguageReward initialization without langid installed."""
+ # Remove langid from modules if it exists
+ langid_module = sys.modules.get("langid")
+ if "langid" in sys.modules:
+ del sys.modules["langid"]
+
+ with patch.dict("sys.modules", {"langid": None}):
+ with self.assertRaises(ImportError) as context:
+ # Re-import to trigger the ImportError
+ import importlib
+
+ import forge.data.rewards
+
+ importlib.reload(forge.data.rewards)
+ forge.data.rewards.LanguageReward()
+
+ self.assertIn("langid is required", str(context.exception))
+ self.assertIn("pip install langid", str(context.exception))
+
+ # Restore langid module if it existed
+ if langid_module is not None:
+ sys.modules["langid"] = langid_module
+
+ def test_regex_pattern(self):
+ """Test that regex pattern is compiled correctly."""
+ reward = self.LanguageReward()
+ self.assertIsNotNone(reward._THINK_BLOCK_RE)
+
+ def test_call_with_english_thinking(self):
+ """Test __call__ with English text in thinking blocks."""
+ response = "<思考>This is English reasoning about math problems.思考>"
+ result = self.reward_en("prompt", response)
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_japanese_thinking(self):
+ """Test __call__ with Japanese text in thinking blocks."""
+ response = "<思考>これは日本語で考えています。数学の問題を解きます。思考>"
+ result = self.reward_ja("prompt", response)
+ self.assertEqual(result, 1.0)
+
+ # English reward should give no_match_reward for Japanese text
+ result = self.reward_en("prompt", response)
+ self.assertEqual(result, 0.0)
+
+ def test_call_with_chinese_thinking(self):
+ """Test __call__ with Chinese text in thinking blocks."""
+ response = "<思考>这是中文思考。我们需要解决这个数学问题。思考>"
+ reward_zh = self.LanguageReward(target_language="zh")
+ result = reward_zh("prompt", response)
+ # langid should detect this as Chinese (zh)
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_spanish_thinking(self):
+ """Test __call__ with Spanish text in thinking blocks."""
+ response = "<思考>Este es un razonamiento en español sobre problemas matemáticos.思考>"
+ reward_es = self.LanguageReward(target_language="es")
+ result = reward_es("prompt", response)
+ # langid should detect this as Spanish (es)
+ self.assertEqual(result, 1.0)
+
+ def test_call_language_mismatch(self):
+ """Test __call__ when detected language doesn't match target."""
+ # Japanese reward with English text
+ response = "<思考>This is English reasoning.思考>"
+ result = self.reward_ja("prompt", response)
+ self.assertEqual(result, 0.0)
+
+ # English reward with Japanese text
+ response = "<思考>これは日本語です。思考>"
+ result = self.reward_en("prompt", response)
+ self.assertEqual(result, 0.0)
+
+ def test_call_with_no_thinking_tags(self):
+ """Test __call__ with response containing no thinking tags but correct language."""
+ result = self.reward_en(
+ "prompt", "This is just a regular response without any thinking tags."
+ )
+ # No thinking blocks -> detect whole response, English detected -> match_reward
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_no_thinking_tags_wrong_language(self):
+ """Test __call__ with response containing no thinking tags and wrong language."""
+ result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。")
+ # No thinking blocks -> detect whole response, Japanese detected -> no_match_reward
+ self.assertEqual(result, 0.0)
+
+ def test_call_with_empty_thinking_block(self):
+ """Test __call__ with empty thinking block."""
+ result = self.reward_en("prompt", "<思考>思考>")
+ self.assertEqual(result, 0.0)
+
+ def test_call_with_whitespace_only_thinking_block(self):
+ """Test __call__ with whitespace-only thinking block."""
+ result = self.reward_en("prompt", "<思考> \n \t 思考>")
+ self.assertEqual(result, 0.0)
+
+ def test_call_with_proper_tags(self):
+ """Test __call__ with properly formatted thinking tags."""
+ response = "<思考>This is English reasoning.思考>"
+ result = self.reward_en("prompt", response)
+ self.assertEqual(result, 1.0)
+
+ # Japanese content should also work
+ response = "<思考>これは日本語です。思考>"
+ result = self.reward_ja("prompt", response)
+ self.assertEqual(result, 1.0)
+
+ def test_call_multiple_thinking_blocks(self):
+ """Test __call__ with multiple thinking blocks - detects whole response language."""
+ response = """
+ <思考>First thought in English.思考>
+ Some text in between.
+ <思考>Second thought also in English.思考>
+ """
+ result = self.reward_en("prompt", response)
+ # Multiple blocks -> detect whole response, English detected -> match_reward
+ self.assertEqual(result, 1.0)
+
+ def test_call_multiline_thinking_block(self):
+ """Test __call__ with multiline thinking blocks."""
+ response = """<思考>
+ This is a multiline
+ thinking block with
+ lots of English content
+ about solving problems
+ 思考>"""
+ result = self.reward_en("prompt", response)
+ self.assertEqual(result, 1.0)
+
+ def test_call_empty_response(self):
+ """Test __call__ with empty response."""
+ result = self.reward_en("prompt", "")
+ self.assertEqual(result, 0.0)
+
+ def test_call_none_response(self):
+ """Test __call__ with None response."""
+ result = self.reward_en("prompt", None)
+ self.assertEqual(result, 0.0)
+
+ def test_call_custom_reward_values(self):
+ """Test __call__ with custom reward values."""
+ response_ja_single = "<思考>これは日本語です。思考>"
+ response_ja_multiple = "<思考>最初の考え。思考><思考>次の考え。思考>"
+ response_ja_no_tags = "これはタグなしの日本語です。"
+ response_en = "<思考>This is English.思考>"
+ response_none = ""
+
+ custom_reward = self.LanguageReward(
+ target_language="ja",
+ match_reward=0.9,
+ no_match_reward=0.1,
+ )
+ # Test custom match reward (single block, correct language)
+ self.assertEqual(custom_reward("prompt", response_ja_single), 0.9)
+ # Test custom match reward (multiple blocks -> whole response, correct language)
+ self.assertEqual(custom_reward("prompt", response_ja_multiple), 0.9)
+ # Test custom match reward (no blocks -> whole response, correct language)
+ self.assertEqual(custom_reward("prompt", response_ja_no_tags), 0.9)
+ # Test custom no_match reward (wrong language)
+ self.assertEqual(custom_reward("prompt", response_en), 0.1)
+ # Test empty response
+ self.assertEqual(custom_reward("prompt", response_none), 0.1)
+
+ def test_call_with_special_characters(self):
+ """Test __call__ with special characters in thinking blocks."""
+ response = (
+ "<思考>English with special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~思考>"
+ )
+ result = self.reward_en("prompt", response)
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_mixed_content_outside_tags(self):
+ """Test __call__ with mixed language content outside thinking tags."""
+ # Content outside think tags should be ignored
+ response = """
+ これは日本語のテキストです。
+ <思考>But this is English reasoning inside the tags.思考>
+ もっと日本語のテキスト。
+ """
+ result = self.reward_en("prompt", response)
+ # Should detect English from thinking block only
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_numbers_and_symbols(self):
+ """Test __call__ with thinking blocks containing mostly numbers."""
+ response = "<思考>Calculate: 2 + 2 = 4, then 4 * 3 = 12思考>"
+ result = self.reward_en("prompt", response)
+ # Should still detect as English due to words like "Calculate" and "then"
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_code_in_thinking(self):
+ """Test __call__ with code snippets in thinking blocks."""
+ response = """<思考>
+ Let me write some Python code to solve this:
+ def calculate(x):
+ return x * 2
+ The function doubles the input value.
+ 思考>"""
+ result = self.reward_en("prompt", response)
+ # Should detect as English due to surrounding text
+ self.assertEqual(result, 1.0)
+
+ def test_different_language_codes(self):
+ """Test __call__ with various ISO 639-1 language codes."""
+ # Test a few common languages
+ languages = {
+ "fr": "Ceci est un texte en français avec beaucoup de contenu.",
+ "de": "Dies ist ein deutscher Text mit viel Inhalt.",
+ "it": "Questo è un testo italiano con molto contenuto.",
+ "pt": "Este é um texto em português com muito conteúdo.",
+ }
+
+ for lang_code, text in languages.items():
+ reward = self.LanguageReward(target_language=lang_code)
+ response = f"<思考>{text}思考>"
+ result = reward("prompt", response)
+ # langid should detect these correctly
+ self.assertEqual(
+ result,
+ 1.0,
+ f"Failed to detect {lang_code} language: '{text[:50]}...'",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py
index b95823e9a..10b7bf38e 100644
--- a/tests/unit_tests/rl/test_thinking_reward.py
+++ b/tests/unit_tests/rl/test_thinking_reward.py
@@ -203,6 +203,19 @@ def test_call_very_long_thinking_block(self):
result = self.reward("prompt", f"{long_content}")
self.assertEqual(result, 1.0)
+ def test_custom_tag(self):
+ """Test that ThinkingReward uses the custom tag passed in."""
+ # Create reward with custom Japanese tag
+ custom_tag_reward = ThinkingReward(tag="思考")
+
+ # Response with custom tag should get full reward
+ result = custom_tag_reward("prompt", "<思考>This is my reasoning思考>")
+ self.assertEqual(result, 1.0)
+
+ # Response with default "think" tag should get no reward
+ result = custom_tag_reward("prompt", "This is my reasoning")
+ self.assertEqual(result, 0.0)
+
if __name__ == "__main__":
unittest.main()