Skip to content

Commit d0e0466

Browse files
authored
Expose mehod name as part of backend init context
Differential Revision: D65386597 Pull Request resolved: #6622
1 parent f01b20b commit d0e0466

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

runtime/backend/backend_execution_context.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ class BackendExecutionContext final {
2121
public:
2222
BackendExecutionContext(
2323
EventTracer* event_tracer = nullptr,
24-
MemoryAllocator* temp_allocator = nullptr)
25-
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
24+
MemoryAllocator* temp_allocator = nullptr,
25+
const char* method_name = nullptr)
26+
: event_tracer_(event_tracer),
27+
temp_allocator_(temp_allocator),
28+
method_name_(method_name) {}
2629

2730
/**
2831
* Returns a pointer to an instance of EventTracer to do profiling/debugging
@@ -52,9 +55,17 @@ class BackendExecutionContext final {
5255
return temp_allocator_;
5356
}
5457

58+
/**
59+
* Get the name of the executing method from the ExecuTorch runtime.
60+
*/
61+
const char* get_method_name() const {
62+
return method_name_;
63+
}
64+
5565
private:
5666
EventTracer* event_tracer_ = nullptr;
5767
MemoryAllocator* temp_allocator_ = nullptr;
68+
const char* method_name_ = nullptr;
5869
};
5970

6071
} // namespace runtime

runtime/backend/backend_init_context.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ namespace runtime {
1818
*/
1919
class BackendInitContext final {
2020
public:
21-
explicit BackendInitContext(MemoryAllocator* runtime_allocator)
22-
: runtime_allocator_(runtime_allocator) {}
21+
explicit BackendInitContext(
22+
MemoryAllocator* runtime_allocator,
23+
const char* method_name = nullptr)
24+
: runtime_allocator_(runtime_allocator), method_name_(method_name) {}
2325

2426
/** Get the runtime allocator passed from Method. It's the same runtime
2527
* executor used by the standard executor runtime and the life span is the
@@ -29,8 +31,20 @@ class BackendInitContext final {
2931
return runtime_allocator_;
3032
}
3133

34+
/** Get the loaded method name from ExecuTorch runtime. Usually it's
35+
* "forward", however, if there are multiple methods in the .pte file, it can
36+
* be different. One example is that we may have prefill and decode methods in
37+
* the same .pte file. In this case, when client loads "prefill" method, the
38+
* `get_method_name` function will return "prefill", when client loads
39+
* "decode" method, the `get_method_name` function will return "decode".
40+
*/
41+
const char* get_method_name() const {
42+
return method_name_;
43+
}
44+
3245
private:
3346
MemoryAllocator* runtime_allocator_ = nullptr;
47+
const char* method_name_ = nullptr;
3448
};
3549

3650
} // namespace runtime

runtime/executor/method.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,9 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
626626

627627
for (size_t i = 0; i < n_delegate; ++i) {
628628
const auto& delegate = *delegates->Get(i);
629-
BackendInitContext backend_init_context(method_allocator);
629+
BackendInitContext backend_init_context(
630+
method_allocator,
631+
/*method_name=*/serialization_plan_->name()->c_str());
630632
Error err = BackendDelegate::Init(
631633
delegate, program_, backend_init_context, &delegates_[i]);
632634
if (err != Error::Ok) {
@@ -1097,8 +1099,9 @@ Error Method::execute_instruction() {
10971099
n_delegate_,
10981100
step_state_.instr_idx);
10991101
BackendExecutionContext backend_execution_context(
1100-
/*event_tracer*/ event_tracer_,
1101-
/*temp_allocator*/ temp_allocator_);
1102+
/*event_tracer=*/event_tracer_,
1103+
/*temp_allocator=*/temp_allocator_,
1104+
/*method_name=*/serialization_plan_->name()->c_str());
11021105
err = delegates_[delegate_idx].Execute(
11031106
backend_execution_context,
11041107
chain.argument_lists_[step_state_.instr_idx].data());

runtime/executor/test/backend_integration_test.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class StubBackend final : public BackendInterface {
9595
}
9696

9797
Error execute(
98-
ET_UNUSED BackendExecutionContext& context,
98+
BackendExecutionContext& context,
9999
DelegateHandle* handle,
100100
EValue** args) const override {
101101
if (execute_fn_) {
@@ -530,6 +530,48 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
530530
EXPECT_EQ(backend_load_was_called, using_segments());
531531
}
532532

533+
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
534+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
535+
ASSERT_EQ(loader.error(), Error::Ok);
536+
const void* processed_data = nullptr;
537+
StubBackend::singleton().install_init(
538+
[&](FreeableBuffer* processed,
539+
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
540+
ET_UNUSED BackendInitContext& backend_init_context)
541+
-> Result<DelegateHandle*> {
542+
auto method_name = backend_init_context.get_method_name();
543+
// Ensure that we can get the method name during init via context
544+
EXPECT_STREQ(method_name, "forward");
545+
processed_data = processed->data();
546+
return nullptr;
547+
});
548+
Result<Program> program = Program::load(&loader.get());
549+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
550+
Result<Method> method = program->load_method("forward", &mmm.get());
551+
EXPECT_TRUE(method.ok());
552+
ASSERT_EQ(program.error(), Error::Ok);
553+
}
554+
555+
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
556+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
557+
ASSERT_EQ(loader.error(), Error::Ok);
558+
StubBackend::singleton().install_execute(
559+
[&](BackendExecutionContext& backend_execution_context,
560+
ET_UNUSED DelegateHandle* handle,
561+
ET_UNUSED EValue** args) -> Error {
562+
// Ensure that we can get the method name during execution via context
563+
auto method_name = backend_execution_context.get_method_name();
564+
EXPECT_STREQ(method_name, "forward");
565+
return Error::Ok;
566+
});
567+
Result<Program> program = Program::load(&loader.get());
568+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
569+
Result<Method> method = program->load_method("forward", &mmm.get());
570+
EXPECT_TRUE(method.ok());
571+
Error err = method->execute();
572+
ASSERT_EQ(err, Error::Ok);
573+
}
574+
533575
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
534576
// - Errors during init() or execute() result in runtime init/execution failures
535577
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)