diff --git a/src/NRedisStack/Search/Query.cs b/src/NRedisStack/Search/Query.cs index 475ef7f4..cf7d5ae1 100644 --- a/src/NRedisStack/Search/Query.cs +++ b/src/NRedisStack/Search/Query.cs @@ -180,7 +180,7 @@ public HighlightTags(string open, string close) /// /// Set the query parameter to sort by ASC by default /// - public bool SortAscending { get; set; } = true; + public bool? SortAscending { get; set; } = null; // highlight and summarize internal bool _wantsHighlight = false, _wantsSummarize = false; @@ -260,7 +260,8 @@ internal void SerializeRedisArgs(List args) { args.Add("SORTBY"); args.Add(SortBy); - args.Add((SortAscending ? "ASC" : "DESC")); + if (SortAscending != null) + args.Add(((bool)SortAscending ? "ASC" : "DESC")); } if (Payload != null) { @@ -605,7 +606,7 @@ public Query SummarizeFields(int contextLen, int fragmentCount, string separator /// the sorting field's name /// if set to true, the sorting order is ascending, else descending /// the query object itself - public Query SetSortBy(string field, bool ascending = true) + public Query SetSortBy(string field, bool? ascending = null) { SortBy = field; SortAscending = ascending; diff --git a/src/NRedisStack/Search/Schema.cs b/src/NRedisStack/Search/Schema.cs index 5f17501b..2b4f98b8 100644 --- a/src/NRedisStack/Search/Schema.cs +++ b/src/NRedisStack/Search/Schema.cs @@ -209,13 +209,16 @@ public enum VectorAlgo public VectorAlgo Algorithm { get; } public Dictionary? Attributes { get; } - public VectorField(string name, VectorAlgo algorithm, Dictionary? attributes = null) + public VectorField(FieldName name, VectorAlgo algorithm, Dictionary? attributes = null) : base(name, FieldType.Vector) { Algorithm = algorithm; Attributes = attributes; } + public VectorField(string name, VectorAlgo algorithm, Dictionary? attributes = null) + : this(FieldName.Of(name), algorithm, attributes) { } + internal override void AddFieldTypeArgs(List args) { args.Add(Algorithm.ToString()); @@ -376,6 +379,19 @@ public Schema AddTagField(string name, bool sortable = false, bool unf = false, return this; } + /// + /// Add a Vector field to the schema. + /// + /// The field's name. + /// The vector similarity algorithm to use. + /// The algorithm attributes for the creation of the vector index. + /// The object. + public Schema AddVectorField(FieldName name, VectorAlgo algorithm, Dictionary? attributes = null) + { + Fields.Add(new VectorField(name, algorithm, attributes)); + return this; + } + /// /// Add a Vector field to the schema. /// diff --git a/src/NRedisStack/Search/SearchResult.cs b/src/NRedisStack/Search/SearchResult.cs index 76a76cfb..7f4c5c77 100644 --- a/src/NRedisStack/Search/SearchResult.cs +++ b/src/NRedisStack/Search/SearchResult.cs @@ -15,8 +15,8 @@ public class SearchResult /// /// Converts the documents to a list of json strings. only works on a json documents index. /// - public IEnumerable? ToJson() => Documents.Select(x => x["json"].ToString()) - .Where(x => !string.IsNullOrEmpty(x)); + public List? ToJson() => Documents.Select(x => x["json"].ToString()) + .Where(x => !string.IsNullOrEmpty(x)).ToList(); internal SearchResult(RedisResult[] resp, bool hasContent, bool hasScores, bool hasPayloads/*, bool shouldExplainScore*/) { diff --git a/tests/NRedisStack.Tests/Search/SearchTests.cs b/tests/NRedisStack.Tests/Search/SearchTests.cs index 593f7642..838d60a1 100644 --- a/tests/NRedisStack.Tests/Search/SearchTests.cs +++ b/tests/NRedisStack.Tests/Search/SearchTests.cs @@ -6,6 +6,7 @@ using static NRedisStack.Search.Schema; using NRedisStack.Search.Aggregation; using NRedisStack.Search.Literals.Enums; +using System.Runtime.InteropServices; namespace NRedisStack.Tests.Search; @@ -1912,6 +1913,98 @@ public async Task TestVectorCount_Issue70() Assert.Equal(expected.Count(), actual.Args.Length); } + [Fact] + public void VectorSimilaritySearch() + { + IDatabase db = redisFixture.Redis.GetDatabase(); + db.Execute("FLUSHALL"); + var ft = db.FT(); + var json = db.JSON(); + + json.Set("vec:1", "$", "{\"vector\":[1,1,1,1]}"); + json.Set("vec:2", "$", "{\"vector\":[2,2,2,2]}"); + json.Set("vec:3", "$", "{\"vector\":[3,3,3,3]}"); + json.Set("vec:4", "$", "{\"vector\":[4,4,4,4]}"); + + var schema = new Schema().AddVectorField(FieldName.Of("$.vector").As("vector"), Schema.VectorField.VectorAlgo.FLAT, new Dictionary() + { + ["TYPE"] = "FLOAT32", + ["DIM"] = "4", + ["DISTANCE_METRIC"] = "L2", + }); + + var idxDef = new FTCreateParams().On(IndexDataType.JSON).Prefix("vec:"); + Assert.True(ft.Create("vss_idx", idxDef, schema)); + + float[] vec = new float[] { 2, 2, 2, 2 }; + byte[] queryVec = MemoryMarshal.Cast(vec).ToArray(); + + + var query = new Query("*=>[KNN 3 @vector $query_vec]") + .AddParam("query_vec", queryVec) + .SetSortBy("__vector_score") + .Dialect(2); + var res = ft.Search("vss_idx", query); + + Assert.Equal(3, res.TotalResults); + + Assert.Equal("vec:2", res.Documents[0].Id.ToString()); + + Assert.Equal(0, res.Documents[0]["__vector_score"]); + + var jsonRes = res.ToJson(); + Assert.Equal("{\"vector\":[2,2,2,2]}", jsonRes![0]); + } + + [Fact] + public void QueryingVectorFields() + { + IDatabase db = redisFixture.Redis.GetDatabase(); + db.Execute("FLUSHALL"); + var ft = db.FT(); + var json = db.JSON(); + + var schema = new Schema().AddVectorField("v", Schema.VectorField.VectorAlgo.HNSW, new Dictionary() + { + ["TYPE"] = "FLOAT32", + ["DIM"] = "2", + ["DISTANCE_METRIC"] = "L2", + }); + + ft.Create("idx", new FTCreateParams(), schema); + + db.HashSet("a", "v", "aaaaaaaa"); + db.HashSet("b", "v", "aaaabaaa"); + db.HashSet("c", "v", "aaaaabaa"); + + var q = new Query("*=>[KNN 2 @v $vec]").ReturnFields("__v_score").Dialect(2); + var res = ft.Search("idx", q.AddParam("vec", "aaaaaaaa")); + Assert.Equal(2, res.TotalResults); + } + + [Fact] + public async Task TestVectorFieldJson_Issue102Async() + { + IDatabase db = redisFixture.Redis.GetDatabase(); + db.Execute("FLUSHALL"); + var ft = db.FT(); + var json = db.JSON(); + + // JSON.SET 1 $ '{"vec":[1,2,3,4]}' + await json.SetAsync("1", "$", "{\"vec\":[1,2,3,4]}"); + + // FT.CREATE my_index ON JSON SCHEMA $.vec as vector VECTOR FLAT 6 TYPE FLOAT32 DIM 4 DISTANCE_METRIC L2 + var schema = new Schema().AddVectorField(FieldName.Of("$.vec").As("vector"), Schema.VectorField.VectorAlgo.FLAT, new Dictionary() + { + ["TYPE"] = "FLOAT32", + ["DIM"] = "4", + ["DISTANCE_METRIC"] = "L2", + }); + + Assert.True(await ft.CreateAsync("my_index", new FTCreateParams().On(IndexDataType.JSON), schema)); + + } + [Fact] public void TestModulePrefixs1() {