From 6cccd9d56948312016708c49275401df27fc77fd Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Sun, 3 Nov 2019 13:20:58 +0100 Subject: [PATCH 1/7] Add support for automated batching Add support for inspection and eviction to queue Mock run info batching Mock run info batching Make TF tests work Add batching for ONNX and ONNX-ML Fix torch API, still WIP Fix torch backend Fixes after rebasing Add auto-batching to TFLite backend Fix from rebase Add batching args to command and change API accordingly Add batching heuristics [WIP] Fix TFLite test by accessing first tensor in first batch safely Temporarily comment out wrong_bg test check Implement batching heuristics Introduce autobatch tests, tflite still fails Fix segfault when error was generated from the backend Fix tflite autobatch test Updated documentation with auto batching Remove stale comments Avoid making extra copies of inputs and outputs when batch count is 1 Address review comments re const-correctness Add tests to detect failures Fix slicing and concatenation Fix tensor slicing and concatenating Temporarily disable tflite autobatch test due to tflite limitation Disable support for autobatching for TFLITE --- docs/commands.md | 20 +- src/backends.c | 8 +- src/backends.h | 4 +- src/backends/onnxruntime.c | 214 ++++++++++++------- src/backends/onnxruntime.h | 2 +- src/backends/tensorflow.c | 118 +++++++++-- src/backends/tensorflow.h | 2 +- src/backends/tflite.c | 81 ++++++-- src/backends/tflite.h | 2 +- src/backends/torch.c | 75 +++++-- src/backends/torch.h | 2 +- src/err.c | 4 +- src/model.c | 128 +++++++++--- src/model.h | 14 +- src/model_struct.h | 14 +- src/redisai.c | 334 +++++++++++++++++++++++++++--- src/redisai.h | 10 +- src/tensor.c | 92 +++++++- src/tensor.h | 10 +- test/test_data/mnist_batched.onnx | 3 + test/test_data/onnx_batch.py | 38 ++++ test/tests_onnx.py | 45 ++++ test/tests_pytorch.py | 40 ++++ test/tests_tensorflow.py | 41 ++++ test/tests_tflite.py | 51 +++++ 25 files changed, 1154 insertions(+), 198 deletions(-) create mode 100644 test/test_data/mnist_batched.onnx create mode 100644 test/test_data/onnx_batch.py diff --git a/docs/commands.md b/docs/commands.md index 6b421eb51..9d8547698 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -67,12 +67,22 @@ AI.TENSORGET foo BLOB Set a model. ```sql -AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob +AI.MODELSET model_key backend device [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob ``` * model_key - Key for storing the model * backend - The backend corresponding to the model being set. Allowed values: `TF`, `TORCH`, `ONNX`. * device - Device where the model is loaded and where the computation will run. Allowed values: `CPU`, `GPU`. +* BATCHSIZE n - Batch incoming requests from multiple clients if they hit the same model and if input tensors have the same + shape. Upon MODELRUN, the request queue is visited, input tensors from compatible requests are concatenated + along the 0-th (batch) dimension, up until BATCHSIZE is exceeded. The model is then run for the entire batch, + results are unpacked back among the individual requests and the respective clients are unblocked. + If the batch size of the inputs to the first request in the queue exceeds BATCHSIZE, the request is served + in any case. Default is 0 (no batching). +* MINBATCHSIZE m - Do not execute a MODELRUN until the batch size has reached MINBATCHSIZE. 
This is primarily used to force + batching during testing, but it can also be used under normal operation. In this case, note that requests + for which MINBATCHSIZE is not reached will hang indefinitely. + Default is 0 (no minimum batch size). * INPUTS name1 name2 ... - Name of the nodes in the provided graph corresponding to inputs [`TF` backend only] * OUTPUTS name1 name2 ... - Name of the nodes in the provided graph corresponding to outputs [`TF` backend only] * model_blob - Binary buffer containing the model protobuf saved from a supported backend @@ -91,6 +101,14 @@ AI.MODELSET resnet18 TF CPU INPUTS in1 OUTPUTS linear4 < foo.pb AI.MODELSET mnist_net ONNX CPU < mnist.onnx ``` +```sql +AI.MODELSET mnist_net ONNX CPU BATCHSIZE 10 < mnist.onnx +``` + +```sql +AI.MODELSET resnet18 TF CPU BATCHSIZE 10 MINBATCHSIZE 6 INPUTS in1 OUTPUTS linear4 < foo.pb +``` + ## AI.MODELGET Get a model. diff --git a/src/backends.c b/src/backends.c index 78051dd8f..a45044cc1 100644 --- a/src/backends.c +++ b/src/backends.c @@ -74,7 +74,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { } init_backend(RedisModule_GetApi); - backend.model_create_with_nodes = (RAI_Model* (*)(RAI_Backend, const char*, + backend.model_create_with_nodes = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts, size_t, const char**, size_t, const char**, const char*, size_t, RAI_Error*)) (unsigned long) dlsym(handle, "RAI_ModelCreateTF"); @@ -140,7 +140,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { } init_backend(RedisModule_GetApi); - backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, + backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts, const char*, size_t, RAI_Error*)) (unsigned long) dlsym(handle, "RAI_ModelCreateTFLite"); if (backend.model_create == NULL) { @@ -205,7 +205,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { } init_backend(RedisModule_GetApi); - backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, + backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts, const char*, size_t, RAI_Error*)) (unsigned long) dlsym(handle, "RAI_ModelCreateTorch"); if (backend.model_create == NULL) { @@ -294,7 +294,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { } init_backend(RedisModule_GetApi); - backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, + backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts, const char*, size_t, RAI_Error*)) (unsigned long) dlsym(handle, "RAI_ModelCreateORT"); if (backend.model_create == NULL) { diff --git a/src/backends.h b/src/backends.h index 77b4cf1a1..6ae6871a0 100644 --- a/src/backends.h +++ b/src/backends.h @@ -8,10 +8,10 @@ #include "err.h" typedef struct RAI_LoadedBackend { - RAI_Model* (*model_create_with_nodes)(RAI_Backend, const char*, + RAI_Model* (*model_create_with_nodes)(RAI_Backend, const char*, RAI_ModelOpts, size_t, const char**, size_t, const char**, const char*, size_t, RAI_Error*); - RAI_Model* (*model_create)(RAI_Backend, const char*, + RAI_Model* (*model_create)(RAI_Backend, const char*, RAI_ModelOpts, const char*, size_t, RAI_Error*); void (*model_free)(RAI_Model*, RAI_Error*); int (*model_run)(RAI_ModelRunCtx*, RAI_Error*); diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 17ad615f1..180267886 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -78,32 +78,82 @@ DLDataType 
RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) { return (DLDataType){ .bits = 0 }; } -OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) { - // TODO: create outside and pass? +OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *error) { + OrtStatus* status = NULL; const OrtApi* ort = OrtGetApiBase()->GetApi(1); - OrtMemoryInfo* memory_info; - OrtStatus* status; - status = ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info); + + if (count == 0) { + return NULL; + } + + OrtAllocator *allocator; + status = ort->GetAllocatorWithDefaultOptions(&allocator); if (status != NULL) { - goto error; + return NULL; } + if (count == 0) { + return NULL; + } + + size_t batch_size = 0; + size_t batch_byte_size = 0; + + for (size_t i=0; itensor.dl_tensor.shape[0]; + batch_byte_size += RAI_TensorByteSize(ts[i]); + } + + RAI_Tensor* t0 = ts[0]; + + const int ndim = t0->tensor.dl_tensor.ndim; + int64_t batched_shape[ndim]; + + for (size_t i=1; itensor.dl_tensor.shape[i]; + } + + batched_shape[0] = batch_size; + OrtValue* out; - status = ort->CreateTensorWithDataAsOrtValue( - memory_info, - t->tensor.dl_tensor.data, - RAI_TensorByteSize(t), - t->tensor.dl_tensor.shape, - t->tensor.dl_tensor.ndim, - RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype), - &out); - if (status != NULL) { - ort->ReleaseMemoryInfo(memory_info); - goto error; + if (count > 1) { + status = ort->CreateTensorAsOrtValue( + allocator, + batched_shape, + t0->tensor.dl_tensor.ndim, + RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype), + &out); + if (status != NULL) { + goto error; + } + + char *ort_data; + status = ort->GetTensorMutableData(out, (void **)&ort_data); + if (status != NULL) { + goto error; + } + + size_t offset = 0; + for (size_t i=0; iCreateTensorWithDataAsOrtValue( + allocator->Info(allocator), + t0->tensor.dl_tensor.data, + RAI_TensorByteSize(t0), + t0->tensor.dl_tensor.shape, + t0->tensor.dl_tensor.ndim, + RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype), + &out); - ort->ReleaseMemoryInfo(memory_info); + if (status != NULL) { + goto error; + } + } return out; @@ -113,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) { return NULL; } -RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) { +RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) { OrtStatus* status = NULL; const OrtApi* ort = OrtGetApiBase()->GetApi(1); @@ -155,18 +205,23 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) { status = ort->GetTensorElementType(info, &ort_dtype); if (status != NULL) goto error; + int64_t total_batch_size = dims[0]; + shape = RedisModule_Calloc(ndims, sizeof(*shape)); strides = RedisModule_Calloc(ndims, sizeof(*strides)); - for (int64_t i = 0; i < ndims; ++i) + for (int64_t i=0; i= 0; --i) { strides[i] *= strides[i + 1] * shape[i + 1]; } + // size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size; + DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype); #ifdef RAI_COPY_RUN_OUTPUT char *ort_data; @@ -180,9 +235,14 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) { goto error; } - size_t len = dtype.bits * elem_count; - char *data = RedisModule_Calloc(len, sizeof(*data)); - memcpy(data, ort_data, len); + const size_t len = dtype.bits * elem_count / 8; + + const size_t total_bytesize = len * sizeof(char); + const size_t sample_bytesize = total_bytesize / total_batch_size; + const size_t 
batch_bytesize = sample_bytesize * batch_size; + + char *data = RedisModule_Calloc(batch_bytesize, sizeof(*data)); + memcpy(data, ort_data + batch_offset * sample_bytesize, batch_bytesize); #endif ort->ReleaseTensorTypeAndShapeInfo(info); @@ -232,13 +292,9 @@ typedef struct RAI_ONNXBuffer { OrtEnv* env = NULL; -RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, +RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *error) { - - // TODO: take from - // https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp - const OrtApi* ort = OrtGetApiBase()->GetApi(1); RAI_Device device; @@ -264,7 +320,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, goto error; } - // TODO: probably these options could be configured at the AI.CONFIG level + // TODO: these options could be configured at the AI.CONFIG level OrtSessionOptions* session_options; status = ort->CreateSessionOptions(&session_options); if (status != NULL) { @@ -289,7 +345,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, deviceid); } #else - // TODO: Do dynamic device/provider check with GetExecutionProviderType or something else + // TODO: Do dynamic device/provider check with GetExecutionProviderType or something on these lines if (device == RAI_DEVICE_GPU) { RAI_SetError(error, RAI_EMODELCREATE, "GPU requested but ONNX couldn't find CUDA"); return NULL; @@ -322,6 +378,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, ret->backend = backend; ret->devicestr = RedisModule_Strdup(devicestr); ret->refCount = 1; + ret->opts = opts; ret->data = onnxbuffer; return ret; @@ -355,26 +412,41 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) return 1; } + const size_t nbatches = array_len(mctx->batches); + if (nbatches == 0) { + RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n"); + return 1; + } + + size_t batch_sizes[nbatches]; + size_t batch_offsets[nbatches]; + if (array_len(mctx->batches[0].inputs) > 0) { + for (size_t b=0; bbatches[b].inputs[0].tensor, 0); + } + batch_offsets[0] = 0; + for (size_t b=1; bGetAllocatorWithDefaultOptions(&allocator); - if (status != NULL) - { + if (status != NULL) { goto error; } size_t n_input_nodes; status = ort->SessionGetInputCount(session, &n_input_nodes); - if (status != NULL) - { + if (status != NULL) { goto error; } size_t n_output_nodes; status = ort->SessionGetOutputCount(session, &n_output_nodes); - if (status != NULL) - { + if (status != NULL) { goto error; } @@ -385,40 +457,39 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) OrtValue *inputs[n_input_nodes]; OrtValue *outputs[n_output_nodes]; - size_t ninputs = array_len(mctx->inputs); - size_t noutputs = array_len(mctx->outputs); - - if (ninputs != n_input_nodes) - { + const size_t ninputs = array_len(mctx->batches[0].inputs); + const size_t noutputs = array_len(mctx->batches[0].outputs); + if (ninputs != n_input_nodes) { char msg[70]; sprintf(msg, "Expected %li inputs but got %li", n_input_nodes, ninputs); RAI_SetError(error, RAI_EMODELRUN, msg); return 1; } - if (noutputs != n_output_nodes) - { + if (noutputs != n_output_nodes) { char msg[70]; sprintf(msg, "Expected %li outputs but got %li", n_output_nodes, noutputs); RAI_SetError(error, RAI_EMODELRUN, msg); return 1; } - for (size_t i = 0; i 
< n_input_nodes; i++) - { + for (size_t i = 0; i < n_input_nodes; i++) { char *input_name; status = ort->SessionGetInputName(session, i, allocator, &input_name); - if (status != NULL) - { + if (status != NULL) { goto error; } input_names[i] = input_name; - inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error); - if (error->code != RAI_OK) - { + RAI_Tensor* batched_input_tensors[nbatches]; + for (size_t b=0; bbatches[b].inputs[i].tensor; + } + + inputs[i] = RAI_OrtValueFromTensors(batched_input_tensors, nbatches, error); + if (error->code != RAI_OK) { ort->ReleaseStatus(status); return 1; } @@ -443,12 +514,10 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) #endif } - for (size_t i = 0; i < n_output_nodes; i++) - { + for (size_t i = 0; i < n_output_nodes; i++) { char *output_name; status = ort->SessionGetOutputName(session, i, allocator, &output_name); - if (status != NULL) - { + if (status != NULL) { goto error; } @@ -464,33 +533,30 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) status = ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs); - if (status) - { + if (status) { goto error; } - for (size_t i = 0; i < n_output_nodes; i++) - { - RAI_Tensor *output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error); - if (error->code != RAI_OK) - { - ort->ReleaseStatus(status); - return 1; - } - if (output_tensor) - { - mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); - RAI_TensorFree(output_tensor); - } - else - { - printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n"); + for (size_t i = 0; i < n_output_nodes; i++) { + for (size_t b=0; bcode != RAI_OK) { + ort->ReleaseStatus(status); + return 1; + } + if (output_tensor) { + mctx->batches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); + RAI_TensorFree(output_tensor); + } + else { + printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n"); + } } + ort->ReleaseValue(outputs[i]); } - for (size_t i = 0; i < n_input_nodes; i++) - { + for (size_t i = 0; i < n_input_nodes; i++) { ort->ReleaseValue(inputs[i]); } diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index ba027feb4..86b83633d 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -8,7 +8,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *)); -RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, +RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *err); diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c index 5d447845c..1e2f7837f 100644 --- a/src/backends/tensorflow.c +++ b/src/backends/tensorflow.c @@ -78,7 +78,7 @@ DLDataType RAI_GetDLDataTypeFromTF(TF_DataType dtype) { return (DLDataType){ .bits = 0 }; } -RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor) { +RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, size_t batch_size) { RAI_Tensor* ret = RedisModule_Calloc(1, sizeof(*ret)); DLContext ctx = (DLContext){ @@ -86,7 +86,9 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor) { .device_id = 0 }; - size_t ndims = TF_NumDims(tensor); + const size_t ndims = TF_NumDims(tensor); + + const int64_t total_batch_size = TF_Dim(tensor, 0); int64_t* shape = RedisModule_Calloc(ndims, sizeof(*shape)); int64_t* strides = RedisModule_Calloc(ndims, 
sizeof(*strides)); @@ -94,19 +96,22 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor) { shape[i] = TF_Dim(tensor, i); strides[i] = 1; } + shape[0] = batch_size; for (int64_t i = ndims-2 ; i >= 0 ; --i) { strides[i] *= strides[i+1] * shape[i+1]; } + const size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size; + // FIXME: In TF, RunSession allocates memory for output tensors // This means that we either memcpy the tensor data and let // Redis be responsible for the memory, or we reuse the TF // allocated memory, which might not be optimal down the road // Note: on YOLO this has no impact on perf #ifdef RAI_COPY_RUN_OUTPUT - size_t len = TF_TensorByteSize(tensor); + const size_t len = sample_bytesize * batch_size; char* data = RedisModule_Calloc(len, sizeof(*data)); - memcpy(data, TF_TensorData(tensor), len); + memcpy(data, TF_TensorData(tensor) + sample_bytesize * batch_offset, len); #endif // TODO: use manager_ctx to ensure TF tensor doesn't get deallocated @@ -160,8 +165,63 @@ TF_Tensor* RAI_TFTensorFromTensor(RAI_Tensor* t){ #endif /* RAI_COPY_RUN_INPUT */ } +TF_Tensor* RAI_TFTensorFromTensors(RAI_Tensor** ts, size_t count){ + + if (count == 0) { + return NULL; + } + + size_t batch_size = 0; + size_t batch_byte_size = 0; + + for (size_t i=0; itensor.dl_tensor.shape[0]; + batch_byte_size += RAI_TensorByteSize(ts[i]); + } + + RAI_Tensor* t0 = ts[0]; -RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, + const int ndim = t0->tensor.dl_tensor.ndim; + int64_t batched_shape[ndim]; + + for (size_t i=0; itensor.dl_tensor.shape[i]; + } + + batched_shape[0] = batch_size; + + TF_Tensor* out = NULL; + + if (count > 1) { + out = TF_AllocateTensor( + RAI_GetTFDataTypeFromDL(t0->tensor.dl_tensor.dtype), + batched_shape, + t0->tensor.dl_tensor.ndim, + batch_byte_size); + + size_t offset = 0; + for (size_t i=0; itensor.dl_tensor.data, tbytesize); + offset += tbytesize; + } + } + else { + out = TF_NewTensor( + RAI_GetTFDataTypeFromDL(t0->tensor.dl_tensor.dtype), + t0->tensor.dl_tensor.shape, + t0->tensor.dl_tensor.ndim, + t0->tensor.dl_tensor.data, + RAI_TensorByteSize(t0), + &RAI_TFDeallocator, + NULL); + } + + return out; +} + + +RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, @@ -312,6 +372,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, ret->devicestr = RedisModule_Strdup(devicestr); ret->inputs = inputs_; ret->outputs = outputs_; + ret->opts = opts; ret->refCount = 1; return ret; @@ -360,17 +421,42 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) { int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) { TF_Status *status = TF_NewStatus(); - const size_t ninputs = array_len(mctx->inputs); - const size_t noutputs = array_len(mctx->outputs); + + const size_t nbatches = array_len(mctx->batches); + if (nbatches == 0) { + RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n"); + return 1; + } + + const size_t ninputs = array_len(mctx->batches[0].inputs); + const size_t noutputs = array_len(mctx->batches[0].outputs); TF_Tensor* inputTensorsValues[ninputs]; TF_Output inputs[ninputs]; TF_Tensor* outputTensorsValues[noutputs]; TF_Output outputs[noutputs]; - for (size_t i=0 ; iinputs[i].tensor); + size_t batch_sizes[nbatches]; + size_t batch_offsets[nbatches]; + if (array_len(mctx->batches[0].inputs) > 0) { + for (size_t b=0; bbatches[b].inputs[0].tensor, 
0); + } + batch_offsets[0] = 0; + for (size_t b=1; bbatches[b].inputs[i].tensor; + } + // inputTensorsValues[i] = RAI_TFTensorFromTensor(mctx->inputs[i].tensor); + inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches); TF_Output port; - port.oper = TF_GraphOperationByName(mctx->model->model, mctx->inputs[i].name); + port.oper = TF_GraphOperationByName(mctx->model->model, mctx->batches[0].inputs[i].name); port.index = 0; if(port.oper == NULL){ return 1; @@ -380,7 +466,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) { for (size_t i=0 ; imodel->model, mctx->outputs[i].name); + port.oper = TF_GraphOperationByName(mctx->model->model, mctx->batches[0].outputs[i].name); port.index = 0; if(port.oper == NULL){ return 1; @@ -407,11 +493,13 @@ int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) { return 1; } - for(size_t i = 0 ; i < noutputs ; ++i) { - RAI_Tensor* output_tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i]); + for(size_t i=0; ibatches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); + RAI_TensorFree(output_tensor); + } TF_DeleteTensor(outputTensorsValues[i]); - mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); - RAI_TensorFree(output_tensor); } // TODO: add (make sure we deallocate once) diff --git a/src/backends/tensorflow.h b/src/backends/tensorflow.h index 11ecc0ffc..b59cfdde7 100644 --- a/src/backends/tensorflow.h +++ b/src/backends/tensorflow.h @@ -8,7 +8,7 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)); -RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, +RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, diff --git a/src/backends/tflite.c b/src/backends/tflite.c index adb6bba8b..96d46d8e1 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -19,7 +19,7 @@ typedef struct RAI_TfLiteBuffer { size_t len; } RAI_TfLiteBuffer; -RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, +RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *error) { DLDeviceType dl_device; @@ -67,6 +67,7 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, ret->inputs = NULL; ret->outputs = NULL; ret->refCount = 1; + ret->opts = opts; ret->data = tflitebuffer; return ret; @@ -83,23 +84,61 @@ void RAI_ModelFreeTFLite(RAI_Model* model, RAI_Error *error) { int RAI_ModelRunTFLite(RAI_ModelRunCtx* mctx, RAI_Error *error) { - size_t ninputs = array_len(mctx->inputs); - size_t noutputs = array_len(mctx->outputs); + const size_t nbatches = array_len(mctx->batches); + if (nbatches == 0) { + RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n"); + return 1; + } - DLManagedTensor* inputs[ninputs]; - DLManagedTensor* outputs[noutputs]; + const size_t ninputs = array_len(mctx->batches[0].inputs); + const size_t noutputs = array_len(mctx->batches[0].outputs); - for (size_t i=0 ; iinputs[i].tensor->tensor; + RAI_Tensor* inputs[ninputs]; + + DLManagedTensor* inputs_dl[ninputs]; + DLManagedTensor* outputs_dl[noutputs]; + + size_t batch_sizes[nbatches]; + size_t batch_offsets[nbatches]; + size_t total_batch_size = 0; + + if (nbatches > 1) { + if (array_len(mctx->batches[0].inputs) > 0) { + for (size_t b=0; bbatches[b].inputs[0].tensor, 0); + total_batch_size += batch_sizes[b]; + } + 
batch_offsets[0] = 0; + for (size_t b=1; bbatches[b].inputs[i].tensor; + } + + inputs[i] = RAI_TensorCreateByConcatenatingTensors(batch, nbatches); + inputs_dl[i] = &inputs[i]->tensor; + } + } + else { + for (size_t i=0 ; ibatches[0].inputs[i].tensor); + inputs_dl[i] = &inputs[i]->tensor; + } } for (size_t i=0 ; ioutputs[i].tensor ? &mctx->outputs[i].tensor->tensor : NULL; + outputs_dl[i] = NULL; } char* error_descr = NULL; tfliteRunModel(mctx->model->model, - ninputs, inputs, noutputs, outputs, + ninputs, inputs_dl, noutputs, outputs_dl, &error_descr, RedisModule_Alloc); if (error_descr != NULL) { @@ -108,17 +147,33 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx* mctx, RAI_Error *error) { return 1; } - for(size_t i=0 ; ioutputs) ; ++i) { - if (outputs[i] == NULL) { + for(size_t i=0 ; ioutputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); + RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]); + if (nbatches > 1 && RAI_TensorDim(output_tensor, 0) != total_batch_size) { + RAI_TensorFree(output_tensor); + RAI_SetError(error, RAI_EMODELRUN, "Model did not generate the expected batch size."); + return 1; + } + if (nbatches > 1) { + for (size_t b=0; bbatches[b].outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]); + } + } + else { + mctx->batches[0].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); + } RAI_TensorFree(output_tensor); RedisModule_Free(outputs[i]); } + for (size_t i=0 ; idevicestr = RedisModule_Strdup(devicestr); ret->inputs = NULL; ret->outputs = NULL; + ret->opts = opts; ret->refCount = 1; return ret; @@ -68,24 +69,61 @@ void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *error) { } int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *error) { + const size_t nbatches = array_len(mctx->batches); + if (nbatches == 0) { + RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n"); + return 1; + } - size_t ninputs = array_len(mctx->inputs); - size_t noutputs = array_len(mctx->outputs); + const size_t ninputs = array_len(mctx->batches[0].inputs); + const size_t noutputs = array_len(mctx->batches[0].outputs); - DLManagedTensor* inputs[ninputs]; - DLManagedTensor* outputs[noutputs]; + RAI_Tensor* inputs[ninputs]; - for (size_t i=0 ; iinputs[i].tensor->tensor; + DLManagedTensor* inputs_dl[ninputs]; + DLManagedTensor* outputs_dl[noutputs]; + + size_t batch_sizes[nbatches]; + size_t batch_offsets[nbatches]; + + if (nbatches > 1) { + size_t total_batch_size = 0; + if (array_len(mctx->batches[0].inputs) > 0) { + for (size_t b=0; bbatches[b].inputs[0].tensor, 0); + total_batch_size += batch_sizes[b]; + } + batch_offsets[0] = 0; + for (size_t b=1; bbatches[b].inputs[i].tensor; + } + + inputs[i] = RAI_TensorCreateByConcatenatingTensors(batch, nbatches); + inputs_dl[i] = &inputs[i]->tensor; + } + } + else { + for (size_t i=0 ; ibatches[0].inputs[i].tensor); + inputs_dl[i] = &inputs[i]->tensor; + } } for (size_t i=0 ; ioutputs[i].tensor ? 
&mctx->outputs[i].tensor->tensor : NULL; + outputs_dl[i] = NULL; } char* error_descr = NULL; torchRunModel(mctx->model->model, - ninputs, inputs, noutputs, outputs, + ninputs, inputs_dl, noutputs, outputs_dl, &error_descr, RedisModule_Alloc); if (error_descr != NULL) { @@ -94,16 +132,27 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *error) { return 1; } - for(size_t i=0 ; ioutputs) ; ++i) { - if (outputs[i] == NULL) { + for(size_t i=0 ; ioutputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); + RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]); + if (nbatches > 1) { + for (size_t b=0; bbatches[b].outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]); + } + } + else { + mctx->batches[0].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); + } RAI_TensorFree(output_tensor); } + for (size_t i=0 ; idetail) { RedisModule_Free(err->detail); - RedisModule_Free(err->detail_oneline); err->detail = NULL; + } + if (err->detail_oneline) { + RedisModule_Free(err->detail_oneline); err->detail_oneline = NULL; } err->code = RAI_OK; diff --git a/src/model.c b/src/model.c index cd37cb339..225f07b16 100644 --- a/src/model.c +++ b/src/model.c @@ -16,14 +16,18 @@ static void* RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) { RAI_Backend backend = RedisModule_LoadUnsigned(io); const char *devicestr = RedisModule_LoadStringBuffer(io, NULL); - size_t ninputs = RedisModule_LoadUnsigned(io); + + const size_t batchsize = RedisModule_LoadUnsigned(io); + const size_t minbatchsize = RedisModule_LoadUnsigned(io); + + const size_t ninputs = RedisModule_LoadUnsigned(io); const char **inputs = RedisModule_Alloc(ninputs * sizeof(char*)); for (size_t i=0; ibackend); RedisModule_SaveStringBuffer(io, model->devicestr, strlen(model->devicestr) + 1); + RedisModule_SaveUnsigned(io, model->opts.batchsize); + RedisModule_SaveUnsigned(io, model->opts.minbatchsize); RedisModule_SaveUnsigned(io, model->ninputs); for (size_t i=0; ininputs; i++) { RedisModule_SaveStringBuffer(io, model->inputs[i], strlen(model->inputs[i]) + 1); @@ -137,9 +148,11 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi const char* backendstr = RAI_BackendName(model->backend); - RedisModule_EmitAOF(aof, "AI.MODELSET", "slccvcvb", + RedisModule_EmitAOF(aof, "AI.MODELSET", "slcclclcvcvb", key, backendstr, model->devicestr, + "BATCHSIZE", model->opts.batchsize, + "MINBATCHSIZE", model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS", outputs_, model->noutputs, buffer, len); @@ -186,7 +199,7 @@ int RAI_ModelInit(RedisModuleCtx* ctx) { return RedisAI_ModelType != NULL; } -RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char* devicestr, +RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, RAI_Error* err) { @@ -196,28 +209,28 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char* devicestr, RAI_SetError(err, RAI_EBACKENDNOTLOADED, "Backend not loaded: TF.\n"); return NULL; } - model = RAI_backends.tf.model_create_with_nodes(backend, devicestr, ninputs, inputs, noutputs, outputs, modeldef, modellen, err); + model = RAI_backends.tf.model_create_with_nodes(backend, devicestr, opts, ninputs, inputs, noutputs, outputs, modeldef, modellen, err); } else if (backend == RAI_BACKEND_TFLITE) { if (!RAI_backends.tflite.model_create) { RAI_SetError(err, 
RAI_EBACKENDNOTLOADED, "Backend not loaded: TFLITE.\n"); return NULL; } - model = RAI_backends.tflite.model_create(backend, devicestr, modeldef, modellen, err); + model = RAI_backends.tflite.model_create(backend, devicestr, opts, modeldef, modellen, err); } else if (backend == RAI_BACKEND_TORCH) { if (!RAI_backends.torch.model_create) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "Backend not loaded: TORCH.\n"); return NULL; } - model = RAI_backends.torch.model_create(backend, devicestr, modeldef, modellen, err); + model = RAI_backends.torch.model_create(backend, devicestr, opts, modeldef, modellen, err); } else if (backend == RAI_BACKEND_ONNXRUNTIME) { if (!RAI_backends.onnx.model_create) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "Backend not loaded: ONNX.\n"); return NULL; } - model = RAI_backends.onnx.model_create(backend, devicestr, modeldef, modellen, err); + model = RAI_backends.onnx.model_create(backend, devicestr, opts, modeldef, modellen, err); } else { RAI_SetError(err, RAI_EUNSUPPORTEDBACKEND, "Unsupported backend.\n"); @@ -268,11 +281,11 @@ void RAI_ModelFree(RAI_Model* model, RAI_Error* err) { } RAI_ModelRunCtx* RAI_ModelRunCtxCreate(RAI_Model* model) { -#define PARAM_INITIAL_SIZE 10 +#define BATCH_INITIAL_SIZE 10 RAI_ModelRunCtx* mctx = RedisModule_Calloc(1, sizeof(*mctx)); mctx->model = RAI_ModelGetShallowCopy(model); - mctx->inputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE); - mctx->outputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE); + mctx->batches = array_new(RAI_ModelCtxBatch, BATCH_INITIAL_SIZE); +#undef BATCH_INITIAL_SIZE return mctx; } @@ -287,42 +300,99 @@ static int Model_RunCtxAddParam(RAI_ModelRunCtx* mctx, RAI_ModelCtxParam** param return 1; } -int RAI_ModelRunCtxAddInput(RAI_ModelRunCtx* mctx, const char* inputName, RAI_Tensor* inputTensor) { - return Model_RunCtxAddParam(mctx, &mctx->inputs, inputName, inputTensor); +int RAI_ModelRunCtxAddInput(RAI_ModelRunCtx* mctx, size_t id, const char* inputName, RAI_Tensor* inputTensor) { + if (id >= RAI_ModelRunCtxNumBatches(mctx)) { + // TODO error + return 0; + } + return Model_RunCtxAddParam(mctx, &mctx->batches[id].inputs, inputName, inputTensor); } -int RAI_ModelRunCtxAddOutput(RAI_ModelRunCtx* mctx, const char* outputName) { - return Model_RunCtxAddParam(mctx, &mctx->outputs, outputName, NULL); +int RAI_ModelRunCtxAddOutput(RAI_ModelRunCtx* mctx, size_t id, const char* outputName) { + if (id >= RAI_ModelRunCtxNumBatches(mctx)) { + // TODO error + return 0; + } + return Model_RunCtxAddParam(mctx, &mctx->batches[id].outputs, outputName, NULL); +} + +size_t RAI_ModelRunCtxNumInputs(RAI_ModelRunCtx* mctx) { + if (RAI_ModelRunCtxNumBatches(mctx) == 0) { + return 0; + } + // Here we assume batch is well-formed (i.e. number of outputs is equal in all batches) + return array_len(mctx->batches[0].inputs); } size_t RAI_ModelRunCtxNumOutputs(RAI_ModelRunCtx* mctx) { - return array_len(mctx->outputs); + if (RAI_ModelRunCtxNumBatches(mctx) == 0) { + return 0; + } + // Here we assume batch is well-formed (i.e. 
number of outputs is equal in all batches) + return array_len(mctx->batches[0].outputs); +} + +int RAI_ModelRunCtxAddBatch(RAI_ModelRunCtx* mctx) { +#define PARAM_INITIAL_SIZE 10 + RAI_ModelCtxBatch batch = { + .inputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE), + .outputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE) + }; +#undef PARAM_INITIAL_SIZE + array_append(mctx->batches, batch); + return array_len(mctx->batches)-1; } -RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t index) { +size_t RAI_ModelRunCtxNumBatches(RAI_ModelRunCtx* mctx) { + return array_len(mctx->batches); +} + +void RAI_ModelRunCtxCopyBatch(RAI_ModelRunCtx* dest, size_t id_dest, RAI_ModelRunCtx* src, size_t id_src) { + size_t ninputs = array_len(src->batches[id_src].inputs); + for (size_t i=0; ibatches[id_src].inputs[i]; + RAI_ModelRunCtxAddInput(dest, id_dest, param.name, param.tensor); + } + + size_t noutputs = array_len(src->batches[id_src].outputs); + for (size_t i=0; ibatches[id_src].outputs[i]; + RAI_ModelRunCtxAddOutput(dest, id_dest, param.name); + } +} + +RAI_Tensor* RAI_ModelRunCtxInputTensor(RAI_ModelRunCtx* mctx, size_t id, size_t index) { + // TODO: add method to collect from batches? + assert(RAI_ModelRunCtxNumInputs(mctx) > index && index >= 0); + return mctx->batches[id].inputs[index].tensor; +} + +RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t id, size_t index) { + // TODO: add method to collect from batches? assert(RAI_ModelRunCtxNumOutputs(mctx) > index && index >= 0); - return mctx->outputs[index].tensor; + return mctx->batches[id].outputs[index].tensor; } void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx) { - for (size_t i = 0 ; i < array_len(mctx->inputs) ; ++i) { - RAI_TensorFree(mctx->inputs[i].tensor); - } - array_free(mctx->inputs); + for (size_t b=0; bbatches); ++b) { + for (size_t i=0; ibatches[b].inputs); ++i) { + RAI_TensorFree(mctx->batches[b].inputs[i].tensor); + } + array_free(mctx->batches[b].inputs); - for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) { - if (mctx->outputs[i].tensor) { - RAI_TensorFree(mctx->outputs[i].tensor); + for (size_t i = 0 ; i < array_len(mctx->batches[b].outputs) ; ++i) { + if (mctx->batches[b].outputs[i].tensor) { + RAI_TensorFree(mctx->batches[b].outputs[i].tensor); + } } + array_free(mctx->batches[b].outputs); } - array_free(mctx->outputs); RAI_Error err = {0}; RAI_ModelFree(mctx->model, &err); if (err.code != RAI_OK) { // TODO: take it to client somehow - printf("ERR: %s\n", err.detail); RAI_ClearError(&err); } diff --git a/src/model.h b/src/model.h index e15eb9413..ce71cf1e2 100644 --- a/src/model.h +++ b/src/model.h @@ -10,17 +10,23 @@ extern RedisModuleType *RedisAI_ModelType; int RAI_ModelInit(RedisModuleCtx* ctx); -RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char* devicestr, +RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, RAI_Error* err); void RAI_ModelFree(RAI_Model* model, RAI_Error* err); RAI_ModelRunCtx* RAI_ModelRunCtxCreate(RAI_Model* model); -int RAI_ModelRunCtxAddInput(RAI_ModelRunCtx* mctx, const char* inputName, RAI_Tensor* inputTensor); -int RAI_ModelRunCtxAddOutput(RAI_ModelRunCtx* mctx, const char* outputName); + +int RAI_ModelRunCtxAddBatch(RAI_ModelRunCtx* mctx); +size_t RAI_ModelRunCtxNumBatches(RAI_ModelRunCtx* mctx); +void RAI_ModelRunCtxCopyBatch(RAI_ModelRunCtx* dest, size_t id_dest, RAI_ModelRunCtx* src, 
size_t id_src); +int RAI_ModelRunCtxAddInput(RAI_ModelRunCtx* mctx, size_t id, const char* inputName, RAI_Tensor* inputTensor); +int RAI_ModelRunCtxAddOutput(RAI_ModelRunCtx* mctx, size_t id, const char* outputName); +size_t RAI_ModelRunCtxNumInputs(RAI_ModelRunCtx* mctx); size_t RAI_ModelRunCtxNumOutputs(RAI_ModelRunCtx* mctx); -RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t index); +RAI_Tensor* RAI_ModelRunCtxInputTensor(RAI_ModelRunCtx* mctx, size_t id, size_t index); +RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t id, size_t index); void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx); int RAI_ModelRun(RAI_ModelRunCtx* mctx, RAI_Error* err); diff --git a/src/model_struct.h b/src/model_struct.h index e2be0b143..64542f0e2 100644 --- a/src/model_struct.h +++ b/src/model_struct.h @@ -4,6 +4,11 @@ #include "config.h" #include "tensor_struct.h" +typedef struct RAI_ModelOpts { + size_t batchsize; + size_t minbatchsize; +} RAI_ModelOpts; + typedef struct RAI_Model { void* model; // TODO: use session pool? The ideal would be to use one session per client. @@ -12,6 +17,7 @@ typedef struct RAI_Model { void *session; RAI_Backend backend; char* devicestr; + RAI_ModelOpts opts; char **inputs; size_t ninputs; char **outputs; @@ -25,11 +31,15 @@ typedef struct RAI_ModelCtxParam { RAI_Tensor* tensor; } RAI_ModelCtxParam; +typedef struct RAI_ModelCtxBatch { + RAI_ModelCtxParam* inputs; + RAI_ModelCtxParam* outputs; +} RAI_ModelCtxBatch; + typedef struct RAI_ModelRunCtx { size_t ctxtype; RAI_Model* model; - RAI_ModelCtxParam* inputs; - RAI_ModelCtxParam* outputs; + RAI_ModelCtxBatch* batches; } RAI_ModelRunCtx; #endif /* SRC_MODEL_STRUCT_H_ */ diff --git a/src/redisai.c b/src/redisai.c index 64e958f6f..e5ec1e797 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -21,6 +21,7 @@ typedef struct queueItem { struct queueItem *next; + struct queueItem *prev; void *value; } queueItem; @@ -50,11 +51,13 @@ void queuePush(queue *queue, void *value) { return; item->value = value; item->next = NULL; + item->prev = NULL; if (queue->len == 0) { queue->front = queue->back = item; } else { queue->back->next = item; + item->prev = queue->back; queue->back = item; } queue->len++; @@ -66,14 +69,48 @@ queueItem *queuePop(queue *queue) { return NULL; } queue->front = item->next; + if (queue->front != NULL) { + queue->front->prev = NULL; + } if (item == queue->back) { queue->back = NULL; } item->next = NULL; + item->prev = NULL; + queue->len--; + return item; +} + +queueItem *queueFront(queue *queue) { + return queue->front; +} + +queueItem *queueNext(queueItem *item) { + return item->next; +} + +queueItem *queueEvict(queue *queue, queueItem *item) { + if (item == queue->front) { + return queuePop(queue); + } + else if (item == queue->back) { + queue->back = item->prev; + queue->back->next = NULL; + } + else { + item->prev->next = item->next->prev; + } + + item->next = NULL; + item->prev = NULL; queue->len--; return item; } +long long queueLength(queue *queue) { + return queue->len; +} + void queueRelease(queue *queue) { unsigned long len; queueItem *current; @@ -199,7 +236,7 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv const char* typestr; AC_GetString(&ac, &typestr, NULL, 0); - size_t datasize = RAI_TensorGetDataSize(typestr); + size_t datasize = RAI_TensorDataSizeFromString(typestr); if (!datasize){ return RedisModule_ReplyWithError(ctx, "ERR invalid data type"); } @@ -527,25 +564,82 @@ void RedisAI_FreeRunStats(RedisModuleCtx *ctx, struct 
RedisAI_RunStats *rstats) RedisModule_Free(rstats->devicestr); } -void *RedisAI_RunSession(void *arg) { - struct RedisAI_RunInfo *rinfo = (struct RedisAI_RunInfo*)arg; - rinfo->err = RedisModule_Calloc(1, sizeof(RAI_Error)); - const long long start = ustime(); - if (rinfo->mctx) { - rinfo->status = RAI_ModelRun(rinfo->mctx, rinfo->err); +void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) { + if (array_len(batch_rinfo) == 0) { + return NULL; } - else if (rinfo->sctx) { - rinfo->status = RAI_ScriptRun(rinfo->sctx, rinfo->err); + + RAI_Error* err = RedisModule_Calloc(1, sizeof(RAI_Error)); + long long rtime; + int status; + RAI_ModelRunCtx* mctx = NULL; + RAI_ScriptRunCtx* sctx = NULL; + if (batch_rinfo[0]->mctx) { + mctx = RAI_ModelRunCtxCreate(batch_rinfo[0]->mctx->model); + for (long long i=0; imctx, 0); + } } - rinfo->duration_us = ustime()-start; + else if (batch_rinfo[0]->sctx) { + // No batching for scripts for now + sctx = batch_rinfo[0]->sctx; + } + + const long long start = ustime(); + if (mctx) { + status = RAI_ModelRun(mctx, err); + } + else if (sctx) { + status = RAI_ScriptRun(sctx, err); + } + rtime = ustime() - start; + + for (long long i=0; ibatches[i].outputs[o].tensor; + if (tensor) { + rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(tensor); + } + else { + rinfo->mctx->batches[0].outputs[o].tensor = NULL; + } + } + } + else if (sctx) { + // No batching for scripts for now + } - if (rinfo->client != NULL) { - RedisModule_UnblockClient(rinfo->client, rinfo); + rinfo->status = status; + rinfo->err = RedisModule_Calloc(1, sizeof(RAI_Error)); + // TODO: add information on whether the call was batched + // and how large the batch was + rinfo->duration_us = rtime; + + rinfo->err->code = err->code; + if (err->code != RAI_OK) { + rinfo->err->detail = RedisModule_Strdup(err->detail); + rinfo->err->detail_oneline = RedisModule_Strdup(err->detail_oneline); + } + if (rinfo->client != NULL) { + RedisModule_UnblockClient(rinfo->client, rinfo); + } + } + + if (mctx) { + RAI_ModelRunCtxFree(mctx); + } + else if (sctx) { + // No batching for scripts for now } + return NULL; } -// key backend device [INPUTS name1 name2] [OUTPUTS name1 name2] modelbuf +// key backend device [BATCHSIZE n] [MINBATCHSIZE m] [INPUTS name1 name2] [OUTPUTS name1 name2] modelbuf int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModule_AutoMemory(ctx); @@ -579,6 +673,26 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, const char* devicestr; AC_GetString(&ac, &devicestr, NULL, 0); + unsigned long long batchsize = 0; + if (AC_AdvanceIfMatch(&ac, "BATCHSIZE")) { + if (backend == RAI_BACKEND_TFLITE) { + return RedisModule_ReplyWithError(ctx, "Auto-batching not supported by the TFLITE backend."); + } + if (AC_GetUnsignedLongLong(&ac, &batchsize, 0) != AC_OK) { + return RedisModule_ReplyWithError(ctx, "Invalid argument for BATCHSIZE."); + } + } + + unsigned long long minbatchsize = 0; + if (AC_AdvanceIfMatch(&ac, "MINBATCHSIZE")) { + if (batchsize == 0) { + return RedisModule_ReplyWithError(ctx, "MINBATCHSIZE specified without BATCHSIZE."); + } + if (AC_GetUnsignedLongLong(&ac, &minbatchsize, 0) != AC_OK) { + return RedisModule_ReplyWithError(ctx, "Invalid argument for MINBATCHSIZE"); + } + } + ArgsCursor optionsac; AC_GetSliceToOffset(&ac, &optionsac, argc-2); @@ -617,6 +731,11 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, AC_GetString(&outac, outputs+i, NULL, 0); } + 
RAI_ModelOpts opts = { + .batchsize = batchsize, + .minbatchsize = minbatchsize + }; + RAI_Model *model = NULL; size_t modellen; @@ -625,7 +744,7 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, RAI_Error err = {0}; - model = RAI_ModelCreate(backend, devicestr, ninputs, inputs, noutputs, outputs, modeldef, modellen, &err); + model = RAI_ModelCreate(backend, devicestr, opts, ninputs, inputs, noutputs, outputs, modeldef, modellen, &err); if (err.code == RAI_EBACKENDNOTLOADED) { RedisModule_Log(ctx, "warning", "Backend %s not loaded, will try loading default backend\n", bckstr); @@ -637,7 +756,7 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, return ret; } RAI_ClearError(&err); - model = RAI_ModelCreate(backend, devicestr, ninputs, inputs, noutputs, outputs, modeldef, modellen, &err); + model = RAI_ModelCreate(backend, devicestr, opts, ninputs, inputs, noutputs, outputs, modeldef, modellen, &err); } if (err.code != RAI_OK) { @@ -651,7 +770,7 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, // TODO: if backend loaded, make sure there's a queue - if (ensureRunQueue(devicestr)==REDISMODULE_ERR) { + if (ensureRunQueue(devicestr) == REDISMODULE_ERR) { RAI_ModelFree(model, &err); if (err.code != RAI_OK) { #ifdef RAI_PRINT_BACKEND_ERRORS @@ -869,8 +988,8 @@ int RedisAI_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (rinfo->status) { RedisModule_Log(ctx, "warning", "ERR %s", rinfo->err->detail); - rstats->calls += 1; rstats->nerrors += 1; + rstats->calls += 1; int ret = RedisModule_ReplyWithError(ctx, rinfo->err->detail_oneline); RedisAI_FreeRunInfo(ctx, rinfo); return ret; @@ -901,7 +1020,7 @@ int RedisAI_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } RAI_Tensor *t = NULL; if (rinfo->mctx) { - t = RAI_ModelRunCtxOutputTensor(rinfo->mctx, i); + t = RAI_ModelRunCtxOutputTensor(rinfo->mctx, 0, i); if (t && batch_size == 0) { batch_size = RAI_TensorDim(t, 0); } @@ -1025,6 +1144,8 @@ int RedisAI_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, rinfo->outkeys = NULL; rinfo->err = NULL; + RAI_ModelRunCtxAddBatch(rinfo->mctx); + for (size_t i=0; iinputs) { opname = mto->inputs[i]; } - if (!RAI_ModelRunCtxAddInput(rinfo->mctx, opname, t)) { + if (!RAI_ModelRunCtxAddInput(rinfo->mctx, 0, opname, t)) { // todo free rinfo return RedisModule_ReplyWithError(ctx, "Input key not found."); } @@ -1059,7 +1180,7 @@ int RedisAI_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, if (mto->outputs) { opname = mto->outputs[i]; } - if (!RAI_ModelRunCtxAddOutput(rinfo->mctx, opname)) { + if (!RAI_ModelRunCtxAddOutput(rinfo->mctx, 0, opname)) { // todo free rinfo return RedisModule_ReplyWithError(ctx, "Output key not found."); } @@ -1096,6 +1217,79 @@ int RedisAI_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, return REDISMODULE_OK; } +size_t RAI_RunInfoBatchSize(struct RedisAI_RunInfo* rinfo) { + if (rinfo->mctx == NULL) { + return -1; + } + + size_t ninputs = RAI_ModelRunCtxNumInputs(rinfo->mctx); + + int batchsize = 0; + + if (ninputs == 0) { + return batchsize; + } + + for (size_t i=0; imctx, 0, i); + + if (i == 0) { + batchsize = RAI_TensorDim(input, 0); + continue; + } + + if (batchsize != RAI_TensorDim(input, 0)) { + batchsize = 0; + break; + } + } + + return batchsize; +} + +int RAI_RunInfoBatchable(struct RedisAI_RunInfo* rinfo1, struct RedisAI_RunInfo* rinfo2) { + if (rinfo1->mctx == NULL || rinfo2->mctx 
== NULL) { + return 0; + } + + if (rinfo1->mctx->model != rinfo2->mctx->model) { + return 0; + } + + int ninputs1 = RAI_ModelRunCtxNumInputs(rinfo1->mctx); + int ninputs2 = RAI_ModelRunCtxNumInputs(rinfo2->mctx); + + if (ninputs1 != ninputs2) { + return 0; + } + + for (int i=0; imctx, 0, i); + RAI_Tensor* input2 = RAI_ModelRunCtxInputTensor(rinfo2->mctx, 0, i); + + int ndims1 = RAI_TensorNumDims(input1); + int ndims2 = RAI_TensorNumDims(input2); + + if (ndims1 != ndims2) { + return 0; + } + + if (ndims1 == 0) { + continue; + } + + for (int j=1; jrun_queue_mutex); while (true){ int rc = pthread_cond_wait(&run_queue_info->queue_condition_var, &run_queue_info->run_queue_mutex); - queueItem *item = NULL; - while ( (item = queuePop(run_queue_info->run_queue)) != NULL){ + + long long run_queue_len = queueLength(run_queue_info->run_queue); + + while (run_queue_len > 0) { + queueItem **evicted_items = NULL; + struct RedisAI_RunInfo **batch_rinfo = NULL; + + queueItem *item = queueFront(run_queue_info->run_queue); + + while (item) { + struct RedisAI_RunInfo *rinfo = (struct RedisAI_RunInfo *)item->value; + + if (evicted_items) { + array_free(evicted_items); + array_free(batch_rinfo); + } + evicted_items = array_new(queueItem *, run_queue_len); + batch_rinfo = array_new(struct RedisAI_RunInfo *, run_queue_len); + + array_append(evicted_items, item); + array_append(batch_rinfo, rinfo); + + if (rinfo->sctx) { + break; + } + + size_t batchsize = rinfo->mctx->model->opts.batchsize; + + if (batchsize == 0) { + break; + } + + size_t current_batchsize = RAI_RunInfoBatchSize(rinfo); + + if (current_batchsize == 0 || + current_batchsize >= batchsize) { + break; + } + + queueItem *next_item = item->next; + + while (next_item != NULL) { + struct RedisAI_RunInfo *next_rinfo = (struct RedisAI_RunInfo *)next_item->value; + + if (RAI_RunInfoBatchable(rinfo, next_rinfo) == 0) { + next_item = queueNext(next_item); + continue; + } + + int next_batchsize = RAI_RunInfoBatchSize(next_rinfo); + + if (current_batchsize + next_batchsize > batchsize) { + break; + } + + array_append(evicted_items, next_item); + array_append(batch_rinfo, next_rinfo); + + current_batchsize += next_batchsize; + next_item = queueNext(next_item); + } + + size_t minbatchsize = rinfo->mctx->model->opts.minbatchsize; + + if (minbatchsize == 0 || current_batchsize >= minbatchsize) { + break; + } + + item = item->next; + } + + if (item == NULL) { + array_free(evicted_items); + array_free(batch_rinfo); + pthread_mutex_unlock(&run_queue_info->run_queue_mutex); + break; + } + + for (long long i=0; irun_queue, evicted_items[i]); + } + pthread_mutex_unlock(&run_queue_info->run_queue_mutex); - RedisAI_RunSession(item->value); - RedisModule_Free(item); + + RedisAI_RunSession(batch_rinfo); + + for (long long i=0; irun_queue_mutex); + + run_queue_len = queueLength(run_queue_info->run_queue); } - } } @@ -1606,7 +1888,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) { REGISTER_API(GetLLAPIVersion, ctx); REGISTER_API(TensorCreate, ctx); - REGISTER_API(TensorGetDataSize, ctx); + REGISTER_API(TensorDataSize, ctx); REGISTER_API(TensorFree, ctx); REGISTER_API(TensorSetData, ctx); REGISTER_API(TensorSetValueFromLongLong, ctx); diff --git a/src/redisai.h b/src/redisai.h index 543619ecb..29b7f86e6 100644 --- a/src/redisai.h +++ b/src/redisai.h @@ -32,8 +32,10 @@ typedef struct RAI_Error RAI_Error; #define REDISAI_INFOMSG_THREADS_PER_QUEUE "Setting THREADS_PER_QUEUE parameter to" RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreate)(const char* dataTypeStr, long 
long* dims, int ndims); +RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreateByConcatenatingTensors)(RAI_Tensor** ts, long long n); +RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreateBySlicingTensor)(RAI_Tensor* t, long long offset, long long len); size_t MODULE_API_FUNC(RedisAI_TensorLength)(RAI_Tensor* t); -size_t MODULE_API_FUNC(RedisAI_TensorGetDataSize)(const char* dataTypeStr); +size_t MODULE_API_FUNC(RedisAI_TensorDataSize)(RAI_Tensor* t); size_t MODULE_API_FUNC(RedisAI_TensorDataType)(RAI_Tensor* t); void MODULE_API_FUNC(RedisAI_TensorFree)(RAI_Tensor* t); int MODULE_API_FUNC(RedisAI_TensorSetData)(RAI_Tensor* tensor, const char* data, size_t len); @@ -47,7 +49,7 @@ long long MODULE_API_FUNC(RedisAI_TensorDim)(RAI_Tensor* t, int dim); size_t MODULE_API_FUNC(RedisAI_TensorByteSize)(RAI_Tensor* t); char* MODULE_API_FUNC(RedisAI_TensorData)(RAI_Tensor* t); -RAI_Model* MODULE_API_FUNC(RedisAI_ModelCreate)(int backend, char* devicestr, +RAI_Model* MODULE_API_FUNC(RedisAI_ModelCreate)(int backend, char* devicestr, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, RAI_Error* err); @@ -92,7 +94,9 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){ REDISAI_MODULE_INIT_FUNCTION(ctx, GetLLAPIVersion); REDISAI_MODULE_INIT_FUNCTION(ctx, TensorCreate); - REDISAI_MODULE_INIT_FUNCTION(ctx, TensorGetDataSize); + REDISAI_MODULE_INIT_FUNCTION(ctx, TensorCreateByConcatenatingTensors); + REDISAI_MODULE_INIT_FUNCTION(ctx, TensorCreateBySlicingTensor); + REDISAI_MODULE_INIT_FUNCTION(ctx, TensorDataSize); REDISAI_MODULE_INIT_FUNCTION(ctx, TensorFree); REDISAI_MODULE_INIT_FUNCTION(ctx, TensorSetData); REDISAI_MODULE_INIT_FUNCTION(ctx, TensorSetValueFromLongLong); diff --git a/src/tensor.c b/src/tensor.c index 2ca2ce1dc..e189d4a0e 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -8,7 +8,7 @@ RedisModuleType *RedisAI_TensorType = NULL; -static DLDataType Tensor_GetDataType(const char* typestr){ +DLDataType RAI_TensorDataTypeFromString(const char* typestr){ if (strcasecmp(typestr, "FLOAT") == 0){ return (DLDataType){ .code = kDLFloat, .bits = 32, .lanes = 1}; } @@ -223,8 +223,7 @@ int RAI_TensorInit(RedisModuleCtx* ctx){ return RedisAI_TensorType != NULL; } -RAI_Tensor* RAI_TensorCreate(const char* dataTypeStr, long long* dims, int ndims, int hasdata) { - DLDataType dtype = Tensor_GetDataType(dataTypeStr); +RAI_Tensor* RAI_TensorCreateWithDLDataType(DLDataType dtype, long long* dims, int ndims, int hasdata) { const size_t dtypeSize = Tensor_DataTypeSize(dtype); if ( dtypeSize == 0){ return NULL; @@ -279,6 +278,11 @@ RAI_Tensor* RAI_TensorCreate(const char* dataTypeStr, long long* dims, int ndims return ret; } +RAI_Tensor* RAI_TensorCreate(const char* dataType, long long* dims, int ndims, int hasdata) { + DLDataType dtype = RAI_TensorDataTypeFromString(dataType); + return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, hasdata); +} + #if 0 void RAI_TensorMoveFrom(RAI_Tensor* dst, RAI_Tensor* src) { if (--dst->refCount <= 0){ @@ -296,6 +300,76 @@ void RAI_TensorMoveFrom(RAI_Tensor* dst, RAI_Tensor* src) { } #endif +RAI_Tensor* RAI_TensorCreateByConcatenatingTensors(RAI_Tensor** ts, long long n) { + + if (n == 0) { + return NULL; + } + + long long total_batch_size = 0; + long long batch_sizes[n]; + long long batch_offsets[n]; + + const long long ndims = RAI_TensorNumDims(ts[0]); + long long dims[ndims]; + + // TODO check that all tensors have compatible dims + + for (long long i=0; i Date: Sun, 8 Mar 2020 23:09:53 +0100 Subject: 
[PATCH 2/7] Fix TFLite and tests after rebase --- src/backends/tflite.c | 2 +- test/includes.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/backends/tflite.c b/src/backends/tflite.c index 96d46d8e1..8e5d1e8d9 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -167,7 +167,7 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx* mctx, RAI_Error *error) { mctx->batches[0].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor); } RAI_TensorFree(output_tensor); - RedisModule_Free(outputs[i]); + RedisModule_Free(outputs_dl[i]); } for (size_t i=0 ; i Date: Sat, 14 Mar 2020 23:34:35 +0100 Subject: [PATCH 3/7] Temporarily disable macos CI build --- .circleci/config.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcc6afdbb..99f71cf33 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -195,12 +195,12 @@ workflows: only: /.*/ tags: only: /.*/ - - build-macos: - filters: - branches: - ignore: /.*/ - tags: - only: /^v[0-9].*/ + # - build-macos: + # filters: + # branches: + # ignore: /.*/ + # tags: + # only: /^v[0-9].*/ #- build-multiarch-docker: # filters: # tags: From c5f56c807e4107c353529c3e570d9183838f2a95 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 20 Mar 2020 10:51:44 +0100 Subject: [PATCH 4/7] Add synchronization to autobatch tests --- test/tests_onnx.py | 4 ++++ test/tests_pytorch.py | 4 ++++ test/tests_tensorflow.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/test/tests_onnx.py b/test/tests_onnx.py index 3e0327c92..e95388f2d 100644 --- a/test/tests_onnx.py +++ b/test/tests_onnx.py @@ -147,6 +147,8 @@ def test_onnx_modelrun_mnist_autobatch(env): con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) con.execute_command('AI.TENSORSET', 'c', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) + ensureSlaveSynced(con, env) + def run(): con = env.getConnection() con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'c', 'OUTPUTS', 'd') @@ -156,6 +158,8 @@ def run(): con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'OUTPUTS', 'b') + ensureSlaveSynced(con, env) + tensor = con.execute_command('AI.TENSORGET', 'b', 'VALUES') values = tensor[-1] argmax = max(range(len(values)), key=lambda i: values[i]) diff --git a/test/tests_pytorch.py b/test/tests_pytorch.py index bb8278e09..9855dc98e 100644 --- a/test/tests_pytorch.py +++ b/test/tests_pytorch.py @@ -139,6 +139,8 @@ def test_pytorch_modelrun_autobatch(env): con.execute_command('AI.TENSORSET', 'd', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) con.execute_command('AI.TENSORSET', 'e', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + ensureSlaveSynced(con, env) + def run(): con = env.getConnection() con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f') @@ -148,6 +150,8 @@ def run(): con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c') + ensureSlaveSynced(con, env) + tensor = con.execute_command('AI.TENSORGET', 'c', 'VALUES') values = tensor[-1] env.assertEqual(values, [b'4', b'6', b'4', b'6']) diff --git a/test/tests_tensorflow.py b/test/tests_tensorflow.py index 62129dbb0..4b344e02a 100644 --- a/test/tests_tensorflow.py +++ b/test/tests_tensorflow.py @@ -388,6 +388,8 @@ def test_run_tf_model_autobatch(env): con.execute_command('AI.TENSORSET', 'd', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) con.execute_command('AI.TENSORSET', 'e', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + ensureSlaveSynced(con, env) + def run(): con = env.getConnection() 
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f') @@ -397,6 +399,8 @@ def run(): con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c') + ensureSlaveSynced(con, env) + tensor = con.execute_command('AI.TENSORGET', 'c', 'VALUES') values = tensor[-1] env.assertEqual(values, [b'4', b'9', b'4', b'9']) From 52b7e4a07a35755c9313795bc682d7bae54e436e Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 20 Mar 2020 12:41:36 +0100 Subject: [PATCH 5/7] Add synchronization to autobatch thread --- test/tests_onnx.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/tests_onnx.py b/test/tests_onnx.py index e95388f2d..604df687b 100644 --- a/test/tests_onnx.py +++ b/test/tests_onnx.py @@ -160,6 +160,9 @@ def run(): ensureSlaveSynced(con, env) + import time + time.sleep(1) + tensor = con.execute_command('AI.TENSORGET', 'b', 'VALUES') values = tensor[-1] argmax = max(range(len(values)), key=lambda i: values[i]) From 9ab3cec3bebcd9989f95993ba8701a15f16d63b3 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 20 Mar 2020 12:55:12 +0100 Subject: [PATCH 6/7] Add synchronization to autobatch thread --- test/tests_pytorch.py | 1 + test/tests_tensorflow.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/tests_pytorch.py b/test/tests_pytorch.py index 9855dc98e..8b6aab549 100644 --- a/test/tests_pytorch.py +++ b/test/tests_pytorch.py @@ -144,6 +144,7 @@ def test_pytorch_modelrun_autobatch(env): def run(): con = env.getConnection() con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f') + ensureSlaveSynced(con, env) t = threading.Thread(target=run) t.start() diff --git a/test/tests_tensorflow.py b/test/tests_tensorflow.py index 4b344e02a..b27c13caf 100644 --- a/test/tests_tensorflow.py +++ b/test/tests_tensorflow.py @@ -393,6 +393,7 @@ def test_run_tf_model_autobatch(env): def run(): con = env.getConnection() con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f') + ensureSlaveSynced(con, env) t = threading.Thread(target=run) t.start() From d468198c0b3066dee09e072dbbb93a32ce3eb764 Mon Sep 17 00:00:00 2001 From: Sherin Thomas Date: Sat, 28 Mar 2020 07:15:35 +0530 Subject: [PATCH 7/7] Batching crashtest (#310) * test cases for crash test * Fix issue with evict. Port test to multiprocessing to allow killing pending command. 
* Use terminate instead of kill Co-authored-by: Luca Antiga --- src/redisai.c | 3 +- test/tests_tensorflow.py | 182 ++++++++++++++++++++++++++++----------- 2 files changed, 132 insertions(+), 53 deletions(-) diff --git a/src/redisai.c b/src/redisai.c index e5ec1e797..89ae9064d 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -98,7 +98,8 @@ queueItem *queueEvict(queue *queue, queueItem *item) { queue->back->next = NULL; } else { - item->prev->next = item->next->prev; + item->prev->next = item->next; + item->next->prev = item->prev; } item->next = NULL; diff --git a/test/tests_tensorflow.py b/test/tests_tensorflow.py index b27c13caf..ff1f7e077 100644 --- a/test/tests_tensorflow.py +++ b/test/tests_tensorflow.py @@ -1,4 +1,6 @@ import redis +from functools import wraps +import multiprocessing as mp from includes import * @@ -7,11 +9,19 @@ ''' -def test_run_mobilenet(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return +def skip_if_no_TF(f): + @wraps(f) + def wrapper(env, *args, **kwargs): + if not TEST_TF: + env.debugPrint("skipping {} since TEST_TF=0".format( + sys._getframe().f_code.co_name), force=True) + return + return f(env, *args, **kwargs) + return wrapper + +@skip_if_no_TF +def test_run_mobilenet(env): con = env.getConnection() input_var = 'input' @@ -24,13 +34,16 @@ def test_run_mobilenet(env): ensureSlaveSynced(con, env) - mobilenet_model_serialized = con.execute_command('AI.MODELGET', 'mobilenet') + mobilenet_model_serialized = con.execute_command( + 'AI.MODELGET', 'mobilenet') ensureSlaveSynced(con, env) if env.useSlaves: con2 = env.getSlaveConnection() - slave_mobilenet_model_serialized = con2.execute_command('AI.MODELGET', 'mobilenet') - env.assertEqual(len(mobilenet_model_serialized), len(slave_mobilenet_model_serialized)) + slave_mobilenet_model_serialized = con2.execute_command( + 'AI.MODELGET', 'mobilenet') + env.assertEqual(len(mobilenet_model_serialized), + len(slave_mobilenet_model_serialized)) con.execute_command('AI.TENSORSET', 'input', 'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2], @@ -38,12 +51,14 @@ def test_run_mobilenet(env): ensureSlaveSynced(con, env) input_tensor_meta = con.execute_command('AI.TENSORGET', 'input', 'META') - env.assertEqual([b'FLOAT', [1, img.shape[1], img.shape[0], img.shape[2]]], input_tensor_meta) + env.assertEqual( + [b'FLOAT', [1, img.shape[1], img.shape[0], img.shape[2]]], input_tensor_meta) ensureSlaveSynced(con, env) if env.useSlaves: con2 = env.getSlaveConnection() - slave_tensor_meta = con2.execute_command('AI.TENSORGET', 'input', 'META') + slave_tensor_meta = con2.execute_command( + 'AI.TENSORGET', 'input', 'META') env.assertEqual(input_tensor_meta, slave_tensor_meta) con.execute_command('AI.MODELRUN', 'mobilenet', @@ -63,19 +78,18 @@ def test_run_mobilenet(env): if env.useSlaves: con2 = env.getSlaveConnection() - slave_dtype, slave_shape, slave_data = con2.execute_command('AI.TENSORGET', 'output', 'BLOB') + slave_dtype, slave_shape, slave_data = con2.execute_command( + 'AI.TENSORGET', 'output', 'BLOB') env.assertEqual(dtype, slave_dtype) env.assertEqual(shape, slave_shape) env.assertEqual(data, slave_data) +@skip_if_no_TF def test_run_mobilenet_multiproc(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return - if VALGRIND: - env.debugPrint("skipping {} since VALGRIND=1".format(sys._getframe().f_code.co_name), force=True) + env.debugPrint("skipping {} since 
VALGRIND=1".format( + sys._getframe().f_code.co_name), force=True) return con = env.getConnection() @@ -106,17 +120,15 @@ def test_run_mobilenet_multiproc(env): if env.useSlaves: con2 = env.getSlaveConnection() - slave_dtype, slave_shape, slave_data = con2.execute_command('AI.TENSORGET', 'output', 'BLOB') + slave_dtype, slave_shape, slave_data = con2.execute_command( + 'AI.TENSORGET', 'output', 'BLOB') env.assertEqual(dtype, slave_dtype) env.assertEqual(shape, slave_shape) env.assertEqual(data, slave_data) +@skip_if_no_TF def test_del_tf_model(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return - con = env.getConnection() test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') @@ -154,14 +166,12 @@ def test_del_tf_model(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) + env.assertEqual( + "WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) +@skip_if_no_TF def test_run_tf_model(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return - con = env.getConnection() test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') @@ -182,8 +192,10 @@ def test_run_tf_model(env): # env.assertEqual(ret[0], b'TF') # env.assertEqual(ret[1], b'CPU') - con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) - con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'a', 'FLOAT', + 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'b', 'FLOAT', + 2, 2, 'VALUES', 2, 3, 2, 3) ensureSlaveSynced(con, env) @@ -217,11 +229,8 @@ def test_run_tf_model(env): env.assertFalse(con2.execute_command('EXISTS', 'm')) +@skip_if_no_TF def test_run_tf_model_errors(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return - con = env.getConnection() test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') @@ -245,7 +254,8 @@ def test_run_tf_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("wrong number of arguments for 'AI.MODELGET' command", exception.__str__()) + env.assertEqual( + "wrong number of arguments for 'AI.MODELGET' command", exception.__str__()) # ERR WRONGTYPE con.execute_command('SET', 'NOT_MODEL', 'BAR') @@ -254,7 +264,8 @@ def test_run_tf_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) + env.assertEqual( + "WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) # cleanup con.execute_command('DEL', 'NOT_MODEL') @@ -365,6 +376,7 @@ def test_run_tf_model_errors(env): env.assertEqual(type(exception), redis.exceptions.ResponseError) +@skip_if_no_TF def test_run_tf_model_autobatch(env): if not TEST_PT: return @@ -382,17 +394,22 @@ def test_run_tf_model_autobatch(env): 'INPUTS', 'a', 'b', 'OUTPUTS', 'mul', model_pb) env.assertEqual(ret, b'OK') - con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) - con.execute_command('AI.TENSORSET', 
'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'a', 'FLOAT', + 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'b', 'FLOAT', + 2, 2, 'VALUES', 2, 3, 2, 3) - con.execute_command('AI.TENSORSET', 'd', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) - con.execute_command('AI.TENSORSET', 'e', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'd', 'FLOAT', + 2, 2, 'VALUES', 2, 3, 2, 3) + con.execute_command('AI.TENSORSET', 'e', 'FLOAT', + 2, 2, 'VALUES', 2, 3, 2, 3) ensureSlaveSynced(con, env) def run(): con = env.getConnection() - con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f') + con.execute_command('AI.MODELRUN', 'm', 'INPUTS', + 'd', 'e', 'OUTPUTS', 'f') ensureSlaveSynced(con, env) t = threading.Thread(target=run) @@ -411,11 +428,8 @@ def run(): env.assertEqual(values, [b'4', b'9', b'4', b'9']) +@skip_if_no_TF def test_tensorflow_modelinfo(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return - con = env.getConnection() test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') @@ -428,17 +442,20 @@ def test_tensorflow_modelinfo(env): 'INPUTS', 'a', 'b', 'OUTPUTS', 'mul', model_pb) env.assertEqual(ret, b'OK') - ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + ret = con.execute_command( + 'AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) env.assertEqual(ret, b'OK') - ret = con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + ret = con.execute_command( + 'AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) env.assertEqual(ret, b'OK') ensureSlaveSynced(con, env) previous_duration = 0 for call in range(1, 10): - ret = con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c') + ret = con.execute_command( + 'AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c') env.assertEqual(ret, b'OK') ensureSlaveSynced(con, env) @@ -466,11 +483,8 @@ def test_tensorflow_modelinfo(env): env.assertEqual(info_dict_0['ERRORS'], 0) +@skip_if_no_TF def test_tensorflow_modelrun_disconnect(env): - if not TEST_TF: - env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True) - return - red = env.getConnection() test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') @@ -483,13 +497,77 @@ def test_tensorflow_modelrun_disconnect(env): 'INPUTS', 'a', 'b', 'OUTPUTS', 'mul', model_pb) env.assertEqual(ret, b'OK') - ret = red.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + ret = red.execute_command( + 'AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) env.assertEqual(ret, b'OK') - ret = red.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + ret = red.execute_command( + 'AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) env.assertEqual(ret, b'OK') ensureSlaveSynced(red, env) - ret = send_and_disconnect(('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c'), red) + ret = send_and_disconnect( + ('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c'), red) env.assertEqual(ret, None) + + +@skip_if_no_TF +def test_with_batch_and_minbatch(env): + con = env.getConnection() + batch_size = 2 + minbatch_size = 2 + model_name = 'model' + another_model_name = 'another_model' + inputvar = 'input' + outputvar = 'MobilenetV2/Predictions/Reshape_1' + + model_pb, labels, img = load_mobilenet_test_data() + + con.execute_command('AI.MODELSET', model_name, 
'TF', DEVICE, + 'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size, + 'INPUTS', inputvar, + 'OUTPUTS', outputvar, + model_pb) + con.execute_command('AI.TENSORSET', 'input', + 'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2], + 'BLOB', img.tobytes()) + + def run(name=model_name, output_name='output'): + con.execute_command('AI.MODELRUN', name, + 'INPUTS', 'input', 'OUTPUTS', output_name) + + # Running thrice since minbatchsize = 2 + p1 = mp.Process(target=run) + p1.start() + p2 = mp.Process(target=run) + p2.start() + p3 = mp.Process(target=run) + p3.start() + + time.sleep(3) + + con.execute_command('AI.MODELSET', another_model_name, 'TF', DEVICE, + 'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size, + 'INPUTS', inputvar, + 'OUTPUTS', outputvar, + model_pb) + + p1 = mp.Process(target=run, args=(another_model_name, 'final1')) + p1.start() + p2 = mp.Process(target=run, args=(another_model_name, 'final2')) + p2.start() + + time.sleep(3) + + dtype, shape, data = con.execute_command('AI.TENSORGET', 'final1', 'BLOB') + dtype_map = {b'FLOAT': np.float32} + tensor = np.frombuffer(data, dtype=dtype_map[dtype]).reshape(shape) + label_id = np.argmax(tensor) - 1 + + _, label = labels[str(label_id)] + + env.assertEqual(label, 'giant_panda') + + p3.terminate() +
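
For reviewers who want to reason about the hunks above without the full files at hand, two minimal standalone sketches follow. Both use illustrative types and helpers (a flat toy_tensor, a queue struct with front/back/len fields) rather than the actual RedisAI structs, so read them as assumptions about the approach, not as the shipped implementation.

First, concatenation along the 0-th (batch) dimension, the operation behind the new RedisAI_TensorCreateByConcatenatingTensors entry point. The hunk above additionally tracks per-request sizes and offsets (batch_sizes, batch_offsets, total_batch_size) so that RedisAI_TensorCreateBySlicingTensor can hand each client back its own slice of the batched output; this sketch only shows the concatenating copy itself.

#include <stdlib.h>
#include <string.h>

/* Illustrative row-major tensor; dims[0] is the batch dimension.
 * This is not RedisAI's RAI_Tensor, just enough structure for the sketch. */
typedef struct {
    long long ndims;
    long long dims[8];
    size_t dtype_size;
    char *data;
} toy_tensor;

static size_t toy_nbytes(const toy_tensor *t) {
    size_t n = t->dtype_size;
    for (long long i = 0; i < t->ndims; i++) {
        n *= (size_t)t->dims[i];
    }
    return n;
}

/* Concatenate n tensors along dim 0, assuming identical dtype and identical
 * non-batch dimensions (the TODO in the hunk above notes that this
 * compatibility check still needs to be enforced). */
toy_tensor *toy_concat(toy_tensor **ts, long long n) {
    if (n == 0) return NULL;
    toy_tensor *out = calloc(1, sizeof(*out));
    *out = *ts[0];                       /* copy dtype and dims */
    out->dims[0] = 0;
    for (long long i = 0; i < n; i++) {
        out->dims[0] += ts[i]->dims[0];  /* total batch size */
    }
    out->data = malloc(toy_nbytes(out));
    size_t offset = 0;
    for (long long i = 0; i < n; i++) {
        size_t nb = toy_nbytes(ts[i]);   /* each input is one batch chunk */
        memcpy(out->data + offset, ts[i]->data, nb);
        offset += nb;
    }
    return out;
}

Second, the queue eviction fix from the crashtest commit. Unlinking a node from a doubly linked queue has to rewire both neighbors; the pre-fix line item->prev->next = item->next->prev left the predecessor's forward pointer aimed at the evicted node and never updated the successor's back pointer. A sketch of the corrected unlink follows, with the front/back bookkeeping folded into two symmetric branches; the field names mirror the patch, but the struct layout and control flow here are assumptions, not the patch's exact code.

typedef struct queueItem {
    struct queueItem *prev;
    struct queueItem *next;
    void *value;
} queueItem;

typedef struct queue {
    queueItem *front;
    queueItem *back;
    unsigned long len;
} queue;

queueItem *queueEvictSketch(queue *q, queueItem *item) {
    if (item->prev) {
        item->prev->next = item->next;   /* forward chain skips item */
    } else {
        q->front = item->next;           /* item was the front */
    }
    if (item->next) {
        item->next->prev = item->prev;   /* backward chain skips item */
    } else {
        q->back = item->prev;            /* item was the back */
    }
    item->prev = NULL;
    item->next = NULL;
    q->len--;
    return item;
}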