diff --git a/src/react_agent/context.py b/src/react_agent/context.py index 8ccfa75..a78bc16 100644 --- a/src/react_agent/context.py +++ b/src/react_agent/context.py @@ -4,7 +4,7 @@ import os from dataclasses import dataclass, field, fields -from typing import Annotated +from typing import Annotated, Any, get_type_hints from . import prompts @@ -36,11 +36,45 @@ class Context: }, ) + temperature: float = field( + default=0.1, + metadata={ + "description": "The temperature setting for the language model (0.0 to 1.0)." + }, + ) + + enable_debug: bool = field( + default=False, + metadata={"description": "Enable debug mode for verbose logging."}, + ) + def __post_init__(self) -> None: - """Fetch env vars for attributes that were not passed as args.""" + """Fetch env vars for attributes that were not passed as args, with type conversion.""" + type_hints = get_type_hints(self.__class__) for f in fields(self): if not f.init: continue - if getattr(self, f.name) == f.default: - setattr(self, f.name, os.environ.get(f.name.upper(), f.default)) + current_value = getattr(self, f.name) + env_value = os.environ.get(f.name.upper(), None) + if current_value == f.default and env_value is not None: + # Convert env_value to the correct type + target_type = type_hints.get(f.name, str) + converted_value: Any = env_value # Default to string value + try: + if target_type is int: + converted_value = int(env_value) + elif target_type is float: + converted_value = float(env_value) + elif target_type is bool: + converted_value = env_value.lower() in ( + "true", + "1", + "yes", + "on", + ) + # str type requires no conversion + except (ValueError, AttributeError): + # If conversion fails, keep the original default value + converted_value = current_value + setattr(self, f.name, converted_value) diff --git a/src/react_agent/tools.py b/src/react_agent/tools.py index 4ce1eb6..7afd11b 100644 --- a/src/react_agent/tools.py +++ b/src/react_agent/tools.py @@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, cast -from langchain_tavily import TavilySearch # type: ignore[import-not-found] +from langchain_tavily import TavilySearch from langgraph.runtime import get_runtime from react_agent.context import Context diff --git a/tests/unit_tests/test_configuration.py b/tests/unit_tests/test_configuration.py index f180b62..4ee05e0 100644 --- a/tests/unit_tests/test_configuration.py +++ b/tests/unit_tests/test_configuration.py @@ -18,3 +18,164 @@ def test_context_init_with_env_vars_and_passed_values() -> None: os.environ["MODEL"] = "openai/gpt-4o-mini" context = Context(model="openai/gpt-5o-mini") assert context.model == "openai/gpt-5o-mini" + + +def test_context_int_type_conversion() -> None: + """Test that integer environment variables are properly converted.""" + # Clean up environment + os.environ.pop("MAX_SEARCH_RESULTS", None) + + # Test int conversion + os.environ["MAX_SEARCH_RESULTS"] = "20" + context = Context() + assert context.max_search_results == 20 + assert isinstance(context.max_search_results, int) + + # Clean up + os.environ.pop("MAX_SEARCH_RESULTS", None) + + +def test_context_int_type_conversion_invalid() -> None: + """Test that invalid integer environment variables keep default value.""" + # Clean up environment + os.environ.pop("MAX_SEARCH_RESULTS", None) + + # Test invalid int conversion - should keep default value + os.environ["MAX_SEARCH_RESULTS"] = "not_a_number" + context = Context() + # Should keep default value when int conversion fails + assert context.max_search_results == 10 # default value + assert isinstance(context.max_search_results, int) + + # Clean up + os.environ.pop("MAX_SEARCH_RESULTS", None) + + +def test_context_string_type_conversion() -> None: + """Test that string environment variables work correctly.""" + # Clean up environment + os.environ.pop("MODEL", None) + + # Test string conversion (no conversion needed) + os.environ["MODEL"] = "test/model-name" + context = Context() + assert context.model == "test/model-name" + assert isinstance(context.model, str) + + # Clean up + os.environ.pop("MODEL", None) + + +def test_context_env_vars_only_used_for_defaults() -> None: + """Test that environment variables are only used when field has default value.""" + # Clean up environment + os.environ.pop("MAX_SEARCH_RESULTS", None) + os.environ.pop("MODEL", None) + + # Set environment variables + os.environ["MAX_SEARCH_RESULTS"] = "99" + os.environ["MODEL"] = "env/model" + + # Pass explicit values - should override env vars + context = Context(max_search_results=5, model="explicit/model") + assert context.max_search_results == 5 + assert context.model == "explicit/model" + + # Clean up + os.environ.pop("MAX_SEARCH_RESULTS", None) + os.environ.pop("MODEL", None) + + +def test_context_float_type_conversion() -> None: + """Test that float environment variables are properly converted.""" + # Clean up environment + os.environ.pop("TEMPERATURE", None) + + # Test float conversion + os.environ["TEMPERATURE"] = "0.5" + context = Context() + assert context.temperature == 0.5 + assert isinstance(context.temperature, float) + + # Clean up + os.environ.pop("TEMPERATURE", None) + + +def test_context_float_type_conversion_invalid() -> None: + """Test that invalid float environment variables keep default value.""" + # Clean up environment + os.environ.pop("TEMPERATURE", None) + + # Test invalid float conversion - should keep default value + os.environ["TEMPERATURE"] = "not_a_float" + context = Context() + # Should keep default value when float conversion fails + assert context.temperature == 0.1 # default value + assert isinstance(context.temperature, float) + + # Clean up + os.environ.pop("TEMPERATURE", None) + + +def test_context_bool_type_conversion() -> None: + """Test that boolean environment variables are properly converted.""" + # Clean up environment + os.environ.pop("ENABLE_DEBUG", None) + + # Test various true values + for true_value in ["true", "True", "TRUE", "1", "yes", "YES", "on", "ON"]: + os.environ["ENABLE_DEBUG"] = true_value + context = Context() + assert context.enable_debug is True + assert isinstance(context.enable_debug, bool) + + # Test various false values + for false_value in [ + "false", + "False", + "FALSE", + "0", + "no", + "NO", + "off", + "OFF", + "anything_else", + ]: + os.environ["ENABLE_DEBUG"] = false_value + context = Context() + assert context.enable_debug is False + assert isinstance(context.enable_debug, bool) + + # Clean up + os.environ.pop("ENABLE_DEBUG", None) + + +def test_context_multiple_type_conversions() -> None: + """Test multiple type conversions at once.""" + # Clean up environment + os.environ.pop("MAX_SEARCH_RESULTS", None) + os.environ.pop("TEMPERATURE", None) + os.environ.pop("ENABLE_DEBUG", None) + os.environ.pop("MODEL", None) + + # Set multiple environment variables + os.environ["MAX_SEARCH_RESULTS"] = "25" + os.environ["TEMPERATURE"] = "0.8" + os.environ["ENABLE_DEBUG"] = "true" + os.environ["MODEL"] = "test/model" + + context = Context() + assert context.max_search_results == 25 + assert isinstance(context.max_search_results, int) + assert context.temperature == 0.8 + assert isinstance(context.temperature, float) + assert context.enable_debug is True + assert isinstance(context.enable_debug, bool) + assert context.model == "test/model" + assert isinstance(context.model, str) + + # Clean up + os.environ.pop("MAX_SEARCH_RESULTS", None) + os.environ.pop("TEMPERATURE", None) + os.environ.pop("ENABLE_DEBUG", None) + os.environ.pop("MODEL", None)