
Commit 4b769f8

[NLP] Catch exceptions thrown during inference and report as errors (#2542)
1 parent 23b6900 commit 4b769f8

File tree (6 files changed: +55, -18 lines)

bin/pytorch_inference/CCommandParser.h
bin/pytorch_inference/Main.cc
bin/pytorch_inference/evaluate.py
docs/CHANGELOG.asciidoc
include/core/CCompressedLfuCache.h
lib/core/unittest/CCompressedLfuCacheTest.cc
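
The common thread in this commit: compute callbacks that previously returned a result directly now return std::optional, so an exception caught inside the callback can be reported as an error document and signalled to the caller as std::nullopt, in which case the read callback is skipped. A minimal, self-contained sketch of that pattern (hypothetical names, not the actual ml-cpp API):

    #include <functional>
    #include <iostream>
    #include <optional>
    #include <stdexcept>
    #include <string>

    using TCompute = std::function<std::optional<std::string>(const std::string&)>;
    using TRead = std::function<void(const std::string&)>;

    // Hypothetical stand-in for infer(): always throws, to exercise the error path.
    std::string runInference(const std::string&) {
        throw std::runtime_error{"inference failed"};
    }

    // Mimics CRequestCacheInterface::lookup: the read callback only runs when
    // the compute callback actually produced a value.
    void lookup(const std::string& request, const TCompute& compute, const TRead& read) {
        if (auto computed = compute(request)) {
            read(*computed);
        }
    }

    int main() {
        lookup("request_1",
               [](const std::string& request) -> std::optional<std::string> {
                   try {
                       return runInference(request);
                   } catch (const std::runtime_error& e) {
                       // Report the failure as an error, then signal "no result".
                       std::cerr << "error for '" << request << "': " << e.what() << '\n';
                       return std::nullopt;
                   }
               },
               [](const std::string& response) { std::cout << response << '\n'; });
        return 0;
    }

The diffs below apply this pattern to the dummy request cache (CCommandParser.h), the inference loop (Main.cc), and the LFU cache (CCompressedLfuCache.h).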

bin/pytorch_inference/CCommandParser.h
Lines changed: 6 additions & 2 deletions

@@ -19,6 +19,7 @@
 #include <functional>
 #include <iosfwd>
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>

@@ -58,7 +59,7 @@ class CCommandParser {
     //! \brief Inference request cache interface.
     class CRequestCacheInterface {
     public:
-        using TComputeResponse = std::function<std::string(SRequest)>;
+        using TComputeResponse = std::function<std::optional<std::string>(SRequest)>;
         using TReadResponse = std::function<void(const std::string&, bool)>;

     public:
@@ -102,7 +103,10 @@ class CCommandParser {
        bool lookup(SRequest request,
                    const TComputeResponse& computeResponse,
                    const TReadResponse& readResponse) override {
-            readResponse(computeResponse(std::move(request)), false);
+            auto computed = computeResponse(std::move(request));
+            if (computed) {
+                readResponse(*computed, false);
+            }
             return false;
         }
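
Note the contract this introduces for the non-caching implementation: readResponse runs only when the compute callback actually produced a value; a std::nullopt result means the callback has already reported the failure, so there is nothing to forward.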

bin/pytorch_inference/Main.cc
Lines changed: 19 additions & 10 deletions

@@ -37,6 +37,7 @@

 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <string>

 torch::Tensor infer(torch::jit::script::Module& module_,
@@ -92,16 +93,24 @@ bool handleRequest(ml::torch::CCommandParser::CRequestCacheInterface& cache,
         // We time the combination of the cache lookup and (if necessary)
         // the inference.
         ml::core::CStopWatch stopWatch(true);
-        cache.lookup(std::move(capturedRequest),
-                     [&](auto request_) -> std::string {
-                         torch::Tensor results = infer(module_, request_);
-                         return resultWriter.createInnerResult(results);
-                     },
-                     [&](const auto& innerResponseJson_, bool isCacheHit) {
-                         resultWriter.wrapAndWriteInnerResponse(innerResponseJson_,
-                                                                requestId, isCacheHit,
-                                                                stopWatch.stop());
-                     });
+        cache.lookup(
+            std::move(capturedRequest),
+            [&](auto request_) -> std::optional<std::string> {
+                try {
+                    torch::Tensor results = infer(module_, request_);
+                    return resultWriter.createInnerResult(results);
+                } catch (const c10::Error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                } catch (std::runtime_error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                }
+            },
+            [&](const auto& innerResponseJson_, bool isCacheHit) {
+                resultWriter.wrapAndWriteInnerResponse(
+                    innerResponseJson_, requestId, isCacheHit, stopWatch.stop());
+            });
     });
     return true;
 }
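
c10::Error is LibTorch's exception type, so both Torch-specific failures and generic std::runtime_error failures raised during infer() are now converted into error documents via resultWriter.writeError rather than escaping handleRequest; returning std::nullopt then tells the cache that no response was produced.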

bin/pytorch_inference/evaluate.py
Lines changed: 1 addition & 1 deletion

@@ -288,7 +288,7 @@ def test_evaluation(args):
     for result in result_docs:

         if 'error' in result:
-            print(f"Inference failed. Request: {result['error']['request_id']}, Msg: {result['error']['error']}")
+            print(f"Inference failed. Request: {result['request_id']}, Msg: {result['error']['error']}")
             results_match = False
             continue

docs/CHANGELOG.asciidoc
Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@

 === Bug Fixes
 * Prevent high memory usage by evaluating batch inference singularly. (See {ml-pull}2538[#2538].)
+* Catch exceptions thrown during inference and report as errors. (See {ml-pull}2542[#2542].)

 == {es} version 8.8.0

include/core/CCompressedLfuCache.h
Lines changed: 12 additions & 5 deletions

@@ -30,6 +30,7 @@
 #include <limits>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <shared_mutex>
 #include <string>
@@ -65,7 +66,7 @@ class CCompressedLfuCache {
     using TDictionary = CCompressedDictionary<COMPRESSED_KEY_BITS / 64>;
     using TCompressedKey = typename TDictionary::CWord;
     using TCompressKey = std::function<TCompressedKey(const TDictionary&, const KEY&)>;
-    using TComputeValueCallback = std::function<VALUE(KEY)>;
+    using TComputeValueCallback = std::function<std::optional<VALUE>(KEY)>;
     using TReadValueCallback = std::function<void(const VALUE&, bool)>;

 public:
@@ -96,6 +97,9 @@ class CCompressedLfuCache {

     //! Lookup an item with \p key in the cache or else fall back to computing.
     //!
+    //! \warning If \p computeValue fails to produce a value (returns std::nullopt)
+    //! then \p readValue will not be called.
+    //!
     //! \param[in] key The item key.
     //! \param[in] computeValue Computes the value in the case of a cache miss.
     //! \param[in] readValue Processes the value.
@@ -137,15 +141,18 @@
         }

         auto value = computeValue(std::move(key));
+        if (!value) {
+            return false;
+        }

-        std::size_t itemMemoryUsage{memory::dynamicSize(value)};
+        std::size_t itemMemoryUsage{memory::dynamicSize(*value)};

         if (this->guardWrite(TIME_OUT, [&] {
                 // It is possible that two values with the same key check the cache
                 // before either takes the write lock. So check if this is already
                 // in the cache before going any further.
                 if (m_ItemCache.find(compressedKey) != m_ItemCache.end()) {
-                    readValue(value, true);
+                    readValue(*value, true);
                     this->incrementCount(compressedKey);
                     return;
                 }
@@ -158,14 +165,14 @@
                     // It's possible that the cache is empty yet isn't big
                     // enough to hold this new item.
                     if (itemToEvict == m_ItemStats.end()) {
-                        readValue(value, false);
+                        readValue(*value, false);
                         return;
                     }
                     m_RemovedCount += lastEvictedCount;
                     lastEvictedCount = itemToEvict->count();
                     this->removeFromCache(itemToEvict);
                 }
-                readValue(this->insert(compressedKey, value, itemMemoryUsage,
+                readValue(this->insert(compressedKey, *value, itemMemoryUsage,
                                        count + lastEvictedCount),
                           false);
             }) == false) {
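
When computeValue returns std::nullopt, lookup now returns false immediately: nothing is inserted into the cache and readValue is never invoked, matching the \warning added above. The new unit test below verifies exactly this behaviour.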

lib/core/unittest/CCompressedLfuCacheTest.cc
Lines changed: 16 additions & 0 deletions

@@ -24,6 +24,7 @@
 #include <boost/test/unit_test.hpp>

 #include <chrono>
+#include <optional>
 #include <sstream>
 #include <string>
 #include <thread>
@@ -612,4 +613,19 @@ BOOST_AUTO_TEST_CASE(testClear) {
     BOOST_TEST_REQUIRE(cache.checkInvariants());
 }

+BOOST_AUTO_TEST_CASE(testComputeValueReturnsNullOpt) {
+    TStrStrCache cache{32 * core::constants::BYTES_IN_KILOBYTES,
+                       [](const TStrStrCache::TDictionary& dictionary, const std::string& key) {
+                           return dictionary.word(key);
+                       }};
+
+    bool valueRead{false};
+
+    BOOST_REQUIRE_EQUAL(
+        false,
+        cache.lookup("key_1", [](std::string) { return std::nullopt; },
+                     [&valueRead](const std::string&, bool) { valueRead = true; }));
+    BOOST_REQUIRE_EQUAL(false, valueRead);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
