Skip to content

Commit da2b489

Browse files
committed
ft.profile query_params
1 parent 42b937f commit da2b489

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

redis/commands/search/commands.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def _get_aggregate_result(self, raw, query, has_cursor):
506506

507507
return AggregateResult(rows, cursor, schema)
508508

509-
def profile(self, query, limited=False):
509+
def profile(self, query, query_params=None, limited=False):
510510
"""
511511
Performs a search or aggregate command and collects performance
512512
information.
@@ -530,6 +530,8 @@ def profile(self, query, limited=False):
530530
elif isinstance(query, Query):
531531
cmd[2] = "SEARCH"
532532
cmd += query.get_args()
533+
if query_params is not None:
534+
cmd += self.get_params_args(query_params)
533535
else:
534536
raise ValueError("Must provide AggregateRequest object or " "Query object.")
535537

tests/test_search.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,6 +1557,30 @@ def test_profile_limited(client):
15571557
assert len(res.docs) == 3 # check also the search result
15581558

15591559

1560+
@pytest.mark.redismod
1561+
@skip_ifmodversion_lt("2.4.3", "search")
1562+
def test_profile_query_params(modclient: redis.Redis):
1563+
modclient.flushdb()
1564+
modclient.ft().create_index(
1565+
(
1566+
VectorField(
1567+
"v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}
1568+
),
1569+
)
1570+
)
1571+
modclient.hset("a", "v", "aaaaaaaa")
1572+
modclient.hset("b", "v", "aaaabaaa")
1573+
modclient.hset("c", "v", "aaaaabaa")
1574+
query = "*=>[KNN 2 @v $vec]"
1575+
q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2)
1576+
res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"})
1577+
assert det["Iterators profile"]["Counter"] == 2.0
1578+
assert det["Iterators profile"]["Type"] == "VECTOR"
1579+
assert res.total == 2
1580+
assert "a" == res.docs[0].id
1581+
assert "0" == res.docs[0].__getattribute__("__v_score")
1582+
1583+
15601584
@pytest.mark.redismod
15611585
@skip_ifmodversion_lt("2.4.3", "search")
15621586
def test_vector_field(modclient):

0 commit comments

Comments
 (0)