@@ -173,6 +173,7 @@ def get_relevant(
173173 session_tag : Optional [str ] = None ,
174174 raw : bool = False ,
175175 distance_threshold : Optional [float ] = None ,
176+ role : Optional [Union [str , List [str ]]] = None ,
176177 ) -> Union [List [str ], List [Dict [str , str ]]]:
177178 """Searches the message history for information semantically related to
178179 the specified prompt.
@@ -195,18 +196,25 @@ def get_relevant(
195196 if no relevant context is found.
196197 raw (bool): Whether to return the full Redis hash entry or just the
197198 message.
199+ role (Optional[Union[str, List[str]]]): Filter messages by role(s).
200+ Can be a single role string ("system", "user", "llm", "tool") or
201+ a list of roles. If None, all roles are returned.
198202
199203 Returns:
200204 Union[List[str], List[Dict[str,str]]: Either a list of strings, or a
201205 list of prompts and responses in JSON containing the most relevant.
202206
203- Raises ValueError: if top_k is not an integer greater or equal to 0.
207+ Raises ValueError: if top_k is not an integer greater or equal to 0,
208+ or if role contains invalid values.
204209 """
205210 if type (top_k ) != int or top_k < 0 :
206211 raise ValueError ("top_k must be an integer greater than or equal to -1" )
207212 if top_k == 0 :
208213 return []
209214
215+ # Validate and normalize role parameter
216+ roles_to_filter = self ._validate_roles (role )
217+
210218 # override distance threshold
211219 distance_threshold = distance_threshold or self ._distance_threshold
212220
@@ -225,21 +233,35 @@ def get_relevant(
225233 else self ._default_session_filter
226234 )
227235
236+ # Combine session filter with role filter if provided
237+ filter_expression = session_filter
238+ if roles_to_filter is not None :
239+ if len (roles_to_filter ) == 1 :
240+ role_filter = Tag (ROLE_FIELD_NAME ) == roles_to_filter [0 ]
241+ else :
242+ # Multiple roles - use OR logic
243+ role_filters = [Tag (ROLE_FIELD_NAME ) == r for r in roles_to_filter ]
244+ role_filter = role_filters [0 ]
245+ for rf in role_filters [1 :]:
246+ role_filter = role_filter | rf
247+
248+ filter_expression = session_filter & role_filter
249+
228250 query = RangeQuery (
229251 vector = self ._vectorizer .embed (prompt ),
230252 vector_field_name = MESSAGE_VECTOR_FIELD_NAME ,
231253 return_fields = return_fields ,
232254 distance_threshold = distance_threshold ,
233255 num_results = top_k ,
234256 return_score = True ,
235- filter_expression = session_filter ,
257+ filter_expression = filter_expression ,
236258 dtype = self ._vectorizer .dtype ,
237259 )
238260 messages = self ._index .query (query )
239261
240262 # if we don't find semantic matches fallback to returning recent context
241263 if not messages and fall_back :
242- return self .get_recent (as_text = as_text , top_k = top_k , raw = raw )
264+ return self .get_recent (as_text = as_text , top_k = top_k , raw = raw , role = role )
243265 if raw :
244266 return messages
245267 return self ._format_context (messages , as_text )
@@ -250,6 +272,7 @@ def get_recent(
250272 as_text : bool = False ,
251273 raw : bool = False ,
252274 session_tag : Optional [str ] = None ,
275+ role : Optional [Union [str , List [str ]]] = None ,
253276 ) -> Union [List [str ], List [Dict [str , str ]]]:
254277 """Retrieve the recent message history in sequential order.
255278
@@ -261,17 +284,24 @@ def get_recent(
261284 prompt and response
262285 session_tag (Optional[str]): Tag of the entries linked to a specific
263286 conversation session. Defaults to instance ULID.
287+ role (Optional[Union[str, List[str]]]): Filter messages by role(s).
288+ Can be a single role string ("system", "user", "llm", "tool") or
289+ a list of roles. If None, all roles are returned.
264290
265291 Returns:
266292 Union[str, List[str]]: A single string transcription of the session
267293 or list of strings if as_text is false.
268294
269295 Raises:
270- ValueError: if top_k is not an integer greater than or equal to 0.
296+ ValueError: if top_k is not an integer greater than or equal to 0,
297+ or if role contains invalid values.
271298 """
272299 if type (top_k ) != int or top_k < 0 :
273300 raise ValueError ("top_k must be an integer greater than or equal to 0" )
274301
302+ # Validate and normalize role parameter
303+ roles_to_filter = self ._validate_roles (role )
304+
275305 return_fields = [
276306 ID_FIELD_NAME ,
277307 SESSION_FIELD_NAME ,
@@ -288,8 +318,22 @@ def get_recent(
288318 else self ._default_session_filter
289319 )
290320
321+ # Combine session filter with role filter if provided
322+ filter_expression = session_filter
323+ if roles_to_filter is not None :
324+ if len (roles_to_filter ) == 1 :
325+ role_filter = Tag (ROLE_FIELD_NAME ) == roles_to_filter [0 ]
326+ else :
327+ # Multiple roles - use OR logic
328+ role_filters = [Tag (ROLE_FIELD_NAME ) == r for r in roles_to_filter ]
329+ role_filter = role_filters [0 ]
330+ for rf in role_filters [1 :]:
331+ role_filter = role_filter | rf
332+
333+ filter_expression = session_filter & role_filter
334+
291335 query = FilterQuery (
292- filter_expression = session_filter ,
336+ filter_expression = filter_expression ,
293337 return_fields = return_fields ,
294338 num_results = top_k ,
295339 )
0 commit comments