Skip to content

Commit b360284

Browse files
lantiga authored and filipecosta90 committed
Avoid splitting outputs in batches when nbatches == 1 (#406)
* Avoid splitting outputs in batches when nbatches == 1
* Add batch size checks
* Fix batch checks
* Update readies
* Add bad batching test
1 parent b66585b commit b360284

File tree

6 files changed

+140
-10
lines changed

6 files changed

+140
-10
lines changed

src/backends/onnxruntime.c

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *erro
163163
return NULL;
164164
}
165165

166-
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
166+
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, long long batch_size, RAI_Error *error) {
167167
OrtStatus* status = NULL;
168168
const OrtApi* ort = OrtGetApiBase()->GetApi(1);
169169

@@ -215,7 +215,12 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
215215
shape[i] = dims[i];
216216
strides[i] = 1;
217217
}
218-
shape[0] = batch_size;
218+
if (batch_size != -1) {
219+
shape[0] = batch_size;
220+
}
221+
else {
222+
batch_size = total_batch_size;
223+
}
219224
for (int64_t i = ndims - 2; i >= 0; --i)
220225
{
221226
strides[i] *= strides[i + 1] * shape[i + 1];
@@ -412,9 +417,11 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
412417

413418
size_t batch_sizes[nbatches];
414419
size_t batch_offsets[nbatches];
420+
size_t total_batch_size = 0;
415421
if (array_len(mctxs[0]->inputs) > 0) {
416422
for (size_t b=0; b<nbatches; ++b) {
417423
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
424+
total_batch_size += batch_sizes[b];
418425
}
419426
batch_offsets[0] = 0;
420427
for (size_t b=1; b<nbatches; ++b) {
@@ -530,14 +537,48 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
530537
}
531538

532539
for (size_t i = 0; i < n_output_nodes; i++) {
533-
for (size_t b=0; b<nbatches; b++) {
534-
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
540+
if (nbatches > 1) {
541+
OrtTensorTypeAndShapeInfo* info;
542+
status = ort->GetTensorTypeAndShape(outputs[i], &info);
543+
if (status != NULL) goto error;
544+
545+
size_t ndims;
546+
status = ort->GetDimensionsCount(info, &ndims);
547+
if (status != NULL) goto error;
548+
549+
int64_t dims[ndims];
550+
status = ort->GetDimensions(info, dims, ndims);
551+
if (status != NULL) goto error;
552+
553+
if (dims[0] != total_batch_size) {
554+
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
555+
ort->ReleaseStatus(status);
556+
return 1;
557+
}
558+
559+
for (size_t b=0; b<nbatches; b++) {
560+
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
561+
if (error->code != RAI_OK) {
562+
ort->ReleaseStatus(status);
563+
return 1;
564+
}
565+
if (output_tensor) {
566+
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
567+
RAI_TensorFree(output_tensor);
568+
}
569+
else {
570+
printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported)");
571+
}
572+
}
573+
}
574+
else {
575+
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], 0, -1, error);
535576
if (error->code != RAI_OK) {
536577
ort->ReleaseStatus(status);
537578
return 1;
538579
}
539580
if (output_tensor) {
540-
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
581+
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
541582
RAI_TensorFree(output_tensor);
542583
}
543584
else {

src/backends/tensorflow.c

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ DLDataType RAI_GetDLDataTypeFromTF(TF_DataType dtype) {
7979
return (DLDataType){ .bits = 0 };
8080
}
8181

82-
RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, size_t batch_size) {
82+
RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, long long batch_size) {
8383
RAI_Tensor* ret = RedisModule_Calloc(1, sizeof(*ret));
8484

8585
DLContext ctx = (DLContext){
@@ -98,7 +98,12 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset,
9898
shape[i] = TF_Dim(tensor, i);
9999
strides[i] = 1;
100100
}
101-
shape[0] = batch_size;
101+
if (batch_size != -1) {
102+
shape[0] = batch_size;
103+
}
104+
else {
105+
batch_size = total_batch_size;
106+
}
102107
for (int64_t i = ndims-2 ; i >= 0 ; --i) {
103108
strides[i] *= strides[i+1] * shape[i+1];
104109
}
@@ -476,9 +481,11 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
476481

477482
size_t batch_sizes[nbatches];
478483
size_t batch_offsets[nbatches];
484+
size_t total_batch_size = 0;
479485
if (ninputs > 0) {
480486
for (size_t b=0; b<nbatches; ++b) {
481487
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
488+
total_batch_size += batch_sizes[b];
482489
}
483490
batch_offsets[0] = 0;
484491
for (size_t b=1; b<nbatches; ++b) {
@@ -532,8 +539,23 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
532539
}
533540

534541
for(size_t i=0; i<noutputs; ++i) {
535-
for (size_t b=0; b<nbatches; b++) {
536-
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
542+
if (nbatches > 1) {
543+
if (TF_NumDims(outputTensorsValues[i]) == 0) {
544+
continue;
545+
}
546+
if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
547+
TF_DeleteTensor(outputTensorsValues[i]);
548+
TF_DeleteStatus(status);
549+
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
550+
return 1;
551+
}
552+
553+
for (size_t b=0; b<nbatches; b++) {
554+
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
555+
}
556+
}
557+
else {
558+
mctxs[0]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
537559
}
538560
TF_DeleteTensor(outputTensorsValues[i]);
539561
}

src/backends/torch.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
9393

9494
size_t batch_sizes[nbatches];
9595
size_t batch_offsets[nbatches];
96+
size_t total_batch_size = 0;
9697

9798
if (nbatches > 1) {
98-
size_t total_batch_size = 0;
9999
if (array_len(mctxs[0]->inputs) > 0) {
100100
for (size_t b=0; b<nbatches; ++b) {
101101
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
@@ -147,6 +147,10 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
147147
}
148148
RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]);
149149
if (nbatches > 1) {
150+
if (outputs_dl[i]->dl_tensor.shape[0] != total_batch_size) {
151+
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
152+
return 1;
153+
}
150154
for (size_t b=0; b<nbatches; b++) {
151155
mctxs[b]->outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]);
152156
}

test/test_data/pt-minimal-bb.pt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:dd657a26454418d7bfd2c02fe76fc15166f6845ec14efa9653ffdc019b021008
3+
size 1514

test/test_data/pt_minimal_bb.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
3+
4+
class MyModule(torch.jit.ScriptModule):
5+
def __init__(self):
6+
super(MyModule, self).__init__()
7+
8+
@torch.jit.script_method
9+
def forward(self, a, b):
10+
return a + b, torch.ones(1)
11+
12+
13+
my_script_module = MyModule()
14+
print(my_script_module(torch.rand(2), torch.rand(2)))
15+
my_script_module.save("pt-minimal-bb.pt")

test/tests_pytorch.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,51 @@ def run():
230230
env.assertEqual(values, [b'4', b'6', b'4', b'6'])
231231

232232

233+
def test_pytorch_modelrun_autobatch_badbatch(env):
234+
if not TEST_PT:
235+
return
236+
237+
con = env.getConnection()
238+
239+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
240+
model_filename = os.path.join(test_data_path, 'pt-minimal-bb.pt')
241+
242+
with open(model_filename, 'rb') as f:
243+
model_pb = f.read()
244+
245+
ret = con.execute_command('AI.MODELSET', 'm', 'TORCH', 'CPU',
246+
'BATCHSIZE', 4, 'MINBATCHSIZE', 3, 'BLOB', model_pb)
247+
env.assertEqual(ret, b'OK')
248+
249+
con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
250+
con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
251+
252+
con.execute_command('AI.TENSORSET', 'd', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
253+
con.execute_command('AI.TENSORSET', 'e', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
254+
255+
ensureSlaveSynced(con, env)
256+
257+
def run():
258+
con = env.getConnection()
259+
try:
260+
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f1', 'f2')
261+
except Exception as e:
262+
exception = e
263+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
264+
env.assertEqual("Model did not generate the expected batch size", exception.__str__())
265+
266+
t = threading.Thread(target=run)
267+
t.start()
268+
269+
try:
270+
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c1', 'c2')
271+
except Exception as e:
272+
exception = e
273+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
274+
env.assertEqual("Model did not generate the expected batch size", exception.__str__())
275+
276+
277+
233278
def test_pytorch_modelinfo(env):
234279
if not TEST_PT:
235280
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)

0 commit comments

Comments (0)