Skip to content

Commit dda2e8c

Browse files
committed
feat: Address pr feedback
1 parent bad6f78 commit dda2e8c

File tree

7 files changed

+111
-21
lines changed

7 files changed

+111
-21
lines changed

src/strands/tools/loader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def load_tools_from_file_path(tool_path: str) -> List[AgentTool]:
5555
# ./path/to/my_cool_tool.py -> my_cool_tool
5656
module_name = os.path.basename(tool_path).split(".")[0]
5757

58-
# This function import a module based on its path, and gives it the provided name
58+
# This function imports a module based on its path, and gives it the provided name
5959

6060
spec: ModuleSpec = cast(ModuleSpec, importlib.util.spec_from_file_location(module_name, abs_path))
6161
if not spec:
@@ -72,31 +72,31 @@ def load_tools_from_file_path(tool_path: str) -> List[AgentTool]:
7272
return load_tools_from_module(module, module_name)
7373

7474

75-
def load_tools_from_module_path(module_path: str) -> list[AgentTool]:
75+
def load_tools_from_module_path(module_tool_path: str) -> list[AgentTool]:
7676
"""Load strands tool from a module path.
7777
7878
Example module paths:
7979
my.module.path
8080
my.module.path:tool_name
8181
"""
82-
if ":" in module_path:
83-
module_name, tool_func_name = module_path.split(":")
82+
if ":" in module_tool_path:
83+
module_path, tool_func_name = module_tool_path.split(":")
8484
else:
85-
module_name, tool_func_name = (module_path, None)
85+
module_path, tool_func_name = (module_tool_path, None)
8686

8787
try:
88-
module = importlib.import_module(module_name)
88+
module = importlib.import_module(module_path)
8989
except ModuleNotFoundError as e:
90-
raise AttributeError(f'Tool string: "{module_path}" is not a valid tool string.') from e
90+
raise AttributeError(f'Tool string: "{module_tool_path}" is not a valid tool string.') from e
9191

9292
# If a ':' is present in the string, then its a targeted function in a module
9393
if tool_func_name:
94-
if tool_func_name in dir(module):
94+
if hasattr(module, tool_func_name):
9595
target_tool = getattr(module, tool_func_name)
9696
if isinstance(target_tool, DecoratedFunctionTool):
9797
return [target_tool]
9898

99-
raise AttributeError(f"Tool {tool_func_name} not found in module {module_name}")
99+
raise AttributeError(f"Tool {tool_func_name} not found in module {module_path}")
100100

101101
# Else, try to import all of the @tool decorated tools, or the module based tool
102102
module_name = module_path.split(".")[-1]

src/strands/tools/registry.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,6 @@ def add_tool(tool: Any) -> None:
7575
self.register_tool(a_tool)
7676
tool_names.append(a_tool.tool_name)
7777

78-
# Dictionary with path only
79-
elif isinstance(tool, dict) and "path" in tool:
80-
tools = load_tool_from_string(tool["path"])
81-
82-
for a_tool in tools:
83-
a_tool.mark_dynamic()
84-
self.register_tool(a_tool)
85-
tool_names.append(a_tool.tool_name)
86-
8778
# Dictionary with name and path
8879
elif isinstance(tool, dict) and "name" in tool and "path" in tool:
8980
tools = load_tool_from_string(tool["path"])
@@ -97,7 +88,16 @@ def add_tool(tool: Any) -> None:
9788
tool_found = True
9889

9990
if not tool_found:
100-
raise ValueError(f"Failed to load tool {tool}")
91+
raise ValueError(f'Tool "{tool["name"]}" not found in "{tool["path"]}"')
92+
93+
# Dictionary with path only
94+
elif isinstance(tool, dict) and "path" in tool:
95+
tools = load_tool_from_string(tool["path"])
96+
97+
for a_tool in tools:
98+
a_tool.mark_dynamic()
99+
self.register_tool(a_tool)
100+
tool_names.append(a_tool.tool_name)
101101

102102
# Imported Python module
103103
elif hasattr(tool, "__file__") and inspect.ismodule(tool):
@@ -131,7 +131,7 @@ def add_tool(tool: Any) -> None:
131131
return tool_names
132132

133133
def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:
134-
"""Load a tool from a file path.
134+
"""DEPRECATED: Load a tool from a file path.
135135
136136
Args:
137137
tool_name: Name of the tool.

tests/fixtures/say_tool.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ def say(input: str) -> str:
1111
def dont_say(input: str) -> str:
1212
"""Dont say something."""
1313
return "Didnt say anything!"
14+
15+
16+
def not_a_tool() -> str:
17+
return "Not a tool!"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TOOL_SPEC = {"hello": "world!"}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
TOOL_SPEC = {"hello": "world"}
2+
3+
tool_with_spec_but_non_callable_function = "not a function!"

tests/strands/tools/test_loader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import os
22
import re
3+
import tempfile
34
import textwrap
45

56
import pytest
67

78
from strands.tools.decorator import DecoratedFunctionTool
8-
from strands.tools.loader import ToolLoader
9+
from strands.tools.loader import ToolLoader, load_tools_from_file_path
910
from strands.tools.tools import PythonAgentTool
1011

1112

@@ -310,3 +311,9 @@ def test_load_tool_path_returns_single_tool(tool_path):
310311

311312
assert loaded_python_tool.tool_name == "alpha"
312313
assert loaded_tool.tool_name == "alpha"
314+
315+
316+
def test_load_tools_from_file_path_module_spec_missing():
317+
with tempfile.NamedTemporaryFile() as f:
318+
with pytest.raises(ImportError, match=f"Could not create spec for {os.path.basename(f.name)}"):
319+
load_tools_from_file_path(f.name)

tests/strands/tools/test_registry.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,78 @@ def test_register_strands_tools_specific_tool_from_module():
185185
assert len(tool_registry.registry) == 1
186186
assert "say" in tool_registry.registry
187187
assert "dont_say" not in tool_registry.registry
188+
189+
190+
def test_register_strands_tools_specific_tool_from_module_tool_missing():
191+
tool_registry = ToolRegistry()
192+
193+
with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:nay: "):
194+
tool_registry.process_tools(["tests.fixtures.say_tool:nay"])
195+
196+
197+
def test_register_strands_tools_specific_tool_from_module_not_a_tool():
198+
tool_registry = ToolRegistry()
199+
200+
with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:not_a_tool: "):
201+
tool_registry.process_tools(["tests.fixtures.say_tool:not_a_tool"])
202+
203+
204+
def test_register_strands_tools_with_dict():
205+
tool_registry = ToolRegistry()
206+
tool_registry.process_tools([{"path": "tests.fixtures.say_tool"}])
207+
208+
assert len(tool_registry.registry) == 2
209+
assert "say" in tool_registry.registry
210+
assert "dont_say" in tool_registry.registry
211+
212+
213+
def test_register_strands_tools_specific_tool_with_dict():
214+
tool_registry = ToolRegistry()
215+
tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "say"}])
216+
217+
assert len(tool_registry.registry) == 1
218+
assert "say" in tool_registry.registry
219+
220+
221+
def test_register_strands_tools_specific_tool_with_dict_not_found():
222+
tool_registry = ToolRegistry()
223+
224+
with pytest.raises(
225+
ValueError,
226+
match="Failed to load tool {'path': 'tests.fixtures.say_tool'"
227+
", 'name': 'nay'}: Tool \"nay\" not found in \"tests.fixtures.say_tool\"",
228+
):
229+
tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "nay"}])
230+
231+
232+
def test_register_strands_tools_module_no_spec():
233+
tool_registry = ToolRegistry()
234+
235+
with pytest.raises(
236+
ValueError,
237+
match="Failed to load tool tests.fixtures.mocked_model_provider: "
238+
"The module mocked_model_provider is not a valid module",
239+
):
240+
tool_registry.process_tools(["tests.fixtures.mocked_model_provider"])
241+
242+
243+
def test_register_strands_tools_module_no_function():
244+
tool_registry = ToolRegistry()
245+
246+
with pytest.raises(
247+
ValueError,
248+
match="Failed to load tool tests.fixtures.tool_with_spec_but_no_function: "
249+
"Module-based tool tool_with_spec_but_no_function missing function tool_with_spec_but_no_function",
250+
):
251+
tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_no_function"])
252+
253+
254+
def test_register_strands_tools_module_non_callable_function():
255+
tool_registry = ToolRegistry()
256+
257+
with pytest.raises(
258+
ValueError,
259+
match="Failed to load tool tests.fixtures.tool_with_spec_but_non_callable_function:"
260+
" Tool tool_with_spec_but_non_callable_function function is not callable",
261+
):
262+
tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_non_callable_function"])

0 commit comments

Comments
 (0)