Skip to content

Commit 0db5017

Browse files
Do not add security_risk unless security analyzer is enabled (#341)
Co-authored-by: openhands <[email protected]>
1 parent 8c86203 commit 0db5017

File tree

20 files changed

+298
-138
lines changed

20 files changed

+298
-138
lines changed

examples/17_llm_security_analyzer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
"""
66

77
import os
8+
import uuid
89

910
from pydantic import SecretStr
1011

11-
from openhands.sdk import LLM, Agent, Conversation, Message, TextContent
12+
from openhands.sdk import LLM, Agent, Conversation, LocalFileStore, Message, TextContent
1213
from openhands.sdk.conversation.state import AgentExecutionStatus
1314
from openhands.sdk.event.utils import get_unmatched_actions
1415
from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
@@ -36,7 +37,12 @@
3637
# Create agent with security analyzer
3738
security_analyzer = LLMSecurityAnalyzer()
3839
agent = Agent(llm=llm, tools=tools, security_analyzer=security_analyzer)
39-
conversation = Conversation(agent=agent)
40+
41+
conversation_id = uuid.uuid4()
42+
file_store = LocalFileStore(f"./.conversations/{conversation_id}")
43+
conversation = Conversation(
44+
agent=agent, conversation_id=conversation_id, persist_filestore=file_store
45+
)
4046

4147
print("\n1) Safe command (LOW risk - should execute automatically)...")
4248
conversation.send_message(

openhands/sdk/agent/agent.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
get_llm_metadata,
3030
)
3131
from openhands.sdk.logger import get_logger
32+
from openhands.sdk.security import risk
33+
from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
3234
from openhands.sdk.tool import (
3335
BUILT_IN_TOOLS,
3436
ActionBase,
@@ -217,7 +219,16 @@ def step(
217219
f"{json.dumps([m.model_dump() for m in _messages], indent=2)}"
218220
)
219221
assert isinstance(self.tools, dict)
220-
tools = [tool.to_openai_tool() for tool in self.tools.values()]
222+
223+
tools = [
224+
# add llm security risk prediction if analyzer is present
225+
tool.to_openai_tool(
226+
add_security_risk_prediction=isinstance(
227+
self.security_analyzer, LLMSecurityAnalyzer
228+
)
229+
)
230+
for tool in self.tools.values()
231+
]
221232
response = self.llm.completion(
222233
messages=_messages,
223234
tools=tools,
@@ -368,10 +379,28 @@ def _get_action_events(
368379
return
369380

370381
# Validate arguments
382+
security_risk: risk.SecurityRisk = risk.SecurityRisk.UNKNOWN
371383
try:
372-
action: ActionBase = tool.action_type.model_validate(
373-
json.loads(tool_call.function.arguments)
374-
)
384+
arguments = json.loads(tool_call.function.arguments)
385+
386+
# if the tool has a security_risk field (when security analyzer = LLM),
387+
# pop it out as it's not part of the tool's action schema
388+
if (_predicted_risk := arguments.pop("security_risk", None)) is not None:
389+
if not isinstance(self.security_analyzer, LLMSecurityAnalyzer):
390+
raise RuntimeError(
391+
"LLM provided a security_risk but no security analyzer is "
392+
"configured - THIS SHOULD NOT HAPPEN!"
393+
)
394+
try:
395+
security_risk = risk.SecurityRisk(_predicted_risk)
396+
except ValueError:
397+
logger.warning(
398+
f"Invalid security_risk value from LLM: {_predicted_risk}"
399+
)
400+
401+
# Arguments we passed in should not contains `security_risk`
402+
# as a field
403+
action: ActionBase = tool.action_type.model_validate(arguments)
375404
except (json.JSONDecodeError, ValidationError) as e:
376405
err = (
377406
f"Error validating args {tool_call.function.arguments} for tool "
@@ -394,6 +423,7 @@ def _get_action_events(
394423
tool_call=tool_call,
395424
llm_response_id=llm_response_id,
396425
metrics=metrics,
426+
security_risk=security_risk,
397427
)
398428
on_event(action_event)
399429
return action_event

openhands/sdk/event/llm_convertible.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from openhands.sdk.event.types import EventID, SourceType, ToolCallID
1111
from openhands.sdk.llm import ImageContent, Message, TextContent, content_to_str
1212
from openhands.sdk.llm.utils.metrics import MetricsSnapshot
13+
from openhands.sdk.security import risk
1314
from openhands.sdk.tool.schema import Action, Observation
1415

1516

@@ -108,6 +109,10 @@ class ActionEvent(LLMConvertibleEvent):
108109
"to the last action when multiple actions share the same LLM response."
109110
),
110111
)
112+
security_risk: risk.SecurityRisk = Field(
113+
default=risk.SecurityRisk.UNKNOWN,
114+
description="The LLM's assessment of the safety risk of this action.",
115+
)
111116

112117
@property
113118
def visualize(self) -> Text:

openhands/sdk/security/analyzer.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from openhands.sdk.event.llm_convertible import ActionEvent
66
from openhands.sdk.logger import get_logger
77
from openhands.sdk.security.risk import SecurityRisk
8-
from openhands.sdk.tool.schema import Action
98
from openhands.sdk.utils.discriminated_union import (
109
DiscriminatedUnionMixin,
1110
DiscriminatedUnionType,
@@ -26,15 +25,15 @@ class SecurityAnalyzerBase(DiscriminatedUnionMixin, ABC):
2625
"""
2726

2827
@abstractmethod
29-
def security_risk(self, action: Action) -> SecurityRisk:
30-
"""Evaluate the security risk of an action.
28+
def security_risk(self, action: ActionEvent) -> SecurityRisk:
29+
"""Evaluate the security risk of an ActionEvent.
3130
32-
This is the core method that analyzes an action and returns its risk level.
31+
This is the core method that analyzes an ActionEvent and returns its risk level.
3332
Implementations should examine the action's content, context, and potential
3433
impact to determine the appropriate risk level.
3534
3635
Args:
37-
action: The action to analyze for security risks
36+
action: The ActionEvent to analyze for security risks
3837
3938
Returns:
4039
ActionSecurityRisk enum indicating the risk level
@@ -54,7 +53,7 @@ def analyze_event(self, event: Event) -> SecurityRisk | None:
5453
ActionSecurityRisk if event is an action, None otherwise
5554
"""
5655
if isinstance(event, ActionEvent):
57-
return self.security_risk(event.action)
56+
return self.security_risk(event)
5857
return None
5958

6059
def should_require_confirmation(
@@ -103,7 +102,7 @@ def analyze_pending_actions(
103102

104103
for action_event in pending_actions:
105104
try:
106-
risk = self.security_risk(action_event.action)
105+
risk = self.security_risk(action_event)
107106
analyzed_actions.append((action_event, risk))
108107
logger.debug(f"Action {action_event} analyzed with risk level: {risk}")
109108
except Exception as e:

openhands/sdk/security/llm_analyzer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from openhands.sdk.event import ActionEvent
12
from openhands.sdk.logger import get_logger
23
from openhands.sdk.security.analyzer import SecurityAnalyzer
34
from openhands.sdk.security.risk import SecurityRisk
4-
from openhands.sdk.tool.schema import Action
55

66

77
logger = get_logger(__name__)
@@ -17,7 +17,7 @@ class LLMSecurityAnalyzer(SecurityAnalyzer):
1717
understanding of action context and potential risks.
1818
"""
1919

20-
def security_risk(self, action: Action) -> SecurityRisk:
20+
def security_risk(self, action: ActionEvent) -> SecurityRisk:
2121
"""Evaluate security risk based on LLM-provided assessment.
2222
2323
This method checks if the action has a security_risk attribute set by the LLM

openhands/sdk/tool/schema.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pydantic import BaseModel, ConfigDict, Field, create_model
55
from rich.text import Text
66

7-
import openhands.sdk.security.risk as risk
87
from openhands.sdk.llm import ImageContent, TextContent
98
from openhands.sdk.llm.message import content_to_str
109
from openhands.sdk.utils.discriminated_union import (
@@ -168,14 +167,6 @@ def from_mcp_schema(
168167
class ActionBase(Schema, DiscriminatedUnionMixin):
169168
"""Base schema for input action."""
170169

171-
# NOTE: We make it optional since some weaker
172-
# LLMs may not be able to fill it out correctly.
173-
# https://github.com/All-Hands-AI/OpenHands/issues/10797
174-
security_risk: risk.SecurityRisk = Field(
175-
default=risk.SecurityRisk.UNKNOWN,
176-
description="The LLM's assessment of the safety risk of this action.",
177-
)
178-
179170
@property
180171
def visualize(self) -> Text:
181172
"""Return Rich Text representation of this action.
@@ -198,23 +189,6 @@ def visualize(self) -> Text:
198189

199190
return content
200191

201-
@classmethod
202-
def to_mcp_schema(cls) -> dict[str, Any]:
203-
"""Convert to JSON schema format compatible with MCP."""
204-
schema = super().to_mcp_schema()
205-
206-
# We need to move the fields from ActionBase to the END of the properties
207-
# We use these properties to generate the llm schema for tool calling
208-
# and we want the ActionBase fields to be at the end
209-
# e.g. LLM should already outputs the argument for tools
210-
# BEFORE it predicts security_risk
211-
assert "properties" in schema, "Schema must have properties"
212-
for field_name in ActionBase.model_fields.keys():
213-
if field_name in schema["properties"]:
214-
v = schema["properties"].pop(field_name)
215-
schema["properties"][field_name] = v
216-
return schema
217-
218192

219193
class MCPActionBase(ActionBase):
220194
"""Base schema for MCP input action."""

openhands/sdk/tool/tool.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
field_validator,
1111
)
1212

13+
from openhands.sdk.security import risk
1314
from openhands.sdk.tool.schema import ActionBase, ObservationBase
1415
from openhands.sdk.utils.discriminated_union import (
1516
DiscriminatedUnionMixin,
@@ -105,16 +106,6 @@ def create(cls, *args, **kwargs) -> "Tool | list[Tool]":
105106
"""
106107
raise NotImplementedError("Tool.create() must be implemented in subclasses")
107108

108-
@computed_field(return_type=dict[str, Any], alias="input_schema")
109-
@property
110-
def input_schema(self) -> dict[str, Any]:
111-
return self.action_type.to_mcp_schema()
112-
113-
@computed_field(return_type=dict[str, Any] | None, alias="output_schema")
114-
@property
115-
def output_schema(self) -> dict[str, Any] | None:
116-
return self.observation_type.to_mcp_schema() if self.observation_type else None
117-
118109
@computed_field(return_type=str, alias="title")
119110
@property
120111
def title(self) -> str:
@@ -190,24 +181,47 @@ def to_mcp_tool(self) -> dict[str, Any]:
190181
out = {
191182
"name": self.name,
192183
"description": self.description,
193-
"inputSchema": self.input_schema,
184+
"inputSchema": self.action_type.to_mcp_schema(),
194185
}
195186
if self.annotations:
196187
out["annotations"] = self.annotations
197188
if self.meta is not None:
198189
out["_meta"] = self.meta
199-
if self.output_schema:
200-
out["outputSchema"] = self.output_schema
190+
if self.observation_type:
191+
out["outputSchema"] = self.observation_type.to_mcp_schema()
201192
return out
202193

203-
def to_openai_tool(self) -> ChatCompletionToolParam:
204-
"""Convert an MCP tool to an OpenAI tool."""
194+
def to_openai_tool(
195+
self,
196+
add_security_risk_prediction: bool = False,
197+
) -> ChatCompletionToolParam:
198+
"""Convert a Tool to an OpenAI tool.
199+
200+
Args:
201+
add_security_risk_prediction: Whether to add a `security_risk` field
202+
to the action schema for LLM to predict. This is useful for
203+
tools that may have safety risks, so the LLM can reason about
204+
the risk level before calling the tool.
205+
"""
206+
207+
class ActionTypeWithRisk(self.action_type):
208+
security_risk: risk.SecurityRisk = Field(
209+
default=risk.SecurityRisk.UNKNOWN,
210+
description="The LLM's assessment of the safety risk of this action.",
211+
)
212+
213+
# We only add security_risk if the tool is not read-only
214+
add_security_risk_prediction = add_security_risk_prediction and (
215+
self.annotations is None or (not self.annotations.readOnlyHint)
216+
)
205217
return ChatCompletionToolParam(
206218
type="function",
207219
function=ChatCompletionToolParamFunctionChunk(
208220
name=self.name,
209221
description=self.description,
210-
parameters=self.input_schema,
222+
parameters=ActionTypeWithRisk.to_mcp_schema()
223+
if add_security_risk_prediction
224+
else self.action_type.to_mcp_schema(),
211225
),
212226
)
213227

0 commit comments

Comments
 (0)