Skip to content

Commit b49352e

Browse files
committed
adding tests for bfloat16
1 parent 4b10eac commit b49352e

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using StackExchange.Redis;
2+
using NRedisStack.Search;
3+
using NRedisStack.RedisStackCommands;
4+
using Xunit;
5+
using System.Runtime.InteropServices;
6+
7+
namespace NRedisStack.Tests.Search;
8+
9+
public class IndexCreationTests : AbstractNRedisStackTest, IDisposable
10+
{
11+
public IndexCreationTests(RedisFixture redisFixture) : base(redisFixture) { }
12+
13+
[SkipIfRedis(Is.OSSCluster, Comparison.LessThan, "7.3.240")]
14+
public void TestCreateFloat16VectorField()
15+
{
16+
IDatabase db = redisFixture.Redis.GetDatabase();
17+
db.Execute("FLUSHALL");
18+
var ft = db.FT(2);
19+
20+
var schema = new Schema().AddVectorField("v", Schema.VectorField.VectorAlgo.FLAT, new Dictionary<string, object>()
21+
{
22+
["TYPE"] = "FLOAT16",
23+
["DIM"] = "5",
24+
["DISTANCE_METRIC"] = "L2",
25+
}).AddVectorField("v2", Schema.VectorField.VectorAlgo.FLAT, new Dictionary<string, object>()
26+
{
27+
["TYPE"] = "BFLOAT16",
28+
["DIM"] = "4",
29+
["DISTANCE_METRIC"] = "L2",
30+
});
31+
Assert.True(ft.Create("idx", new FTCreateParams(), schema));
32+
33+
short[] vec1 = new short[] { 2, 1, 2, 2, 2 };
34+
byte[] vec1ToBytes = MemoryMarshal.Cast<short, byte>(vec1).ToArray();
35+
36+
short[] vec2 = new short[] { 1, 2, 2, 2 };
37+
byte[] vec2ToBytes = MemoryMarshal.Cast<short, byte>(vec2).ToArray();
38+
39+
var entries = new HashEntry[] { new HashEntry("v", vec1ToBytes), new HashEntry("v2", vec2ToBytes) };
40+
db.HashSet("a", entries);
41+
db.HashSet("b", entries);
42+
db.HashSet("c", entries);
43+
44+
var q = new Query("*=>[KNN 2 @v $vec]").ReturnFields("__v_score");
45+
var res = ft.Search("idx", q.AddParam("vec", vec1ToBytes));
46+
Assert.Equal(2, res.TotalResults);
47+
48+
q = new Query("*=>[KNN 2 @v2 $vec]").ReturnFields("__v_score");
49+
res = ft.Search("idx", q.AddParam("vec", vec2ToBytes));
50+
Assert.Equal(2, res.TotalResults);
51+
}
52+
}

0 commit comments

Comments
 (0)