Skip to content

Commit 81f7e85

Browse files
committed
feat: add role filtering to message history get_recent method (#349)
Add role parameter to get_recent() and get_relevant() methods in both MessageHistory and SemanticMessageHistory classes to enable filtering messages by role type. Features: - Support single role filtering: role="system" - Support multiple role filtering: role=["system", "user"] - Valid roles: "system", "user", "llm", "tool" - Backward compatible: role=None returns all messages - Works with existing parameters (top_k, session_tag, raw, etc.) - Comprehensive validation with clear error messages The implementation maintains full backward compatibility while enabling users to retrieve only specific message types like system prompts.
1 parent 7607554 commit 81f7e85

File tree

5 files changed

+496
-8
lines changed

5 files changed

+496
-8
lines changed

CLAUDE.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ index = SearchIndex(schema, redis_url="redis://localhost:6379")
4848
token.strip().strip(",").replace(""", "").replace(""", "").lower()
4949
```
5050

51+
### Git Operations
52+
**CRITICAL**: NEVER use `git push` or attempt to push to remote repositories. The user will handle all git push operations.
53+
54+
### Code Quality
55+
**IMPORTANT**: Always run `make format` before committing code to ensure proper formatting and linting compliance.
56+
5157
### README.md Maintenance
5258
**IMPORTANT**: DO NOT modify README.md unless explicitly requested.
5359

redisvl/extensions/message_history/base_history.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def get_recent(
6060
as_text: bool = False,
6161
raw: bool = False,
6262
session_tag: Optional[str] = None,
63+
role: Optional[Union[str, List[str]]] = None,
6364
) -> Union[List[str], List[Dict[str, str]]]:
6465
"""Retrieve the recent conversation history in sequential order.
6566
@@ -72,16 +73,60 @@ def get_recent(
7273
prompt and response
7374
session_tag (str): Tag to be added to entries to link to a specific
7475
conversation session. Defaults to instance ULID.
76+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
77+
Can be a single role string ("system", "user", "llm", "tool") or
78+
a list of roles. If None, all roles are returned.
7579
7680
Returns:
7781
Union[str, List[str]]: A single string transcription of the messages
7882
or list of strings if as_text is false.
7983
8084
Raises:
81-
ValueError: If top_k is not an integer greater than or equal to 0.
85+
ValueError: If top_k is not an integer greater than or equal to 0,
86+
or if role contains invalid values.
8287
"""
8388
raise NotImplementedError
8489

90+
def _validate_roles(
91+
self, role: Optional[Union[str, List[str]]]
92+
) -> Optional[List[str]]:
93+
"""Validate and normalize role parameter.
94+
95+
Args:
96+
role: A single role string, list of roles, or None.
97+
98+
Returns:
99+
List of valid role strings if role is provided, None otherwise.
100+
101+
Raises:
102+
ValueError: If role contains invalid values.
103+
"""
104+
if role is None:
105+
return None
106+
107+
valid_roles = {"system", "user", "llm", "tool"}
108+
109+
# Handle single role string
110+
if isinstance(role, str):
111+
if role not in valid_roles:
112+
raise ValueError(
113+
f"Invalid role '{role}'. Valid roles are: {valid_roles}"
114+
)
115+
return [role]
116+
117+
# Handle list of roles
118+
if isinstance(role, list):
119+
if not role: # Empty list
120+
raise ValueError("roles cannot be empty")
121+
for r in role:
122+
if r not in valid_roles:
123+
raise ValueError(
124+
f"Invalid role '{r}'. Valid roles are: {valid_roles}"
125+
)
126+
return role
127+
128+
raise ValueError("role must be a string or list of strings")
129+
85130
def _format_context(
86131
self, messages: List[Dict[str, Any]], as_text: bool
87132
) -> Union[List[str], List[Dict[str, str]]]:

redisvl/extensions/message_history/message_history.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def get_recent(
119119
as_text: bool = False,
120120
raw: bool = False,
121121
session_tag: Optional[str] = None,
122+
role: Optional[Union[str, List[str]]] = None,
122123
) -> Union[List[str], List[Dict[str, str]]]:
123124
"""Retrieve the recent message history in sequential order.
124125
@@ -130,17 +131,24 @@ def get_recent(
130131
prompt and response.
131132
session_tag (Optional[str]): Tag of the entries linked to a specific
132133
conversation session. Defaults to instance ULID.
134+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
135+
Can be a single role string ("system", "user", "llm", "tool") or
136+
a list of roles. If None, all roles are returned.
133137
134138
Returns:
135139
Union[str, List[str]]: A single string transcription of the messages
136140
or list of strings if as_text is false.
137141
138142
Raises:
139-
ValueError: if top_k is not an integer greater than or equal to 0.
143+
ValueError: if top_k is not an integer greater than or equal to 0,
144+
or if role contains invalid values.
140145
"""
141146
if type(top_k) != int or top_k < 0:
142147
raise ValueError("top_k must be an integer greater than or equal to 0")
143148

149+
# Validate and normalize role parameter
150+
roles_to_filter = self._validate_roles(role)
151+
144152
return_fields = [
145153
ID_FIELD_NAME,
146154
SESSION_FIELD_NAME,
@@ -157,8 +165,22 @@ def get_recent(
157165
else self._default_session_filter
158166
)
159167

168+
# Combine session filter with role filter if provided
169+
filter_expression = session_filter
170+
if roles_to_filter is not None:
171+
if len(roles_to_filter) == 1:
172+
role_filter = Tag(ROLE_FIELD_NAME) == roles_to_filter[0]
173+
else:
174+
# Multiple roles - use OR logic
175+
role_filters = [Tag(ROLE_FIELD_NAME) == r for r in roles_to_filter]
176+
role_filter = role_filters[0]
177+
for rf in role_filters[1:]:
178+
role_filter = role_filter | rf
179+
180+
filter_expression = session_filter & role_filter
181+
160182
query = FilterQuery(
161-
filter_expression=session_filter,
183+
filter_expression=filter_expression,
162184
return_fields=return_fields,
163185
num_results=top_k,
164186
)

redisvl/extensions/message_history/semantic_history.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def get_relevant(
173173
session_tag: Optional[str] = None,
174174
raw: bool = False,
175175
distance_threshold: Optional[float] = None,
176+
role: Optional[Union[str, List[str]]] = None,
176177
) -> Union[List[str], List[Dict[str, str]]]:
177178
"""Searches the message history for information semantically related to
178179
the specified prompt.
@@ -195,18 +196,25 @@ def get_relevant(
195196
if no relevant context is found.
196197
raw (bool): Whether to return the full Redis hash entry or just the
197198
message.
199+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
200+
Can be a single role string ("system", "user", "llm", "tool") or
201+
a list of roles. If None, all roles are returned.
198202
199203
Returns:
200204
Union[List[str], List[Dict[str,str]]: Either a list of strings, or a
201205
list of prompts and responses in JSON containing the most relevant.
202206
203-
Raises ValueError: if top_k is not an integer greater or equal to 0.
207+
Raises ValueError: if top_k is not an integer greater or equal to 0,
208+
or if role contains invalid values.
204209
"""
205210
if type(top_k) != int or top_k < 0:
206211
raise ValueError("top_k must be an integer greater than or equal to -1")
207212
if top_k == 0:
208213
return []
209214

215+
# Validate and normalize role parameter
216+
roles_to_filter = self._validate_roles(role)
217+
210218
# override distance threshold
211219
distance_threshold = distance_threshold or self._distance_threshold
212220

@@ -225,21 +233,35 @@ def get_relevant(
225233
else self._default_session_filter
226234
)
227235

236+
# Combine session filter with role filter if provided
237+
filter_expression = session_filter
238+
if roles_to_filter is not None:
239+
if len(roles_to_filter) == 1:
240+
role_filter = Tag(ROLE_FIELD_NAME) == roles_to_filter[0]
241+
else:
242+
# Multiple roles - use OR logic
243+
role_filters = [Tag(ROLE_FIELD_NAME) == r for r in roles_to_filter]
244+
role_filter = role_filters[0]
245+
for rf in role_filters[1:]:
246+
role_filter = role_filter | rf
247+
248+
filter_expression = session_filter & role_filter
249+
228250
query = RangeQuery(
229251
vector=self._vectorizer.embed(prompt),
230252
vector_field_name=MESSAGE_VECTOR_FIELD_NAME,
231253
return_fields=return_fields,
232254
distance_threshold=distance_threshold,
233255
num_results=top_k,
234256
return_score=True,
235-
filter_expression=session_filter,
257+
filter_expression=filter_expression,
236258
dtype=self._vectorizer.dtype,
237259
)
238260
messages = self._index.query(query)
239261

240262
# if we don't find semantic matches fallback to returning recent context
241263
if not messages and fall_back:
242-
return self.get_recent(as_text=as_text, top_k=top_k, raw=raw)
264+
return self.get_recent(as_text=as_text, top_k=top_k, raw=raw, role=role)
243265
if raw:
244266
return messages
245267
return self._format_context(messages, as_text)
@@ -250,6 +272,7 @@ def get_recent(
250272
as_text: bool = False,
251273
raw: bool = False,
252274
session_tag: Optional[str] = None,
275+
role: Optional[Union[str, List[str]]] = None,
253276
) -> Union[List[str], List[Dict[str, str]]]:
254277
"""Retrieve the recent message history in sequential order.
255278
@@ -261,17 +284,24 @@ def get_recent(
261284
prompt and response
262285
session_tag (Optional[str]): Tag of the entries linked to a specific
263286
conversation session. Defaults to instance ULID.
287+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
288+
Can be a single role string ("system", "user", "llm", "tool") or
289+
a list of roles. If None, all roles are returned.
264290
265291
Returns:
266292
Union[str, List[str]]: A single string transcription of the session
267293
or list of strings if as_text is false.
268294
269295
Raises:
270-
ValueError: if top_k is not an integer greater than or equal to 0.
296+
ValueError: if top_k is not an integer greater than or equal to 0,
297+
or if role contains invalid values.
271298
"""
272299
if type(top_k) != int or top_k < 0:
273300
raise ValueError("top_k must be an integer greater than or equal to 0")
274301

302+
# Validate and normalize role parameter
303+
roles_to_filter = self._validate_roles(role)
304+
275305
return_fields = [
276306
ID_FIELD_NAME,
277307
SESSION_FIELD_NAME,
@@ -288,8 +318,22 @@ def get_recent(
288318
else self._default_session_filter
289319
)
290320

321+
# Combine session filter with role filter if provided
322+
filter_expression = session_filter
323+
if roles_to_filter is not None:
324+
if len(roles_to_filter) == 1:
325+
role_filter = Tag(ROLE_FIELD_NAME) == roles_to_filter[0]
326+
else:
327+
# Multiple roles - use OR logic
328+
role_filters = [Tag(ROLE_FIELD_NAME) == r for r in roles_to_filter]
329+
role_filter = role_filters[0]
330+
for rf in role_filters[1:]:
331+
role_filter = role_filter | rf
332+
333+
filter_expression = session_filter & role_filter
334+
291335
query = FilterQuery(
292-
filter_expression=session_filter,
336+
filter_expression=filter_expression,
293337
return_fields=return_fields,
294338
num_results=top_k,
295339
)

0 commit comments

Comments
 (0)