From 4938e1270bdc72bee5465e473b1427711ee038cf Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Thu, 23 Oct 2025 07:53:39 -0700 Subject: [PATCH 1/2] test: add integration tests for multi-vector query, UNF/NOINDEX, and multi-field sorting - Add MultiVectorQueryIntegrationTest with 7 tests for multi-vector query feature (#402) - Add UnfNoindexIntegrationTest with 10 tests for UNF/NOINDEX attributes (#374) - Enhance QuerySortingIntegrationTest with 5 multi-field sorting tests Note: skip_decode and text field weights already have comprehensive unit test coverage --- .../MultiVectorQueryIntegrationTest.java | 276 ++++++++++++ .../vl/query/QuerySortingIntegrationTest.java | 132 ++++++ .../vl/schema/UnfNoindexIntegrationTest.java | 400 ++++++++++++++++++ 3 files changed, 808 insertions(+) create mode 100644 core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java create mode 100644 core/src/test/java/com/redis/vl/schema/UnfNoindexIntegrationTest.java diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java new file mode 100644 index 0000000..083da3f --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java @@ -0,0 +1,276 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.schema.*; +import java.util.*; +import org.junit.jupiter.api.*; + +/** + * Integration tests for Multi-Vector Query support (#402). + * + *

<p>Tests simultaneous search across multiple vector fields with weighted score combination. + * + *

Python reference: PR #402 - Multi-vector query support + */ +@Tag("integration") +@DisplayName("Multi-Vector Query Integration Tests") +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class MultiVectorQueryIntegrationTest extends BaseIntegrationTest { + + private static final String INDEX_NAME = "multi_vector_test_idx"; + private static SearchIndex searchIndex; + + @BeforeAll + static void setupIndex() { + // Clean up any existing index + try { + unifiedJedis.ftDropIndex(INDEX_NAME); + } catch (Exception e) { + // Ignore if index doesn't exist + } + + // Create schema with multiple vector fields + IndexSchema schema = + IndexSchema.builder() + .name(INDEX_NAME) + .prefix("product:") + .field(TextField.builder().name("title").build()) + .field(TextField.builder().name("description").build()) + .field(TagField.builder().name("category").build()) + .field(NumericField.builder().name("price").sortable(true).build()) + // Text embeddings (3 dimensions) + .field( + VectorField.builder() + .name("text_embedding") + .dimensions(3) + .distanceMetric(VectorField.DistanceMetric.COSINE) + .build()) + // Image embeddings (2 dimensions) + .field( + VectorField.builder() + .name("image_embedding") + .dimensions(2) + .distanceMetric(VectorField.DistanceMetric.COSINE) + .build()) + .build(); + + searchIndex = new SearchIndex(schema, unifiedJedis); + searchIndex.create(); + + // Insert test documents with multiple vector embeddings + Map doc1 = new HashMap<>(); + doc1.put("id", "1"); + doc1.put("title", "Red Laptop"); + doc1.put("description", "Premium laptop"); + doc1.put("category", "electronics"); + doc1.put("price", 1200); + doc1.put("text_embedding", new float[] {0.1f, 0.2f, 0.3f}); + doc1.put("image_embedding", new float[] {0.5f, 0.5f}); + + Map doc2 = new HashMap<>(); + doc2.put("id", "2"); + doc2.put("title", "Blue Phone"); + doc2.put("description", "Budget smartphone"); + doc2.put("category", "electronics"); + doc2.put("price", 300); + doc2.put("text_embedding", new 
float[] {0.4f, 0.5f, 0.6f}); + doc2.put("image_embedding", new float[] {0.3f, 0.4f}); + + Map doc3 = new HashMap<>(); + doc3.put("id", "3"); + doc3.put("title", "Green Tablet"); + doc3.put("description", "Mid-range tablet"); + doc3.put("category", "electronics"); + doc3.put("price", 500); + doc3.put("text_embedding", new float[] {0.7f, 0.8f, 0.9f}); + doc3.put("image_embedding", new float[] {0.1f, 0.2f}); + + // Load all documents + searchIndex.load(Arrays.asList(doc1, doc2, doc3), "id"); + + // Wait for indexing + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + @AfterAll + static void cleanupIndex() { + if (searchIndex != null) { + try { + searchIndex.drop(); + } catch (Exception e) { + // Ignore + } + } + } + + @Test + @Order(1) + @DisplayName("Should create multi-vector query with single vector") + void testSingleVectorQuery() { + Vector textVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.3f}) + .fieldName("text_embedding") + .dtype("float32") + .weight(1.0) + .build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vector(textVec).numResults(10).build(); + + assertThat(query.getVectors()).hasSize(1); + assertThat(query.getNumResults()).isEqualTo(10); + + Map params = query.toParams(); + assertThat(params).containsKey("vector_0"); + assertThat(params.get("vector_0")).isInstanceOf(byte[].class); + } + + @Test + @Order(2) + @DisplayName("Should create multi-vector query with multiple vectors") + void testMultipleVectorsQuery() { + Vector textVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.3f}) + .fieldName("text_embedding") + .weight(0.7) + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {0.5f, 0.5f}) + .fieldName("image_embedding") + .weight(0.3) + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder().vectors(textVec, imageVec).numResults(10).build(); + + assertThat(query.getVectors()).hasSize(2); + + // Verify params + Map 
params = query.toParams(); + assertThat(params).containsKeys("vector_0", "vector_1"); + + // Verify query string format + String queryString = query.toQueryString(); + assertThat(queryString) + .contains("@text_embedding:[VECTOR_RANGE 2.0 $vector_0]") + .contains("@image_embedding:[VECTOR_RANGE 2.0 $vector_1]") + .contains(" | "); + + // Verify scoring + String formula = query.getScoringFormula(); + assertThat(formula).contains("0.70 * score_0").contains("0.30 * score_1"); + } + + @Test + @Order(3) + @DisplayName("Should combine multi-vector query with filter expression") + void testMultiVectorQueryWithFilter() { + Vector textVec = + Vector.builder().vector(new float[] {0.1f, 0.2f, 0.3f}).fieldName("text_embedding").build(); + + Filter filter = Filter.tag("category", "electronics"); + + MultiVectorQuery query = + MultiVectorQuery.builder().vector(textVec).filterExpression(filter).numResults(5).build(); + + String queryString = query.toQueryString(); + assertThat(queryString).contains(" AND ").contains("@category:{electronics}"); + } + + @Test + @Order(4) + @DisplayName("Should calculate score from multiple vectors with different weights") + void testWeightedScoringCalculation() { + Vector v1 = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.3f}) + .fieldName("text_embedding") + .weight(0.6) + .build(); + + Vector v2 = + Vector.builder() + .vector(new float[] {0.5f, 0.5f}) + .fieldName("image_embedding") + .weight(0.4) + .build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(v1, v2).build(); + + // Verify individual score calculations + Map calculations = query.getScoreCalculations(); + assertThat(calculations).hasSize(2); + assertThat(calculations.get("score_0")).isEqualTo("(2 - distance_0)/2"); + assertThat(calculations.get("score_1")).isEqualTo("(2 - distance_1)/2"); + + // Verify combined scoring formula + String formula = query.getScoringFormula(); + assertThat(formula).isEqualTo("0.60 * score_0 + 0.40 * score_1"); + } + + @Test + 
@Order(5) + @DisplayName("Should support different vector dimensions and dtypes") + void testDifferentDimensionsAndDtypes() { + Vector v1 = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.3f}) // 3 dimensions + .fieldName("text_embedding") + .dtype("float32") + .weight(0.5) + .build(); + + Vector v2 = + Vector.builder() + .vector(new float[] {0.5f, 0.5f}) // 2 dimensions + .fieldName("image_embedding") + .dtype("float32") + .weight(0.5) + .build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(v1, v2).build(); + + assertThat(query.getVectors().get(0).getVector()).hasSize(3); + assertThat(query.getVectors().get(1).getVector()).hasSize(2); + } + + @Test + @Order(6) + @DisplayName("Should specify return fields") + void testReturnFields() { + Vector textVec = + Vector.builder().vector(new float[] {0.1f, 0.2f, 0.3f}).fieldName("text_embedding").build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vector(textVec) + .returnFields("title", "price", "category") + .build(); + + assertThat(query.getReturnFields()).containsExactly("title", "price", "category"); + } + + @Test + @Order(7) + @DisplayName("Should use VECTOR_RANGE with threshold 2.0") + void testVectorRangeThreshold() { + Vector textVec = + Vector.builder().vector(new float[] {0.1f, 0.2f, 0.3f}).fieldName("text_embedding").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vector(textVec).build(); + + String queryString = query.toQueryString(); + // Distance threshold hardcoded at 2.0 to include all eligible documents + assertThat(queryString).contains("VECTOR_RANGE 2.0"); + } +} diff --git a/core/src/test/java/com/redis/vl/query/QuerySortingIntegrationTest.java b/core/src/test/java/com/redis/vl/query/QuerySortingIntegrationTest.java index 5ee7064..2567c49 100644 --- a/core/src/test/java/com/redis/vl/query/QuerySortingIntegrationTest.java +++ b/core/src/test/java/com/redis/vl/query/QuerySortingIntegrationTest.java @@ -10,6 +10,7 @@ import 
org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; /** @@ -24,6 +25,7 @@ * *

Python reference: /redis-vl-python/tests/integration/test_query.py */ +@Tag("integration") @DisplayName("Query Sorting Integration Tests") class QuerySortingIntegrationTest extends BaseIntegrationTest { @@ -284,6 +286,136 @@ void testSortFilterQueryAlreadyWorks() { } } + /** Test multi-field sorting with FilterQuery (only first field used - Redis limitation) */ + @Test + void testMultiFieldSortingFilterQuery() { + // Specify multiple sort fields - only first should be used (Redis limitation) + List sortFields = List.of(SortField.desc("age"), SortField.asc("credit_score")); + + FilterQuery query = + FilterQuery.builder() + .filterExpression(Filter.tag("credit_score", "high")) + .returnFields(List.of("user", "age", "credit_score")) + .sortBy(sortFields) + .build(); + + // Should use only the first field (age DESC) + assertThat(query.getSortBy()).isEqualTo("age"); + + List> results = index.query(query); + + // Verify results are sorted by age in descending order (first field) + for (int i = 0; i < results.size() - 1; i++) { + int currentAge = getIntValue(results.get(i), "age"); + int nextAge = getIntValue(results.get(i + 1), "age"); + assertThat(currentAge) + .as("Age at position %d should be >= age at position %d (DESC)", i, i + 1) + .isGreaterThanOrEqualTo(nextAge); + } + + // First result should be oldest + assertThat(results.get(0).get("user")).isEqualTo("tyler"); + } + + /** Test multi-field sorting with VectorQuery (only first field used) */ + @Test + void testMultiFieldSortingVectorQuery() { + // Multiple sort fields - only first is used + List sortFields = List.of(SortField.asc("age"), SortField.desc("credit_score")); + + VectorQuery query = + VectorQuery.builder() + .field("user_embedding") + .vector(new float[] {0.1f, 0.1f, 0.5f}) + .returnFields(List.of("user", "age", "credit_score")) + .sortBy(sortFields) + .build(); + + // Only first field should be used + assertThat(query.getSortBy()).isEqualTo("age"); + assertThat(query.isSortDescending()).isFalse(); + 
+ List> results = index.query(query); + + // Verify sorted by first field (age ASC) + for (int i = 0; i < results.size() - 1; i++) { + int currentAge = getIntValue(results.get(i), "age"); + int nextAge = getIntValue(results.get(i + 1), "age"); + assertThat(currentAge).isLessThanOrEqualTo(nextAge); + } + + // First result should be youngest + assertThat(results.get(0).get("user")).isEqualTo("tim"); + } + + /** Test multi-field sorting with VectorRangeQuery (only first field used) */ + @Test + void testMultiFieldSortingVectorRangeQuery() { + List sortFields = List.of(SortField.desc("age"), SortField.asc("user")); + + VectorRangeQuery query = + VectorRangeQuery.builder() + .field("user_embedding") + .vector(new float[] {0.1f, 0.1f, 0.5f}) + .distanceThreshold(1.0f) + .returnFields(List.of("user", "age")) + .sortBy(sortFields) + .build(); + + // Only first field used + assertThat(query.getSortBy()).isEqualTo("age"); + assertThat(query.isSortDescending()).isTrue(); + + List> results = index.query(query); + assertThat(results).hasSizeGreaterThan(0); + + // Sorted by age DESC (first field) + for (int i = 0; i < results.size() - 1; i++) { + int currentAge = getIntValue(results.get(i), "age"); + int nextAge = getIntValue(results.get(i + 1), "age"); + assertThat(currentAge).isGreaterThanOrEqualTo(nextAge); + } + } + + /** Test multi-field sorting with TextQuery (only first field used) */ + @Test + void testMultiFieldSortingTextQuery() { + List sortFields = List.of(SortField.asc("age"), SortField.desc("credit_score")); + + TextQuery query = + TextQuery.builder().text("engineer").textField("job").sortBy(sortFields).build(); + + // Only first field used + assertThat(query.getSortBy()).isEqualTo("age"); + assertThat(query.isSortDescending()).isFalse(); + + List> results = index.query(query); + assertThat(results).hasSizeGreaterThan(0); + + // Sorted by age ASC (first field) + for (int i = 0; i < results.size() - 1; i++) { + int currentAge = getIntValue(results.get(i), "age"); + 
int nextAge = getIntValue(results.get(i + 1), "age"); + assertThat(currentAge).isLessThanOrEqualTo(nextAge); + } + } + + /** Test that empty sort list is handled gracefully */ + @Test + void testEmptyMultiFieldSort() { + FilterQuery query = + FilterQuery.builder() + .returnFields(List.of("user", "age")) + .sortBy(List.of()) // Empty list + .build(); + + assertThat(query.getSortBy()).isNull(); + + // Should still execute without sorting + List> results = index.query(query); + assertThat(results).hasSizeGreaterThan(0); + } + // Helper method for type conversion (Hash storage returns strings) private int getIntValue(Map map, String key) { Object value = map.get(key); diff --git a/core/src/test/java/com/redis/vl/schema/UnfNoindexIntegrationTest.java b/core/src/test/java/com/redis/vl/schema/UnfNoindexIntegrationTest.java new file mode 100644 index 0000000..263e06e --- /dev/null +++ b/core/src/test/java/com/redis/vl/schema/UnfNoindexIntegrationTest.java @@ -0,0 +1,400 @@ +package com.redis.vl.schema; + +import static org.assertj.core.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import com.redis.vl.index.SearchIndex; +import java.util.*; +import org.junit.jupiter.api.*; +import redis.clients.jedis.args.SortingOrder; +import redis.clients.jedis.search.SearchResult; + +/** + * Integration tests for UNF (un-normalized form) and NOINDEX field attributes (#374). + * + *

<p>Tests field attributes for controlling sorting normalization and indexing behavior. + * + *

Python reference: PR #386 - UNF/NOINDEX support + */ +@Tag("integration") +@DisplayName("UNF/NOINDEX Field Attributes Integration Tests") +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class UnfNoindexIntegrationTest extends BaseIntegrationTest { + + private static final String INDEX_NAME = "unf_noindex_test_idx"; + private static SearchIndex searchIndex; + + @BeforeAll + static void setupIndex() { + // Clean up any existing index + try { + unifiedJedis.ftDropIndex(INDEX_NAME); + } catch (Exception e) { + // Ignore if index doesn't exist + } + + // Create schema with UNF and NOINDEX fields + IndexSchema schema = + IndexSchema.builder() + .name(INDEX_NAME) + .prefix("product:") + // Regular sortable text field (normalized) + .field(TextField.builder().name("name").sortable(true).build()) + // UNF sortable text field (un-normalized, preserves case) + .field(TextField.builder().name("brand").sortable(true).unf(true).build()) + // NOINDEX sortable text field (not searchable, but sortable) + .field(TextField.builder().name("sku").sortable(true).indexed(false).build()) + // Regular sortable numeric field + .field(NumericField.builder().name("price").sortable(true).build()) + // UNF numeric field (flag stored, but Jedis limitation) + .field(NumericField.builder().name("stock").sortable(true).unf(true).build()) + // Regular indexed text field for search + .field(TextField.builder().name("description").build()) + .build(); + + searchIndex = new SearchIndex(schema, unifiedJedis); + searchIndex.create(); + + // Insert test documents with mixed case data + Map doc1 = new HashMap<>(); + doc1.put("id", "1"); + doc1.put("name", "apple laptop"); + doc1.put("brand", "Apple"); + doc1.put("sku", "SKU-001"); + doc1.put("price", 1200); + doc1.put("stock", 50); + doc1.put("description", "Premium laptop"); + + Map doc2 = new HashMap<>(); + doc2.put("id", "2"); + doc2.put("name", "banana phone"); + doc2.put("brand", "banana"); + doc2.put("sku", "SKU-002"); + doc2.put("price", 
800); + doc2.put("stock", 30); + doc2.put("description", "Budget phone"); + + Map doc3 = new HashMap<>(); + doc3.put("id", "3"); + doc3.put("name", "cherry tablet"); + doc3.put("brand", "CHERRY"); + doc3.put("sku", "SKU-003"); + doc3.put("price", 500); + doc3.put("stock", 20); + doc3.put("description", "Mid-range tablet"); + + Map doc4 = new HashMap<>(); + doc4.put("id", "4"); + doc4.put("name", "date watch"); + doc4.put("brand", "Date"); + doc4.put("sku", "SKU-004"); + doc4.put("price", 300); + doc4.put("stock", 10); + doc4.put("description", "Smart watch"); + + // Load all documents + searchIndex.load(Arrays.asList(doc1, doc2, doc3, doc4), "id"); + + // Wait for indexing + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + @AfterAll + static void cleanupIndex() { + if (searchIndex != null) { + try { + searchIndex.drop(); + } catch (Exception e) { + // Ignore + } + } + } + + @Test + @Order(1) + @DisplayName("Should create fields with UNF attribute") + void testUnfFieldCreation() { + IndexSchema schema = searchIndex.getSchema(); + + // TextField with UNF + TextField brandField = + (TextField) + schema.getFields().stream().filter(f -> f.getName().equals("brand")).findFirst().get(); + assertThat(brandField.isUnf()).isTrue(); + assertThat(brandField.isSortable()).isTrue(); + + // NumericField with UNF (flag stored despite Jedis limitation) + NumericField stockField = + (NumericField) + schema.getFields().stream().filter(f -> f.getName().equals("stock")).findFirst().get(); + assertThat(stockField.isUnf()).isTrue(); + assertThat(stockField.isSortable()).isTrue(); + } + + @Test + @Order(2) + @DisplayName("Should sort by regular field with case normalization") + void testRegularSortableCaseNormalization() { + // Query sorted by 'name' (regular sortable, normalized) + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, + "*", + redis.clients.jedis.search.FTSearchParams.searchParams() + .sortBy("name", 
SortingOrder.ASC) + .limit(0, 10)); + + assertThat(result.getTotalResults()).isEqualTo(4); + + // With normalization, all names are treated lowercase + // Expected order: "apple laptop" < "banana phone" < "cherry tablet" < "date watch" + List names = new ArrayList<>(); + result + .getDocuments() + .forEach( + doc -> { + names.add(doc.getString("name")); + }); + + assertThat(names) + .containsExactly("apple laptop", "banana phone", "cherry tablet", "date watch"); + } + + @Test + @Order(3) + @DisplayName("Should sort by UNF field preserving original case") + void testUnfSortablePreservesCase() { + // Query sorted by 'brand' (UNF sortable, case-preserved) + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, + "*", + redis.clients.jedis.search.FTSearchParams.searchParams() + .sortBy("brand", SortingOrder.ASC) + .limit(0, 10)); + + assertThat(result.getTotalResults()).isEqualTo(4); + + // With UNF, case is preserved in sorting + // Expected order: "Apple" < "CHERRY" < "Date" < "banana" + // (uppercase letters sort before lowercase in ASCII) + List brands = new ArrayList<>(); + result + .getDocuments() + .forEach( + doc -> { + brands.add(doc.getString("brand")); + }); + + // ASCII order: 'A' (65) < 'C' (67) < 'D' (68) < 'b' (98) + assertThat(brands).containsExactly("Apple", "CHERRY", "Date", "banana"); + } + + @Test + @Order(4) + @DisplayName("Should allow sorting by numeric field") + void testNumericSortable() { + // Query sorted by 'price' (regular numeric sortable) + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, + "*", + redis.clients.jedis.search.FTSearchParams.searchParams() + .sortBy("price", SortingOrder.ASC) + .limit(0, 10)); + + assertThat(result.getTotalResults()).isEqualTo(4); + + List prices = new ArrayList<>(); + result + .getDocuments() + .forEach( + doc -> { + prices.add(Double.parseDouble(doc.getString("price"))); + }); + + assertThat(prices).containsExactly(300.0, 500.0, 800.0, 1200.0); + } + + @Test + @Order(5) + 
@DisplayName("Should allow sorting by UNF numeric field (Jedis limitation noted)") + void testUnfNumericSortable() { + // Query sorted by 'stock' (UNF numeric sortable, but Jedis doesn't support sortableUNF for + // numeric) + // This test documents current behavior with Jedis limitation + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, + "*", + redis.clients.jedis.search.FTSearchParams.searchParams() + .sortBy("stock", SortingOrder.ASC) + .limit(0, 10)); + + assertThat(result.getTotalResults()).isEqualTo(4); + + List stock = new ArrayList<>(); + result + .getDocuments() + .forEach( + doc -> { + stock.add(Double.parseDouble(doc.getString("stock"))); + }); + + // Even though UNF is set, numeric fields sort normally + // (Jedis doesn't have sortableUNF() for NumericField yet) + assertThat(stock).containsExactly(10.0, 20.0, 30.0, 50.0); + } + + @Test + @Order(6) + @DisplayName("NOINDEX field should not be searchable but should be retrievable") + void testNoindexFieldNotSearchable() { + // Try to search by 'sku' field (NOINDEX, should not match) + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, "@sku:SKU-001", redis.clients.jedis.search.FTSearchParams.searchParams()); + + // NOINDEX field should not be searchable + assertThat(result.getTotalResults()).isEqualTo(0); + } + + @Test + @Order(7) + @DisplayName("NOINDEX field should be retrievable in results") + void testNoindexFieldRetrievable() { + // Search by indexed field, but retrieve NOINDEX field + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, + "@description:laptop", + redis.clients.jedis.search.FTSearchParams.searchParams()); + + assertThat(result.getTotalResults()).isEqualTo(1); + + // NOINDEX field should still be retrievable + redis.clients.jedis.search.Document doc = result.getDocuments().get(0); + assertThat(doc.getString("sku")).isEqualTo("SKU-001"); + } + + @Test + @Order(8) + @DisplayName("NOINDEX field should be sortable") + void testNoindexFieldSortable() { + // 
Query sorted by 'sku' (NOINDEX but sortable) + SearchResult result = + unifiedJedis.ftSearch( + INDEX_NAME, + "*", + redis.clients.jedis.search.FTSearchParams.searchParams() + .sortBy("sku", SortingOrder.ASC) + .limit(0, 10)); + + assertThat(result.getTotalResults()).isEqualTo(4); + + List skus = new ArrayList<>(); + result + .getDocuments() + .forEach( + doc -> { + skus.add(doc.getString("sku")); + }); + + // Should sort correctly even though not indexed + assertThat(skus).containsExactly("SKU-001", "SKU-002", "SKU-003", "SKU-004"); + } + + @Test + @Order(9) + @DisplayName("UNF should only apply when sortable is true") + void testUnfRequiresSortable() { + // Create a field with UNF but not sortable + TextField field = TextField.builder().name("test").unf(true).build(); + + // UNF flag is set but field is not sortable + assertThat(field.isUnf()).isTrue(); + assertThat(field.isSortable()).isFalse(); + + // When converted to Jedis field, UNF should be ignored (no sortable) + redis.clients.jedis.search.schemafields.SchemaField jedisField = field.toJedisSchemaField(); + assertThat(jedisField).isNotNull(); + // Field should not be sortable, so UNF has no effect + } + + @Test + @Order(10) + @DisplayName("Should support combined UNF and NOINDEX attributes") + void testUnfAndNoindexCombined() { + // Clean up existing index + try { + unifiedJedis.ftDropIndex("combined_test_idx"); + } catch (Exception e) { + // Ignore + } + + // Create field with both UNF and NOINDEX + IndexSchema schema = + IndexSchema.builder() + .name("combined_test_idx") + .prefix("test:") + .field(TextField.builder().name("code").sortable(true).unf(true).indexed(false).build()) + .field(TextField.builder().name("description").build()) + .build(); + + SearchIndex testIndex = new SearchIndex(schema, unifiedJedis); + testIndex.create(); + + try { + // Insert test data + Map doc1 = new HashMap<>(); + doc1.put("id", "1"); + doc1.put("code", "Alpha"); + doc1.put("description", "First"); + + Map doc2 = new 
HashMap<>(); + doc2.put("id", "2"); + doc2.put("code", "beta"); + doc2.put("description", "Second"); + + // Load documents + testIndex.load(Arrays.asList(doc1, doc2), "id"); + + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Field should be sortable with UNF (case preserved) + SearchResult result = + unifiedJedis.ftSearch( + "combined_test_idx", + "*", + redis.clients.jedis.search.FTSearchParams.searchParams() + .sortBy("code", SortingOrder.ASC)); + + assertThat(result.getTotalResults()).isEqualTo(2); + + List codes = new ArrayList<>(); + result.getDocuments().forEach(doc -> codes.add(doc.getString("code"))); + + // UNF preserves case: 'A' (65) < 'b' (98) + assertThat(codes).containsExactly("Alpha", "beta"); + + // Field should not be searchable (NOINDEX) + SearchResult searchResult = + unifiedJedis.ftSearch( + "combined_test_idx", + "@code:Alpha", + redis.clients.jedis.search.FTSearchParams.searchParams()); + assertThat(searchResult.getTotalResults()).isEqualTo(0); + + } finally { + testIndex.drop(); + } + } +} From 11a8514f1a79221ae42ec6bf373dec516b67d1cf Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Thu, 23 Oct 2025 08:32:15 -0700 Subject: [PATCH 2/2] feat(messagehistory): add role filtering to getRecent() method (#349) Implements role-based filtering for message retrieval, allowing users to filter messages by role type (system, user, llm, tool) when querying conversation history. 
Changes: - Add validateRoles() method to BaseMessageHistory for role validation - Update getRecent() signature to accept optional role parameter - Support single role string or List of multiple roles - Implement role filtering in MessageHistory using Filter.tag() combinations - Maintain backward compatibility with existing getRecent() calls Tests: - Add RoleFilteringTest with 15 unit tests for validation logic - Add RoleFilteringIntegrationTest with 15 integration tests - Test single role, multiple roles, null role, and error cases - Test role filtering with top_k, session_tag, raw, and asText parameters Python reference: PR #387 - Role filtering for message history Ported from: tests/integration/test_role_filter_get_recent.py --- .../messagehistory/BaseMessageHistory.java | 66 ++- .../messagehistory/MessageHistory.java | 40 +- .../RoleFilteringIntegrationTest.java | 398 ++++++++++++++++++ .../messagehistory/RoleFilteringTest.java | 227 ++++++++++ 4 files changed, 726 insertions(+), 5 deletions(-) create mode 100644 core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringIntegrationTest.java create mode 100644 core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java index 66064e0..194674e 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java @@ -7,6 +7,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; /** * Base class for message history implementations. @@ -15,6 +16,9 @@ */ public abstract class BaseMessageHistory { + /** Valid role values for message filtering. 
*/ + private static final Set VALID_ROLES = Set.of("system", "user", "llm", "tool"); + protected final String name; protected final String sessionTag; @@ -55,10 +59,14 @@ protected BaseMessageHistory(String name, String sessionTag) { * @param raw Whether to return the full Redis hash entry or just the role/content/tool_call_id. * @param sessionTag Tag of the entries linked to a specific conversation session. Defaults to * instance ULID. + * @param role Filter messages by role(s). Can be a single role string ("system", "user", "llm", + * "tool"), a List of role strings, or null for no filtering. * @return List of messages (either as text strings or maps depending on asText parameter) - * @throws IllegalArgumentException if topK is not an integer greater than or equal to 0 + * @throws IllegalArgumentException if topK is not an integer greater than or equal to 0, or if + * role contains invalid values */ - public abstract List getRecent(int topK, boolean asText, boolean raw, String sessionTag); + public abstract List getRecent( + int topK, boolean asText, boolean raw, String sessionTag, Object role); /** * Insert a prompt:response pair into the message history. @@ -117,6 +125,60 @@ protected List formatContext(List> messages, boolean return context; } + /** + * Validate and normalize role parameter for filtering messages. + * + *

<p>Matches Python _validate_roles from base_history.py (lines 90-128) + * + * @param role A single role string, List of roles, or null + * @return List of valid role strings if role is provided, null otherwise + * @throws IllegalArgumentException if role contains invalid values (including a null list element) or is the wrong type + */ + @SuppressWarnings("unchecked") + protected List<String> validateRoles(Object role) { + if (role == null) { + return null; + } + + // Handle single role string + if (role instanceof String) { + String roleStr = (String) role; + if (!VALID_ROLES.contains(roleStr)) { + throw new IllegalArgumentException( + String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); + } + return List.of(roleStr); + } + + // Handle list of roles + if (role instanceof List) { + List<?> roleList = (List<?>) role; + + if (roleList.isEmpty()) { + throw new IllegalArgumentException("roles cannot be empty"); + } + + // Validate all roles in the list + List<String> validatedRoles = new ArrayList<>(); + for (Object r : roleList) { + if (!(r instanceof String)) { // null fails instanceof too, so r may be null here + throw new IllegalArgumentException( + "role list must contain only strings, found: " + (r == null ? "null" : r.getClass().getSimpleName())); + } + String roleStr = (String) r; + if (!VALID_ROLES.contains(roleStr)) { + throw new IllegalArgumentException( + String.format("Invalid role '%s'. 
Valid roles are: %s", roleStr, VALID_ROLES)); + } + validatedRoles.add(roleStr); + } + + return validatedRoles; + } + + throw new IllegalArgumentException("role must be a String, List, or null"); + } + public String getName() { return name; } diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java index 33b4607..22a2e36 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java @@ -77,7 +77,7 @@ public void delete() { public void drop(String id) { if (id == null) { // Get the most recent message - List> recent = getRecent(1, false, true, null); + List> recent = getRecent(1, false, true, null, null); if (!recent.isEmpty()) { id = (String) recent.get(0).get(ID_FIELD_NAME); } else { @@ -111,14 +111,31 @@ public List> getMessages() { return formatContext(messages, false); } + /** + * Retrieve the recent conversation history (backward-compatible overload without role filter). 
+ * + * @param topK The number of previous messages to return + * @param asText Whether to return as text strings or maps + * @param raw Whether to return full Redis hash entries + * @param sessionTag Session tag to filter by + * @return List of messages + */ + public List getRecent(int topK, boolean asText, boolean raw, String sessionTag) { + return getRecent(topK, asText, raw, sessionTag, null); + } + @Override @SuppressWarnings("unchecked") - public List getRecent(int topK, boolean asText, boolean raw, String sessionTag) { + public List getRecent( + int topK, boolean asText, boolean raw, String sessionTag, Object role) { // Validate topK if (topK < 0) { throw new IllegalArgumentException("topK must be an integer greater than or equal to 0"); } + // Validate and normalize role parameter + List rolesToFilter = validateRoles(role); + List returnFields = List.of( ID_FIELD_NAME, @@ -131,9 +148,26 @@ public List getRecent(int topK, boolean asText, boolean raw, String sessi Filter sessionFilter = (sessionTag != null) ? 
Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter; + // Combine session filter with role filter if provided + Filter filterExpression = sessionFilter; + if (rolesToFilter != null) { + if (rolesToFilter.size() == 1) { + // Single role filter + Filter roleFilter = Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(0)); + filterExpression = Filter.and(sessionFilter, roleFilter); + } else { + // Multiple roles - use OR logic + Filter roleFilter = Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(0)); + for (int i = 1; i < rolesToFilter.size(); i++) { + roleFilter = Filter.or(roleFilter, Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(i))); + } + filterExpression = Filter.and(sessionFilter, roleFilter); + } + } + FilterQuery query = FilterQuery.builder() - .filterExpression(sessionFilter) + .filterExpression(filterExpression) .returnFields(returnFields) .numResults(topK) .sortBy(TIMESTAMP_FIELD_NAME) diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringIntegrationTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringIntegrationTest.java new file mode 100644 index 0000000..e4cb6a4 --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringIntegrationTest.java @@ -0,0 +1,398 @@ +package com.redis.vl.extensions.messagehistory; + +import static com.redis.vl.extensions.Constants.*; +import static org.assertj.core.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import java.util.*; +import org.junit.jupiter.api.*; + +/** + * Integration tests for role filtering in getRecent() method (#349). + * + *

Ported from Python: tests/integration/test_role_filter_get_recent.py + * + *

Tests role filtering functionality with real Redis operations. + * + *

Python reference: PR #387 - Role filtering for message history + */ +@Tag("integration") +@DisplayName("Role Filtering Integration Tests") +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class RoleFilteringIntegrationTest extends BaseIntegrationTest { + + private MessageHistory history; + + @BeforeEach + void setUp() { + // Intentionally empty: every test constructs its own uniquely named MessageHistory + } + + @AfterEach + void tearDown() { + if (history != null) { + try { + history.delete(); + } catch (Exception e) { + // Best-effort cleanup: ignore errors so a failed delete cannot fail the test + } + } + } + + @Test + @Order(1) + @DisplayName("Should filter by single role: system") + void testGetRecentSingleRoleSystem() { + history = new MessageHistory("test_role_system", unifiedJedis); + history.clear(); + + // Add various messages with different roles + List> messages = new ArrayList<>(); + messages.add(Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System initialization")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "Hello")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "Hi there")); + messages.add( + Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System configuration updated")); + messages.add( + Map.of( + ROLE_FIELD_NAME, + "tool", + CONTENT_FIELD_NAME, + "Function executed", + TOOL_FIELD_NAME, + "call1")); + + history.addMessages(messages); + + // Get only system messages + List> result = history.getRecent(10, false, false, null, "system"); + + assertThat(result).hasSize(2); + assertThat(result).allMatch(msg -> "system".equals(msg.get(ROLE_FIELD_NAME))); + assertThat(result.get(0).get(CONTENT_FIELD_NAME)).isEqualTo("System initialization"); + assertThat(result.get(1).get(CONTENT_FIELD_NAME)).isEqualTo("System configuration updated"); + } + + @Test + @Order(2) + @DisplayName("Should filter by single role: user") + void testGetRecentSingleRoleUser() { + history = new MessageHistory("test_role_user", unifiedJedis); + history.clear(); + + List> messages = new ArrayList<>(); + 
messages.add(Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "Welcome")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "First question")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "First answer")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "Second question")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "Third question")); + + history.addMessages(messages); + + List> result = history.getRecent(10, false, false, null, "user"); + + assertThat(result).hasSize(3); + assertThat(result).allMatch(msg -> "user".equals(msg.get(ROLE_FIELD_NAME))); + assertThat(result.get(0).get(CONTENT_FIELD_NAME)).isEqualTo("First question"); + assertThat(result.get(2).get(CONTENT_FIELD_NAME)).isEqualTo("Third question"); + } + + @Test + @Order(3) + @DisplayName("Should filter by single role: llm") + void testGetRecentSingleRoleLlm() { + history = new MessageHistory("test_role_llm", unifiedJedis); + history.clear(); + + List> messages = new ArrayList<>(); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "Question 1")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "Answer 1")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "Question 2")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "Answer 2")); + messages.add(Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System note")); + + history.addMessages(messages); + + List> result = history.getRecent(10, false, false, null, "llm"); + + assertThat(result).hasSize(2); + assertThat(result).allMatch(msg -> "llm".equals(msg.get(ROLE_FIELD_NAME))); + assertThat(result.get(0).get(CONTENT_FIELD_NAME)).isEqualTo("Answer 1"); + assertThat(result.get(1).get(CONTENT_FIELD_NAME)).isEqualTo("Answer 2"); + } + + @Test + @Order(4) + @DisplayName("Should filter by single role: tool") + void testGetRecentSingleRoleTool() { + history = new MessageHistory("test_role_tool", 
unifiedJedis); + history.clear(); + + List> messages = new ArrayList<>(); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "Run function")); + messages.add( + Map.of( + ROLE_FIELD_NAME, + "tool", + CONTENT_FIELD_NAME, + "Function result 1", + TOOL_FIELD_NAME, + "call1")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "Processing")); + messages.add( + Map.of( + ROLE_FIELD_NAME, + "tool", + CONTENT_FIELD_NAME, + "Function result 2", + TOOL_FIELD_NAME, + "call2")); + + history.addMessages(messages); + + List> result = history.getRecent(10, false, false, null, "tool"); + + assertThat(result).hasSize(2); + assertThat(result).allMatch(msg -> "tool".equals(msg.get(ROLE_FIELD_NAME))); + assertThat(result).allMatch(msg -> msg.containsKey(TOOL_FIELD_NAME)); + } + + @Test + @Order(5) + @DisplayName("Should filter by multiple roles") + void testGetRecentMultipleRoles() { + history = new MessageHistory("test_multi_roles", unifiedJedis); + history.clear(); + + List> messages = new ArrayList<>(); + messages.add(Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System message")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "User message")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "LLM message")); + messages.add( + Map.of( + ROLE_FIELD_NAME, "tool", CONTENT_FIELD_NAME, "Tool message", TOOL_FIELD_NAME, "call1")); + + history.addMessages(messages); + + // Get system and user messages only + List> result = + history.getRecent(10, false, false, null, Arrays.asList("system", "user")); + + assertThat(result).hasSize(2); + assertThat(result) + .allMatch(msg -> List.of("system", "user").contains(msg.get(ROLE_FIELD_NAME))); + assertThat(result.get(0).get(CONTENT_FIELD_NAME)).isEqualTo("System message"); + assertThat(result.get(1).get(CONTENT_FIELD_NAME)).isEqualTo("User message"); + } + + @Test + @Order(6) + @DisplayName("Should return all messages when role=null (backward compatibility)") + void 
testGetRecentNoRoleFilterBackwardCompatibility() { + history = new MessageHistory("test_no_filter", unifiedJedis); + history.clear(); + + List> messages = new ArrayList<>(); + messages.add(Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System")); + messages.add(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "User")); + messages.add(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "LLM")); + messages.add( + Map.of(ROLE_FIELD_NAME, "tool", CONTENT_FIELD_NAME, "Tool", TOOL_FIELD_NAME, "call1")); + + history.addMessages(messages); + + // No role filter - should return all messages + List> result = history.getRecent(10, false, false, null, null); + + assertThat(result).hasSize(4); + Set roles = + result.stream() + .map(msg -> (String) msg.get(ROLE_FIELD_NAME)) + .collect(java.util.stream.Collectors.toSet()); + assertThat(roles).containsExactlyInAnyOrder("system", "user", "llm", "tool"); + } + + @Test + @Order(7) + @DisplayName("Should throw exception for invalid role") + void testGetRecentInvalidRoleRaisesError() { + history = new MessageHistory("test_invalid", unifiedJedis); + + assertThatThrownBy(() -> history.getRecent(10, false, false, null, "invalid_role")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role"); + } + + @Test + @Order(8) + @DisplayName("Should throw exception for invalid role in list") + void testGetRecentInvalidRoleInListRaisesError() { + history = new MessageHistory("test_invalid_list", unifiedJedis); + + assertThatThrownBy( + () -> + history.getRecent(10, false, false, null, Arrays.asList("system", "invalid_role"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role"); + } + + @Test + @Order(9) + @DisplayName("Should throw exception for empty role list") + void testGetRecentEmptyRoleListRaisesError() { + history = new MessageHistory("test_empty_list", unifiedJedis); + + assertThatThrownBy(() -> history.getRecent(10, false, false, null, Collections.emptyList())) + 
.isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("roles cannot be empty"); + } + + @Test + @Order(10) + @DisplayName("Should work with top_k parameter") + void testGetRecentRoleWithTopKParameter() { + history = new MessageHistory("test_with_params", unifiedJedis); + history.clear(); + + // Add many system messages + for (int i = 0; i < 5; i++) { + history.addMessage( + Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System message " + i)); + } + + // Add other messages + history.addMessage(Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "User message")); + history.addMessage(Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "LLM message")); + + // Get only 2 most recent system messages (should get last 2: index 3 and 4) + List> result = history.getRecent(2, false, false, null, "system"); + + assertThat(result).hasSize(2); + assertThat(result).allMatch(msg -> "system".equals(msg.get(ROLE_FIELD_NAME))); + // Should get 2 most recent ones in chronological order + // We added 0..4, so 3 and 4 are expected; the assertions below only pin their order + String first = (String) result.get(0).get(CONTENT_FIELD_NAME); + String second = (String) result.get(1).get(CONTENT_FIELD_NAME); + + // Verify we got 2 of the system messages + assertThat(first).startsWith("System message"); + assertThat(second).startsWith("System message"); + + // The second should be later than the first (chronological order after reversal) + int firstNum = Integer.parseInt(first.replace("System message ", "")); + int secondNum = Integer.parseInt(second.replace("System message ", "")); + assertThat(secondNum).isGreaterThan(firstNum); + } + + @Test + @Order(11) + @DisplayName("Should work with session_tag parameter") + void testGetRecentRoleWithSessionTag() { + history = new MessageHistory("test_session", unifiedJedis); + history.clear(); + + // Add messages with different session tags + history.addMessages( + Arrays.asList( + Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, 
"System for session1"), + Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "User for session1")), + "session1"); + + history.addMessages( + Arrays.asList( + Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System for session2"), + Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "LLM for session2")), + "session2"); + + // Get system messages from session2 only + List> result = history.getRecent(10, false, false, "session2", "system"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).get(ROLE_FIELD_NAME)).isEqualTo("system"); + assertThat(result.get(0).get(CONTENT_FIELD_NAME)).isEqualTo("System for session2"); + } + + @Test + @Order(12) + @DisplayName("Should work with raw=true parameter") + void testGetRecentRoleWithRawOutput() { + history = new MessageHistory("test_raw", unifiedJedis); + history.clear(); + + history.addMessage(Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System message")); + + List> result = history.getRecent(10, false, true, null, "system"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).get(ROLE_FIELD_NAME)).isEqualTo("system"); + // Raw should include additional metadata + assertThat(result.get(0)).containsKeys(ID_FIELD_NAME, TIMESTAMP_FIELD_NAME, SESSION_FIELD_NAME); + } + + @Test + @Order(13) + @DisplayName("Should accept all valid roles") + void testValidRolesAccepted() { + String[] validRoles = {"system", "user", "llm", "tool"}; + history = new MessageHistory("test_valid_roles", unifiedJedis); + history.clear(); + + // Add messages with all valid roles + for (String role : validRoles) { + if ("tool".equals(role)) { + history.addMessage( + Map.of( + ROLE_FIELD_NAME, + role, + CONTENT_FIELD_NAME, + role + " message", + TOOL_FIELD_NAME, + "call1")); + } else { + history.addMessage(Map.of(ROLE_FIELD_NAME, role, CONTENT_FIELD_NAME, role + " message")); + } + } + + // Test each valid role works + for (String role : validRoles) { + List> result = history.getRecent(10, false, false, null, role); + 
assertThat(result).hasSizeGreaterThanOrEqualTo(1); + assertThat(result).allMatch(msg -> role.equals(msg.get(ROLE_FIELD_NAME))); + } + } + + @Test + @Order(14) + @DisplayName("Should be case-sensitive for role validation") + void testCaseSensitiveRoles() { + history = new MessageHistory("test_case", unifiedJedis); + + // Uppercase should fail + assertThatThrownBy(() -> history.getRecent(10, false, false, null, "SYSTEM")) + .isInstanceOf(IllegalArgumentException.class); + + // Mixed case should fail + assertThatThrownBy(() -> history.getRecent(10, false, false, null, "User")) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + @Order(15) + @DisplayName("Should work with asText=true parameter") + void testGetRecentRoleWithAsText() { + history = new MessageHistory("test_as_text", unifiedJedis); + history.clear(); + + history.addMessages( + Arrays.asList( + Map.of(ROLE_FIELD_NAME, "system", CONTENT_FIELD_NAME, "System message"), + Map.of(ROLE_FIELD_NAME, "user", CONTENT_FIELD_NAME, "User message"), + Map.of(ROLE_FIELD_NAME, "llm", CONTENT_FIELD_NAME, "LLM message"))); + + // Get only user messages as text + List result = history.getRecent(10, true, false, null, "user"); + + assertThat(result).hasSize(1); + assertThat(result.get(0)).isEqualTo("User message"); + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java new file mode 100644 index 0000000..6026b7f --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java @@ -0,0 +1,227 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for role filtering functionality (#349). + * + *

Ported from Python: tests/integration/test_role_filter_get_recent.py (TestRoleValidation) + * + *

Tests the validateRoles() method that validates role parameter for filtering messages. + * + *

Python reference: PR #387 - Role filtering for message history + */ +@DisplayName("Role Filtering Validation Tests") +class RoleFilteringTest { + + /** + * Concrete implementation of BaseMessageHistory for testing validation logic. This allows us to + * test the protected validateRoles method without requiring Redis. + */ + private static class TestableMessageHistory extends BaseMessageHistory { + public TestableMessageHistory() { + super("test", null); + } + + // Public bridge to the protected validateRoles; every override below is an inert stub + public List testValidateRoles(Object role) { + return validateRoles(role); + } + + @Override + public void clear() {} + + @Override + public void delete() {} + + @Override + public void drop(String id) {} + + @Override + public List> getMessages() { + return null; + } + + @Override + public List getRecent( + int topK, boolean asText, boolean raw, String sessionTag, Object role) { + return null; + } + + @Override + public void store(String prompt, String response, String sessionTag) {} + + @Override + public void addMessages(List> messages, String sessionTag) {} + + @Override + public void addMessage(java.util.Map message, String sessionTag) {} + } + + @Test + @DisplayName("Should accept null role (no filtering)") + void testNullRoleReturnsNull() { + TestableMessageHistory history = new TestableMessageHistory(); + + List result = history.testValidateRoles(null); + + assertThat(result).isNull(); + } + + @Test + @DisplayName("Should accept valid single role: system") + void testValidSingleRoleSystem() { + TestableMessageHistory history = new TestableMessageHistory(); + + List result = history.testValidateRoles("system"); + + assertThat(result).isNotNull().containsExactly("system"); + } + + @Test + @DisplayName("Should accept valid single role: user") + void testValidSingleRoleUser() { + TestableMessageHistory history = new TestableMessageHistory(); + + List result = history.testValidateRoles("user"); + + assertThat(result).isNotNull().containsExactly("user"); + } + + @Test 
+ @DisplayName("Should accept valid single role: llm") + void testValidSingleRoleLlm() { + TestableMessageHistory history = new TestableMessageHistory(); + + List result = history.testValidateRoles("llm"); + + assertThat(result).isNotNull().containsExactly("llm"); + } + + @Test + @DisplayName("Should accept valid single role: tool") + void testValidSingleRoleTool() { + TestableMessageHistory history = new TestableMessageHistory(); + + List result = history.testValidateRoles("tool"); + + assertThat(result).isNotNull().containsExactly("tool"); + } + + @Test + @DisplayName("Should reject invalid single role") + void testInvalidSingleRole() { + TestableMessageHistory history = new TestableMessageHistory(); + + assertThatThrownBy(() -> history.testValidateRoles("invalid_role")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role 'invalid_role'") + .hasMessageContaining("system") + .hasMessageContaining("user") + .hasMessageContaining("llm") + .hasMessageContaining("tool"); + } + + @Test + @DisplayName("Should reject uppercase role (case-sensitive)") + void testCaseSensitiveUppercase() { + TestableMessageHistory history = new TestableMessageHistory(); + + assertThatThrownBy(() -> history.testValidateRoles("SYSTEM")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role 'SYSTEM'"); + } + + @Test + @DisplayName("Should reject mixed-case role (case-sensitive)") + void testCaseSensitiveMixedCase() { + TestableMessageHistory history = new TestableMessageHistory(); + + assertThatThrownBy(() -> history.testValidateRoles("User")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role 'User'"); + } + + @Test + @DisplayName("Should accept list of valid roles") + void testValidListOfRoles() { + TestableMessageHistory history = new TestableMessageHistory(); + + List input = Arrays.asList("system", "user"); + List result = history.testValidateRoles(input); + + 
assertThat(result).isNotNull().containsExactly("system", "user"); + } + + @Test + @DisplayName("Should accept list with all valid roles") + void testValidListAllRoles() { + TestableMessageHistory history = new TestableMessageHistory(); + + List input = Arrays.asList("system", "user", "llm", "tool"); + List result = history.testValidateRoles(input); + + assertThat(result).isNotNull().containsExactly("system", "user", "llm", "tool"); + } + + @Test + @DisplayName("Should reject empty role list") + void testEmptyRoleList() { + TestableMessageHistory history = new TestableMessageHistory(); + + assertThatThrownBy(() -> history.testValidateRoles(Collections.emptyList())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("roles cannot be empty"); + } + + @Test + @DisplayName("Should reject list containing invalid role") + void testListWithInvalidRole() { + TestableMessageHistory history = new TestableMessageHistory(); + + List input = Arrays.asList("system", "invalid_role"); + + assertThatThrownBy(() -> history.testValidateRoles(input)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role 'invalid_role'"); + } + + @Test + @DisplayName("Should reject list containing uppercase role") + void testListWithUppercaseRole() { + TestableMessageHistory history = new TestableMessageHistory(); + + List input = Arrays.asList("system", "USER"); + + assertThatThrownBy(() -> history.testValidateRoles(input)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid role 'USER'"); + } + + @Test + @DisplayName("Should reject non-string, non-list input") + void testInvalidInputType() { + TestableMessageHistory history = new TestableMessageHistory(); + + assertThatThrownBy(() -> history.testValidateRoles(Integer.valueOf(42))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("role must be a String, List, or null"); + } + + @Test + @DisplayName("Should handle list with single role") + void 
testListWithSingleRole() { + TestableMessageHistory history = new TestableMessageHistory(); + + List input = Arrays.asList("system"); + List result = history.testValidateRoles(input); + + assertThat(result).isNotNull().containsExactly("system"); + } +}