diff --git a/langchain_benchmarks/tool_usage/agents/__init__.py b/langchain_benchmarks/tool_usage/agents/__init__.py
index 4e9f2896..4ad7c7ee 100644
--- a/langchain_benchmarks/tool_usage/agents/__init__.py
+++ b/langchain_benchmarks/tool_usage/agents/__init__.py
@@ -1,7 +1,15 @@
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
+from langchain_benchmarks.tool_usage.agents.anthropic_tool_user import (
+ AnthropicToolUserFactory,
+)
from langchain_benchmarks.tool_usage.agents.experimental.factory import (
CustomAgentFactory,
)
from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory
-__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"]
+__all__ = [
+ "OpenAIAgentFactory",
+ "apply_agent_executor_adapter",
+ "CustomAgentFactory",
+ "AnthropicToolUserFactory",
+]
diff --git a/langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py b/langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py
new file mode 100644
index 00000000..df698c86
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py
@@ -0,0 +1,228 @@
+"""
+Module contains re-implementation of the anthropic tool agent SDK using
+langchain primitives.
+"""
+import re
+from typing import Dict, Optional, Union
+from typing import List, Sequence, Tuple
+
+import xmltodict
+from langchain.agents import AgentOutputParser
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel, Field
+from langchain.schema.runnable import Runnable
+from langchain.tools import StructuredTool
+from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
+from langchain_core.exceptions import OutputParserException
+from langchain_core.language_models import BaseChatModel, BaseLanguageModel
+from langchain_core.messages import AIMessage
+from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.prompts import MessagesPlaceholder
+from typing_extensions import NotRequired, TypedDict
+
+from langchain_benchmarks import RateLimiter
+from langchain_benchmarks.rate_limiting import with_rate_limit
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
+ AstPrinter,
+ FunctionResult,
+ AnthropicXMLEncoder,
+)
+from langchain_benchmarks.tool_usage.agents.experimental.prompts import (
+ _ANTHROPIC_TOOL_USER_PROMPT,
+)
+from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
+ convert_tool_to_function_definition,
+)
+
+
+class _ToolInvocationRequest(BaseModel):
+ """Light-weight pydantic model for validating the raw tool invocation request.
+
+ The purpose of this model is to make sure that whatever was parsed from
+ the raw LLM output has a `tool_name` field and, optionally, an `arguments`
+ field.
+
+ NOTE(review): the original wording also promised "and nothing else"; no
+ visible config rejects extra keys -- confirm whether that is intended.
+ """
+
+ tool_name: str
+ # Optional so that parameterless tools, which take no arguments, validate.
+ arguments: Optional[Dict] = Field(default_factory=dict)
+
+
+class AnthropicToolParser(AgentOutputParser):
+    """Output parser for the Anthropic XML tool-calling convention.
+
+    Text containing a ``<function_calls>`` block is turned into a tool
+    invocation; any other text is treated as the agent's final answer.
+    """
+
+    def parse(self, text: str) -> Union[AgentFinish, AgentAction]:
+        """Parse the raw output of the model.
+
+        Args:
+            text: The raw text produced by the language model.
+
+        Returns:
+            The action produced by ``parse_invocation`` when ``text`` contains
+            an opening ``<function_calls>`` tag, otherwise an ``AgentFinish``
+            whose output is ``text`` unchanged.
+
+        Raises:
+            OutputParserException: Propagated from ``parse_invocation`` when
+                the invocation cannot be parsed.
+        """
+        wrapping_xml_tag = "function_calls"
+        open_tag = f"<{wrapping_xml_tag}>"
+        close_tag = f"</{wrapping_xml_tag}>"
+
+        _, found_open_tag, remainder = text.partition(open_tag)
+        if found_open_tag:
+            # The closing tag may be used as a stop sequence, in which case
+            # the model output can end before it. Partitioning on the closing
+            # tag copes with both cases: when the tag is absent, everything
+            # after the opening tag is taken as the invocation.
+            content, _, _ = remainder.partition(close_tag)
+            return parse_invocation(content.strip(), wrapping_xml_tag)
+
+        return AgentFinish(
+            log=text,
+            return_values={
+                "output": text,
+            },
+        )
+
+
+def _extract_invocation(parsed: Optional[Dict], tag: str) -> Dict:
+    """Pull the single ``<invoke>`` element out of the parsed ``<tag>`` XML."""
+    body = parsed.get(tag) if isinstance(parsed, dict) else None
+    invoke = body.get("invoke") if isinstance(body, dict) else None
+    if not isinstance(invoke, dict):
+        # Either there is no <invoke> element, or there are several of them
+        # (a list); only a single invocation maps onto one AgentAction.
+        raise ValueError(f"Expected exactly one <invoke> element inside <{tag}>.")
+    return {
+        "tool_name": invoke.get("tool_name"),
+        # An empty <parameters> element is parsed as None.
+        "arguments": invoke.get("parameters") or {},
+    }
+
+
+def parse_invocation(text: str, tag: str) -> AgentAction:
+    """Parse the content of the function invocation.
+
+    Args:
+        text: The text to parse; the XML found between the wrapping tags.
+        tag: The tag that wraps the function invocation request.
+
+    Returns:
+        An AgentAction that corresponds to the function invocation. The
+        re-wrapped XML is attached as an ``AIMessage`` in the message log.
+
+    Raises:
+        OutputParserException: If the XML is malformed, or if it does not
+            contain exactly one ``<invoke>`` element with a ``tool_name``.
+
+            This exception is meant to be caught by the agent executor and
+            handled appropriately to provide feedback to the LLM.
+    """
+    ai_content = f"<{tag}>{text}</{tag}>\n"
+    expected_format = (
+        f"<{tag}>\n"
+        "<invoke>\n"
+        "<tool_name>$TOOL_NAME</tool_name>\n"
+        "<parameters>\n"
+        "<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>\n"
+        "...\n"
+        "</parameters>\n"
+        "</invoke>\n"
+        f"</{tag}>"
+    )
+
+    try:
+        parsed = xmltodict.parse(ai_content)
+    except Exception as e:  # The XML parser can raise more than one error type
+        raise OutputParserException(
+            error=e,
+            llm_output=ai_content,
+            observation=(
+                "ERROR: Could not parse the tool invocation as XML. "
+                f"Please use the format:\n{expected_format}\n"
+            ),
+            send_to_llm=True,
+        ) from e
+
+    try:
+        request = _ToolInvocationRequest.validate(_extract_invocation(parsed, tag))
+    except Exception as e:  # Using broad exception since it's not just ValidationError
+        raise OutputParserException(
+            error=e,
+            llm_output=ai_content,
+            send_to_llm=True,
+            observation=f"ERROR: Please use the format:\n{expected_format}\n",
+        ) from e
+
+    return AgentActionMessageLog(
+        message_log=[AIMessage(content=ai_content)],
+        tool=request.tool_name,
+        tool_input=request.arguments,
+        log=f"\nInvoking {request.tool_name}: {request.arguments}\n\t",
+    )
+
+
+def format_steps_for_chat(
+    intermediate_steps: List[Tuple[AgentAction, str]],
+    ast_printer: AstPrinter,
+) -> List[BaseMessage]:
+    """Turn the agent's intermediate steps into chat messages.
+
+    Args:
+        intermediate_steps: Pairs of (action, observation) taken so far.
+        ast_printer: Encoder used to render the result of each tool call.
+
+    Returns:
+        The messages for every step, in order. A regular step contributes the
+        action's own messages followed by a human message holding the rendered
+        function result; an ``_Exception`` step contributes the action log as
+        an AI message followed by its tool input as a human message.
+    """
+    chat_history = []
+    for step_action, step_observation in intermediate_steps:
+        if step_action.tool != "_Exception":
+            # The action messages hold the tool invocation request from the
+            # LLM; the result of that invocation is appended right after.
+            chat_history.extend(step_action.messages)
+            outcome: FunctionResult = {
+                "name": step_action.tool,
+                "error": None,
+                "result": step_observation,
+            }
+            rendered_outcome = ast_printer.visit_function_result(outcome)
+            chat_history.append(HumanMessage(content=rendered_outcome))
+            continue
+
+        chat_history.append(AIMessage(content=step_action.log))
+        # Tool input is the error message for the exception
+        chat_history.append(HumanMessage(content=step_action.tool_input))
+
+    return chat_history
+
+
+# PUBLIC API
+
+
+class AgentInput(TypedDict):
+    """Payload accepted by the runnable returned from ``create_agent``."""
+
+    # The input to the agent.
+    input: str
+    # The intermediate steps taken by the agent.
+    intermediate_steps: List[Tuple[AgentAction, str]]
+    # A list of messages that can be used to form example traces.
+    examples: NotRequired[List[BaseMessage]]
+
+
+def create_agent(
+    model: Union[BaseChatModel, BaseLanguageModel],
+    tools: Sequence[StructuredTool],
+    parser: AgentOutputParser,
+    *,
+    rate_limiter: Optional[RateLimiter] = None,
+) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
+    """Create an agent for a chat model.
+
+    Args:
+        model: The model that drives the agent.
+        tools: The tools the agent may call; their definitions are rendered
+            into the system prompt.
+        parser: The output parser applied to the model output.
+        rate_limiter: Optional rate limiter wrapped around the model.
+
+    Returns:
+        A runnable that maps an ``AgentInput`` to whatever ``parser`` yields.
+    """
+
+    function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
+    ast_printer_ = AnthropicXMLEncoder()
+    tool_description = ast_printer_.visit_function_definitions(function_definitions)
+
+    template = ChatPromptTemplate.from_messages(
+        [
+            ("system", _ANTHROPIC_TOOL_USER_PROMPT),
+            MessagesPlaceholder("examples"),  # Can use to add example traces
+            ("human", "{input}"),
+            MessagesPlaceholder("history"),
+        ]
+    ).partial(tool_description=tool_description)
+
+    # For the time being, hard-coding the fact that we're using a
+    # <function_calls> tag: generation stops at its closing tag. The stop
+    # sequence itself may be left out of the model output, so `parser` has to
+    # tolerate a missing closing tag.
+    model = model.bind(stop=["</function_calls>"])
+
+    if rate_limiter:
+        # Apply a rate limiter if it was provided
+        model = with_rate_limit(model, rate_limiter)
+
+    agent = (
+        {
+            "input": lambda x: x["input"],
+            "history": lambda x: format_steps_for_chat(
+                x["intermediate_steps"], ast_printer_
+            ),
+            "examples": lambda x: x.get("examples", []),
+        }
+        | template
+        | model
+        | parser
+    )
+    return agent
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/encoder.py b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py
index c6799609..ba82df96 100644
--- a/langchain_benchmarks/tool_usage/agents/experimental/encoder.py
+++ b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py
@@ -74,10 +74,20 @@ def visit_function_definitions(
def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str:
"""Render a function invocation."""
+ @abc.abstractmethod
+ def visit_function_invocations(
+ self, function_invocations: List[FunctionInvocation]
+ ) -> str:
+ """Render a list of function invocations as a single string."""
+
@abc.abstractmethod
def visit_function_result(self, function_result: FunctionResult) -> str:
"""Render a function result."""
+ @abc.abstractmethod
+ def visit_function_results(self, function_results: List[FunctionResult]) -> str:
+ """Render a list of function results as a single string."""
+
class AstPrinter(Visitor):
"""Print the AST."""
@@ -154,6 +164,18 @@ def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
)
return "\n".join(lines)
+    def visit_function_invocations(
+        self, function_invocations: List[FunctionInvocation]
+    ) -> str:
+        """Render several function invocations.
+
+        Each invocation is rendered with ``visit_function_invocation``; the
+        renderings are joined with newlines and the joined text is framed by
+        one leading and one trailing newline.
+        """
+        rendered = "\n".join(
+            map(self.visit_function_invocation, function_invocations)
+        )
+        return "\n" + rendered + "\n"
+
def visit_function_result(self, function_result: FunctionResult) -> str:
"""Render a function result."""
lines = [
@@ -180,6 +202,158 @@ def visit_function_result(self, function_result: FunctionResult) -> str:
return "\n".join(lines)
+    def visit_function_results(self, function_results: List[FunctionResult]) -> str:
+        """Render several function results.
+
+        Each result is rendered with ``visit_function_result``; the renderings
+        are joined with newlines and the joined text is framed by one leading
+        and one trailing newline.
+        """
+        rendered = "\n".join(map(self.visit_function_result, function_results))
+        return "\n" + rendered + "\n"
+
+
+class AnthropicXMLEncoder(AstPrinter):
+    """Adapter for Anthropic tool usage api.
+
+    Renders function definitions, invocations and results using the XML
+    format described here:
+    https://github.com/anthropics/anthropic-tools/tree/main
+
+    Values are interpolated verbatim; no XML escaping is applied.
+    """
+
+    def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
+        """Render a function definition as a ``<tool_description>`` element.
+
+        Function definition example:
+
+        <tool_description>
+        <tool_name>get_time_of_day</tool_name>
+        <description>
+        get_time_of_day(time_zone: str) -> str - Retrieve the current time of day
+        </description>
+        <parameters>
+        <parameter>
+        <name>time_zone</name>
+        <type>str</type>
+        <description>The time zone to get the current time for</description>
+        </parameter>
+        </parameters>
+        </tool_description>
+        """
+        parameters_lines = []
+
+        for parameter in function_definition["parameters"]:
+            parameters_lines.extend(
+                [
+                    "<parameter>",
+                    f"<name>{parameter['name']}</name>",
+                    f"<type>{parameter['type']}</type>",
+                    f"<description>{parameter['description']}</description>",
+                    "</parameter>",
+                ]
+            )
+        lines = [
+            "<tool_description>",
+            f"<tool_name>{function_definition['name']}</tool_name>",
+            "<description>",
+            f"{function_definition['description']}",
+            "</description>",
+            "<parameters>",
+            *parameters_lines,
+            "</parameters>",
+            "</tool_description>",
+        ]
+        return "\n".join(lines)
+
+    def visit_function_definitions(
+        self, function_definitions: List[FunctionDefinition]
+    ) -> str:
+        """Render all function definitions wrapped in a ``<tools>`` element."""
+        strs = [
+            self.visit_function_definition(function_definition)
+            for function_definition in function_definitions
+        ]
+
+        lines = [
+            "<tools>",
+            *strs,
+            "</tools>",
+        ]
+        return "\n".join(lines)
+
+    def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
+        """Render a function invocation as an ``<invoke>`` element.
+
+        <invoke>
+        <tool_name>get_time_of_day</tool_name>
+        <parameters>
+        <time_zone>UTC</time_zone>
+        </parameters>
+        </invoke>
+        """
+        arguments_as_strings = [
+            f"<{argument['name']}>{argument['value']}</{argument['name']}>"
+            for argument in invocation["arguments"]
+        ]
+        lines = [
+            "<invoke>",
+            f"<tool_name>{invocation['name']}</tool_name>",
+            "<parameters>",
+            *arguments_as_strings,
+            "</parameters>",
+            "</invoke>",
+        ]
+        return "\n".join(lines)
+
+    def visit_function_invocations(self, invocations: List[FunctionInvocation]) -> str:
+        """Render all invocations wrapped in a ``<function_calls>`` element."""
+        strs = [
+            self.visit_function_invocation(invocation) for invocation in invocations
+        ]
+
+        lines = [
+            "<function_calls>",
+            *strs,
+            "</function_calls>",
+        ]
+        return "\n".join(lines)
+
+    def visit_function_result(self, function_result: FunctionResult) -> str:
+        """Render a single function result as a ``<result>`` element.
+
+        <result>
+        <tool_name>get_time_of_day</tool_name>
+        <stdout>02:57:27</stdout>
+        </result>
+
+        The surrounding ``<function_results>`` wrapper is added by
+        ``visit_function_results``. The ``error`` entry of the result is not
+        rendered.
+        """
+        lines = [
+            "<result>",
+            f"<tool_name>{function_result['name']}</tool_name>",
+            f"<stdout>{function_result['result']}</stdout>",
+            "</result>",
+        ]
+        return "\n".join(lines)
+
+    def visit_function_results(self, function_results: List[FunctionResult]) -> str:
+        """Render all results wrapped in a ``<function_results>`` element."""
+        strs = [
+            self.visit_function_result(function_result)
+            for function_result in function_results
+        ]
+
+        lines = [
+            "<function_results>",
+            *strs,
+            "</function_results>",
+        ]
+        return "\n".join(lines)
+
class TypeScriptEncoder(AstPrinter):
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
@@ -238,3 +412,18 @@ def visit_function_result(self, function_result: FunctionResult) -> str:
if function_result.get("id"):
lines.append(f"// ID: {function_result['id']}")
return "\n".join(lines)
+
+    def visit_function_results(self, function_results: List[FunctionResult]) -> str:
+        """Render several function results, one rendering per line group.
+
+        Each result is rendered with ``visit_function_result`` and the
+        renderings are joined with newlines.
+        """
+        return "\n".join(map(self.visit_function_result, function_results))
+
+    def visit_function_invocations(self, invocations: List[FunctionInvocation]) -> str:
+        """Render several function invocations.
+
+        Each invocation is rendered with ``visit_function_invocation`` and the
+        renderings are joined with newlines.
+        """
+        return "\n".join(map(self.visit_function_invocation, invocations))
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/prompts.py b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py
index 9abc051e..6d7b67c9 100644
--- a/langchain_benchmarks/tool_usage/agents/experimental/prompts.py
+++ b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py
@@ -1,4 +1,4 @@
-AGENT_INSTRUCTIONS_XML_FORMAT = """\
+_ANTHROPIC_TOOL_USER_PROMPT = """\
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this:
diff --git a/poetry.lock b/poetry.lock
index 2ec957bd..e4db049b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3860,6 +3860,17 @@ files = [
{file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"},
]
+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+optional = false
+python-versions = ">=3.4"
+files = [
+ {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+ {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
+
[[package]]
name = "y-py"
version = "0.6.2"
@@ -4083,4 +4094,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
-content-hash = "91171e1e590780b3d7df5efcf5eaddddabbe2715294add5ccf14f52cd3fa3b6d"
+content-hash = "f01a0553fe50c69a84eb318ce208dcbc61edd208286c22ab294aaac242b508dc"
diff --git a/pyproject.toml b/pyproject.toml
index f62f52d3..ff1211c0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ langsmith = ">=0.0.70"
tqdm = "^4"
ipywidgets = "^8"
tabulate = ">=0.8.0"
+xmltodict = "^0.13.0"
[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
diff --git a/tests/unit_tests/agents/test_anthropic_tool_parsing.py b/tests/unit_tests/agents/test_anthropic_tool_parsing.py
new file mode 100644
index 00000000..f53e66bd
--- /dev/null
+++ b/tests/unit_tests/agents/test_anthropic_tool_parsing.py
@@ -0,0 +1,17 @@
+from langchain_benchmarks.tool_usage.agents.anthropic_tool_agent import parse_invocation
+from xmltodict import parse
+
+
+def test_parse_invocation() -> None:
+    """Test parsing a tool invocation.
+
+    The text handed to ``parse_invocation`` is what sits between the wrapping
+    ``<function_calls>`` tags: a single ``<invoke>`` element.
+    """
+    invocation = parse_invocation(
+        """
+        <invoke>
+        <tool_name>get_time_of_day</tool_name>
+        <parameters>
+        <time_zone>UTC</time_zone>
+        </parameters>
+        </invoke>
+        """,
+        "function_calls",
+    )
+    # Without assertions the test only checked that no exception was raised.
+    assert invocation.tool == "get_time_of_day"
+    assert invocation.tool_input == {"time_zone": "UTC"}