Skip to content

Added 16 Validations & Some Critical Tests #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- add_validations
pull_request:
# All PRs, including stacked PRs

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
"typing-extensions>=4.12.2, <5",
"requests>=2.0, <3",
"types-requests>=2.0, <3",
"mcp>=1.9.4, <2; python_version >= '3.10'",
"mcp>=1.11.0, <2; python_version >= '3.10'",
]
classifiers = [
"Typing :: Typed",
Expand Down
109 changes: 109 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,118 @@ class Agent(AgentBase, Generic[TContext]):
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""

def __post_init__(self):
from typing import get_origin

if not isinstance(self.name, str):
raise TypeError(f"Agent name must be a string, got {type(self.name).__name__}")

if self.handoff_description is not None and not isinstance(self.handoff_description, str):
raise TypeError(
f"Agent handoff_description must be a string or None, "
f"got {type(self.handoff_description).__name__}"
)

if not isinstance(self.tools, list):
raise TypeError(f"Agent tools must be a list, got {type(self.tools).__name__}")

if not isinstance(self.mcp_servers, list):
raise TypeError(
f"Agent mcp_servers must be a list, got {type(self.mcp_servers).__name__}"
)

if not isinstance(self.mcp_config, dict):
raise TypeError(
f"Agent mcp_config must be a dict, got {type(self.mcp_config).__name__}"
)

if (
self.instructions is not None
and not isinstance(self.instructions, str)
and not callable(self.instructions)
):
raise TypeError(
f"Agent instructions must be a string, callable, or None, "
f"got {type(self.instructions).__name__}"
)

if (
self.prompt is not None
and not callable(self.prompt)
and not hasattr(self.prompt, "get")
):
raise TypeError(
f"Agent prompt must be a Prompt, DynamicPromptFunction, or None, "
f"got {type(self.prompt).__name__}"
)

if not isinstance(self.handoffs, list):
raise TypeError(f"Agent handoffs must be a list, got {type(self.handoffs).__name__}")

if self.model is not None and not isinstance(self.model, str):
from .models.interface import Model

if not isinstance(self.model, Model):
raise TypeError(
f"Agent model must be a string, Model, or None, got {type(self.model).__name__}"
)

if not isinstance(self.model_settings, ModelSettings):
raise TypeError(
f"Agent model_settings must be a ModelSettings instance, "
f"got {type(self.model_settings).__name__}"
)

if not isinstance(self.input_guardrails, list):
raise TypeError(
f"Agent input_guardrails must be a list, got {type(self.input_guardrails).__name__}"
)

if not isinstance(self.output_guardrails, list):
raise TypeError(
f"Agent output_guardrails must be a list, "
f"got {type(self.output_guardrails).__name__}"
)

if self.output_type is not None:
from .agent_output import AgentOutputSchemaBase

if not (
isinstance(self.output_type, (type, AgentOutputSchemaBase))
or get_origin(self.output_type) is not None
):
raise TypeError(
f"Agent output_type must be a type, AgentOutputSchemaBase, or None, "
f"got {type(self.output_type).__name__}"
)

if self.hooks is not None:
from .lifecycle import AgentHooksBase

if not isinstance(self.hooks, AgentHooksBase):
raise TypeError(
f"Agent hooks must be an AgentHooks instance or None, "
f"got {type(self.hooks).__name__}"
)

if (
not (
isinstance(self.tool_use_behavior, str)
and self.tool_use_behavior in ["run_llm_again", "stop_on_first_tool"]
)
and not isinstance(self.tool_use_behavior, dict)
and not callable(self.tool_use_behavior)
):
raise TypeError(
f"Agent tool_use_behavior must be 'run_llm_again', 'stop_on_first_tool', "
f"StopAtTools dict, or callable, got {type(self.tool_use_behavior).__name__}"
)

if not isinstance(self.reset_tool_choice, bool):
raise TypeError(
f"Agent reset_tool_choice must be a boolean, "
f"got {type(self.reset_tool_choice).__name__}"
)

def clone(self, **kwargs: Any) -> Agent[TContext]:
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
```
Expand Down
39 changes: 39 additions & 0 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""

def __init__(self, use_structured_content: bool = False):
"""
Args:
use_structured_content: Whether to use `tool_result.structured_content` when calling an
MCP tool.Defaults to False for backwards compatibility - most MCP servers still
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
"""
self.use_structured_content = use_structured_content

@abc.abstractmethod
async def connect(self):
"""Connect to the server. For example, this might mean spawning a subprocess or
Expand Down Expand Up @@ -86,6 +97,7 @@ def __init__(
cache_tools_list: bool,
client_session_timeout_seconds: float | None,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
):
"""
Args:
Expand All @@ -98,7 +110,13 @@ def __init__(

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
use_structured_content: Whether to use `tool_result.structured_content` when calling an
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
"""
super().__init__(use_structured_content=use_structured_content)
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
Expand Down Expand Up @@ -346,6 +364,7 @@ def __init__(
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
):
"""Create a new MCP server based on the stdio transport.

Expand All @@ -364,11 +383,17 @@ def __init__(
command.
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
use_structured_content: Whether to use `tool_result.structured_content` when calling an
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
use_structured_content,
)

self.params = StdioServerParameters(
Expand Down Expand Up @@ -429,6 +454,7 @@ def __init__(
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
):
"""Create a new MCP server based on the HTTP with SSE transport.

Expand All @@ -449,11 +475,17 @@ def __init__(

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
use_structured_content: Whether to use `tool_result.structured_content` when calling an
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
use_structured_content,
)

self.params = params
Expand Down Expand Up @@ -514,6 +546,7 @@ def __init__(
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
):
"""Create a new MCP server based on the Streamable HTTP transport.

Expand All @@ -535,11 +568,17 @@ def __init__(

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
use_structured_content: Whether to use `tool_result.structured_content` when calling an
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
use_structured_content,
)

self.params = params
Expand Down
10 changes: 9 additions & 1 deletion src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,16 @@ async def invoke_mcp_tool(
# string. We'll try to convert.
if len(result.content) == 1:
tool_output = result.content[0].model_dump_json()
# Append structured content if it exists and we're using it.
if server.use_structured_content and result.structuredContent:
tool_output = f"{tool_output}\n{json.dumps(result.structuredContent)}"
elif len(result.content) > 1:
tool_output = json.dumps([item.model_dump(mode="json") for item in result.content])
tool_results = [item.model_dump(mode="json") for item in result.content]
if server.use_structured_content and result.structuredContent:
tool_results.append(result.structuredContent)
tool_output = json.dumps(tool_results)
elif server.use_structured_content and result.structuredContent:
tool_output = json.dumps(result.structuredContent)
else:
logger.error(f"Errored MCP tool result: {result}")
tool_output = "Error running tool."
Expand Down
1 change: 1 addition & 0 deletions tests/mcp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
tool_filter: ToolFilter = None,
server_name: str = "fake_mcp_server",
):
super().__init__(use_structured_content=False)
self.tools: list[MCPTool] = tools or []
self.tool_calls: list[str] = []
self.tool_results: list[str] = []
Expand Down
6 changes: 3 additions & 3 deletions tests/mcp/test_mcp_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def test_mcp_tracing():
"data": {
"name": "test_tool_1",
"input": "",
"output": '{"type":"text","text":"result_test_tool_1_{}","annotations":null}', # noqa: E501
"output": '{"type":"text","text":"result_test_tool_1_{}","annotations":null,"meta":null}', # noqa: E501
"mcp_data": {"server": "fake_mcp_server"},
},
},
Expand Down Expand Up @@ -133,7 +133,7 @@ async def test_mcp_tracing():
"data": {
"name": "test_tool_2",
"input": "",
"output": '{"type":"text","text":"result_test_tool_2_{}","annotations":null}', # noqa: E501
"output": '{"type":"text","text":"result_test_tool_2_{}","annotations":null,"meta":null}', # noqa: E501
"mcp_data": {"server": "fake_mcp_server"},
},
},
Expand Down Expand Up @@ -197,7 +197,7 @@ async def test_mcp_tracing():
"data": {
"name": "test_tool_3",
"input": "",
"output": '{"type":"text","text":"result_test_tool_3_{}","annotations":null}', # noqa: E501
"output": '{"type":"text","text":"result_test_tool_3_{}","annotations":null,"meta":null}', # noqa: E501
"mcp_data": {"server": "fake_mcp_server"},
},
},
Expand Down
57 changes: 57 additions & 0 deletions tests/test_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pydantic import BaseModel

from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff
from agents.lifecycle import AgentHooksBase
from agents.model_settings import ModelSettings
from agents.run import AgentRunner


Expand Down Expand Up @@ -167,3 +169,58 @@ async def test_agent_final_output():
assert schema.is_strict_json_schema() is True
assert schema.json_schema() is not None
assert not schema.is_plain_text()


class TestAgentValidation:
"""Essential validation tests for Agent __post_init__"""

def test_name_validation_critical_cases(self):
"""Test name validation - the original issue that started this PR"""
# This was the original failing case that caused JSON serialization errors
with pytest.raises(TypeError, match="Agent name must be a string, got int"):
Agent(name=1) # type: ignore

with pytest.raises(TypeError, match="Agent name must be a string, got NoneType"):
Agent(name=None) # type: ignore

def test_tool_use_behavior_dict_validation(self):
"""Test tool_use_behavior accepts StopAtTools dict - fixes existing test failures"""
# This test ensures the existing failing tests now pass
Agent(name="test", tool_use_behavior={"stop_at_tool_names": ["tool1"]})

# Invalid cases that should fail
with pytest.raises(TypeError, match="Agent tool_use_behavior must be"):
Agent(name="test", tool_use_behavior=123) # type: ignore

def test_hooks_validation_python39_compatibility(self):
"""Test hooks validation works with Python 3.9 - fixes generic type issues"""

class MockHooks(AgentHooksBase):
pass

# Valid case
Agent(name="test", hooks=MockHooks()) # type: ignore

# Invalid case
with pytest.raises(TypeError, match="Agent hooks must be an AgentHooks instance"):
Agent(name="test", hooks="invalid") # type: ignore

def test_list_field_validation(self):
"""Test critical list fields that commonly get wrong types"""
# These are the most common mistakes users make
with pytest.raises(TypeError, match="Agent tools must be a list"):
Agent(name="test", tools="not_a_list") # type: ignore

with pytest.raises(TypeError, match="Agent handoffs must be a list"):
Agent(name="test", handoffs="not_a_list") # type: ignore

def test_model_settings_validation(self):
"""Test model_settings validation - prevents runtime errors"""
# Valid case
Agent(name="test", model_settings=ModelSettings())

# Invalid case that could cause runtime issues
with pytest.raises(
TypeError, match="Agent model_settings must be a ModelSettings instance"
):
Agent(name="test", model_settings={}) # type: ignore
Loading