diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 56433324e..5935077db 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -4,8 +4,9 @@ import logging import os import sys +import warnings from pathlib import Path -from typing import cast +from typing import List, cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -18,60 +19,42 @@ class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod - def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: - """Load a Python tool module. - - Args: - tool_path: Path to the Python tool file. - tool_name: Name of the tool. + def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: + """Load a Python tool module and return all discovered function-based tools as a list. - Returns: - Tool instance. - - Raises: - AttributeError: If required attributes are missing from the tool module. - ImportError: If there are issues importing the tool module. - TypeError: If the tool function is not callable. - ValueError: If function in module is not a valid tool. - Exception: For other errors during tool loading. + This method always returns a list of AgentTool (possibly length 1). It is the + canonical API for retrieving multiple tools from a single Python file. """ try: - # Check if tool_path is in the format "package.module:function"; but keep in mind windows whose file path - # could have a colon so also ensure that it's not a file + # Support module:function style (e.g. package.module:function) if not os.path.exists(tool_path) and ":" in tool_path: module_path, function_name = tool_path.rsplit(":", 1) logger.debug("tool_name=<%s>, module_path=<%s> | importing tool from path", function_name, module_path) try: - # Import the module module = __import__(module_path, fromlist=["*"]) - - # Get the function - if not hasattr(module, function_name): - raise AttributeError(f"Module {module_path} has no function named {function_name}") - - func = getattr(module, function_name) - - if isinstance(func, DecoratedFunctionTool): - logger.debug( - "tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path - ) - # mypy has problems converting between DecoratedFunctionTool <-> AgentTool - return cast(AgentTool, func) - else: - raise ValueError( - f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)" - ) - except ImportError as e: raise ImportError(f"Failed to import module {module_path}: {str(e)}") from e + if not hasattr(module, function_name): + raise AttributeError(f"Module {module_path} has no function named {function_name}") + + func = getattr(module, function_name) + if isinstance(func, DecoratedFunctionTool): + logger.debug( + "tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path + ) + return [cast(AgentTool, func)] + else: + raise ValueError( + f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)" + ) + # Normal file-based tool loading abs_path = str(Path(tool_path).resolve()) - logger.debug("tool_path=<%s> | loading python tool from path", abs_path) - # First load the module to get TOOL_SPEC and check for Lambda deployment + # Load the module by spec spec = importlib.util.spec_from_file_location(tool_name, abs_path) if not spec: raise ImportError(f"Could not create spec for {tool_name}") @@ -82,24 +65,26 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: sys.modules[tool_name] = module spec.loader.exec_module(module) - # First, check for function-based tools with @tool decorator + # Collect function-based tools decorated with @tool + function_tools: List[AgentTool] = [] for attr_name in dir(module): attr = getattr(module, attr_name) if isinstance(attr, DecoratedFunctionTool): logger.debug( "tool_name=<%s>, tool_path=<%s> | found function-based tool in path", attr_name, tool_path ) - # mypy has problems converting between DecoratedFunctionTool <-> AgentTool - return cast(AgentTool, attr) + function_tools.append(cast(AgentTool, attr)) + + if function_tools: + return function_tools - # If no function-based tools found, fall back to traditional module-level tool + # Fall back to module-level TOOL_SPEC + function tool_spec = getattr(module, "TOOL_SPEC", None) if not tool_spec: raise AttributeError( f"Tool {tool_name} missing TOOL_SPEC (neither at module level nor as a decorated function)" ) - # Standard local tool loading tool_func_name = tool_name if not hasattr(module, tool_func_name): raise AttributeError(f"Tool {tool_name} missing function {tool_func_name}") @@ -108,22 +93,61 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: if not callable(tool_func): raise TypeError(f"Tool {tool_name} function is not callable") - return PythonAgentTool(tool_name, tool_spec, tool_func) + return [PythonAgentTool(tool_name, tool_spec, tool_func)] except Exception: - logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path) + logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool(s)", tool_name, sys.path) raise + @staticmethod + def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: + """DEPRECATED: Load a Python tool module and return a single AgentTool for backwards compatibility. + + Use `load_python_tools` to retrieve all tools defined in a .py file (returns a list). + This function will emit a `DeprecationWarning` and return the first discovered tool. + """ + warnings.warn( + "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.", + DeprecationWarning, + stacklevel=2, + ) + + tools = ToolLoader.load_python_tools(tool_path, tool_name) + if not tools: + raise RuntimeError(f"No tools found in {tool_path} for {tool_name}") + return tools[0] + @classmethod def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: - """Load a tool based on its file extension. + """DEPRECATED: Load a single tool based on its file extension for backwards compatibility. + + Use `load_tools` to retrieve all tools defined in a file (returns a list). + This function will emit a `DeprecationWarning` and return the first discovered tool. + """ + warnings.warn( + "ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use ToolLoader.load_tools(...) which always returns a list of AgentTool.", + DeprecationWarning, + stacklevel=2, + ) + + tools = ToolLoader.load_tools(tool_path, tool_name) + if not tools: + raise RuntimeError(f"No tools found in {tool_path} for {tool_name}") + + return tools[0] + + @classmethod + def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: + """Load tools from a file based on its file extension. Args: tool_path: Path to the tool file. tool_name: Name of the tool. Returns: - Tool instance. + A single Tool instance. Raises: FileNotFoundError: If the tool file does not exist. @@ -138,7 +162,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: try: if ext == ".py": - return cls.load_python_tool(abs_path, tool_name) + return cls.load_python_tools(abs_path, tool_name) else: raise ValueError(f"Unsupported tool file type: {ext}") except Exception: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 7cb03e46f..5d9dd0b0f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -318,10 +318,12 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) - mapped_content = [ - mapped_content + # Build a typed list of ToolResultContent. Use a clearer local name to avoid shadowing + # and annotate the result for mypy so it knows the intended element type. + mapped_contents: list[ToolResultContent] = [ + mc for content in call_tool_result.content - if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None + if (mc := self._map_mcp_content_to_tool_result_content(content)) is not None ] status: ToolResultStatus = "error" if call_tool_result.isError else "success" @@ -329,8 +331,9 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes result = MCPToolResult( status=status, toolUseId=tool_use_id, - content=mapped_content, + content=mapped_contents, ) + if call_tool_result.structuredContent: result["structuredContent"] = call_tool_result.structuredContent diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 471472a64..0660337a2 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -127,11 +127,11 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: if not os.path.exists(tool_path): raise FileNotFoundError(f"Tool file not found: {tool_path}") - loaded_tool = ToolLoader.load_tool(tool_path, tool_name) - loaded_tool.mark_dynamic() - - # Because we're explicitly registering the tool we don't need an allowlist - self.register_tool(loaded_tool) + loaded_tools = ToolLoader.load_tools(tool_path, tool_name) + for t in loaded_tools: + t.mark_dynamic() + # Because we're explicitly registering the tool we don't need an allowlist + self.register_tool(t) except Exception as e: exception_str = str(e) logger.exception("tool_name=<%s> | failed to load tool", tool_name) diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index c1b4d7040..6b86d00ee 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -235,3 +235,78 @@ def no_spec(): def test_load_tool_no_spec(tool_path): with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): ToolLoader.load_tool(tool_path, "no_spec") + + with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): + ToolLoader.load_tools(tool_path, "no_spec") + + with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): + ToolLoader.load_python_tool(tool_path, "no_spec") + + with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): + ToolLoader.load_python_tools(tool_path, "no_spec") + + +@pytest.mark.parametrize( + "tool_path", + [ + textwrap.dedent( + """ + import strands + + @strands.tools.tool + def alpha(): + return "alpha" + + @strands.tools.tool + def bravo(): + return "bravo" + """ + ) + ], + indirect=True, +) +def test_load_python_tool_path_multiple_function_based(tool_path): + # load_python_tools, load_tools returns a list when multiple decorated tools are present + loaded_python_tools = ToolLoader.load_python_tools(tool_path, "alpha") + + assert isinstance(loaded_python_tools, list) + assert len(loaded_python_tools) == 2 + assert all(isinstance(t, DecoratedFunctionTool) for t in loaded_python_tools) + names = {t.tool_name for t in loaded_python_tools} + assert names == {"alpha", "bravo"} + + loaded_tools = ToolLoader.load_tools(tool_path, "alpha") + + assert isinstance(loaded_tools, list) + assert len(loaded_tools) == 2 + assert all(isinstance(t, DecoratedFunctionTool) for t in loaded_tools) + names = {t.tool_name for t in loaded_tools} + assert names == {"alpha", "bravo"} + + +@pytest.mark.parametrize( + "tool_path", + [ + textwrap.dedent( + """ + import strands + + @strands.tools.tool + def alpha(): + return "alpha" + + @strands.tools.tool + def bravo(): + return "bravo" + """ + ) + ], + indirect=True, +) +def test_load_tool_path_returns_single_tool(tool_path): + # loaded_python_tool and loaded_tool returns single item + loaded_python_tool = ToolLoader.load_python_tool(tool_path, "alpha") + loaded_tool = ToolLoader.load_tool(tool_path, "alpha") + + assert loaded_python_tool.tool_name == "alpha" + assert loaded_tool.tool_name == "alpha"