Skip to content

Commit 7df134a

Browse files
committed
- Add modelstore command (new version of modelset) and move modelset to deprecated.c
- Replace modelrun and modelset with modelexecute and modelstore in onnx and pytorch tests. - Create a new test file (currently still empty) for testing deprecated APIs.
1 parent d6673e8 commit 7df134a

File tree

8 files changed

+392
-217
lines changed

8 files changed

+392
-217
lines changed

src/execution/deprecated.c

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include "command_parser.h"
44
#include "util/string_utils.h"
55
#include "execution/utils.h"
6+
#include "rmutil/args.h"
7+
#include "backends/backends.h"
8+
#include "execution/background_workers.h"
9+
#include "redis_ai_objects/stats.h"
610

711
static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
812
RAI_Model **model, RAI_Error *error,
@@ -104,3 +108,240 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModu
104108
RedisModule_FreeThreadSafeContext(ctx);
105109
return res;
106110
}
111+
112+
int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
113+
if (argc < 4)
114+
return RedisModule_WrongArity(ctx);
115+
116+
ArgsCursor ac;
117+
ArgsCursor_InitRString(&ac, argv + 1, argc - 1);
118+
119+
RedisModuleString *keystr;
120+
AC_GetRString(&ac, &keystr, 0);
121+
122+
const char *bckstr;
123+
int backend;
124+
AC_GetString(&ac, &bckstr, NULL, 0);
125+
if (strcasecmp(bckstr, "TF") == 0) {
126+
backend = RAI_BACKEND_TENSORFLOW;
127+
} else if (strcasecmp(bckstr, "TFLITE") == 0) {
128+
backend = RAI_BACKEND_TFLITE;
129+
} else if (strcasecmp(bckstr, "TORCH") == 0) {
130+
backend = RAI_BACKEND_TORCH;
131+
} else if (strcasecmp(bckstr, "ONNX") == 0) {
132+
backend = RAI_BACKEND_ONNXRUNTIME;
133+
} else {
134+
return RedisModule_ReplyWithError(ctx, "ERR unsupported backend");
135+
}
136+
137+
const char *devicestr;
138+
AC_GetString(&ac, &devicestr, NULL, 0);
139+
140+
if (strlen(devicestr) > 10 || strcasecmp(devicestr, "INPUTS") == 0 ||
141+
strcasecmp(devicestr, "OUTPUTS") == 0 || strcasecmp(devicestr, "TAG") == 0 ||
142+
strcasecmp(devicestr, "BATCHSIZE") == 0 || strcasecmp(devicestr, "MINBATCHSIZE") == 0 ||
143+
strcasecmp(devicestr, "MINBATCHTIMEOUT") == 0 || strcasecmp(devicestr, "BLOB") == 0) {
144+
return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE");
145+
}
146+
147+
RedisModuleString *tag = NULL;
148+
if (AC_AdvanceIfMatch(&ac, "TAG")) {
149+
AC_GetRString(&ac, &tag, 0);
150+
}
151+
152+
unsigned long long batchsize = 0;
153+
if (AC_AdvanceIfMatch(&ac, "BATCHSIZE")) {
154+
if (backend == RAI_BACKEND_TFLITE) {
155+
return RedisModule_ReplyWithError(
156+
ctx, "ERR Auto-batching not supported by the TFLITE backend");
157+
}
158+
if (AC_GetUnsignedLongLong(&ac, &batchsize, 0) != AC_OK) {
159+
return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for BATCHSIZE");
160+
}
161+
}
162+
163+
unsigned long long minbatchsize = 0;
164+
if (AC_AdvanceIfMatch(&ac, "MINBATCHSIZE")) {
165+
if (batchsize == 0) {
166+
return RedisModule_ReplyWithError(ctx, "ERR MINBATCHSIZE specified without BATCHSIZE");
167+
}
168+
if (AC_GetUnsignedLongLong(&ac, &minbatchsize, 0) != AC_OK) {
169+
return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for MINBATCHSIZE");
170+
}
171+
}
172+
173+
unsigned long long minbatchtimeout = 0;
174+
if (AC_AdvanceIfMatch(&ac, "MINBATCHTIMEOUT")) {
175+
if (batchsize == 0) {
176+
return RedisModule_ReplyWithError(ctx,
177+
"ERR MINBATCHTIMEOUT specified without BATCHSIZE");
178+
}
179+
if (minbatchsize == 0) {
180+
return RedisModule_ReplyWithError(ctx,
181+
"ERR MINBATCHTIMEOUT specified without MINBATCHSIZE");
182+
}
183+
if (AC_GetUnsignedLongLong(&ac, &minbatchtimeout, 0) != AC_OK) {
184+
return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for MINBATCHTIMEOUT");
185+
}
186+
}
187+
188+
if (AC_IsAtEnd(&ac)) {
189+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing model BLOB");
190+
}
191+
192+
ArgsCursor optionsac;
193+
const char *blob_matches[] = {"BLOB"};
194+
AC_GetSliceUntilMatches(&ac, &optionsac, 1, blob_matches);
195+
196+
if (optionsac.argc == 0 && backend == RAI_BACKEND_TENSORFLOW) {
197+
return RedisModule_ReplyWithError(
198+
ctx, "ERR Insufficient arguments, INPUTS and OUTPUTS not specified");
199+
}
200+
201+
ArgsCursor inac = {0};
202+
ArgsCursor outac = {0};
203+
if (optionsac.argc > 0 && backend == RAI_BACKEND_TENSORFLOW) {
204+
if (!AC_AdvanceIfMatch(&optionsac, "INPUTS")) {
205+
return RedisModule_ReplyWithError(ctx, "ERR INPUTS not specified");
206+
}
207+
208+
const char *matches[] = {"OUTPUTS"};
209+
AC_GetSliceUntilMatches(&optionsac, &inac, 1, matches);
210+
211+
if (!AC_IsAtEnd(&optionsac)) {
212+
if (!AC_AdvanceIfMatch(&optionsac, "OUTPUTS")) {
213+
return RedisModule_ReplyWithError(ctx, "ERR OUTPUTS not specified");
214+
}
215+
216+
AC_GetSliceToEnd(&optionsac, &outac);
217+
}
218+
}
219+
220+
size_t ninputs = inac.argc;
221+
const char *inputs[ninputs];
222+
for (size_t i = 0; i < ninputs; i++) {
223+
AC_GetString(&inac, inputs + i, NULL, 0);
224+
}
225+
226+
size_t noutputs = outac.argc;
227+
const char *outputs[noutputs];
228+
for (size_t i = 0; i < noutputs; i++) {
229+
AC_GetString(&outac, outputs + i, NULL, 0);
230+
}
231+
232+
RAI_ModelOpts opts = {
233+
.batchsize = batchsize,
234+
.minbatchsize = minbatchsize,
235+
.minbatchtimeout = minbatchtimeout,
236+
.backends_intra_op_parallelism = getBackendsIntraOpParallelism(),
237+
.backends_inter_op_parallelism = getBackendsInterOpParallelism(),
238+
};
239+
240+
RAI_Model *model = NULL;
241+
242+
AC_AdvanceUntilMatches(&ac, 1, blob_matches);
243+
244+
if (AC_IsAtEnd(&ac)) {
245+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing model BLOB");
246+
}
247+
248+
AC_Advance(&ac);
249+
250+
ArgsCursor blobsac;
251+
AC_GetSliceToEnd(&ac, &blobsac);
252+
253+
size_t modellen;
254+
char *modeldef;
255+
256+
if (blobsac.argc == 1) {
257+
AC_GetString(&blobsac, (const char **)&modeldef, &modellen, 0);
258+
} else {
259+
const char *chunks[blobsac.argc];
260+
size_t chunklens[blobsac.argc];
261+
modellen = 0;
262+
while (!AC_IsAtEnd(&blobsac)) {
263+
AC_GetString(&blobsac, &chunks[blobsac.offset], &chunklens[blobsac.offset], 0);
264+
modellen += chunklens[blobsac.offset - 1];
265+
}
266+
267+
modeldef = RedisModule_Calloc(modellen, sizeof(char));
268+
size_t offset = 0;
269+
for (size_t i = 0; i < blobsac.argc; i++) {
270+
memcpy(modeldef + offset, chunks[i], chunklens[i]);
271+
offset += chunklens[i];
272+
}
273+
}
274+
275+
RAI_Error err = {0};
276+
277+
model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs,
278+
modeldef, modellen, &err);
279+
280+
if (err.code == RAI_EBACKENDNOTLOADED) {
281+
RedisModule_Log(ctx, "warning", "backend %s not loaded, will try loading default backend",
282+
bckstr);
283+
int ret = RAI_LoadDefaultBackend(ctx, backend);
284+
if (ret == REDISMODULE_ERR) {
285+
RedisModule_Log(ctx, "error", "could not load %s default backend", bckstr);
286+
int ret = RedisModule_ReplyWithError(ctx, "ERR Could not load backend");
287+
RAI_ClearError(&err);
288+
return ret;
289+
}
290+
RAI_ClearError(&err);
291+
model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs,
292+
modeldef, modellen, &err);
293+
}
294+
295+
if (blobsac.argc > 1) {
296+
RedisModule_Free(modeldef);
297+
}
298+
299+
if (err.code != RAI_OK) {
300+
RedisModule_Log(ctx, "error", "%s", err.detail);
301+
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
302+
RAI_ClearError(&err);
303+
return ret;
304+
}
305+
306+
// TODO: if backend loaded, make sure there's a queue
307+
RunQueueInfo *run_queue_info = NULL;
308+
if (ensureRunQueue(devicestr, &run_queue_info) != REDISMODULE_OK) {
309+
RAI_ModelFree(model, &err);
310+
if (err.code != RAI_OK) {
311+
RedisModule_Log(ctx, "error", "%s", err.detail);
312+
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
313+
RAI_ClearError(&err);
314+
return ret;
315+
}
316+
return RedisModule_ReplyWithError(ctx,
317+
"ERR Could not initialize queue on requested device");
318+
}
319+
320+
RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE);
321+
int type = RedisModule_KeyType(key);
322+
if (type != REDISMODULE_KEYTYPE_EMPTY &&
323+
!(type == REDISMODULE_KEYTYPE_MODULE &&
324+
RedisModule_ModuleTypeGetType(key) == RedisAI_ModelType)) {
325+
RedisModule_CloseKey(key);
326+
RAI_ModelFree(model, &err);
327+
if (err.code != RAI_OK) {
328+
RedisModule_Log(ctx, "error", "%s", err.detail);
329+
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
330+
RAI_ClearError(&err);
331+
return ret;
332+
}
333+
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
334+
}
335+
336+
RedisModule_ModuleTypeSetValue(key, RedisAI_ModelType, model);
337+
338+
model->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_MODEL, backend, devicestr, tag);
339+
340+
RedisModule_CloseKey(key);
341+
342+
RedisModule_ReplyWithSimpleString(ctx, "OK");
343+
344+
RedisModule_ReplicateVerbatim(ctx);
345+
346+
return REDISMODULE_OK;
347+
}

src/execution/deprecated.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
*/
1313
int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
1414
int argc);
15+
16+
/**
 * Handler for the deprecated model-set command (superseded by the modelstore
 * command). Parses the backend, device, optional tag/batching options,
 * INPUTS/OUTPUTS and the model BLOB from argv, creates the model and stores
 * it under the key given in argv[1].
 *
 * @param ctx  Redis module context of the calling client.
 * @param argv Command arguments (argv[0] is the command name).
 * @param argc Number of entries in argv.
 * @return REDISMODULE_OK on success; on failure, the value returned by the
 *         error reply that was sent to the client.
 */
int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);

src/redis_ai_objects/model.c

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,20 +272,19 @@ int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RA
272272
RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode);
273273
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
274274
RedisModule_CloseKey(key);
275-
// #IFDEF LITE
275+
#ifndef LITE
276+
RedisModule_Log(ctx, "warning", "could not load %s from keyspace, key doesn't exist",
277+
RedisModule_StringPtrLen(keyName, NULL));
278+
RAI_SetError(err, RAI_EKEYEMPTY, "ERR model key is empty");
279+
return REDISMODULE_ERR;
280+
#else
276281
if (VerifyKeyInThisShard(ctx, keyName)) { // Relevant for enterprise cluster.
277282
RAI_SetError(err, RAI_EKEYEMPTY, "ERR model key is empty");
278283
} else {
279284
RAI_SetError(err, RAI_EKEYEMPTY,
280285
"ERR CROSSSLOT Keys in request don't hash to the same slot");
281286
}
282-
return REDISMODULE_ERR;
283-
// #ELSE
284-
RedisModule_Log(ctx, "error", "could not load %s from keyspace, key doesn't exist",
285-
RedisModule_StringPtrLen(keyName, NULL));
286-
RAI_SetError(err, RAI_EKEYEMPTY, "ERR model key is empty");
287-
return REDISMODULE_ERR;
288-
// #ENDIF
287+
#endif
289288
}
290289
if (RedisModule_ModuleTypeGetType(key) != RedisAI_ModelType) {
291290
RedisModule_CloseKey(key);

0 commit comments

Comments
 (0)