
Commit 23b6900

[NLP] Evaluate batched inference calls singularly (#2538)
Evaluating the inference calls one at a time uses much less memory, and throughput is not significantly diminished.
1 parent 3c8abbd commit 23b6900
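
In effect, the change trades one large forward pass over the whole batch for a loop of single-sequence passes whose outputs are concatenated back together. Below is a minimal sketch of that trade-off; it is not the commit's code, and it uses a stand-in torch::nn::Linear model with made-up dimensions in place of the real TorchScript module loaded by pytorch_inference.

// Sketch only: batched vs. one-at-a-time evaluation of the same inputs.
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    torch::InferenceMode guard; // no autograd bookkeeping, as in infer()

    torch::nn::Linear model(16, 4); // stand-in for the traced NLP model

    std::int64_t numberInferences = 8;  // rows in the batched request
    std::int64_t numberInputTokens = 16;
    torch::Tensor batch = torch::randn({numberInferences, numberInputTokens});

    // Batched: a single forward call; peak memory for the intermediate
    // activations is proportional to the whole batch.
    torch::Tensor batched = model->forward(batch);

    // Singular: one forward call per row, stitched back together afterwards.
    std::vector<torch::Tensor> rows;
    rows.reserve(numberInferences);
    for (std::int64_t i = 0; i < numberInferences; ++i) {
        rows.push_back(model->forward(batch.slice(0, i, i + 1)));
    }
    torch::Tensor singular = torch::cat(rows, 0);

    // Both paths produce the same result; only peak memory differs.
    std::cout << torch::allclose(batched, singular) << '\n'; // prints 1
    return 0;
}

Both paths yield the same tensor; the singular path only ever materializes one row's worth of intermediate activations at a time, which is where the memory saving comes from.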

File tree

4 files changed: 32 additions & 22 deletions

bin/pytorch_inference/CCommandParser.cc

Lines changed: 0 additions & 6 deletions
@@ -245,12 +245,6 @@ bool CCommandParser::checkArrayContainsUInts(const rapidjson::Value::ConstArray&
     }) == arr.End();
 }
 
-bool CCommandParser::checkArrayContainsDoubles(const rapidjson::Value::ConstArray& arr) {
-    return std::find_if(arr.Begin(), arr.End(), [](const auto& i) {
-               return i.IsDouble() == false;
-           }) == arr.End();
-}
-
 CCommandParser::SRequest
 CCommandParser::jsonToInferenceRequest(const rapidjson::Document& doc) {
     SRequest request;

bin/pytorch_inference/CCommandParser.h

Lines changed: 0 additions & 1 deletion
@@ -183,7 +183,6 @@ class CCommandParser {
     static EMessageType validateControlMessageJson(const rapidjson::Document& doc,
                                                    const TErrorHandlerFunc& errorHandler);
     static bool checkArrayContainsUInts(const rapidjson::Value::ConstArray& arr);
-    static bool checkArrayContainsDoubles(const rapidjson::Value::ConstArray& arr);
     static SRequest jsonToInferenceRequest(const rapidjson::Document& doc);
     static SControlMessage jsonToControlMessage(const rapidjson::Document& doc);

bin/pytorch_inference/Main.cc

Lines changed: 29 additions & 15 deletions
@@ -31,6 +31,7 @@
 #include "CThreadSettings.h"
 
 #include <ATen/Parallel.h>
+#include <ATen/ops/cat.h>
 #include <torch/csrc/api/include/torch/types.h>
 #include <torch/script.h>
 
@@ -44,26 +45,39 @@ torch::Tensor infer(torch::jit::script::Module& module_,
     std::vector<torch::jit::IValue> inputs;
     inputs.reserve(1 + request.s_SecondaryArguments.size());
 
-    std::array<std::int64_t, 2> dimensions = {request.s_NumberInferences,
-                                              request.s_NumberInputTokens};
+    std::array<std::int64_t, 2> dimensions = {1, request.s_NumberInputTokens};
     at::IntArrayRef inputSize{dimensions};
 
-    // Sequence tokens.
-    inputs.emplace_back(torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
-                                         inputSize, at::dtype(torch::kInt64)));
-    // Attention mask.
-    for (auto& args : request.s_SecondaryArguments) {
-        inputs.emplace_back(torch::from_blob(static_cast<void*>(args.data()),
-                                             inputSize, at::dtype(torch::kInt64)));
-    }
+    std::vector<at::Tensor> all;
 
     torch::InferenceMode inferenceModeGuard;
-    auto result = module_.forward(inputs);
-    if (result.isTuple()) {
-        // For transformers the result tensor is the first element in a tuple.
-        return result.toTuple()->elements()[0].toTensor();
+
+    for (int i = 0; i < request.s_NumberInferences; i++) {
+
+        std::size_t offset = i * request.s_NumberInputTokens;
+
+        // Sequence tokens.
+        inputs.emplace_back(
+            torch::from_blob(static_cast<void*>(request.s_Tokens.data() + offset),
+                             inputSize, at::dtype(torch::kInt64)));
+        // Attention mask etc
+        for (auto& args : request.s_SecondaryArguments) {
+            inputs.emplace_back(torch::from_blob(static_cast<void*>(args.data() + offset),
+                                                 inputSize, at::dtype(torch::kInt64)));
+        }
+
+        auto output = module_.forward(inputs);
+        if (output.isTuple()) {
+            // For transformers the result tensor is the first element in a tuple.
+            all.push_back(output.toTuple()->elements()[0].toTensor());
+        } else {
+            all.push_back(output.toTensor());
+        }
+
+        inputs.clear();
     }
-    return result.toTensor();
+
+    return at::cat(all, 0);
 }
 
 bool handleRequest(ml::torch::CCommandParser::CRequestCacheInterface& cache,
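
Two details of the new infer() loop are worth noting. torch::from_blob wraps the request's existing buffers without copying, so each single-sequence input is just a view into the batched token arrays at offset i * s_NumberInputTokens, and reusing the inputs vector (cleared each iteration) keeps per-call allocations minimal. At the end, at::cat(all, 0) stitches the per-sequence outputs back together along the batch dimension, so callers still receive a single result tensor of the same shape as the old batched path produced.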

docs/CHANGELOG.asciidoc

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@
 * Improve detection of time shifts, for example for day light saving. (See {ml-pull}2479[#2479].)
 * Improve detection of calendar cyclic components with long bucket lengths. (See {ml-pull}2493[#2493].)
 
+=== Bug Fixes
+* Prevent high memory usage by evaluating batch inference singularly. (See {ml-pull}2538[#2538].)
+
 == {es} version 8.8.0
 
 === Enhancements
