From ff2c94f7f0a87346e916ff6ec7cbcbe4fd44a64b Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 1 Jul 2025 15:28:35 +0000 Subject: [PATCH] refactor: Remove unused code --- src/strands/tools/loader.py | 85 +----------------- src/strands/tools/registry.py | 31 ++++++- tests/strands/tools/test_loader.py | 124 --------------------------- tests/strands/tools/test_registry.py | 32 +++++++ 4 files changed, 60 insertions(+), 212 deletions(-) diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 1b3cfddbc..7bf5c5e75 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -1,12 +1,11 @@ """Tool loading utilities.""" import importlib -import inspect import logging import os import sys from pathlib import Path -from typing import Any, Dict, List, Optional, cast +from typing import cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -15,88 +14,6 @@ logger = logging.getLogger(__name__) -def load_function_tool(func: Any) -> Optional[DecoratedFunctionTool]: - """Load a function as a tool if it's decorated with @tool. - - Args: - func: The function to load. - - Returns: - FunctionTool if successful, None otherwise. - """ - logger.warning( - "issue=<%s> | load_function_tool will be removed in a future version", - "https://github.com/strands-agents/sdk-python/pull/258", - ) - - if isinstance(func, DecoratedFunctionTool): - return func - else: - return None - - -def scan_module_for_tools(module: Any) -> List[DecoratedFunctionTool]: - """Scan a module for function-based tools. - - Args: - module: The module to scan. - - Returns: - List of FunctionTool instances found in the module. - """ - tools = [] - - for name, obj in inspect.getmembers(module): - if isinstance(obj, DecoratedFunctionTool): - # Create a function tool with correct name - try: - tools.append(obj) - except Exception as e: - logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) - - return tools - - -def scan_directory_for_tools(directory: Path) -> Dict[str, DecoratedFunctionTool]: - """Scan a directory for Python modules containing function-based tools. - - Args: - directory: The directory to scan. - - Returns: - Dictionary mapping tool names to FunctionTool instances. - """ - tools: Dict[str, DecoratedFunctionTool] = {} - - if not directory.exists() or not directory.is_dir(): - return tools - - for file_path in directory.glob("*.py"): - if file_path.name.startswith("_"): - continue - - try: - # Dynamically import the module - module_name = file_path.stem - spec = importlib.util.spec_from_file_location(module_name, file_path) - if not spec or not spec.loader: - continue - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find tools in the module - for attr_name in dir(module): - attr = getattr(module, attr_name) - if isinstance(attr, DecoratedFunctionTool): - tools[attr.tool_name] = attr - - except Exception as e: - logger.warning("tool_path=<%s> | failed to load tools under path | %s", file_path, e) - - return tools - - class ToolLoader: """Handles loading of tools from different sources.""" diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 5e335ff2b..5ab611e0c 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -15,8 +15,9 @@ from typing_extensions import TypedDict, cast +from strands.tools.decorator import DecoratedFunctionTool + from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec -from .loader import scan_module_for_tools from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -84,7 +85,7 @@ def process_tools(self, tools: List[Any]) -> List[str]: self.load_tool_from_filepath(tool_name=tool_name, tool_path=module_path) tool_names.append(tool_name) else: - function_tools = scan_module_for_tools(tool) + function_tools = self._scan_module_for_tools(tool) for function_tool in function_tools: self.register_tool(function_tool) tool_names.append(function_tool.tool_name) @@ -313,7 +314,7 @@ def reload_tool(self, tool_name: str) -> None: # Look for function-based tools first try: - function_tools = scan_module_for_tools(module) + function_tools = self._scan_module_for_tools(module) if function_tools: for function_tool in function_tools: @@ -400,7 +401,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: if tool_path.suffix == ".py": # Check for decorated function tools first try: - function_tools = scan_module_for_tools(module) + function_tools = self._scan_module_for_tools(module) if function_tools: for function_tool in function_tools: @@ -592,3 +593,25 @@ def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict else: tool_config["tools"].append(new_tool_entry) logger.debug("tool_name=<%s> | added new tool", new_tool_name) + + def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: + """Scan a module for function-based tools. + + Args: + module: The module to scan. + + Returns: + List of FunctionTool instances found in the module. + """ + tools: List[AgentTool] = [] + + for name, obj in inspect.getmembers(module): + if isinstance(obj, DecoratedFunctionTool): + # Create a function tool with correct name + try: + # Cast as AgentTool for mypy + tools.append(cast(AgentTool, obj)) + except Exception as e: + logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) + + return tools diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 4f600e430..c1b4d7040 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -1,138 +1,14 @@ import os -import pathlib import re import textwrap -import unittest.mock import pytest -import strands from strands.tools.decorator import DecoratedFunctionTool from strands.tools.loader import ToolLoader from strands.tools.tools import PythonAgentTool -def test_load_function_tool(): - @strands.tools.tool - def tool_function(a): - return a - - tool = strands.tools.loader.load_function_tool(tool_function) - - assert isinstance(tool, DecoratedFunctionTool) - - -def test_load_function_tool_no_function(): - tool = strands.tools.loader.load_function_tool("no_function") - - assert tool is None - - -def test_load_function_tool_no_spec(): - def tool_function(a): - return a - - tool = strands.tools.loader.load_function_tool(tool_function) - - assert tool is None - - -def test_load_function_tool_invalid(): - def tool_function(a): - return a - - tool_function.TOOL_SPEC = "invalid" - - tool = strands.tools.loader.load_function_tool(tool_function) - - assert tool is None - - -def test_scan_module_for_tools(): - @strands.tools.tool - def tool_function_1(a): - return a - - @strands.tools.tool - def tool_function_2(b): - return b - - def tool_function_3(c): - return c - - def tool_function_4(d): - return d - - tool_function_4.tool_spec = "invalid" - - mock_module = unittest.mock.MagicMock() - mock_module.tool_function_1 = tool_function_1 - mock_module.tool_function_2 = tool_function_2 - mock_module.tool_function_3 = tool_function_3 - mock_module.tool_function_4 = tool_function_4 - - tools = strands.tools.loader.scan_module_for_tools(mock_module) - - assert len(tools) == 2 - assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools) - - -def test_scan_directory_for_tools(tmp_path): - tool_definition_1 = textwrap.dedent(""" - import strands - - @strands.tools.tool - def tool_function_1(a): - return a - """) - tool_definition_2 = textwrap.dedent(""" - import strands - - @strands.tools.tool - def tool_function_2(b): - return b - """) - tool_definition_3 = textwrap.dedent(""" - def tool_function_3(c): - return c - """) - tool_definition_4 = textwrap.dedent(""" - def tool_function_4(d): - return d - """) - tool_definition_5 = "" - tool_definition_6 = "**invalid**" - - tool_path_1 = tmp_path / "tool_1.py" - tool_path_2 = tmp_path / "tool_2.py" - tool_path_3 = tmp_path / "tool_3.py" - tool_path_4 = tmp_path / "tool_4.py" - tool_path_5 = tmp_path / "_tool_5.py" - tool_path_6 = tmp_path / "tool_6.py" - - tool_path_1.write_text(tool_definition_1) - tool_path_2.write_text(tool_definition_2) - tool_path_3.write_text(tool_definition_3) - tool_path_4.write_text(tool_definition_4) - tool_path_5.write_text(tool_definition_5) - tool_path_6.write_text(tool_definition_6) - - tools = strands.tools.loader.scan_directory_for_tools(tmp_path) - - tru_tool_names = sorted(tools.keys()) - exp_tool_names = ["tool_function_1", "tool_function_2"] - - assert tru_tool_names == exp_tool_names - assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools.values()) - - -def test_scan_directory_for_tools_does_not_exist(): - tru_tools = strands.tools.loader.scan_directory_for_tools(pathlib.Path("does_not_exist")) - exp_tools = {} - - assert tru_tools == exp_tools - - @pytest.fixture def tool_path(request, tmp_path, monkeypatch): definition = request.param diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 1b274f46b..bfdc2a47d 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -7,6 +7,7 @@ import pytest from strands.tools import PythonAgentTool +from strands.tools.decorator import DecoratedFunctionTool, tool from strands.tools.registry import ToolRegistry @@ -43,3 +44,34 @@ def test_register_tool_with_similar_name_raises(): str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. " "Cannot add a duplicate tool which differs by a '-' or '_'" ) + + +def test_scan_module_for_tools(): + @tool + def tool_function_1(a): + return a + + @tool + def tool_function_2(b): + return b + + def tool_function_3(c): + return c + + def tool_function_4(d): + return d + + tool_function_4.tool_spec = "invalid" + + mock_module = MagicMock() + mock_module.tool_function_1 = tool_function_1 + mock_module.tool_function_2 = tool_function_2 + mock_module.tool_function_3 = tool_function_3 + mock_module.tool_function_4 = tool_function_4 + + tool_registry = ToolRegistry() + + tools = tool_registry._scan_module_for_tools(mock_module) + + assert len(tools) == 2 + assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)