From 638f95a153eb1437c8b2cffd7de58cbc8b37743f Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Fri, 26 Sep 2025 21:15:40 -0700 Subject: [PATCH] feat: add skip_decode parameter to return_fields method (#252) Implements skip_decode parameter for return_fields() method to improve field deserialization UX. This allows users to skip decoding of binary fields like embeddings while still returning them in query results. - Added optional skip_decode parameter to BaseQuery.return_fields() - Parameter accepts string or list of field names to skip decoding - Maintains backward compatibility when skip_decode is not provided - Comprehensive unit test coverage for all query types - Enhanced skip_decode to use parent's return_field with decode_field=False - Added comprehensive integration tests with real Redis - Maintained full backward compatibility with return_field(decode_field=False) - Tests confirm proper binary field handling (embeddings, image data) --- redisvl/query/query.py | 55 +++ .../test_skip_decode_fields_integration.py | 325 ++++++++++++++++++ tests/unit/test_skip_decode_fields.py | 171 +++++++++ 3 files changed, 551 insertions(+) create mode 100644 tests/integration/test_skip_decode_fields_integration.py create mode 100644 tests/unit/test_skip_decode_fields.py diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 28dea9e0..d2eb5ee6 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -47,6 +47,9 @@ def __init__(self, query_string: str = "*"): # has not been built yet. self._built_query_string = None + # Initialize skip_decode_fields set + self._skip_decode_fields: Set[str] = set() + def __str__(self) -> str: """Return the string representation of the query.""" return " ".join([str(x) for x in self.get_args()]) @@ -107,6 +110,58 @@ def _query_string(self, value: Optional[str]): """Setter for _query_string to maintain compatibility with parent class.""" self._built_query_string = value + def return_fields( + self, *fields, skip_decode: Optional[Union[str, List[str]]] = None + ): + """ + Set the fields to return with search results. + + Args: + *fields: Variable number of field names to return. + skip_decode: Optional field name or list of field names that should not be + decoded. Useful for binary data like embeddings. + + Returns: + self: Returns the query object for method chaining. + + Raises: + TypeError: If skip_decode is not a string, list, or None. + """ + # Only clear fields when skip_decode is provided (indicating user is explicitly setting fields) + # This preserves backward compatibility when return_fields is called multiple times + if skip_decode is not None: + # Clear existing fields to provide replacement behavior + self._return_fields = [] + self._return_fields_decode_as = {} + + # Process skip_decode parameter to prepare decode settings + if isinstance(skip_decode, str): + skip_decode_set = {skip_decode} + self._skip_decode_fields = {skip_decode} + elif isinstance(skip_decode, list): + skip_decode_set = set(skip_decode) + self._skip_decode_fields = set(skip_decode) + else: + raise TypeError("skip_decode must be a string or list of strings") + + # Add fields using parent's return_field method with proper decode settings + for field in fields: + if field in skip_decode_set: + # Use return_field with decode_field=False for skip_decode fields + super().return_field(field, decode_field=False) + else: + # Use normal return_field for other fields + super().return_field(field) + else: + # Standard additive behavior (backward compatible) + super().return_fields(*fields) + + # Initialize skip_decode_fields if not already set + if not hasattr(self, "_skip_decode_fields"): + self._skip_decode_fields = set() + + return self + class FilterQuery(BaseQuery): def __init__( diff --git a/tests/integration/test_skip_decode_fields_integration.py b/tests/integration/test_skip_decode_fields_integration.py new file mode 100644 index 00000000..6f868ce4 --- /dev/null +++ b/tests/integration/test_skip_decode_fields_integration.py @@ -0,0 +1,325 @@ +"""Integration tests for skip_decode parameter in query return_fields (issue #252).""" + +import numpy as np +import pytest +from redis import Redis + +from redisvl.exceptions import RedisSearchError +from redisvl.index import SearchIndex +from redisvl.query import FilterQuery, RangeQuery, VectorQuery +from redisvl.schema import IndexSchema + + +@pytest.fixture +def sample_schema(): + """Create a sample schema with various field types.""" + return IndexSchema.from_dict( + { + "index": { + "name": "test_skip_decode", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + {"name": "year", "type": "numeric"}, + {"name": "description", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "dims": 128, + "algorithm": "flat", + "distance_metric": "cosine", + }, + }, + { + "name": "image_data", + "type": "tag", + }, # Will store binary data as tag + ], + } + ) + + +@pytest.fixture +def search_index(redis_url, sample_schema): + """Create and populate a test index.""" + index = SearchIndex(sample_schema, redis_url=redis_url) + + # Clear any existing data + try: + index.delete(drop=True) + except RedisSearchError: + pass # Index may not exist, which is fine + + # Create the index + index.create(overwrite=True) + + # Populate with test data + data = [] + for i in range(5): + embedding_vector = np.random.rand(128).astype(np.float32) + doc = { + "title": f"Document {i}", + "year": 2020 + i, + "description": f"This is document number {i}", + "embedding": embedding_vector.tobytes(), # Store as binary + "image_data": f"binary_image_{i}".encode("utf-8"), # Store as binary + } + data.append(doc) + + # Load data into Redis + index.load(data, id_field="title") + + yield index + + # Cleanup + try: + index.delete(drop=True) + except RedisSearchError: + pass # Ignore cleanup errors + + +class TestSkipDecodeIntegration: + """Integration tests for skip_decode functionality with real Redis.""" + + def test_filter_query_skip_decode_single_field(self, search_index): + """Test FilterQuery with skip_decode for embedding field.""" + query = FilterQuery(num_results=10) + query.return_fields("title", "year", "embedding", skip_decode=["embedding"]) + + results = search_index.query(query) + + # Verify we got results + assert len(results) > 0 + + # Check first result + first_result = results[0] + assert "title" in first_result + assert "year" in first_result + assert "embedding" in first_result + + # Title and year should be decoded strings + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) # Redis returns as string + + # Embedding should remain as bytes (not decoded) + assert isinstance(first_result["embedding"], bytes) + + def test_filter_query_skip_decode_multiple_fields(self, search_index): + """Test FilterQuery with skip_decode for multiple binary fields.""" + query = FilterQuery(num_results=10) + query.return_fields( + "title", + "year", + "embedding", + "image_data", + skip_decode=["embedding", "image_data"], + ) + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # Decoded fields + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) + + # Non-decoded fields (should be bytes) + assert isinstance(first_result["embedding"], bytes) + assert isinstance(first_result["image_data"], bytes) + + def test_filter_query_no_skip_decode_default(self, search_index): + """Test FilterQuery without skip_decode (default behavior).""" + query = FilterQuery(num_results=10) + query.return_fields("title", "year", "description") + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # All fields should be decoded to strings + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) + assert isinstance(first_result["description"], str) + + def test_vector_query_skip_decode(self, search_index): + """Test VectorQuery with skip_decode for embedding field.""" + # Create a random query vector + query_vector = np.random.rand(128).astype(np.float32) + + query = VectorQuery( + vector=query_vector.tolist(), + vector_field_name="embedding", + return_fields=None, # Will set with method + num_results=3, + return_score=True, # Explicitly request distance score + ) + + # Use skip_decode for embedding + query.return_fields("title", "embedding", skip_decode=["embedding"]) + + results = search_index.query(query) + + assert len(results) > 0 + + for result in results: + assert isinstance(result["title"], str) + # Embedding should be bytes (not decoded) + assert isinstance(result["embedding"], bytes) + # Distance score is added automatically by VectorQuery when return_score=True + # but may not be in the result dict, just check the fields we requested + + def test_range_query_skip_decode(self, search_index): + """Test RangeQuery with skip_decode for binary fields.""" + # Create a random query vector + query_vector = np.random.rand(128).astype(np.float32) + + query = RangeQuery( + vector=query_vector.tolist(), + vector_field_name="embedding", + distance_threshold=1.0, + return_fields=None, + num_results=10, + ) + + query.return_fields("title", "year", "embedding", skip_decode=["embedding"]) + + results = search_index.query(query) + + if len(results) > 0: # Range query might not return results + first_result = results[0] + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) + assert isinstance(first_result["embedding"], bytes) + + def test_backward_compat_return_field_decode_false(self, search_index): + """Test backward compatibility with return_field(decode_field=False).""" + query = FilterQuery(num_results=10) + + # Use old API - return_field with decode_field=False + query.return_field("embedding", decode_field=False) + query.return_field("image_data", decode_field=False) + query.return_fields("title", "year") # These should be decoded + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # Decoded fields + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) + + # Non-decoded fields (using old API) + assert isinstance(first_result["embedding"], bytes) + assert isinstance(first_result["image_data"], bytes) + + def test_mixed_api_usage(self, search_index): + """Test mixing old and new API calls.""" + query = FilterQuery(num_results=10) + + # First use old API + query.return_field("image_data", decode_field=False) + + # Then use new API with skip_decode + query.return_fields("title", "year", "embedding", skip_decode=["embedding"]) + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # The new API call should have replaced everything + # (when skip_decode is provided, it clears previous fields) + assert "title" in first_result + assert "year" in first_result + assert "embedding" in first_result + + # image_data should not be in results since return_fields + # with skip_decode clears previous fields + assert "image_data" not in first_result + + def test_skip_decode_with_empty_list(self, search_index): + """Test skip_decode with empty list (all fields decoded).""" + query = FilterQuery(num_results=10) + query.return_fields("title", "year", "description", skip_decode=[]) + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # All fields should be decoded + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) + assert isinstance(first_result["description"], str) + + def test_skip_decode_with_string_parameter(self, search_index): + """Test skip_decode accepts a single string instead of list.""" + query = FilterQuery(num_results=10) + + # Pass a single string instead of list + query.return_fields("title", "embedding", skip_decode="embedding") + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + assert isinstance(first_result["title"], str) + # Embedding should be bytes (not decoded) + assert isinstance(first_result["embedding"], bytes) + + def test_multiple_calls_without_skip_decode(self, search_index): + """Test multiple return_fields calls without skip_decode (additive behavior).""" + query = FilterQuery(num_results=10) + + # Multiple calls without skip_decode should be additive + query.return_fields("title") + query.return_fields("year") + query.return_field("embedding", decode_field=False) + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # All fields should be present (additive behavior) + assert "title" in first_result + assert "year" in first_result + assert "embedding" in first_result + + # Check types + assert isinstance(first_result["title"], str) + assert isinstance(first_result["year"], str) + assert isinstance(first_result["embedding"], bytes) + + def test_replacement_behavior_with_skip_decode(self, search_index): + """Test that skip_decode parameter triggers replacement behavior.""" + query = FilterQuery(num_results=10) + + # First set some fields + query.return_fields("title", "description") + + # Then call with skip_decode - should replace, not add + query.return_fields("year", "embedding", skip_decode=["embedding"]) + + results = search_index.query(query) + + assert len(results) > 0 + + first_result = results[0] + # Only fields from second call should be present + assert "year" in first_result + assert "embedding" in first_result + + # Fields from first call should NOT be present (replaced) + assert "title" not in first_result + assert "description" not in first_result + + # Check embedding is not decoded + assert isinstance(first_result["embedding"], bytes) diff --git a/tests/unit/test_skip_decode_fields.py b/tests/unit/test_skip_decode_fields.py new file mode 100644 index 00000000..258191d6 --- /dev/null +++ b/tests/unit/test_skip_decode_fields.py @@ -0,0 +1,171 @@ +"""Unit tests for skip_decode parameter in query return_fields (issue #252).""" + +import pytest + +from redisvl.query import FilterQuery, RangeQuery, VectorQuery + + +class TestSkipDecodeFields: + """Test the skip_decode parameter for return_fields method.""" + + def test_filter_query_skip_decode_single_field(self): + """Test FilterQuery with skip_decode for a single field.""" + query = FilterQuery(num_results=10) + + # Use the new skip_decode parameter + query.return_fields("title", "year", "embedding", skip_decode=["embedding"]) + + # Check that fields are added correctly + assert hasattr(query, "_return_fields") + assert "title" in query._return_fields + assert "year" in query._return_fields + assert "embedding" in query._return_fields + + # Check that decode settings are tracked + assert hasattr(query, "_skip_decode_fields") + assert "embedding" in query._skip_decode_fields + + def test_filter_query_skip_decode_multiple_fields(self): + """Test FilterQuery with skip_decode for multiple fields.""" + query = FilterQuery(num_results=10) + + # Use skip_decode with multiple fields + query.return_fields( + "title", + "year", + "embedding", + "image_data", + skip_decode=["embedding", "image_data"], + ) + + # Check that all fields are added + assert len(query._return_fields) == 4 + + # Check that both fields are in skip_decode + assert "embedding" in query._skip_decode_fields + assert "image_data" in query._skip_decode_fields + + def test_vector_query_skip_decode_single_field(self): + """Test VectorQuery with skip_decode parameter.""" + query = VectorQuery( + vector=[0.1, 0.2, 0.3], + vector_field_name="vector_field", + return_fields=None, # Will set with method + num_results=5, + ) + + # Use the new API + query.return_fields( + "id", "vector_field", "metadata", skip_decode=["vector_field"] + ) + + # Check fields + assert "id" in query._return_fields + assert "vector_field" in query._return_fields + assert "metadata" in query._return_fields + + # Check skip_decode + assert hasattr(query, "_skip_decode_fields") + assert "vector_field" in query._skip_decode_fields + + def test_range_query_skip_decode(self): + """Test RangeQuery with skip_decode parameter.""" + query = RangeQuery( + vector=[0.1, 0.2, 0.3], + vector_field_name="embedding", + distance_threshold=0.5, + return_fields=None, + num_results=10, + ) + + # Use skip_decode + query.return_fields("doc_id", "text", "embedding", skip_decode=["embedding"]) + + # Verify + assert "doc_id" in query._return_fields + assert "text" in query._return_fields + assert "embedding" in query._return_fields + assert "embedding" in query._skip_decode_fields + + def test_skip_decode_empty_list(self): + """Test skip_decode with empty list (all fields should be decoded).""" + query = FilterQuery(num_results=10) + + # Empty skip_decode list + query.return_fields("field1", "field2", "field3", skip_decode=[]) + + # All fields present but none skipped + assert len(query._return_fields) == 3 + assert len(query._skip_decode_fields) == 0 + + def test_skip_decode_none_default(self): + """Test that skip_decode defaults to None (backwards compatible).""" + query = FilterQuery(num_results=10) + + # No skip_decode parameter (backwards compatibility) + query.return_fields("field1", "field2", "field3") + + # Fields should be added normally + assert len(query._return_fields) == 3 + + # No skip_decode_fields should be set + if hasattr(query, "_skip_decode_fields"): + assert len(query._skip_decode_fields) == 0 + + def test_skip_decode_field_not_in_return_fields(self): + """Test skip_decode with field not in return_fields (should be ignored).""" + query = FilterQuery(num_results=10) + + # Skip decode for field not being returned + query.return_fields("field1", "field2", skip_decode=["field3"]) + + # Only requested fields should be present + assert len(query._return_fields) == 2 + assert "field1" in query._return_fields + assert "field2" in query._return_fields + + # Skip decode should be tracked even if field not returned + # (implementation may choose to ignore or track it) + assert hasattr(query, "_skip_decode_fields") + + def test_multiple_return_fields_calls_with_skip_decode(self): + """Test calling return_fields multiple times with skip_decode.""" + query = FilterQuery(num_results=10) + + # First call + query.return_fields("field1", skip_decode=["field1"]) + + # Second call should replace, not append + query.return_fields("field2", "field3", skip_decode=["field3"]) + + # Should only have fields from second call + assert "field1" not in query._return_fields + assert "field2" in query._return_fields + assert "field3" in query._return_fields + + # Skip decode should also be replaced + assert "field1" not in query._skip_decode_fields + assert "field3" in query._skip_decode_fields + + def test_skip_decode_with_string_input(self): + """Test skip_decode accepts single string as well as list.""" + query = FilterQuery(num_results=10) + + # Single string for skip_decode + query.return_fields("field1", "field2", skip_decode="field1") + + # Should work same as list with single element + assert "field1" in query._skip_decode_fields + assert "field2" not in query._skip_decode_fields + + def test_skip_decode_type_validation(self): + """Test that skip_decode validates input types.""" + query = FilterQuery(num_results=10) + + # Invalid type should raise error + with pytest.raises(TypeError, match="skip_decode must be"): + query.return_fields("field1", skip_decode=123) + + # Dict should also fail + with pytest.raises(TypeError, match="skip_decode must be"): + query.return_fields("field1", skip_decode={"field": True})