Skip to content

Commit 24b6f44

Browse files
committed
fix: fix the schema not found for node error
Signed-off-by: Bo Wang <[email protected]>
1 parent 84ffb67 commit 24b6f44

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

core/partitioning/partitioning.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
4040
}
4141

4242
bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
43-
const auto& schema = node->schema();
43+
const torch::jit::FunctionSchema* schema = node->maybeSchema();
44+
if (!schema) {
45+
return false;
46+
}
4447
for (size_t i = 0; i < node->inputs().size(); ++i) {
4548
if (node->inputs()[i] == val) {
46-
const at::AliasInfo* formal = schema.arguments()[i].alias_info();
49+
const at::AliasInfo* formal = schema->arguments()[i].alias_info();
4750
if (formal && formal->isWrite()) {
4851
return true;
4952
}
@@ -124,7 +127,8 @@ void find_all_fallback_nodes(
124127
if (!isTensor(output)) {
125128
for (auto use : output->uses()) {
126129
auto node = use.user;
127-
if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
130+
if (node->kind() != torch::jit::prim::Constant &&
131+
global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
128132
q.push(node);
129133
}
130134
}

0 commit comments

Comments
 (0)