Commit 48e0df3

Catch exceptions thrown during inference and report as errors

1 parent 23b6900 commit 48e0df3

4 files changed: +38 -18 lines changed

bin/pytorch_inference/CCommandParser.h

Lines changed: 6 additions & 2 deletions

@@ -19,6 +19,7 @@
 #include <functional>
 #include <iosfwd>
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>

@@ -58,7 +59,7 @@ class CCommandParser {
     //! \brief Inference request cache interface.
     class CRequestCacheInterface {
     public:
-        using TComputeResponse = std::function<std::string(SRequest)>;
+        using TComputeResponse = std::function<std::optional<std::string>(SRequest)>;
         using TReadResponse = std::function<void(const std::string&, bool)>;

     public:

@@ -102,7 +103,10 @@ class CCommandParser {
         bool lookup(SRequest request,
                     const TComputeResponse& computeResponse,
                     const TReadResponse& readResponse) override {
-            readResponse(computeResponse(std::move(request)), false);
+            auto computed = computeResponse(std::move(request));
+            if (computed) {
+                readResponse(*computed, false);
+            }
             return false;
         }
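To make the new contract concrete, here is a minimal, self-contained sketch of the pass-through (no-op cache) lookup above: the compute callback now returns std::optional<std::string>, and the read callback fires only when a value was actually produced. SRequest is reduced to a stand-in struct and the main() driver is invented for illustration.

#include <functional>
#include <iostream>
#include <optional>
#include <string>

struct SRequest {
    std::string s_RequestId;
};

using TComputeResponse = std::function<std::optional<std::string>(SRequest)>;
using TReadResponse = std::function<void(const std::string&, bool)>;

bool lookup(SRequest request,
            const TComputeResponse& computeResponse,
            const TReadResponse& readResponse) {
    auto computed = computeResponse(std::move(request));
    if (computed) {
        readResponse(*computed, false);
    }
    // The pass-through "cache" never reports a hit.
    return false;
}

int main() {
    TReadResponse read = [](const std::string& json, bool isCacheHit) {
        std::cout << "read: " << json << " (cache hit: " << isCacheHit << ")\n";
    };
    // Success path: the read callback receives the computed response.
    lookup(SRequest{"req_1"},
           [](SRequest) -> std::optional<std::string> { return "{\"ok\":true}"; }, read);
    // Failure path: std::nullopt means the read callback is skipped entirely,
    // so the caller must report the error through another channel.
    lookup(SRequest{"req_2"},
           [](SRequest) -> std::optional<std::string> { return std::nullopt; }, read);
    return 0;
}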

bin/pytorch_inference/Main.cc

Lines changed: 19 additions & 10 deletions

@@ -18,6 +18,7 @@
 #include <core/CStringUtils.h>
 #include <core/Concurrency.h>

+#include <optional>
 #include <seccomp/CSystemCallFilter.h>

 #include <ver/CBuildInfo.h>

@@ -92,16 +93,24 @@ bool handleRequest(ml::torch::CCommandParser::CRequestCacheInterface& cache,
         // We time the combination of the cache lookup and (if necessary)
         // the inference.
         ml::core::CStopWatch stopWatch(true);
-        cache.lookup(std::move(capturedRequest),
-                     [&](auto request_) -> std::string {
-                         torch::Tensor results = infer(module_, request_);
-                         return resultWriter.createInnerResult(results);
-                     },
-                     [&](const auto& innerResponseJson_, bool isCacheHit) {
-                         resultWriter.wrapAndWriteInnerResponse(innerResponseJson_,
-                                                                requestId, isCacheHit,
-                                                                stopWatch.stop());
-                     });
+        cache.lookup(
+            std::move(capturedRequest),
+            [&](auto request_) -> std::optional<std::string> {
+                try {
+                    torch::Tensor results = infer(module_, request_);
+                    return resultWriter.createInnerResult(results);
+                } catch (const c10::Error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                } catch (std::runtime_error& e) {
+                    resultWriter.writeError(request_.s_RequestId, e.what());
+                    return std::nullopt;
+                }
+            },
+            [&](const auto& innerResponseJson_, bool isCacheHit) {
+                resultWriter.wrapAndWriteInnerResponse(
+                    innerResponseJson_, requestId, isCacheHit, stopWatch.stop());
+            });
     });
     return true;
 }
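The essence of this change is translating exceptions into std::nullopt at the lambda boundary, so a failed inference produces an error document instead of killing the process. A hedged sketch of that pattern follows; runInference and ErrorWriter are invented stand-ins for infer() and the result writer, and std::runtime_error replaces c10::Error so the example builds without libtorch.

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

struct ErrorWriter {
    // Emits an error document shaped like the one evaluate.py parses below.
    void writeError(const std::string& requestId, const std::string& message) {
        std::cerr << "{\"request_id\":\"" << requestId
                  << "\",\"error\":{\"error\":\"" << message << "\"}}\n";
    }
};

std::string runInference(const std::string& requestId) {
    if (requestId == "bad") {
        throw std::runtime_error{"invalid tensor shape"};
    }
    return "{\"inference\":[]}";
}

int main() {
    ErrorWriter writer;
    auto compute = [&](const std::string& requestId) -> std::optional<std::string> {
        try {
            return runInference(requestId);
        } catch (const std::runtime_error& e) {
            // Report the failure as an error document instead of letting the
            // exception propagate, then signal "no value" to the cache.
            writer.writeError(requestId, e.what());
            return std::nullopt;
        }
    };
    std::cout << compute("good").value_or("<none>") << '\n';
    std::cout << compute("bad").value_or("<none>") << '\n';
    return 0;
}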

bin/pytorch_inference/evaluate.py

Lines changed: 1 addition & 1 deletion

@@ -288,7 +288,7 @@ def test_evaluation(args):
     for result in result_docs:

         if 'error' in result:
-            print(f"Inference failed. Request: {result['error']['request_id']}, Msg: {result['error']['error']}")
+            print(f"Inference failed. Request: {result['request_id']}, Msg: {result['error']['error']}")
             results_match = False
             continue
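For reference, the corrected key paths imply an error document with request_id at the top level rather than nested inside error, shaped roughly as follows (a reconstruction from the two f-strings; the values are invented):

{"request_id": "req_2", "error": {"error": "invalid tensor shape"}}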

include/core/CCompressedLfuCache.h

Lines changed: 12 additions & 5 deletions

@@ -30,6 +30,7 @@
 #include <limits>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <shared_mutex>
 #include <string>

@@ -65,7 +66,7 @@ class CCompressedLfuCache {
     using TDictionary = CCompressedDictionary<COMPRESSED_KEY_BITS / 64>;
     using TCompressedKey = typename TDictionary::CWord;
     using TCompressKey = std::function<TCompressedKey(const TDictionary&, const KEY&)>;
-    using TComputeValueCallback = std::function<VALUE(KEY)>;
+    using TComputeValueCallback = std::function<std::optional<VALUE>(KEY)>;
     using TReadValueCallback = std::function<void(const VALUE&, bool)>;

 public:

@@ -96,6 +97,9 @@ class CCompressedLfuCache {

     //! Lookup an item with \p key in the cache or else fall back to computing.
     //!
+    //! \warning If \p computeValue fails to produce a value (returns std::nullopt)
+    //! then \p readValue will not be called.
+    //!
     //! \param[in] key The item key.
     //! \param[in] computeValue Computes the value in the case of a cache miss.
     //! \param[in] readValue Processes the value.

@@ -137,15 +141,18 @@
         }

         auto value = computeValue(std::move(key));
+        if (!value) {
+            return false;
+        }

-        std::size_t itemMemoryUsage{memory::dynamicSize(value)};
+        std::size_t itemMemoryUsage{memory::dynamicSize(*value)};

         if (this->guardWrite(TIME_OUT, [&] {
                 // It is possible that two values with the same key check the cache
                 // before either takes the write lock. So check if this is already
                 // in the cache before going any further.
                 if (m_ItemCache.find(compressedKey) != m_ItemCache.end()) {
-                    readValue(value, true);
+                    readValue(*value, true);
                     this->incrementCount(compressedKey);
                     return;
                 }

@@ -158,14 +165,14 @@
                     // It's possible that the cache is empty yet isn't big
                     // enough to hold this new item.
                     if (itemToEvict == m_ItemStats.end()) {
-                        readValue(value, false);
+                        readValue(*value, false);
                         return;
                     }
                     m_RemovedCount += lastEvictedCount;
                     lastEvictedCount = itemToEvict->count();
                     this->removeFromCache(itemToEvict);
                 }
-            readValue(this->insert(compressedKey, value, itemMemoryUsage,
+            readValue(this->insert(compressedKey, *value, itemMemoryUsage,
                                    count + lastEvictedCount),
                       false);
         }) == false) {
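The double-checked pattern in lookup is worth spelling out: the cache is consulted under a read lock, the value is computed outside any lock (inference may be slow), the computation can now fail, and the cache is re-checked under the write lock because another thread may have inserted the same key in the meantime. Below is a simplified sketch of that control flow, not the Elastic implementation: ToyCache uses a plain map and omits the LFU statistics, key compression, memory accounting, and eviction.

#include <functional>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <string>
#include <unordered_map>

class ToyCache {
public:
    using TComputeValue = std::function<std::optional<std::string>(const std::string&)>;
    using TReadValue = std::function<void(const std::string&, bool /*isCacheHit*/)>;

    // Returns true on a cache hit, false otherwise.
    bool lookup(const std::string& key,
                const TComputeValue& computeValue,
                const TReadValue& readValue) {
        {
            std::shared_lock<std::shared_mutex> readLock{m_Mutex};
            auto hit = m_Items.find(key);
            if (hit != m_Items.end()) {
                readValue(hit->second, true);
                return true;
            }
        }

        // Compute outside any lock so slow computations don't block readers.
        auto value = computeValue(key);
        if (!value) {
            // Mirrors this commit: on failure the value is neither cached
            // nor passed to readValue.
            return false;
        }

        std::unique_lock<std::shared_mutex> writeLock{m_Mutex};
        // Two callers can miss on the same key before either takes the write
        // lock, so re-check before inserting.
        if (auto recheck = m_Items.find(key); recheck != m_Items.end()) {
            readValue(*value, true);
            return true;
        }
        readValue(m_Items.emplace(key, *value).first->second, false);
        return false;
    }

private:
    std::shared_mutex m_Mutex;
    std::unordered_map<std::string, std::string> m_Items;
};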
