
Commit 686d0c3

Add AI.CONFIG GET sub-command (#918)
1 parent: 0ae695c

File tree: 4 files changed, +99 -33 lines

* docs/commands.md (+9 -1)
* src/redisai.c (+14 -1)
* tests/flow/tests_commands.py (+76 -0)
* tests/flow/tests_pytorch.py (+0 -31)

docs/commands.md

Lines changed: 9 additions & 1 deletion
````diff
@@ -1144,7 +1144,7 @@ The **AI.CONFIG** command sets the value of configuration directives at run-time
 
 **Redis API**
 ```
-AI.CONFIG <BACKENDSPATH <path>> | <LOADBACKEND <backend> <path>> | <MODEL_CHUNK_SIZE <chunk_size>>
+AI.CONFIG <BACKENDSPATH <path>> | <LOADBACKEND <backend> <path>> | <MODEL_CHUNK_SIZE <chunk_size>> | <GET <BACKENDSPATH | MODEL_CHUNK_SIZE>>
 ```
 
 _Arguments_
@@ -1156,6 +1156,7 @@ _Arguments_
 * **TORCH**: The PyTorch backend
 * **ONNX**: ONNXRuntime backend
 * **MODEL_CHUNK_SIZE**: Sets the size of chunks (in bytes) in which model payloads are split for serialization, replication and `MODELGET`. Default is `511 * 1024 * 1024`.
+* **GET**: Retrieves the current value of the `BACKENDSPATH` or `MODEL_CHUNK_SIZE` configuration. Note that additional information about the module's runtime configuration can be retrieved as part of Redis' info report via the `INFO MODULES` command.
 
 _Return_
 
@@ -1190,3 +1191,10 @@ This sets model chunk size to one megabyte (not recommended):
 redis> AI.CONFIG MODEL_CHUNK_SIZE 1048576
 OK
 ```
+
+This returns the current model chunk size configuration:
+
+```
+redis> AI.CONFIG GET MODEL_CHUNK_SIZE
+1048576
+```
````
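Seen from a client, the new sub-command maps onto the usual RESP reply types. The following is a minimal sketch, not part of the commit, assuming a local Redis server with the RedisAI module loaded and the redis-py package installed:

```python
# Minimal sketch of AI.CONFIG GET from a client (assumes redis-py and a
# local Redis server with the RedisAI module loaded).
import redis

r = redis.Redis(host='localhost', port=6379)

# MODEL_CHUNK_SIZE comes back as a RESP integer; redis-py yields a Python int.
chunk_size = r.execute_command('AI.CONFIG', 'GET', 'MODEL_CHUNK_SIZE')
print(chunk_size)  # 535822336 (511 * 1024 * 1024) unless reconfigured

# BACKENDSPATH comes back as a bulk string (bytes), or None if it was never set.
backends_path = r.execute_command('AI.CONFIG', 'GET', 'BACKENDSPATH')
print(backends_path)
```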

src/redisai.c

Lines changed: 14 additions & 1 deletion
```diff
@@ -939,7 +939,20 @@ int RedisAI_Config_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i
             return RedisModule_ReplyWithError(ctx, "ERR MODEL_CHUNK_SIZE: missing chunk size");
         }
     }
-
+    if (!strcasecmp(subcommand, "GET")) {
+        if (argc > 2) {
+            const char *config = RedisModule_StringPtrLen(argv[2], NULL);
+            if (!strcasecmp(config, "BACKENDSPATH")) {
+                return RedisModule_ReplyWithCString(ctx, Config_GetBackendsPath());
+            } else if (!strcasecmp(config, "MODEL_CHUNK_SIZE")) {
+                return RedisModule_ReplyWithLongLong(ctx, Config_GetModelChunkSize());
+            } else {
+                return RedisModule_ReplyWithNull(ctx);
+            }
+        } else {
+            return RedisModule_WrongArity(ctx);
+        }
+    }
     return RedisModule_ReplyWithError(ctx, "ERR unsupported subcommand");
 }
```
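Note the reply shapes this introduces: `GET BACKENDSPATH` replies with a bulk string via `RedisModule_ReplyWithCString`, `GET MODEL_CHUNK_SIZE` with an integer via `RedisModule_ReplyWithLongLong`, an unrecognized configuration name gets a null reply rather than an error, and a bare `AI.CONFIG GET` falls through to the standard arity error.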

tests/flow/tests_commands.py

Lines changed: 76 additions & 0 deletions
```diff
@@ -536,3 +536,79 @@ def run_model_execute_from_llapi():
     info = info_to_dict(con.execute_command('AI.INFO', 'm{1}'))
     env.assertGreaterEqual(info['calls'], 0)
     env.assertGreaterEqual(num_parallel_clients, info['calls'])
+
+
+def test_ai_config(env):
+    if not TEST_PT:
+        env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
+        return
+
+    conns = env.getOSSMasterNodesConnectionList()
+    if env.isCluster():
+        env.assertEqual(len(conns), env.shardsCount)
+
+    model = load_file_content('pt-minimal.pt')
+
+    for con in conns:
+        # Get the default configs.
+        res = con.execute_command('AI.CONFIG', 'GET', 'BACKENDSPATH')
+        env.assertEqual(res, None)
+        res = con.execute_command('AI.CONFIG', 'GET', 'MODEL_CHUNK_SIZE')
+        env.assertEqual(res, 511*1024*1024)
+
+        # Change the default backends path and load backend.
+        path = f'{ROOT}/install-{DEVICE.lower()}'
+        con.execute_command('AI.CONFIG', 'BACKENDSPATH', path)
+        res = con.execute_command('AI.CONFIG', 'GET', 'BACKENDSPATH')
+        env.assertEqual(res, path.encode())
+        be_info = get_info_section(con, "backends_info")
+        env.assertEqual(len(be_info), 0)  # no backends are loaded.
+        check_error_message(env, con, 'error loading backend', 'AI.CONFIG', 'LOADBACKEND', 'TORCH', ".")
+
+        res = con.execute_command('AI.CONFIG', 'LOADBACKEND', 'TORCH', "backends/redisai_torch/redisai_torch.so")
+        env.assertEqual(res, b'OK')
+        be_info = get_info_section(con, "backends_info")
+        env.assertEqual(len(be_info), 1)  # one backend is loaded now - torch.
+
+    # Set the same model twice on some shard - with and without chunks, and assert equality.
+    con = get_connection(env, '{1}')
+    chunk_size = len(model) // 3
+    model_chunks = [model[i:i + chunk_size] for i in range(0, len(model), chunk_size)]
+    con.execute_command('AI.MODELSTORE', 'm1{1}', 'TORCH', DEVICE, 'BLOB', model)
+    con.execute_command('AI.MODELSTORE', 'm2{1}', 'TORCH', DEVICE, 'BLOB', *model_chunks)
+    model1 = con.execute_command('AI.MODELGET', 'm1{1}', 'BLOB')
+    model2 = con.execute_command('AI.MODELGET', 'm2{1}', 'BLOB')
+    env.assertEqual(model1, model2)
+
+    for con in conns:
+        # Change the default model_chunk_size.
+        ret = con.execute_command('AI.CONFIG', 'MODEL_CHUNK_SIZE', chunk_size)
+        env.assertEqual(ret, b'OK')
+        res = con.execute_command('AI.CONFIG', 'GET', 'MODEL_CHUNK_SIZE')
+        env.assertEqual(res, chunk_size)
+
+    # Verify that AI.MODELGET returns the model's blob in chunks, with or without the META arg.
+    con = get_connection(env, '{1}')
+    model2 = con.execute_command('AI.MODELGET', 'm1{1}', 'BLOB')
+    env.assertEqual(len(model2), len(model_chunks))
+    env.assertTrue(all([el1 == el2 for el1, el2 in zip(model2, model_chunks)]))
+
+    model3 = con.execute_command('AI.MODELGET', 'm1{1}', 'META', 'BLOB')[-1]  # Extract the BLOB list from the result
+    env.assertEqual(len(model3), len(model_chunks))
+    env.assertTrue(all([el1 == el2 for el1, el2 in zip(model3, model_chunks)]))
+
+
+def test_ai_config_errors(env):
+    con = get_connection(env, '{1}')
+
+    check_error_message(env, con, "wrong number of arguments for 'AI.CONFIG' command", 'AI.CONFIG')
+    check_error_message(env, con, 'unsupported subcommand', 'AI.CONFIG', "bad_subcommand")
+    check_error_message(env, con, "wrong number of arguments for 'AI.CONFIG' command", 'AI.CONFIG', 'LOADBACKEND')
+    check_error_message(env, con, 'unsupported backend', 'AI.CONFIG', 'LOADBACKEND', 'bad_backend', "backends/redisai_torch/redisai_torch.so")
+    check_error_message(env, con, "wrong number of arguments for 'AI.CONFIG' command", 'AI.CONFIG', 'LOADBACKEND', "TORCH")
+
+    check_error_message(env, con, 'BACKENDSPATH: missing path argument', 'AI.CONFIG', 'BACKENDSPATH')
+    check_error_message(env, con, 'MODEL_CHUNK_SIZE: missing chunk size', 'AI.CONFIG', 'MODEL_CHUNK_SIZE')
+
+    check_error_message(env, con, "wrong number of arguments for 'AI.CONFIG' command", 'AI.CONFIG', 'GET')
+    env.assertEqual(con.execute_command('AI.CONFIG', 'GET', 'bad_config'), None)
```
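One detail worth calling out: splitting with `len(model) // 3` produces four chunks, not three, whenever the blob length is not divisible by 3, and the test expects `AI.MODELGET` to mirror that chunk count once `MODEL_CHUNK_SIZE` is lowered. A standalone sketch of the arithmetic (pure Python, not part of the commit; the 100-byte blob is a made-up stand-in):

```python
# Standalone sketch of the chunking arithmetic used above (no Redis required;
# the 100-byte blob is a hypothetical stand-in for a model payload).
model = b'x' * 100
chunk_size = len(model) // 3  # 33

chunks = [model[i:i + chunk_size] for i in range(0, len(model), chunk_size)]
assert len(chunks) == 4             # 33 + 33 + 33 + 1 bytes
assert b''.join(chunks) == model    # chunks reassemble to the original blob
```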

tests/flow/tests_pytorch.py

Lines changed: 0 additions & 31 deletions
```diff
@@ -9,37 +9,6 @@
 '''
 
 
-def test_pytorch_chunked_modelstore(env):
-    if not TEST_PT:
-        env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
-        return
-
-    con = get_connection(env, '{1}')
-    model = load_file_content('pt-minimal.pt')
-
-    chunk_size = len(model) // 3
-
-    model_chunks = [model[i:i + chunk_size] for i in range(0, len(model), chunk_size)]
-
-    ret = con.execute_command('AI.MODELSTORE', 'm1{1}', 'TORCH', DEVICE, 'BLOB', model)
-    ret = con.execute_command('AI.MODELSTORE', 'm2{1}', 'TORCH', DEVICE, 'BLOB', *model_chunks)
-
-    model1 = con.execute_command('AI.MODELGET', 'm1{1}', 'BLOB')
-    model2 = con.execute_command('AI.MODELGET', 'm2{1}', 'BLOB')
-
-    env.assertEqual(model1, model2)
-
-    ret = con.execute_command('AI.CONFIG', 'MODEL_CHUNK_SIZE', chunk_size)
-
-    model2 = con.execute_command('AI.MODELGET', 'm2{1}', 'BLOB')
-    env.assertEqual(len(model2), len(model_chunks))
-    env.assertTrue(all([el1 == el2 for el1, el2 in zip(model2, model_chunks)]))
-
-    model3 = con.execute_command('AI.MODELGET', 'm2{1}', 'META', 'BLOB')[-1]  # Extract the BLOB list from the result
-    env.assertEqual(len(model3), len(model_chunks))
-    env.assertTrue(all([el1 == el2 for el1, el2 in zip(model3, model_chunks)]))
-
-
 def test_pytorch_modelrun(env):
     if not TEST_PT:
         env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
```
