diff --git a/README.md b/README.md index 0ae637d..61fc94c 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ if __name__ == "__main__": import os from stagehand.sync import Stagehand from stagehand import StagehandConfig -from stagehand.schemas import AgentConfig, AgentExecuteOptions, AgentProvider +from stagehand.types import AgentConfigAPI as AgentConfig, AgentExecuteOptionsAPI as AgentExecuteOptions, AgentProvider from dotenv import load_dotenv load_dotenv() @@ -246,7 +246,7 @@ if __name__ == "__main__": The `ActOptions` model takes an `action` field that tells the AI what to do on the page, plus optional fields such as `useVision` and `variables`: ```python - from stagehand.schemas import ActOptions + from stagehand.types import ActOptions # Example: await page.act(ActOptions(action="click on the 'Quickstart' button")) @@ -256,7 +256,7 @@ if __name__ == "__main__": The `ObserveOptions` model lets you find elements on the page using natural language. The `onlyVisible` option helps limit the results: ```python - from stagehand.schemas import ObserveOptions + from stagehand.types import ObserveOptions # Example: await page.observe(ObserveOptions(instruction="find the button labeled 'News'", onlyVisible=True)) @@ -266,7 +266,7 @@ if __name__ == "__main__": The `ExtractOptions` model extracts structured data from the page. Pass your instructions and a schema defining your expected data format. **Note:** If you are using a Pydantic model for the schema, call its `.model_json_schema()` method to ensure JSON serializability. ```python - from stagehand.schemas import ExtractOptions + from stagehand.types import ExtractOptions from pydantic import BaseModel class DescriptionSchema(BaseModel): diff --git a/examples/agent_example.py b/examples/agent_example.py index c8cd56f..f52c7d7 100644 --- a/examples/agent_example.py +++ b/examples/agent_example.py @@ -8,7 +8,7 @@ from rich.theme import Theme from stagehand import Stagehand, StagehandConfig, AgentConfig, configure_logging -from stagehand.schemas import AgentExecuteOptions, AgentProvider +from stagehand.types import AgentExecuteOptionsAPI as AgentExecuteOptions, AgentProvider # Create a custom theme for consistent styling custom_theme = Theme( diff --git a/stagehand/__init__.py b/stagehand/__init__.py index 8f3a5f0..c6edf68 100644 --- a/stagehand/__init__.py +++ b/stagehand/__init__.py @@ -8,11 +8,11 @@ from .main import Stagehand from .metrics import StagehandFunctionName, StagehandMetrics from .page import StagehandPage -from .schemas import ( +from .types import ( ActOptions, ActResult, - AgentConfig, - AgentExecuteOptions, + AgentConfigAPI as AgentConfig, + AgentExecuteOptionsAPI as AgentExecuteOptions, AgentExecuteResult, AgentProvider, ExtractOptions, diff --git a/stagehand/agent.py b/stagehand/agent.py index 32e689d..093ac71 100644 --- a/stagehand/agent.py +++ b/stagehand/agent.py @@ -1,6 +1,6 @@ -from .schemas import ( - AgentConfig, - AgentExecuteOptions, +from .types import ( + AgentConfigAPI as AgentConfig, + AgentExecuteOptionsAPI as AgentExecuteOptions, AgentExecuteResult, AgentProvider, ) diff --git a/stagehand/config.py b/stagehand/config.py index 1cb6b25..90a7977 100644 --- a/stagehand/config.py +++ b/stagehand/config.py @@ -3,7 +3,7 @@ from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams from pydantic import BaseModel, ConfigDict, Field -from stagehand.schemas import AvailableModel +from stagehand.types import AvailableModel class StagehandConfig(BaseModel): diff --git a/stagehand/handlers/observe_handler.py b/stagehand/handlers/observe_handler.py index 251f857..c747ef5 100644 --- a/stagehand/handlers/observe_handler.py +++ b/stagehand/handlers/observe_handler.py @@ -5,7 +5,7 @@ from stagehand.a11y.utils import get_accessibility_tree, get_xpath_by_resolved_object_id from stagehand.llm.inference import observe as observe_inference from stagehand.metrics import StagehandFunctionName # Changed import location -from stagehand.schemas import ObserveOptions, ObserveResult +from stagehand.types import ObserveOptions, ObserveResult from stagehand.utils import draw_observe_overlay diff --git a/stagehand/main.py b/stagehand/main.py index e8662c2..2e216a5 100644 --- a/stagehand/main.py +++ b/stagehand/main.py @@ -28,7 +28,7 @@ from .logging import StagehandLogger, default_log_handler from .metrics import StagehandFunctionName, StagehandMetrics from .page import StagehandPage -from .schemas import AgentConfig +from .types import AgentConfigAPI as AgentConfig from .utils import make_serializable load_dotenv() diff --git a/stagehand/page.py b/stagehand/page.py index d01c83a..decf849 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -7,7 +7,7 @@ from stagehand.handlers.extract_handler import ExtractHandler from stagehand.handlers.observe_handler import ObserveHandler -from .schemas import ( +from .types import ( DEFAULT_EXTRACT_SCHEMA, ActOptions, ActResult, diff --git a/stagehand/schemas.py b/stagehand/schemas.py deleted file mode 100644 index 472ec05..0000000 --- a/stagehand/schemas.py +++ /dev/null @@ -1,288 +0,0 @@ -from enum import Enum -from typing import Any, Optional, Union - -from pydantic import BaseModel, ConfigDict, Field, field_serializer - -# Default extraction schema that matches the TypeScript version -DEFAULT_EXTRACT_SCHEMA = { - "type": "object", - "properties": {"extraction": {"type": "string"}}, - "required": ["extraction"], -} - - -# TODO: Remove this -class AvailableModel(str, Enum): - GPT_4O = "gpt-4o" - GPT_4O_MINI = "gpt-4o-mini" - CLAUDE_3_5_SONNET_LATEST = "claude-3-5-sonnet-latest" - CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest" - COMPUTER_USE_PREVIEW = "computer-use-preview" - GEMINI_2_0_FLASH = "gemini-2.0-flash" - - -class StagehandBaseModel(BaseModel): - """Base model for all Stagehand models with camelCase conversion support""" - - model_config = ConfigDict( - populate_by_name=True, # Allow accessing fields by their Python name - alias_generator=lambda field_name: "".join( - [field_name.split("_")[0]] - + [word.capitalize() for word in field_name.split("_")[1:]] - ), # snake_case to camelCase - ) - - -class ActOptions(StagehandBaseModel): - """ - Options for the 'act' command. - - Attributes: - action (str): The action command to be executed by the AI. - variables (Optional[dict[str, str]]): Key-value pairs for variable substitution. - model_name (Optional[str]): The model to use for processing. - slow_dom_based_act (Optional[bool]): Whether to use DOM-based action execution. - dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle after an action. - timeout_ms (Optional[int]): Timeout for the action in milliseconds. - """ - - action: str = Field(..., description="The action command to be executed by the AI.") - variables: Optional[dict[str, str]] = None - model_name: Optional[str] = None - slow_dom_based_act: Optional[bool] = None - dom_settle_timeout_ms: Optional[int] = None - timeout_ms: Optional[int] = None - model_client_options: Optional[dict[str, Any]] = None - - -class ActResult(StagehandBaseModel): - """ - Result of the 'act' command. - - Attributes: - success (bool): Whether the action was successful. - message (str): Message from the AI about the action. - action (str): The action command that was executed. - """ - - success: bool = Field(..., description="Whether the action was successful.") - message: str = Field(..., description="Message from the AI about the action.") - action: str = Field(..., description="The action command that was executed.") - - -class ExtractOptions(StagehandBaseModel): - """ - Options for the 'extract' command. - - Attributes: - instruction (str): Instruction specifying what data to extract using AI. - model_name (Optional[str]): The model to use for processing. - selector (Optional[str]): CSS selector to limit extraction to. - schema_definition (Union[dict[str, Any], type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data. - Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable. - use_text_extract (Optional[bool]): Whether to use text-based extraction. - dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before extraction. - """ - - instruction: str = Field( - ..., description="Instruction specifying what data to extract using AI." - ) - model_name: Optional[str] = None - selector: Optional[str] = None - # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method - # to convert it to a JSON serializable dictionary before sending it with the extract command. - schema_definition: Union[dict[str, Any], type[BaseModel]] = Field( - default=DEFAULT_EXTRACT_SCHEMA, - description="A JSON schema or Pydantic model that defines the structure of the expected data.", - ) - use_text_extract: Optional[bool] = None - dom_settle_timeout_ms: Optional[int] = None - model_client_options: Optional[dict[Any, Any]] = None - - @field_serializer("schema_definition") - def serialize_schema_definition( - self, schema_definition: Union[dict[str, Any], type[BaseModel]] - ) -> dict[str, Any]: - """Serialize schema_definition to a JSON schema if it's a Pydantic model""" - if isinstance(schema_definition, type) and issubclass( - schema_definition, BaseModel - ): - # Get the JSON schema using default ref_template ('#/$defs/{model}') - schema = schema_definition.model_json_schema() - - defs_key = "$defs" - if defs_key not in schema: - defs_key = "definitions" - if defs_key not in schema: - return schema - - definitions = schema.get(defs_key, {}) - if definitions: - self._resolve_references(schema, definitions, f"#/{defs_key}/") - schema.pop(defs_key, None) - - return schema - - elif isinstance(schema_definition, dict): - return schema_definition - - raise TypeError("schema_definition must be a Pydantic model or a dict") - - def _resolve_references(self, obj: Any, definitions: dict, ref_prefix: str) -> None: - """Recursively resolve $ref references in a schema using definitions.""" - if isinstance(obj, dict): - if "$ref" in obj and obj["$ref"].startswith(ref_prefix): - ref_name = obj["$ref"][len(ref_prefix) :] # Get name after prefix - if ref_name in definitions: - original_keys = {k: v for k, v in obj.items() if k != "$ref"} - resolved_definition = definitions[ref_name].copy() # Use a copy - self._resolve_references( - resolved_definition, definitions, ref_prefix - ) - - obj.clear() - obj.update(resolved_definition) - obj.update(original_keys) - else: - # Recursively process all values in the dictionary - for _, value in obj.items(): - self._resolve_references(value, definitions, ref_prefix) - - elif isinstance(obj, list): - # Process all items in the list - for item in obj: - self._resolve_references(item, definitions, ref_prefix) - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class ExtractResult(StagehandBaseModel): - """ - Result of the 'extract' command. - - This is a generic model to hold extraction results of different types. - The actual fields will depend on the schema provided in ExtractOptions. - """ - - # This class is intentionally left without fields so it can accept - # any fields from the extraction result based on the schema - - model_config = ConfigDict(extra="allow") # Allow any extra fields - - def __getitem__(self, key): - """ - Enable dictionary-style access to attributes. - This allows usage like result["selector"] in addition to result.selector - """ - return getattr(self, key) - - -class ObserveOptions(StagehandBaseModel): - """ - Options for the 'observe' command. - - Attributes: - instruction (str): Instruction detailing what the AI should observe. - model_name (Optional[str]): The model to use for processing. - return_action (Optional[bool]): Whether to include action information in the result. - draw_overlay (Optional[bool]): Whether to draw an overlay on observed elements. - dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before observation. - """ - - instruction: str = Field( - ..., description="Instruction detailing what the AI should observe." - ) - model_name: Optional[str] = None - draw_overlay: Optional[bool] = None - dom_settle_timeout_ms: Optional[int] = None - model_client_options: Optional[dict[str, Any]] = None - - -class ObserveResult(StagehandBaseModel): - """ - Result of the 'observe' command. - - Attributes: - selector (str): The selector of the observed element. - description (str): The description of the observed element. - backend_node_id (Optional[int]): The backend node ID. - method (Optional[str]): The method to execute. - arguments (Optional[list[str]]): The arguments for the method. - """ - - selector: str = Field(..., description="The selector of the observed element.") - description: str = Field( - ..., description="The description of the observed element." - ) - backend_node_id: Optional[int] = None - method: Optional[str] = None - arguments: Optional[list[str]] = None - - def __getitem__(self, key): - """ - Enable dictionary-style access to attributes. - This allows usage like result["selector"] in addition to result.selector - """ - return getattr(self, key) - - -class AgentProvider(str, Enum): - """Supported agent providers""" - - OPENAI = "openai" - ANTHROPIC = "anthropic" - - -class AgentConfig(StagehandBaseModel): - """ - Configuration for agent execution. - - Attributes: - provider (Optional[AgentProvider]): The provider to use (openai or anthropic). - model (Optional[str]): The model name to use. - instructions (Optional[str]): Custom instructions for the agent. - options (Optional[dict[str, Any]]): Additional provider-specific options. - """ - - provider: Optional[AgentProvider] = None - model: Optional[str] = None - instructions: Optional[str] = None - options: Optional[dict[str, Any]] = None - - -class AgentExecuteOptions(StagehandBaseModel): - """ - Options for agent execution. - - Attributes: - instruction (str): The task instruction for the agent. - max_steps (Optional[int]): Maximum number of steps the agent can take. - auto_screenshot (Optional[bool]): Whether to automatically take screenshots between steps. - wait_between_actions (Optional[int]): Milliseconds to wait between actions. - context (Optional[str]): Additional context for the agent. - """ - - instruction: str = Field(..., description="The task instruction for the agent.") - max_steps: Optional[int] = None - auto_screenshot: Optional[bool] = None - wait_between_actions: Optional[int] = None - context: Optional[str] = None - - -class AgentExecuteResult(StagehandBaseModel): - """ - Result of agent execution. - - Attributes: - success (bool): Whether the execution was successful. - actions (Optional[list[dict[str, Any]]]): Actions taken by the agent. - message (Optional[str]): Final result message from the agent. - completed (bool): Whether the agent has completed its task. - """ - - success: bool = Field(..., description="Whether the execution was successful.") - actions: Optional[list[dict[str, Any]]] = None - message: Optional[str] = None - completed: bool = Field( - False, description="Whether the agent has completed its task." - ) diff --git a/stagehand/types/__init__.py b/stagehand/types/__init__.py index ac1af17..3904348 100644 --- a/stagehand/types/__init__.py +++ b/stagehand/types/__init__.py @@ -1,5 +1,5 @@ """ -Exports for accessibility types. +Exports for Stagehand types. """ from .a11y import ( @@ -15,6 +15,16 @@ ) from .agent import ( AgentConfig, + AgentConfigAPI, + AgentExecuteOptions, + AgentExecuteOptionsAPI, + AgentExecuteResult, +) +from .base import ( + AgentProvider, + AvailableModel, + DEFAULT_EXTRACT_SCHEMA, + StagehandBaseModel, ) from .llm import ( ChatMessage, @@ -33,6 +43,12 @@ ) __all__ = [ + # Base types + "StagehandBaseModel", + "AgentProvider", + "AvailableModel", + "DEFAULT_EXTRACT_SCHEMA", + # A11y types "AXProperty", "AXValue", "AXNode", @@ -42,7 +58,9 @@ "Locator", "PlaywrightCommandError", "PlaywrightMethodNotSupportedError", + # LLM types "ChatMessage", + # Page types "ObserveElementSchema", "ObserveInferenceSchema", "ActOptions", @@ -53,7 +71,10 @@ "DefaultExtractSchema", "ExtractOptions", "ExtractResult", + # Agent types "AgentConfig", + "AgentConfigAPI", "AgentExecuteOptions", - "AgentResult", + "AgentExecuteOptionsAPI", + "AgentExecuteResult", ] diff --git a/stagehand/types/agent.py b/stagehand/types/agent.py index b533538..d37a0a1 100644 --- a/stagehand/types/agent.py +++ b/stagehand/types/agent.py @@ -1,6 +1,8 @@ from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, RootModel +from pydantic import BaseModel, RootModel, Field + +from .base import StagehandBaseModel, AgentProvider class AgentConfig(BaseModel): @@ -175,3 +177,59 @@ class EnvState(BaseModel): # The screenshot in PNG format. screenshot: bytes url: str + + +# Schemas from the API (with camelCase serialization) +class AgentConfigAPI(StagehandBaseModel): + """ + Configuration for agent execution. + + Attributes: + provider (Optional[AgentProvider]): The provider to use (openai or anthropic). + model (Optional[str]): The model name to use. + instructions (Optional[str]): Custom instructions for the agent. + options (Optional[dict[str, Any]]): Additional provider-specific options. + """ + + provider: Optional[AgentProvider] = None + model: Optional[str] = None + instructions: Optional[str] = None + options: Optional[dict[str, Any]] = None + + +class AgentExecuteOptionsAPI(StagehandBaseModel): + """ + Options for agent execution. + + Attributes: + instruction (str): The task instruction for the agent. + max_steps (Optional[int]): Maximum number of steps the agent can take. + auto_screenshot (Optional[bool]): Whether to automatically take screenshots between steps. + wait_between_actions (Optional[int]): Milliseconds to wait between actions. + context (Optional[str]): Additional context for the agent. + """ + + instruction: str = Field(..., description="The task instruction for the agent.") + max_steps: Optional[int] = None + auto_screenshot: Optional[bool] = None + wait_between_actions: Optional[int] = None + context: Optional[str] = None + + +class AgentExecuteResult(StagehandBaseModel): + """ + Result of agent execution. + + Attributes: + success (bool): Whether the execution was successful. + actions (Optional[list[dict[str, Any]]]): Actions taken by the agent. + message (Optional[str]): Final result message from the agent. + completed (bool): Whether the agent has completed its task. + """ + + success: bool = Field(..., description="Whether the execution was successful.") + actions: Optional[list[dict[str, Any]]] = None + message: Optional[str] = None + completed: bool = Field( + False, description="Whether the agent has completed its task." + ) diff --git a/stagehand/types/base.py b/stagehand/types/base.py new file mode 100644 index 0000000..cee0df7 --- /dev/null +++ b/stagehand/types/base.py @@ -0,0 +1,40 @@ +from enum import Enum +from typing import Any + +from pydantic import BaseModel, ConfigDict + +# Default extraction schema that matches the TypeScript version +DEFAULT_EXTRACT_SCHEMA = { + "type": "object", + "properties": {"extraction": {"type": "string"}}, + "required": ["extraction"], +} + + +# TODO: Remove this +class AvailableModel(str, Enum): + GPT_4O = "gpt-4o" + GPT_4O_MINI = "gpt-4o-mini" + CLAUDE_3_5_SONNET_LATEST = "claude-3-5-sonnet-latest" + CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest" + COMPUTER_USE_PREVIEW = "computer-use-preview" + GEMINI_2_0_FLASH = "gemini-2.0-flash" + + +class StagehandBaseModel(BaseModel): + """Base model for all Stagehand models with camelCase conversion support""" + + model_config = ConfigDict( + populate_by_name=True, # Allow accessing fields by their Python name + alias_generator=lambda field_name: "".join( + [field_name.split("_")[0]] + + [word.capitalize() for word in field_name.split("_")[1:]] + ), # snake_case to camelCase + ) + + +class AgentProvider(str, Enum): + """Supported agent providers""" + + OPENAI = "openai" + ANTHROPIC = "anthropic" \ No newline at end of file diff --git a/stagehand/types/page.py b/stagehand/types/page.py index ecfee16..80616cd 100644 --- a/stagehand/types/page.py +++ b/stagehand/types/page.py @@ -1,6 +1,8 @@ from typing import Any, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, field_serializer + +from .base import DEFAULT_EXTRACT_SCHEMA, StagehandBaseModel # Ignore linting error for this class name since it's used as a constant @@ -27,7 +29,7 @@ class MetadataSchema(BaseModel): progress: str -class ActOptions(BaseModel): +class ActOptions(StagehandBaseModel): """ Options for the 'act' command. @@ -35,6 +37,7 @@ class ActOptions(BaseModel): action (str): The action command to be executed by the AI. variables (Optional[dict[str, str]]): Key-value pairs for variable substitution. model_name (Optional[str]): The model to use for processing. + slow_dom_based_act (Optional[bool]): Whether to use DOM-based action execution. dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle after an action. timeout_ms (Optional[int]): Timeout for the action in milliseconds. """ @@ -42,12 +45,13 @@ class ActOptions(BaseModel): action: str = Field(..., description="The action command to be executed by the AI.") variables: Optional[dict[str, str]] = None model_name: Optional[str] = None + slow_dom_based_act: Optional[bool] = None dom_settle_timeout_ms: Optional[int] = None timeout_ms: Optional[int] = None model_client_options: Optional[dict[str, Any]] = None -class ActResult(BaseModel): +class ActResult(StagehandBaseModel): """ Result of the 'act' command. @@ -59,48 +63,107 @@ class ActResult(BaseModel): success: bool = Field(..., description="Whether the action was successful.") message: str = Field(..., description="Message from the AI about the action.") - action: str = Field(description="The action command that was executed.") + action: str = Field(..., description="The action command that was executed.") -class ObserveOptions(BaseModel): +class ExtractOptions(StagehandBaseModel): """ - Options for the 'observe' command. + Options for the 'extract' command. Attributes: - instruction (str): Instruction detailing what the AI should observe. - model_name (Optional[AvailableModel]): The model to use for processing. - draw_overlay (Optional[bool]): Whether to draw an overlay on observed elements. - dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before observation. + instruction (str): Instruction specifying what data to extract using AI. + model_name (Optional[str]): The model to use for processing. + selector (Optional[str]): CSS selector to limit extraction to. + schema_definition (Union[dict[str, Any], type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data. + Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable. + use_text_extract (Optional[bool]): Whether to use text-based extraction. + dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before extraction. """ instruction: str = Field( - ..., description="Instruction detailing what the AI should observe." + ..., description="Instruction specifying what data to extract using AI." ) model_name: Optional[str] = None - draw_overlay: Optional[bool] = None + selector: Optional[str] = None + # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method + # to convert it to a JSON serializable dictionary before sending it with the extract command. + schema_definition: Union[dict[str, Any], type[BaseModel]] = Field( + default=DEFAULT_EXTRACT_SCHEMA, + description="A JSON schema or Pydantic model that defines the structure of the expected data.", + ) + use_text_extract: Optional[bool] = None dom_settle_timeout_ms: Optional[int] = None - model_client_options: Optional[dict[str, Any]] = None - + model_client_options: Optional[dict[Any, Any]] = None -class ObserveResult(BaseModel): + @field_serializer("schema_definition") + def serialize_schema_definition( + self, schema_definition: Union[dict[str, Any], type[BaseModel]] + ) -> dict[str, Any]: + """Serialize schema_definition to a JSON schema if it's a Pydantic model""" + if isinstance(schema_definition, type) and issubclass( + schema_definition, BaseModel + ): + # Get the JSON schema using default ref_template ('#/$defs/{model}') + schema = schema_definition.model_json_schema() + + defs_key = "$defs" + if defs_key not in schema: + defs_key = "definitions" + if defs_key not in schema: + return schema + + definitions = schema.get(defs_key, {}) + if definitions: + self._resolve_references(schema, definitions, f"#/{defs_key}/") + schema.pop(defs_key, None) + + return schema + + elif isinstance(schema_definition, dict): + return schema_definition + + raise TypeError("schema_definition must be a Pydantic model or a dict") + + def _resolve_references(self, obj: Any, definitions: dict, ref_prefix: str) -> None: + """Recursively resolve $ref references in a schema using definitions.""" + if isinstance(obj, dict): + if "$ref" in obj and obj["$ref"].startswith(ref_prefix): + ref_name = obj["$ref"][len(ref_prefix) :] # Get name after prefix + if ref_name in definitions: + original_keys = {k: v for k, v in obj.items() if k != "$ref"} + resolved_definition = definitions[ref_name].copy() # Use a copy + self._resolve_references( + resolved_definition, definitions, ref_prefix + ) + + obj.clear() + obj.update(resolved_definition) + obj.update(original_keys) + else: + # Recursively process all values in the dictionary + for _, value in obj.items(): + self._resolve_references(value, definitions, ref_prefix) + + elif isinstance(obj, list): + # Process all items in the list + for item in obj: + self._resolve_references(item, definitions, ref_prefix) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ExtractResult(StagehandBaseModel): """ - Result of the 'observe' command. + Result of the 'extract' command. - Attributes: - selector (str): The selector of the observed element. - description (str): The description of the observed element. - backend_node_id (Optional[int]): The backend node ID. - method (Optional[str]): The method to execute. - arguments (Optional[list[str]]): The arguments for the method. + This is a generic model to hold extraction results of different types. + The actual fields will depend on the schema provided in ExtractOptions. """ - selector: str = Field(..., description="The selector of the observed element.") - description: str = Field( - ..., description="The description of the observed element." - ) - backend_node_id: Optional[int] = None - method: Optional[str] = None - arguments: Optional[list[str]] = None + # This class is intentionally left without fields so it can accept + # any fields from the extraction result based on the schema + + model_config = ConfigDict(extra="allow") # Allow any extra fields def __getitem__(self, key): """ @@ -110,45 +173,46 @@ def __getitem__(self, key): return getattr(self, key) -class ExtractOptions(BaseModel): +class ObserveOptions(StagehandBaseModel): """ - Options for the 'extract' command. + Options for the 'observe' command. Attributes: - instruction (str): Instruction specifying what data to extract using AI. - model_name (Optional[AvailableModel]): The model to use for processing. - selector (Optional[str]): CSS selector to limit extraction to. - schema_definition (Union[dict[str, Any], type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data. - Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable. - use_text_extract (Optional[bool]): Whether to use text-based extraction. - dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before extraction. + instruction (str): Instruction detailing what the AI should observe. + model_name (Optional[str]): The model to use for processing. + return_action (Optional[bool]): Whether to include action information in the result. + draw_overlay (Optional[bool]): Whether to draw an overlay on observed elements. + dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before observation. """ instruction: str = Field( - ..., description="Instruction specifying what data to extract using AI." + ..., description="Instruction detailing what the AI should observe." ) model_name: Optional[str] = None - selector: Optional[str] = None - # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method - # to convert it to a JSON serializable dictionary before sending it with the extract command. - schema_definition: Union[dict[str, Any], type[BaseModel]] = Field( - default=DefaultExtractSchema, - description="A JSON schema or Pydantic model that defines the structure of the expected data.", - ) - use_text_extract: Optional[bool] = None + draw_overlay: Optional[bool] = None dom_settle_timeout_ms: Optional[int] = None - model_client_options: Optional[dict[Any, Any]] = None + model_client_options: Optional[dict[str, Any]] = None -class ExtractResult(BaseModel): +class ObserveResult(StagehandBaseModel): """ - Result of the 'extract' command. + Result of the 'observe' command. - The 'data' field will contain the Pydantic model instance if a schema was provided - and validation was successful, otherwise it may contain the raw extracted dictionary. + Attributes: + selector (str): The selector of the observed element. + description (str): The description of the observed element. + backend_node_id (Optional[int]): The backend node ID. + method (Optional[str]): The method to execute. + arguments (Optional[list[str]]): The arguments for the method. """ - data: Optional[Any] = None + selector: str = Field(..., description="The selector of the observed element.") + description: str = Field( + ..., description="The description of the observed element." + ) + backend_node_id: Optional[int] = None + method: Optional[str] = None + arguments: Optional[list[str]] = None def __getitem__(self, key): """ diff --git a/tests/conftest.py b/tests/conftest.py index 2ba2809..24c648b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from typing import Dict, Any from stagehand import Stagehand, StagehandConfig -from stagehand.schemas import ActResult, ExtractResult, ObserveResult +from stagehand.types import ActResult, ExtractResult, ObserveResult # Set up pytest-asyncio as the default diff --git a/tests/e2e/test_extract_integration.py b/tests/e2e/test_extract_integration.py index d88b51a..2d01802 100644 --- a/tests/e2e/test_extract_integration.py +++ b/tests/e2e/test_extract_integration.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field, HttpUrl from stagehand import Stagehand, StagehandConfig -from stagehand.schemas import ExtractOptions +from stagehand.types import ExtractOptions class Article(BaseModel): diff --git a/tests/e2e/test_stagehand_integration.py b/tests/e2e/test_stagehand_integration.py index 0150cfa..a548042 100644 --- a/tests/e2e/test_stagehand_integration.py +++ b/tests/e2e/test_stagehand_integration.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field, HttpUrl from stagehand import Stagehand, StagehandConfig -from stagehand.schemas import ExtractOptions +from stagehand.types import ExtractOptions class Company(BaseModel): diff --git a/tests/e2e/test_workflows.py b/tests/e2e/test_workflows.py index a03a06f..4a57501 100644 --- a/tests/e2e/test_workflows.py +++ b/tests/e2e/test_workflows.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from stagehand import Stagehand, StagehandConfig -from stagehand.schemas import ActResult, ObserveResult, ExtractResult +from stagehand.types import ActResult, ObserveResult, ExtractResult from tests.mocks.mock_llm import MockLLMClient from tests.mocks.mock_browser import create_mock_browser_stack, setup_page_with_content from tests.mocks.mock_server import create_mock_server_with_client, setup_successful_session_flow @@ -648,7 +648,7 @@ class ProductList(BaseModel): await stagehand.page.goto("https://electronics-store.com") # Extract with Pydantic schema - from stagehand.schemas import ExtractOptions + from stagehand.types import ExtractOptions extract_options = ExtractOptions( instruction="extract all products with detailed information", diff --git a/tests/integration/api/test_core_api.py b/tests/integration/api/test_core_api.py index f5410e1..2d98f27 100644 --- a/tests/integration/api/test_core_api.py +++ b/tests/integration/api/test_core_api.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from stagehand import Stagehand, StagehandConfig -from stagehand.schemas import ExtractOptions +from stagehand.types import ExtractOptions class Article(BaseModel): diff --git a/tests/unit/core/test_page.py b/tests/unit/core/test_page.py index 777a880..8e9c1f4 100644 --- a/tests/unit/core/test_page.py +++ b/tests/unit/core/test_page.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from stagehand.page import StagehandPage -from stagehand.schemas import ( +from stagehand.types import ( ActOptions, ActResult, ExtractOptions, diff --git a/tests/unit/handlers/test_observe_handler.py b/tests/unit/handlers/test_observe_handler.py index f934e08..4af7e5d 100644 --- a/tests/unit/handlers/test_observe_handler.py +++ b/tests/unit/handlers/test_observe_handler.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from stagehand.handlers.observe_handler import ObserveHandler -from stagehand.schemas import ObserveOptions, ObserveResult +from stagehand.types import ObserveOptions, ObserveResult from tests.mocks.mock_llm import MockLLMClient