diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agents/adapters.py b/agents/adapters.py new file mode 100644 index 00000000..a1dca755 --- /dev/null +++ b/agents/adapters.py @@ -0,0 +1,52 @@ +import inspect +from textwrap import dedent +from typing import List + +from langchain.tools.base import StructuredTool + +from agents.encoder import FunctionDefinition, Parameter + + +# This is temporary until we have a better way to represent parameters +def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]: + """Convert a langchain tool to a tool user tool.""" + schema = tool.args_schema.schema() + + properties = schema["properties"] + parameters = [] + # Is this needed or is string OK? + type_adapter = { + "string": "str", # str or string? + "integer": "int", + "number": "float", + "boolean": "bool", + } + for key, value in properties.items(): + parameters.append( + { + "name": key, + "type": type_adapter.get(value["type"], value["type"]), + "description": value.get("description", ""), + } + ) + + return parameters + + +# +def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition: + """Convert a langchain tool to a tool user tool.""" + # Here we re-inspect the underlying function to get the doc-string + # since StructuredTool modifies it, but we want the raw one for maximum + # flexibility. + description = inspect.getdoc(tool.func) + + parameters = get_parameters_from_tool(tool) + return { + "name": tool.name, + "description": dedent(description), + "parameters": parameters, + "return_value": { + "type": "Any", + }, + } diff --git a/agents/agent.py b/agents/agent.py new file mode 100644 index 00000000..11509e0b --- /dev/null +++ b/agents/agent.py @@ -0,0 +1,105 @@ +from typing import List, Literal, Sequence, Tuple, Union + +from langchain.agents import AgentOutputParser +from langchain.prompts.chat import ChatPromptTemplate +from langchain.schema.messages import HumanMessage +from langchain.schema.runnable import Runnable +from langchain.tools import StructuredTool +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.language_models import BaseChatModel, BaseLanguageModel +from langchain_core.messages import BaseMessage +from langchain_core.prompts import MessagesPlaceholder +from typing_extensions import NotRequired, TypedDict + +from agents.adapters import convert_tool_to_function_definition +from agents.encoder import AstPrinter, TypeScriptEncoder +from agents.prompts import AGENT_INSTRUCTIONS_BLOB_STYLE + + +def format_observation(tool_name: str, observation: str) -> BaseMessage: + """Format the observation.""" + result = ( + "\n" + f"{tool_name}\n" + f"{observation}\n" + "" + ) + + return HumanMessage(content=result) + + +def format_steps_for_chat( + intermediate_steps: List[Tuple[AgentAction, str]] +) -> List[BaseMessage]: + """Format the steps.""" + messages = [] + for action, observation in intermediate_steps: + if not isinstance(action, AgentAction): + if action.tool != "_Exception": + raise AssertionError(f"Unexpected step: {action}. 
type: {type(action)}") + + messages.append(HumanMessage(content=observation)) + messages.extend(action.messages) + messages.append(format_observation(action.tool, observation)) + return messages + + +# PUBLIC API + + +class AgentInput(TypedDict): + """The input to the agent.""" + + input: str + """The input to the agent.""" + intermediate_steps: List[Tuple[AgentAction, str]] + """The intermediate steps taken by the agent.""" + examples: NotRequired[List[BaseMessage]] + """A list of messages that can be used to form example traces.""" + + +def create_agent( + model: Union[BaseChatModel, BaseLanguageModel], + tools: Sequence[StructuredTool], + parser: AgentOutputParser, + *, + ast_printer: Union[AstPrinter, Literal["xml"]] = "xml", +) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]: + """Create an agent for a chat model.""" + if isinstance(ast_printer, str): + if ast_printer == "xml": + ast_printer = AstPrinter() + elif ast_printer == "typescript": + ast_printer = TypeScriptEncoder() + else: + raise ValueError(f"Unknown ast printer: {ast_printer}") + elif isinstance(ast_printer, AstPrinter): + pass + else: + raise TypeError( + f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`" + ) + + function_definitions = [convert_tool_to_function_definition(tool) for tool in tools] + tool_description = ast_printer.visit_function_definitions(function_definitions) + + template = ChatPromptTemplate.from_messages( + [ + ("system", AGENT_INSTRUCTIONS_BLOB_STYLE), + MessagesPlaceholder("examples"), # Can use to add example traces + ("human", "{input}"), + MessagesPlaceholder("history"), + ] + ).partial(tool_description=tool_description) + + agent = ( + { + "input": lambda x: x["input"], + "history": lambda x: format_steps_for_chat(x["intermediate_steps"]), + "examples": lambda x: x.get("examples", []), + } + | template + | model.bind(stop=[""]) + | parser + ) + return agent diff --git a/agents/encoder.py b/agents/encoder.py new file mode 100644 index 00000000..fccc6cd4 --- /dev/null +++ b/agents/encoder.py @@ -0,0 +1,226 @@ +"""Prototyping code for rendering function definitions, invocations, and results. + +Types are simplified for now to `str`. + +We should actually support something like pydantic or jsonschema for the types, so +we can expand them recursively for nested types. 
+""" +import abc +from typing import Any, List, Optional + +from typing_extensions import NotRequired, TypedDict + + +class Parameter(TypedDict): + """Representation for a parameter.""" + + name: str + type: str + description: str + + +class Arguments(TypedDict): + """Arguments are passed to a function during function invocation.""" + + name: Optional[str] + value: Any + + +class ReturnValue(TypedDict): + """Representation for a return value of a function call.""" + + type: str + description: NotRequired[str] + + +class FunctionDefinition(TypedDict): + """Representation for a function.""" + + name: str + description: str # Function description + parameters: List[Parameter] + return_value: ReturnValue + + +class FunctionInvocation(TypedDict): + """Representation for a function invocation.""" + + id: NotRequired[str] + name: str + arguments: List[Arguments] + + +class FunctionResult(TypedDict): + """Representation for a function result.""" + + id: NotRequired[str] + name: str + result: Optional[str] + error: Optional[str] + + +class Visitor(abc.ABC): + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function.""" + raise NotImplementedError() + + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + raise NotImplementedError() + + def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str: + """Render a function invocation.""" + raise NotImplementedError() + + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result.""" + raise NotImplementedError() + + +class AstPrinter(Visitor): + """Print the AST.""" + + +class XMLEncoder(AstPrinter): + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function.""" + parameters_as_strings = [ + "\n" + f"{parameter['name']}\n" + f"{parameter['type']}\n" + f"{parameter['description']}\n" + "\n" + for parameter in function_definition["parameters"] + ] + function = ( + "\n" + f"{function_definition['name']}\n" + "\n" + f"{function_definition['description']}\n" + "\n" + "\n" + f"{''.join(parameters_as_strings)}" # Already includes trailing newline + "\n" + "\n" + f"{function_definition['return_value']['type']}\n" + f"{function_definition['return_value']['description']}\n" + "\n" + "" + ) + return function + + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + strs = [ + self.visit_function_definition(function_definition) + for function_definition in function_definitions + ] + return "\n" + "\n".join(strs) + "\n" + + def visit_function_invocation(self, invocation: FunctionInvocation) -> str: + """Render a function invocation.""" + arguments_as_strings = [ + "\n" + f"{argument['name']}\n" + f"{argument['value']}\n" + "\n" + for argument in invocation["arguments"] + ] + lines = [""] + + if invocation.get("id"): + lines.append(f"{invocation['id']}") + + lines.extend( + [ + f"{invocation['name']}\n" + "\n" + f"{''.join(arguments_as_strings)}" # Already includes trailing newline + "\n" + "" + ] + ) + return "\n".join(lines) + + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result.""" + lines = [ + "", + ] + + if function_result.get("id"): + lines.append(f"{function_result['id']}") + + lines.extend( + [ + f"{function_result['name']}", + f"{function_result['result']}", + f"{function_result['error']}", + 
"", + ] + ) + + return "\n".join(lines) + + +class TypeScriptEncoder(AstPrinter): + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function.""" + parameters_as_strings = [ + f"{parameter['name']}: {parameter['type']}" + for parameter in function_definition["parameters"] + ] + # Let's use JSdoc style comments + # First the function description + lines = [ + f"// {function_definition['description']}", + # Then the parameter descriptions + *[ + f"// @param {parameter['name']} {parameter['description']}" + for parameter in function_definition["parameters"] + ], + # Then the return value description + f"// @returns {function_definition['return_value']['description']}", + # Then the function definition + f"function {function_definition['name']}(" + f"{', '.join(parameters_as_strings)}): " + f"{function_definition['return_value']['type']};", + ] + + # finally join + function = "\n".join(lines) + return function + + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + strs = [ + self.visit_function_definition(function_definition) + for function_definition in function_definitions + ] + return "\n\n".join(strs) + + def visit_function_invocation(self, invocation: FunctionInvocation) -> str: + """Render a function invocation.""" + arguments_as_strings = [ + f"{argument['name']}: {argument['value']}" + for argument in invocation["arguments"] + ] + lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"] + return "\n".join(lines) + + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result.""" + lines = [] + if function_result["error"]: + lines.append(f"ERROR: {function_result['error']}") + else: + lines.append(f"> {function_result['result']}") + if function_result.get("id"): + lines.append(f"// ID: {function_result['id']}") + return "\n".join(lines) diff --git a/agents/example_traces.py b/agents/example_traces.py new file mode 100644 index 00000000..109e8724 --- /dev/null +++ b/agents/example_traces.py @@ -0,0 +1,25 @@ +# EXAMPLE_TRACE = [ +# HumanMessage(content="type the letter 'o'"), +# AIMessage( +# content=""" +# +# { +# "tool_name": "type_letter", +# "arguments": { +# "letter": "o" +# } +# } +# \ +# """ +# ), +# HumanMessage( +# content="""\ +# +# type_letter +# o +# \ +# """ +# ), +# ] +# +# diff --git a/agents/factory.py b/agents/factory.py new file mode 100644 index 00000000..6a147988 --- /dev/null +++ b/agents/factory.py @@ -0,0 +1,75 @@ +from typing import List, Optional + +from langchain.agents import AgentExecutor +from langchain.chat_models import ChatAnthropic, ChatFireworks +from langchain_core.runnables import Runnable, RunnableConfig + +from agents.agent import create_agent +from agents.parser import ParameterizedAgentParser +from langchain_benchmarks.model_registration import FIREWORK_NAME_TO_MODEL +from langchain_benchmarks.schema import ToolUsageTask +from langchain_benchmarks.tool_usage import apply_agent_executor_adapter + + +class CustomAgentFactory: + def __init__(self, task: ToolUsageTask, model: str) -> None: + """Create an OpenAI agent factory for the given task. + + Args: + task: The task to create an agent factory for. 
+ """ + if model not in self.list_models(): + raise ValueError(f"Unknown model: {model}") + self.task = task + self.model = model + + @staticmethod + def list_models() -> List[str]: + """List all models.""" + return sorted( + [ + "claude-2.1", + "claude-2", + *FIREWORK_NAME_TO_MODEL.keys(), + ] + ) + + def __call__(self) -> Runnable: + env = self.task.create_environment() + if self.model in {"claude-2.1", "claude-2"}: + model = ChatAnthropic(model=self.model, temperature=0) + elif self.model in FIREWORK_NAME_TO_MODEL: + model = ChatFireworks( + model=FIREWORK_NAME_TO_MODEL[self.model], temperature=0 + ) + else: + raise ValueError(f"Unknown model: {self.model}") + + def _add_task_instructions( + input: dict, config: Optional[RunnableConfig] = None, **kwargs + ) -> dict: + """Add task instructions to the question.""" + input = input.copy() + input["question"] = ( + f"{self.task.instructions}\nWrite down your answer, " + f"but do not explain it. Input: `{input['question']}`" + ) + return input + + agent = create_agent( + model, + env.tools, + ParameterizedAgentParser( + wrapping_xml_tag="tool", require_closing_xml_tag=False + ), + ) + executor = AgentExecutor( + agent=agent, + tools=env.tools, + handle_parsing_errors=True, + return_intermediate_steps=True, + ) + + return _add_task_instructions | apply_agent_executor_adapter( + executor, state_reader=env.read_state + ) diff --git a/agents/parser.py b/agents/parser.py new file mode 100644 index 00000000..071f2560 --- /dev/null +++ b/agents/parser.py @@ -0,0 +1,120 @@ +import ast +import re +from typing import Any, Union + +from langchain.agents import AgentOutputParser +from langchain.pydantic_v1 import BaseModel, Field, ValidationError +from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import AIMessage + + +class _ToolInvocationRequest(BaseModel): + """Light-weight pydantic model for validating the raw tool invocation request. + + The purpose of this model, is to make sure that whatever as parsed from + the raw llm output has `tool_name` and potential `arguments` fields, and + nothing else. + """ + + tool_name: str + # OK parameterless tools which do not take arguments + named_arguments: Any = Field(default_factory=dict) + + +class ParameterizedAgentParser(AgentOutputParser): + """A generalized parser that makes it easier to parameterize different parsing.""" + + wrapping_xml_tag: str + """The tag that wraps the function invocation request. + + For example, if "tool", then the function invocation request should be wrapped + in .... + """ + require_closing_xml_tag: bool = False + """Whether we should require a closing tag for the wrapping_xml_tag. + + For example, if True, then the function invocation request should be wrapped + """ + + def parse(self, text: str) -> Union[AgentFinish, AgentAction]: + """Parse the output of the agent.""" + open_tag = f"<{self.wrapping_xml_tag}>" + close_tag = f"" + if open_tag in text: + # This is a hack to make sure that is always present + # in the output if . may be a stop sequence for the + # language model, so depending on implementation + # the stop sequence may be cut off. + # There might be a better way to do this, but this works and + # is simple. 
+ if not self.require_closing_xml_tag: + text += close_tag + + pattern = rf"{open_tag}(?P.*?){close_tag}" + match = re.search(pattern, text, re.DOTALL) + if match: + content = match.group("invocation").strip() + return parse_invocation(content, self.wrapping_xml_tag) + + return AgentFinish( + log=text, + return_values={ + "output": text, + }, + ) + + +def parse_invocation(text: str, tag: str) -> AgentAction: + """Parse the content of the function invocation. + + Args: + text: The text to parse. + tag: The tag that wraps the function invocation request. + + Returns: + An AgentAction that corresponds to the function invocation. + + Raises: + OutputParserException: If the parsing fails. + + This exception is meant to be caught by the agent executor and + handled appropriately to provide feedback to the LLM. + """ + ai_content = f"<{tag}>{text}" + + try: + result = ast.literal_eval(text) + except Exception as e: + # Convert this to something controllable by the user. + err_msg = ( + f"ERROR: Please use the format " + f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}' + ) + raise OutputParserException( + error=e, + llm_output=ai_content, + observation=err_msg, + send_to_llm=True, + ) + + try: + request = _ToolInvocationRequest(**result) + except ValidationError as e: + err_msg = ( + f"ERROR: Please use the format " + f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}' + ) + raise OutputParserException( + error=e, + llm_output=ai_content, + send_to_llm=True, + observation=err_msg, + ) + + return AgentActionMessageLog( + message_log=[AIMessage(content=ai_content)], + tool=request.tool_name, + tool_input=request.named_arguments, + log=f"\nInvoking {request.tool_name}: {request.named_arguments}\n\t", + ) diff --git a/agents/prompts.py b/agents/prompts.py new file mode 100644 index 00000000..7218bf13 --- /dev/null +++ b/agents/prompts.py @@ -0,0 +1,42 @@ +AGENT_INSTRUCTIONS_XML_FORMAT = """\ +In this environment you have access to a set of tools you can use to answer the user's question. + +You may call them like this: + + +$TOOL_NAME + +<$PARAMETER_NAME>$PARAMETER_VALUE +... + + + + +Here are the tools available: + +{tool_description} +""" # noqa: E501 + +AGENT_INSTRUCTIONS_BLOB_STYLE = """\ +In this environment you have access to a set of tools you can use to answer the user's question. + +Here are the tools available: + +{tool_description} + +You may call one tool at a time using a format that includes and tag. + +Inside the tag the content is a python dictionary that uses python literals (e.g., numbers, strings, lists, dictionaries, etc.) to specify the tool invocation. + +It must match the schema of the function as described in the tool description. +"arguments" is a dictionary of the arguments to the function. + + +{{ + "tool_name": $TOOL_NAME, + "arguments": $ARGUMENTS +}} + + +If you do not know the answer use more tools. 
You can only take a single action at a time.\ +""" # noqa: E501 diff --git a/agents/tests/__init__.py b/agents/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agents/tests/parsing/__init__.py b/agents/tests/parsing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agents/tests/parsing/test_typescript_encoding.py b/agents/tests/parsing/test_typescript_encoding.py new file mode 100644 index 00000000..c62e0ad7 --- /dev/null +++ b/agents/tests/parsing/test_typescript_encoding.py @@ -0,0 +1,29 @@ +"""Test XML encoding and decoding of function definitions, invocation, and results.""" +from agents.encoder import ( + FunctionDefinition, + TypeScriptEncoder, +) + + +def test_function_definition() -> None: + """Test encoding a function definition.""" + function_definition = FunctionDefinition( + name="test_function", + description="A test function", + parameters=[ + {"name": "test_parameter", "type": "str", "description": "A test parameter"} + ], + return_value={"type": "str", "description": "A test return value"}, + ) + encoder = TypeScriptEncoder() + xml = encoder.visit_function_definition(function_definition) + assert xml == ( + "// A test function\n" + "// @param test_parameter A test parameter\n" + "// @returns A test return value\n" + "function test_function(test_parameter: str): str;" + ) + + +# Not important to test other ones right now since we can't parse / interpret +# typescript anyway. diff --git a/agents/tests/parsing/test_xml_encoding.py b/agents/tests/parsing/test_xml_encoding.py new file mode 100644 index 00000000..b22a6ebf --- /dev/null +++ b/agents/tests/parsing/test_xml_encoding.py @@ -0,0 +1,79 @@ +"""Test XML encoding and decoding of function definitions, invocation, and results.""" +from agents.encoder import ( + FunctionDefinition, + FunctionInvocation, + FunctionResult, + XMLEncoder, +) + + +def test_function_definition_encoding() -> None: + """Test encoding a function definition.""" + function_definition = FunctionDefinition( + name="test_function", + description="A test function", + parameters=[ + {"name": "test_parameter", "type": "str", "description": "A test parameter"} + ], + return_value={"type": "str", "description": "A test return value"}, + ) + encoder = XMLEncoder() + xml = encoder.visit_function_definition(function_definition) + assert xml == ( + "\n" + "test_function\n" + "\n" + "A test function\n" + "\n" + "\n" + "\n" + "test_parameter\n" + "str\n" + "A test parameter\n" + "\n" + "\n" + "\n" + "str\n" + "A test return value\n" + "\n" + "" + ) + + +def test_function_result_encoding() -> None: + """Test encoding a function result.""" + function_result = FunctionResult( + name="test_function", + result="test_result", + error="test_error", + ) + encoder = XMLEncoder() + xml = encoder.visit_function_result(function_result) + assert xml == ( + "\n" + "test_function\n" + "test_result\n" + "test_error\n" + "" + ) + + +def test_function_invocation() -> None: + """Test function invocation.""" + function_invocation = FunctionInvocation( + name="test_function", + arguments=[{"name": "test_argument", "value": "test_value"}], + ) + encoder = XMLEncoder() + xml = encoder.visit_function_invocation(function_invocation) + assert xml == ( + "\n" + "test_function\n" + "\n" + "\n" + "test_argument\n" + "test_value\n" + "\n" + "\n" + "" + ) diff --git a/agents/tests/test_lagnchain_adapter.py b/agents/tests/test_lagnchain_adapter.py new file mode 100644 index 00000000..adda0a9e --- /dev/null +++ 
b/agents/tests/test_lagnchain_adapter.py @@ -0,0 +1,58 @@ +import pytest +from langchain.tools import tool + +from agents.adapters import convert_tool_to_function_definition +from agents.encoder import XMLEncoder + + +@tool +def get_hello() -> str: + """Get hello.""" + return "hello" + + +@tool +def repeat(x: str) -> str: + """Repeat x. + + Args: + x: The string to repeat. + + Returns: + The repeated string. + """ + return x + + +def test_parameterless_function() -> None: + """Test foo.""" + function_definition = convert_tool_to_function_definition(get_hello) + assert function_definition == { + "name": "get_hello", + "description": "Get hello.", + "parameters": [], + "return_value": { + "type": "Any", + }, + } + + +@pytest.mark.skip("Need to fix handling of leading whitespace") +def test_function_with_parameters() -> None: + import textwrap + + doc = textwrap.dedent(repeat.func.__doc__) + assert convert_tool_to_function_definition(repeat) == { + "name": "repeat", + "description": doc, + "parameters": [ + { + "name": "x", + "type": "str", + "description": "", # Need to fix this + } + ], + "return_value": { + "type": "Any", + }, + } diff --git a/agents/tests/test_parsers.py b/agents/tests/test_parsers.py new file mode 100644 index 00000000..615b9c57 --- /dev/null +++ b/agents/tests/test_parsers.py @@ -0,0 +1,11 @@ +from langchain_adapters.alternative import AgentOutputParser +from langchain_core.agents import AgentFinish + + +def test_parser() -> None: + """Test parser.""" + parser = AgentOutputParser(require_closing_tag=False, tag="tool") + assert isinstance(parser.invoke("goodbye"), AgentFinish) + assert parser.invoke("hello") == "hello" + assert parser.invoke("hello") == "hello" + # assert isinstance(parser.invoke("hello"), AgentAction) diff --git a/agents/throttle.py b/agents/throttle.py new file mode 100644 index 00000000..17c4d6e6 --- /dev/null +++ b/agents/throttle.py @@ -0,0 +1,35 @@ +"""Throttle using a token bucket.""" +import threading +import time + + +class Throttle: + def __init__(self, rate: int) -> None: + """Initialize the throttle.""" + self.rate = rate + self.tokens = 0 + self._consume_lock = threading.Lock() + self.last = None + + def consume(self, amount: int = 0) -> int: + """Consume the given amount of tokens.""" + with self._consume_lock: + now = time.time() + + # initialize on first call to avoid a burst + if self.last is None: + self.last = now + + elapsed = now - self.last + + if elapsed * self.rate > 1: + self.tokens += elapsed * self.rate + self.last = now + + self.tokens = min(self.tokens, self.rate) + + if self.tokens >= amount: + self.tokens -= amount + return amount + + return 0 diff --git a/docs/source/notebooks/datasets.ipynb b/docs/source/notebooks/datasets.ipynb index d77e875f..3136c8ad 100644 --- a/docs/source/notebooks/datasets.ipynb +++ b/docs/source/notebooks/datasets.ipynb @@ -1,225 +1,226 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "033684fb-65b2-4586-a959-68c614741ca2", - "metadata": {}, - "source": [ - "# Datasets\n", - "[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/langchain-ai/langchain-benchmarks/blob/main/docs/source/notebooks/datasets.ipynb)\n", - "\n", - "Here, we'll see how to work with LangSmith datasets." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install -U langchain-benchmarks" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "6d272fbf-710e-4a49-a0da-67e010541905", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain_benchmarks import clone_public_dataset, download_public_dataset" - ] - }, - { - "cell_type": "markdown", - "id": "18ee0f96-e5c4-4ae9-aebf-7d8b88c51662", - "metadata": {}, - "source": [ - "Let's first download the dataset to the local file system" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "58b94f6d-0c91-4361-9b22-f758ffaa150a", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Fetching examples...\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5a2fad8c0c3549ec96a3b38fe8a002b0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/21 [00:00\n", + "\n", + "Name Type Provider Description \n", + "\n", + "\n", + "gpt-3.5-turbo-1106 chat openai The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.\n", + "gpt-3.5-turbo chat openai Currently points to gpt-3.5-turbo-0613. \n", + "gpt-3.5-turbo-16k chat openai Currently points to gpt-3.5-turbo-0613. \n", + "gpt-3.5-turbo-instructllm openai Similar capabilities as text-davinci-003 but compatible with legacy Completions endpoint and not Chat Completions. \n", + "gpt-3.5-turbo-0613 chat openai Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024. \n", + "gpt-3.5-turbo-16k-0613chat openai Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024. \n", + "gpt-3.5-turbo-0301 chat openai Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024. \n", + "text-davinci-003 llm openai Legacy Can do language tasks with better quality and consistency than the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024. \n", + "text-davinci-002 llm openai Legacy Similar capabilities to text-davinci-003 but trained with supervised fine-tuning instead of reinforcement learning. Will be deprecated on Jan 4th 2024. \n", + "code-davinci-002 llm openai Legacy Optimized for code-completion tasks. Will be deprecated on Jan 4th 2024. \n", + "llama-v2-7b-chat chat fireworks 7b parameter LlamaChat model \n", + "llama-v2-13b-chat chat fireworks 13b parameter LlamaChat model \n", + "llama-v2-70b-chat chat fireworks 70b parameter LlamaChat model \n", + "\n", + "" + ], + "text/plain": [ + "ModelRegistry(registered_models=[RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. 
Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo-16k'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-instruct', provider='openai', description='Similar capabilities as text-davinci-003 but compatible with legacy Completions endpoint and not Chat Completions.', params={'model': 'gpt-3.5-turbo-instruct'}, type='llm', path=None), RegisteredModel(name='gpt-3.5-turbo-0613', provider='openai', description='Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024.', params={'model': 'gpt-3.5-turbo-0613'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k-0613', provider='openai', description='Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024.', params={'model': 'gpt-3.5-turbo-16k-0613'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-0301', provider='openai', description='Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024.', params={'model': 'gpt-3.5-turbo-0301'}, type='chat', path=None), RegisteredModel(name='text-davinci-003', provider='openai', description='Legacy Can do language tasks with better quality and consistency than the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024.', params={'model': 'text-davinci-003'}, type='llm', path=None), RegisteredModel(name='text-davinci-002', provider='openai', description='Legacy Similar capabilities to text-davinci-003 but trained with supervised fine-tuning instead of reinforcement learning. Will be deprecated on Jan 4th 2024.', params={'model': 'text-davinci-002'}, type='llm', path=None), RegisteredModel(name='code-davinci-002', provider='openai', description='Legacy Optimized for code-completion tasks. Will be deprecated on Jan 4th 2024.', params={'model': 'code-davinci-002'}, type='llm', path=None), RegisteredModel(name='llama-v2-7b-chat', provider='fireworks', description='7b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-13b-chat', provider='fireworks', description='13b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-13b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-70b-chat', provider='fireworks', description='70b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-70b-chat'}, type='chat', path=None)])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_registry" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "64bfc631-1f1e-4cf4-8636-b8be7b46fef8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
name gpt-3.5-turbo-1106
type chat
provider openai
description The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.
model_path langchain.chat_models.openai.ChatOpenAI
" + ], + "text/plain": [ + "RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "registered_model = model_registry[0]\n", + "registered_model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3604d49e-afbe-48ad-ac10-1e538b1ad376", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model = registered_model.get_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "bdece532-9843-427a-a10b-4545ed4ec151", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='Hello there! How can I assist you today?')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.invoke('hello!')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "db40d4da-dc70-4e6d-b7e8-61de1e15ed2e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Name Type Provider Description
gpt-3.5-turbo-1106 chat openai The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.
gpt-3.5-turbo chat openai Currently points to gpt-3.5-turbo-0613.
gpt-3.5-turbo-16k chat openai Currently points to gpt-3.5-turbo-0613.
" + ], + "text/plain": [ + "ModelRegistry(registered_models=[RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo-16k'}, type='chat', path=None)])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_registry[:3]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9874846a-52f3-4921-b1ed-0858521bb9a9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Name Type Provider Description
llama-v2-7b-chat chat fireworks 7b parameter LlamaChat model
llama-v2-13b-chat chat fireworks 13b parameter LlamaChat model
llama-v2-70b-chat chat fireworks 70b parameter LlamaChat model
" + ], + "text/plain": [ + "ModelRegistry(registered_models=[RegisteredModel(name='llama-v2-7b-chat', provider='fireworks', description='7b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-13b-chat', provider='fireworks', description='13b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-13b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-70b-chat', provider='fireworks', description='70b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-70b-chat'}, type='chat', path=None)])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_registry.filter(provider='fireworks')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "eb531591-f46b-4745-ae67-4dfd6217ec5f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gpt-3.5-turbo-1106\n", + "gpt-3.5-turbo\n", + "gpt-3.5-turbo-16k\n", + "gpt-3.5-turbo-instruct\n", + "gpt-3.5-turbo-0613\n", + "gpt-3.5-turbo-16k-0613\n", + "gpt-3.5-turbo-0301\n", + "text-davinci-003\n", + "text-davinci-002\n", + "code-davinci-002\n", + "llama-v2-7b-chat\n", + "llama-v2-13b-chat\n", + "llama-v2-70b-chat\n" + ] + } + ], + "source": [ + "for registered_model in model_registry:\n", + " print(registered_model.name)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain_benchmarks/__init__.py b/langchain_benchmarks/__init__.py index deb45bc9..e2a83fd7 100644 --- a/langchain_benchmarks/__init__.py +++ b/langchain_benchmarks/__init__.py @@ -1,3 +1,4 @@ +from langchain_benchmarks.model_registration import model_registry from langchain_benchmarks.registration import registry from langchain_benchmarks.utils._langsmith import ( clone_public_dataset, @@ -5,4 +6,10 @@ ) # Please keep this list sorted! -__all__ = ["clone_public_dataset", "download_public_dataset", "registry"] +__all__ = [ + "clone_public_dataset", + "download_public_dataset", + "registry", + "model_registry", + +] diff --git a/langchain_benchmarks/model_registration.py b/langchain_benchmarks/model_registration.py new file mode 100644 index 00000000..fb9d0ecf --- /dev/null +++ b/langchain_benchmarks/model_registration.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from langchain_benchmarks.schema import RegisteredModel, ModelRegistry + +_OpenAIModels = [ + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo-1106", + type="chat", + description=( + "The latest GPT-3.5 Turbo model with improved instruction following, " + "JSON mode, reproducible outputs, parallel function calling, and more. " + "Returns a maximum of 4,096 output tokens. Learn more." 
+ ), + params={ + "model": "gpt-3.5-turbo-1106", + }, + ), + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo", + type="chat", + description="Currently points to gpt-3.5-turbo-0613.", + params={ + "model": "gpt-3.5-turbo", + }, + ), + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo-16k", + type="chat", + description="Currently points to gpt-3.5-turbo-0613.", + params={ + "model": "gpt-3.5-turbo-16k", + }, + ), + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo-instruct", + type="llm", + description=( + "Similar capabilities as text-davinci-003 but compatible with legacy " + "Completions endpoint and not Chat Completions." + ), + params={ + "model": "gpt-3.5-turbo-instruct", + }, + ), + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo-0613", + type="chat", + description=( + "Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. " + "Will be deprecated on June 13, 2024." + ), + params={ + "model": "gpt-3.5-turbo-0613", + }, + ), + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo-16k-0613", + type="chat", + description=( + "Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. " + "Will be deprecated on June 13, 2024." + ), + params={ + "model": "gpt-3.5-turbo-16k-0613", + }, + ), + RegisteredModel( + provider="openai", + name="gpt-3.5-turbo-0301", + type="chat", + description=( + "Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. " + "Will be deprecated on June 13th 2024." + ), + params={ + "model": "gpt-3.5-turbo-0301", + }, + ), + RegisteredModel( + provider="openai", + name="text-davinci-003", + type="llm", + description=( + "Legacy Can do language tasks with better quality and consistency than " + "the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024." + ), + params={ + "model": "text-davinci-003", + }, + ), + RegisteredModel( + provider="openai", + name="text-davinci-002", + type="llm", + description=( + "Legacy Similar capabilities to text-davinci-003 but trained with " + "supervised fine-tuning instead of reinforcement learning. " + "Will be deprecated on Jan 4th 2024." + ), + params={ + "model": "text-davinci-002", + }, + ), + RegisteredModel( + provider="openai", + name="code-davinci-002", + type="llm", + description="Legacy Optimized for code-completion tasks. 
Will be deprecated " + "on Jan 4th 2024.", + params={ + "model": "code-davinci-002", + }, + ), +] + +_FireworksModels = [ + RegisteredModel( + provider="fireworks", + name="llama-v2-7b-chat", + type="chat", + description="7b parameter LlamaChat model", + params={ + "model": "accounts/fireworks/models/llama-v2-7b-chat", + }, + ), + RegisteredModel( + provider="fireworks", + name="llama-v2-13b-chat", + type="chat", + description="13b parameter LlamaChat model", + params={ + "model": "accounts/fireworks/models/llama-v2-13b-chat", + }, + ), + RegisteredModel( + provider="fireworks", + name="llama-v2-70b-chat", + type="chat", + description="70b parameter LlamaChat model", + params={ + "model": "accounts/fireworks/models/llama-v2-70b-chat", + }, + ), +] + +model_registry = ModelRegistry(registered_models=_OpenAIModels + _FireworksModels) diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py index 5024957c..68227a6b 100644 --- a/langchain_benchmarks/schema.py +++ b/langchain_benchmarks/schema.py @@ -2,8 +2,12 @@ from __future__ import annotations import dataclasses +import importlib import urllib -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, Sequence + +from langchain_core.language_models import BaseChatModel, BaseLanguageModel +from typing_extensions import Literal from langchain.prompts import ChatPromptTemplate from langchain.schema import BaseRetriever @@ -153,6 +157,7 @@ def __post_init__(self) -> None: raise ValueError( f"Duplicate task name {task.name}. " f"Task names must be unique." ) + seen_names.add(task.name) def _repr_html_(self) -> str: """Return an HTML representation of the registry.""" @@ -210,3 +215,190 @@ def add(self, task: BaseTask) -> None: if not isinstance(task, BaseTask): raise TypeError("Only tasks can be added to the registry.") self.tasks.append(task) + + +Provider = Literal["fireworks", "openai"] +ModelType = Literal["chat", "llm"] +AUTHORIZED_NAMESPACES = {"langchain"} + + +def _get_model_class_from_path( + path: str +) -> Union[Type[BaseChatModel], Type[BaseLanguageModel]]: + """Get the class of the model.""" + module_name, attribute_name = path.rsplit(".", 1) + top_namespace = path.split(".")[0] + + if top_namespace not in AUTHORIZED_NAMESPACES: + raise ValueError( + f"Unauthorized namespace {top_namespace}. " + f"Authorized namespaces are: {AUTHORIZED_NAMESPACES}" + ) + + # Import the module dynamically + module = importlib.import_module(module_name) + model_class = getattr(module, attribute_name) + if not issubclass(model_class, (BaseLanguageModel, BaseChatModel)): + raise ValueError( + f"Model class {model_class} is not a subclass of BaseLanguageModel" + ) + return model_class + + +def _get_default_path(provider: str, type_: ModelType) -> str: + """Get the default path for a model.""" + paths = { + ("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks", + ("fireworks", "llm"): "langchain.language_models.fireworks.Fireworks", + ("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI", + ("openai", "llm"): "langchain.language_models.openai.OpenAI", + } + + if (provider, type_) not in paths: + raise ValueError(f"Unknown provider {provider} and type {type_}") + + return paths[(provider, type_)] + + +@dataclasses.dataclass(frozen=True) +class RegisteredModel: + """Descriptive information about a model. + + This information can be used to instantiate the underlying model. 
+ """ + + name: str + provider: Provider + description: str + params: Dict[str, Any] + type: ModelType + # Path to the model class. + # For example, "langchain.chat_models.anthropic import ChatAnthropicModel" + path: Optional[str] = None # If not provided, will use default path + + def get_model( + self, *, model_params: Optional[Dict[str, Any]] = None + ) -> Union[BaseChatModel, BaseLanguageModel]: + """Get the class of the model.""" + all_params = {**self.params, **(model_params or {})} + model_class = _get_model_class_from_path(self.model_path) + return model_class(**all_params) + + @property + def model_path(self) -> str: + """Get the path of the model.""" + return self.path or _get_default_path(self.provider, self.type) + + @property + def _table(self) -> List[List[str]]: + """Return a table representation of the environment.""" + return [ + ["name", self.name], + ["type", self.type], + ["provider", self.provider], + ["description", self.description], + ["model_path", self.model_path], + ] + + def _repr_html_(self) -> str: + """Return an HTML representation of the environment.""" + return tabulate( + self._table, + tablefmt="unsafehtml", + ) + + +StrFilter = Union[None, str, Sequence[str]] + + +def _is_in_filter(actual_value: str, filter_value: StrFilter) -> bool: + """Filter for a string attribute.""" + if filter_value is None: + return True + + if isinstance(filter_value, str): + return actual_value == filter_value + + return actual_value in filter_value + + +@dataclasses.dataclass(frozen=False) +class ModelRegistry: + registered_models: Sequence[RegisteredModel] + + def __post_init__(self) -> None: + """Validate that all the tasks have unique names and IDs.""" + seen_names = set() + for model in self.registered_models: + if model.name in seen_names: + raise ValueError( + f"Duplicate model name {model.name}. " f"Task names must be unique." 
+ ) + seen_names.add(model.name) + + def get_model(self, name: str) -> Optional[RegisteredModel]: + """Get model info.""" + return next(model for model in self.registered_models if model.name == name) + + def filter( + self, + *, + type: StrFilter = None, + name: StrFilter = None, + provider: StrFilter = None, + ) -> ModelRegistry: + """Filter the tasks in the registry.""" + models = self.registered_models + selected_models = [] + + for model in models: + if not _is_in_filter(model.type, type): + continue + if not _is_in_filter(model.name, name): + continue + if not _is_in_filter(model.provider, provider): + continue + selected_models.append(model) + return ModelRegistry(registered_models=selected_models) + + def _repr_html_(self) -> str: + """Return an HTML representation of the registry.""" + headers = [ + "Name", + "Type", + "Provider", + "Description", + ] + table = [ + [ + model.name, + model.type, + model.provider, + model.description, + ] + for model in self.registered_models + ] + return tabulate(table, headers=headers, tablefmt="unsafehtml") + + def __len__(self) -> int: + """Return the number of tasks in the registry.""" + return len(self.registered_models) + + def __iter__(self) -> Iterable[RegisteredModel]: + """Iterate over the tasks in the registry.""" + return iter(self.registered_models) + + def __getitem__( + self, key: Union[int, str] + ) -> Union[RegisteredModel, ModelRegistry]: + """Get an environment from the registry.""" + if isinstance(key, slice): + return ModelRegistry(registered_models=self.registered_models[key]) + elif isinstance(key, (int, str)): + # If key is an integer, return the corresponding environment + if isinstance(key, str): + return self.get_model(key) + else: + return self.registered_models[key] + else: + raise TypeError("Key must be an integer or a slice.") diff --git a/tests/unit_tests/test_model_registry.py b/tests/unit_tests/test_model_registry.py new file mode 100644 index 00000000..8aa04dc8 --- /dev/null +++ b/tests/unit_tests/test_model_registry.py @@ -0,0 +1,68 @@ +import pytest + +from langchain_benchmarks.schema import RegisteredModel, ModelRegistry + +# Create some sample RegisteredModel instances for testing +SAMPLE_MODELS = [ + RegisteredModel( + "model1", "fireworks", "Description 1", {"param1": "value1"}, "chat" + ), + RegisteredModel("model2", "openai", "Description 2", {"param2": "value2"}, "llm"), +] + + +@pytest.fixture +def sample_registry() -> ModelRegistry: + return ModelRegistry(SAMPLE_MODELS) + + +def test_init() -> None: + # Test the constructor of ModelRegistry + registry = ModelRegistry(SAMPLE_MODELS) + assert len(registry.registered_models) == 2 + + +def test_get_model(sample_registry: ModelRegistry) -> None: + # Test the get_model method + model = sample_registry.get_model("model1") + assert model.name == "model1" + + +def test_filter(sample_registry: ModelRegistry) -> None: + # Test the filter method + filtered_registry = sample_registry.filter(type="chat") + assert len(filtered_registry.registered_models) == 1 + assert filtered_registry.registered_models[0].type == "chat" + + +def test_repr_html(sample_registry: ModelRegistry) -> None: + # Test the _repr_html_ method + html_representation = sample_registry._repr_html_() + assert "" in html_representation + + +def test_len(sample_registry: ModelRegistry) -> None: + # Test the __len__ method + assert len(sample_registry) == 2 + + +def test_iter(sample_registry: ModelRegistry) -> None: + # Test the __iter__ method + models = list(iter(sample_registry)) + assert 
len(models) == 2 + assert isinstance(models[0], RegisteredModel) + + +def test_getitem(sample_registry: ModelRegistry) -> None: + # Test the __getitem__ method for integer and string keys + model = sample_registry[0] + assert model.name == "model1" + model = sample_registry["model2"] + assert model.name == "model2" + + +def test_getitem_slice(sample_registry: ModelRegistry) -> None: + # Test the __getitem__ method for slices + sliced_registry = sample_registry[:1] + assert len(sliced_registry.registered_models) == 1 + assert sliced_registry.registered_models[0].name == "model1"
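
Usage note (illustrative, not part of the diff): the sketch below shows how the pieces introduced here are intended to compose -- picking a model from `model_registry`, building the XML-style agent with `create_agent` and `ParameterizedAgentParser`, and running it inside an `AgentExecutor`, roughly mirroring `CustomAgentFactory.__call__` in `agents/factory.py`. The task name and the input question are hypothetical placeholders; any `ToolUsageTask` exposed by the existing benchmark registry would work, and a Fireworks API key is assumed to be configured for the chosen model.

from langchain.agents import AgentExecutor

from agents.agent import create_agent
from agents.parser import ParameterizedAgentParser
from langchain_benchmarks import model_registry, registry

# Hypothetical task name -- substitute any ToolUsageTask registered in `registry`.
task = registry["Tool Usage - Typewriter (1 tool)"]
env = task.create_environment()

# Look up a registered model by name and instantiate it (ChatFireworks under the hood).
model = model_registry["llama-v2-70b-chat"].get_model(model_params={"temperature": 0})

# Build the runnable agent from the task's tools and the XML-tag parser.
agent = create_agent(
    model,
    env.tools,
    ParameterizedAgentParser(wrapping_xml_tag="tool", require_closing_xml_tag=False),
)

# Wrap it in an executor, as the factory does, and run a single question.
executor = AgentExecutor(
    agent=agent,
    tools=env.tools,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
)
executor.invoke({"input": "type the letter 'o'"})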