|
10 | 10 | field_validator, |
11 | 11 | ) |
12 | 12 |
|
| 13 | +from openhands.sdk.security import risk |
13 | 14 | from openhands.sdk.tool.schema import ActionBase, ObservationBase |
14 | 15 | from openhands.sdk.utils.discriminated_union import ( |
15 | 16 | DiscriminatedUnionMixin, |
@@ -105,16 +106,6 @@ def create(cls, *args, **kwargs) -> "Tool | list[Tool]": |
105 | 106 | """ |
106 | 107 | raise NotImplementedError("Tool.create() must be implemented in subclasses") |
107 | 108 |
|
108 | | - @computed_field(return_type=dict[str, Any], alias="input_schema") |
109 | | - @property |
110 | | - def input_schema(self) -> dict[str, Any]: |
111 | | - return self.action_type.to_mcp_schema() |
112 | | - |
113 | | - @computed_field(return_type=dict[str, Any] | None, alias="output_schema") |
114 | | - @property |
115 | | - def output_schema(self) -> dict[str, Any] | None: |
116 | | - return self.observation_type.to_mcp_schema() if self.observation_type else None |
117 | | - |
118 | 109 | @computed_field(return_type=str, alias="title") |
119 | 110 | @property |
120 | 111 | def title(self) -> str: |
@@ -190,24 +181,47 @@ def to_mcp_tool(self) -> dict[str, Any]: |
190 | 181 | out = { |
191 | 182 | "name": self.name, |
192 | 183 | "description": self.description, |
193 | | - "inputSchema": self.input_schema, |
| 184 | + "inputSchema": self.action_type.to_mcp_schema(), |
194 | 185 | } |
195 | 186 | if self.annotations: |
196 | 187 | out["annotations"] = self.annotations |
197 | 188 | if self.meta is not None: |
198 | 189 | out["_meta"] = self.meta |
199 | | - if self.output_schema: |
200 | | - out["outputSchema"] = self.output_schema |
| 190 | + if self.observation_type: |
| 191 | + out["outputSchema"] = self.observation_type.to_mcp_schema() |
201 | 192 | return out |
202 | 193 |
|
203 | | - def to_openai_tool(self) -> ChatCompletionToolParam: |
204 | | - """Convert an MCP tool to an OpenAI tool.""" |
| 194 | + def to_openai_tool( |
| 195 | + self, |
| 196 | + add_security_risk_prediction: bool = False, |
| 197 | + ) -> ChatCompletionToolParam: |
| 198 | + """Convert a Tool to an OpenAI tool. |
| 199 | +
|
| 200 | + Args: |
| 201 | + add_security_risk_prediction: Whether to add a `security_risk` field |
| 202 | + to the action schema for LLM to predict. This is useful for |
| 203 | + tools that may have safety risks, so the LLM can reason about |
| 204 | + the risk level before calling the tool. |
| 205 | + """ |
| 206 | + |
| 207 | + class ActionTypeWithRisk(self.action_type): |
| 208 | + security_risk: risk.SecurityRisk = Field( |
| 209 | + default=risk.SecurityRisk.UNKNOWN, |
| 210 | + description="The LLM's assessment of the safety risk of this action.", |
| 211 | + ) |
| 212 | + |
| 213 | + # We only add security_risk if the tool is not read-only |
| 214 | + add_security_risk_prediction = add_security_risk_prediction and ( |
| 215 | + self.annotations is None or (not self.annotations.readOnlyHint) |
| 216 | + ) |
205 | 217 | return ChatCompletionToolParam( |
206 | 218 | type="function", |
207 | 219 | function=ChatCompletionToolParamFunctionChunk( |
208 | 220 | name=self.name, |
209 | 221 | description=self.description, |
210 | | - parameters=self.input_schema, |
| 222 | + parameters=ActionTypeWithRisk.to_mcp_schema() |
| 223 | + if add_security_risk_prediction |
| 224 | + else self.action_type.to_mcp_schema(), |
211 | 225 | ), |
212 | 226 | ) |
213 | 227 |
|
|
0 commit comments