Skip to content

Commit 9e93f68

Browse files
committed
feat: add skip_decode parameter to return_fields method (#252)
Implements a skip_decode parameter for the return_fields() method to improve the field deserialization UX. This allows users to skip decoding of binary fields, such as embeddings, while still returning them in query results.
- Added an optional skip_decode parameter to BaseQuery.return_fields()
- The parameter accepts a string or a list of field names whose decoding should be skipped
- Maintains backward compatibility when skip_decode is not provided
- Comprehensive unit test coverage for all query types
- Enhanced skip_decode to use the parent's return_field with decode_field=False
- Added comprehensive integration tests with real Redis
- Maintained full backward compatibility with return_field(decode_field=False)
- Tests confirm proper binary field handling (embeddings, image data)
1 parent 27dabc0 commit 9e93f68

File tree

3 files changed

+551
-0
lines changed

3 files changed

+551
-0
lines changed

redisvl/query/query.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def __init__(self, query_string: str = "*"):
4747
# has not been built yet.
4848
self._built_query_string = None
4949

50+
# Initialize skip_decode_fields set
51+
self._skip_decode_fields: Set[str] = set()
52+
5053
def __str__(self) -> str:
    """Render the query as its arguments joined by single spaces."""
    rendered_args = [str(arg) for arg in self.get_args()]
    return " ".join(rendered_args)
@@ -107,6 +110,58 @@ def _query_string(self, value: Optional[str]):
107110
"""Setter for _query_string to maintain compatibility with parent class."""
108111
self._built_query_string = value
109112

113+
def return_fields(
    self, *fields, skip_decode: Optional[Union[str, List[str]]] = None
):
    """
    Set the fields to return with search results.

    The method has two modes:

    * ``skip_decode`` omitted (``None``): the call is additive and is
      delegated unchanged to the parent class' ``return_fields``.
    * ``skip_decode`` given (including an empty list): any return fields
      registered so far are discarded and replaced by ``fields``.

    Args:
        *fields: Variable number of field names to return.
        skip_decode: Optional field name or list of field names that should not be
            decoded. Useful for binary data like embeddings. Names listed here
            that are not also present in ``fields`` are recorded in
            ``_skip_decode_fields`` but are not added to the returned fields.

    Returns:
        self: Returns the query object for method chaining.

    Raises:
        TypeError: If skip_decode is not a string, list, or None. The check
            runs before any state is modified, so a rejected call leaves the
            previously configured return fields intact.
    """
    if skip_decode is None:
        # Standard additive behavior (backward compatible).
        super().return_fields(*fields)

        # Defensive: only this path can be reached without the attribute
        # having been assigned (e.g. if __init__ was bypassed).
        if not hasattr(self, "_skip_decode_fields"):
            self._skip_decode_fields = set()
        return self

    # Validate and normalise BEFORE touching any state, so that an invalid
    # argument cannot wipe the fields the caller configured earlier.
    if isinstance(skip_decode, str):
        raw_fields = {skip_decode}
    elif isinstance(skip_decode, list):
        raw_fields = set(skip_decode)
    else:
        raise TypeError("skip_decode must be a string or list of strings")

    # Replacement behavior: start from a clean slate.
    self._return_fields = []
    self._return_fields_decode_as = {}
    self._skip_decode_fields = raw_fields

    # Register every field through the parent's return_field so the decode
    # setting is tracked by the parent implementation.
    for field in fields:
        if field in raw_fields:
            super().return_field(field, decode_field=False)
        else:
            super().return_field(field)

    return self
164+
110165

111166
class FilterQuery(BaseQuery):
112167
def __init__(
Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""Integration tests for skip_decode parameter in query return_fields (issue #252)."""
2+
3+
import numpy as np
4+
import pytest
5+
from redis import Redis
6+
7+
from redisvl.exceptions import RedisSearchError
8+
from redisvl.index import SearchIndex
9+
from redisvl.query import FilterQuery, RangeQuery, VectorQuery
10+
from redisvl.schema import IndexSchema
11+
12+
13+
@pytest.fixture
def sample_schema():
    """Provide a hash-storage index schema with text, numeric, vector and tag fields."""
    index_settings = {
        "name": "test_skip_decode",
        "prefix": "doc",
        "storage_type": "hash",
    }
    vector_attrs = {
        "dims": 128,
        "algorithm": "flat",
        "distance_metric": "cosine",
    }
    field_specs = [
        {"name": "title", "type": "text"},
        {"name": "year", "type": "numeric"},
        {"name": "description", "type": "text"},
        {"name": "embedding", "type": "vector", "attrs": vector_attrs},
        # Declared as a tag field; the tests store binary data in it.
        {"name": "image_data", "type": "tag"},
    ]
    return IndexSchema.from_dict({"index": index_settings, "fields": field_specs})
43+
44+
45+
@pytest.fixture
def search_index(redis_url, sample_schema):
    """Yield a freshly created index loaded with five documents, then drop it."""
    index = SearchIndex(sample_schema, redis_url=redis_url)

    # Remove leftovers from an earlier run; a missing index is not an error.
    try:
        index.delete(drop=True)
    except RedisSearchError:
        pass

    index.create(overwrite=True)

    # Five documents, each with a random float32 embedding and a byte payload,
    # both stored as raw bytes.
    documents = [
        {
            "title": f"Document {i}",
            "year": 2020 + i,
            "description": f"This is document number {i}",
            "embedding": np.random.rand(128).astype(np.float32).tobytes(),
            "image_data": f"binary_image_{i}".encode("utf-8"),
        }
        for i in range(5)
    ]
    index.load(documents, id_field="title")

    yield index

    # Best-effort teardown: cleanup failures must not fail the test.
    try:
        index.delete(drop=True)
    except RedisSearchError:
        pass
82+
83+
84+
class TestSkipDecodeIntegration:
    """Integration tests for skip_decode functionality with real Redis."""

    @staticmethod
    def _assert_decoded(doc, *names):
        """Assert that every named field of ``doc`` is a decoded ``str``."""
        for name in names:
            assert isinstance(doc[name], str), f"{name!r} should be decoded to str"

    @staticmethod
    def _assert_raw_bytes(doc, *names):
        """Assert that every named field of ``doc`` was left as ``bytes``."""
        for name in names:
            assert isinstance(doc[name], bytes), f"{name!r} should remain bytes"

    def test_filter_query_skip_decode_single_field(self, search_index):
        """Test FilterQuery with skip_decode for embedding field."""
        query = FilterQuery(num_results=10)
        query.return_fields("title", "year", "embedding", skip_decode=["embedding"])

        results = search_index.query(query)

        # Verify we got results
        assert len(results) > 0

        # Check first result
        first_result = results[0]
        assert "title" in first_result
        assert "year" in first_result
        assert "embedding" in first_result

        # Title and year should be decoded strings (the numeric field is
        # expected back as a string as well).
        self._assert_decoded(first_result, "title", "year")

        # Embedding should remain as bytes (not decoded)
        self._assert_raw_bytes(first_result, "embedding")

    def test_filter_query_skip_decode_multiple_fields(self, search_index):
        """Test FilterQuery with skip_decode for multiple binary fields."""
        query = FilterQuery(num_results=10)
        query.return_fields(
            "title",
            "year",
            "embedding",
            "image_data",
            skip_decode=["embedding", "image_data"],
        )

        results = search_index.query(query)

        assert len(results) > 0

        first_result = results[0]
        # Decoded fields
        self._assert_decoded(first_result, "title", "year")

        # Non-decoded fields (should be bytes)
        self._assert_raw_bytes(first_result, "embedding", "image_data")

    def test_filter_query_no_skip_decode_default(self, search_index):
        """Test FilterQuery without skip_decode (default behavior)."""
        query = FilterQuery(num_results=10)
        query.return_fields("title", "year", "description")

        results = search_index.query(query)

        assert len(results) > 0

        # All fields should be decoded to strings
        self._assert_decoded(results[0], "title", "year", "description")

    def test_vector_query_skip_decode(self, search_index):
        """Test VectorQuery with skip_decode for embedding field."""
        # Create a random query vector
        query_vector = np.random.rand(128).astype(np.float32)

        query = VectorQuery(
            vector=query_vector.tolist(),
            vector_field_name="embedding",
            return_fields=None,  # Will set with method
            num_results=3,
            return_score=True,  # Explicitly request distance score
        )

        # Use skip_decode for embedding
        query.return_fields("title", "embedding", skip_decode=["embedding"])

        results = search_index.query(query)

        assert len(results) > 0

        # Only the explicitly requested fields are checked; the distance score
        # requested via return_score=True is deliberately not asserted here.
        for result in results:
            self._assert_decoded(result, "title")
            # Embedding should be bytes (not decoded)
            self._assert_raw_bytes(result, "embedding")

    def test_range_query_skip_decode(self, search_index):
        """Test RangeQuery with skip_decode for binary fields."""
        # Create a random query vector
        query_vector = np.random.rand(128).astype(np.float32)

        query = RangeQuery(
            vector=query_vector.tolist(),
            vector_field_name="embedding",
            distance_threshold=1.0,
            return_fields=None,
            num_results=10,
        )

        query.return_fields("title", "year", "embedding", skip_decode=["embedding"])

        results = search_index.query(query)

        # A range query might not return results. Report that as a skip
        # instead of letting the test pass without checking anything.
        if len(results) == 0:
            pytest.skip("range query returned no documents; nothing to verify")

        first_result = results[0]
        self._assert_decoded(first_result, "title", "year")
        self._assert_raw_bytes(first_result, "embedding")

    def test_backward_compat_return_field_decode_false(self, search_index):
        """Test backward compatibility with return_field(decode_field=False)."""
        query = FilterQuery(num_results=10)

        # Use old API - return_field with decode_field=False
        query.return_field("embedding", decode_field=False)
        query.return_field("image_data", decode_field=False)
        query.return_fields("title", "year")  # These should be decoded

        results = search_index.query(query)

        assert len(results) > 0

        first_result = results[0]
        # Decoded fields
        self._assert_decoded(first_result, "title", "year")

        # Non-decoded fields (using old API)
        self._assert_raw_bytes(first_result, "embedding", "image_data")

    def test_mixed_api_usage(self, search_index):
        """Test mixing old and new API calls."""
        query = FilterQuery(num_results=10)

        # First use old API
        query.return_field("image_data", decode_field=False)

        # Then use new API with skip_decode
        query.return_fields("title", "year", "embedding", skip_decode=["embedding"])

        results = search_index.query(query)

        assert len(results) > 0

        first_result = results[0]
        # The new API call should have replaced everything
        # (when skip_decode is provided, it clears previous fields)
        assert "title" in first_result
        assert "year" in first_result
        assert "embedding" in first_result

        # image_data should not be in results since return_fields
        # with skip_decode clears previous fields
        assert "image_data" not in first_result

    def test_skip_decode_with_empty_list(self, search_index):
        """Test skip_decode with empty list (all fields decoded)."""
        query = FilterQuery(num_results=10)
        query.return_fields("title", "year", "description", skip_decode=[])

        results = search_index.query(query)

        assert len(results) > 0

        # All fields should be decoded
        self._assert_decoded(results[0], "title", "year", "description")

    def test_skip_decode_with_string_parameter(self, search_index):
        """Test skip_decode accepts a single string instead of list."""
        query = FilterQuery(num_results=10)

        # Pass a single string instead of list
        query.return_fields("title", "embedding", skip_decode="embedding")

        results = search_index.query(query)

        assert len(results) > 0

        first_result = results[0]
        self._assert_decoded(first_result, "title")
        # Embedding should be bytes (not decoded)
        self._assert_raw_bytes(first_result, "embedding")

    def test_multiple_calls_without_skip_decode(self, search_index):
        """Test multiple return_fields calls without skip_decode (additive behavior)."""
        query = FilterQuery(num_results=10)

        # Multiple calls without skip_decode should be additive
        query.return_fields("title")
        query.return_fields("year")
        query.return_field("embedding", decode_field=False)

        results = search_index.query(query)

        assert len(results) > 0

        first_result = results[0]
        # All fields should be present (additive behavior)
        assert "title" in first_result
        assert "year" in first_result
        assert "embedding" in first_result

        # Check types
        self._assert_decoded(first_result, "title", "year")
        self._assert_raw_bytes(first_result, "embedding")

    def test_replacement_behavior_with_skip_decode(self, search_index):
        """Test that skip_decode parameter triggers replacement behavior."""
        query = FilterQuery(num_results=10)

        # First set some fields
        query.return_fields("title", "description")

        # Then call with skip_decode - should replace, not add
        query.return_fields("year", "embedding", skip_decode=["embedding"])

        results = search_index.query(query)

        assert len(results) > 0

        first_result = results[0]
        # Only fields from second call should be present
        assert "year" in first_result
        assert "embedding" in first_result

        # Fields from first call should NOT be present (replaced)
        assert "title" not in first_result
        assert "description" not in first_result

        # Check embedding is not decoded
        self._assert_raw_bytes(first_result, "embedding")

0 commit comments

Comments
 (0)