Commit cebc58f
fix: Updated paradigm for device casting to depend on user-specified device
- Added a field to LowerInfo to hold device information
- Updated the internal Device struct location to allow streamlined imports
- Updated BUILD files
- Build device strings in the lowering phase using the user-specified target device
1 parent 115edfb commit cebc58f
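For orientation, a minimal sketch of the new flow, assuming the signatures introduced in the diffs below: the user-specified device is stored on LowerInfo as an ir::Device, rendered as a device string, and passed to the casting passes. The wrapper function here is hypothetical; only the LowerInfo fields, getGPUDeviceString(), and the pass signatures come from this commit.

#include "core/lowering/lowering.h"
#include "core/lowering/passes/passes.h"

namespace lowering = torch_tensorrt::core::lowering;

// Hypothetical caller: thread the user-specified device through the lowering passes.
void lower_for_device(std::shared_ptr<torch::jit::Graph>& g, int64_t gpu_id) {
  lowering::LowerInfo lower_info;
  lower_info.target_device.gpu_id = gpu_id;           // ir::Device populated from the user spec
  auto device_str = lower_info.getGPUDeviceString();  // e.g. "cuda:1" when gpu_id == 1

  // The casting passes now receive the device string instead of hard-coding "cuda"
  lowering::passes::UnpackAndCastMaskedFill(g, device_str);
  lowering::passes::UnpackAndCastNumToTensor(g, device_str);
  lowering::passes::UnpackAndCastFull(g, device_str);
}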

11 files changed (+69, -38 lines)

core/conversion/conversionctx/BUILD

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 9 deletions
@@ -9,28 +9,21 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include <cuda_runtime.h>
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {
 namespace conversion {

-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set<nvinfer1::DataType> enabled_precisions = {};
   bool sparse_weights = false;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;

core/ir/ir.h

Lines changed: 8 additions & 0 deletions
@@ -11,6 +11,14 @@ namespace torch_tensorrt {
 namespace core {
 namespace ir {

+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(

core/lowering/BUILD

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/lowering/lowering.cpp

Lines changed: 3 additions & 3 deletions
@@ -70,9 +70,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
-  passes::UnpackAndCastMaskedFill(g);
-  passes::UnpackAndCastNumToTensor(g);
-  passes::UnpackAndCastFull(g);
+  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
   passes::ReplaceScalarImplicit(g);
   LOG_GRAPH(*g);
 }

core/lowering/lowering.h

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 #pragma once
 #include <memory>
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {
@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };

 void LowerBlock(torch::jit::Block* b);

core/lowering/passes/device_casting.cpp

Lines changed: 32 additions & 14 deletions
@@ -8,68 +8,86 @@ namespace core {
 namespace lowering {
 namespace passes {

-void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string masked_fill_pattern = R"IR(
     graph(%self, %mask, %value):
         %out: Tensor = aten::masked_fill_(%self, %mask, %value)
         return (%out))IR";

   // Calls to masked_fill_ often utilize CPU tensors, and as such
-  // should be casted to CUDA to avoid device mismatch errors
-  std::string unpacked_pattern = R"IR(
+  // should be moved to gpu to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%self, %mask, %value):
-        %device: Device = prim::Constant[value="cuda"]()
+        %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
         %dtype: NoneType = prim::Constant()
         %false: bool = prim::Constant[value=0]()
         %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
        %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
-        %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
+        %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
         return (%out))IR";

+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter masked_fill_rewriter;
   masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
   masked_fill_rewriter.runOnGraph(graph);
   LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
 }

-void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string num_to_tensor_cast_pattern = R"IR(
     graph(%1: Scalar):
         %2: Tensor = prim::NumToTensor(%1)
         return (%2))IR";

-  // 0D Tensors are initialized on cpu, and need to be casted to CUDA
+  // 0D Tensors are initialized on cpu, and need to be moved to gpu
   // to avoid device mismatch issues
-  std::string num_to_tensor_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1: Scalar):
         %2: Tensor = prim::NumToTensor(%1)
-        %device: Device = prim::Constant[value="cuda"]()
+        %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
         %dtype: NoneType = prim::Constant()
         %false: bool = prim::Constant[value=0]()
         %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
         return (%3))IR";

+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
   num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
   num_to_tensor_cast_rewriter.runOnGraph(graph);

   LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
 }

-void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string full_cast_pattern = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
         %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
         return (%out))IR";

-  // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA
+  // Tensors created via aten::full are initialized on cpu, and need to be casted to gpu
   // to avoid device mismatch issues
-  std::string full_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
-        %cuda: Device = prim::Constant[value="cuda"]()
-        %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6)
+        %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+        %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
         return (%out))IR";

+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter full_cast_rewriter;
   full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
   full_cast_rewriter.runOnGraph(graph);
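To make the split-pattern approach concrete, here is a standalone sketch (assuming a target_device_name of "cuda:0", i.e. gpu_id 0) that prints the rewrite pattern produced by splicing the device name between the two raw-string halves, mirroring UnpackAndCastNumToTensor above:

// Standalone illustration, not part of the commit.
#include <iostream>
#include <string>

int main() {
  std::string target_device_name = "cuda:0";  // would come from LowerInfo::getGPUDeviceString()

  std::string clean_pattern_part_1 = R"IR(
    graph(%1: Scalar):
        %2: Tensor = prim::NumToTensor(%1)
        %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
        %dtype: NoneType = prim::Constant()
        %false: bool = prim::Constant[value=0]()
        %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
        return (%3))IR";

  // The concatenation yields a complete pattern whose prim::Constant carries
  // "cuda:0" rather than a hard-coded "cuda".
  std::cout << clean_pattern_part_1 + target_device_name + clean_pattern_part_2 << "\n";
  return 0;
}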

core/lowering/passes/passes.h

Lines changed: 3 additions & 3 deletions
@@ -40,9 +40,9 @@ void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 2 deletions
@@ -90,8 +90,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   if (current_tensor_device != target_device) {
     LOG_WARNING(
         "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
-        << " but should be on " << target_device
-        << ". This tensor is being moved manually by the runtime but "
+        << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
         << "for performance considerations, ensure your inputs are all on GPU "
         << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
         << "warning persists.");

cpp/src/compile_spec.cpp

Lines changed: 5 additions & 0 deletions
@@ -110,6 +110,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
+  internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;

   TORCHTRT_CHECK(
       !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -130,10 +131,12 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   switch (external.device.device_type) {
     case Device::DeviceType::kDLA:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
       break;
     case Device::DeviceType::kGPU:
     default:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
+      internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
   }

   switch (external.capability) {
@@ -150,6 +153,8 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {

   internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;
   internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
+  internal.lower_info.target_device.gpu_id = external.device.gpu_id;
+  internal.lower_info.target_device.dla_core = external.device.dla_core;
   internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
   internal.convert_info.engine_settings.workspace_size = external.workspace_size;
   internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
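On the user-facing side, the same Device settings now drive both the converter (engine_settings.device) and the lowering phase (lower_info.target_device). A hedged usage sketch against the public C++ API (header path and CompileSpec construction as commonly documented; the input shape is illustrative):

#include "torch_tensorrt/torch_tensorrt.h"

// Requesting gpu_id 1 here now also means the lowering passes cast tensors to "cuda:1".
torch_tensorrt::ts::CompileSpec make_spec() {
  torch_tensorrt::ts::CompileSpec spec({torch_tensorrt::Input({1, 3, 224, 224})});
  spec.device.device_type = torch_tensorrt::Device::DeviceType::kGPU;
  spec.device.gpu_id = 1;
  spec.device.allow_gpu_fallback = false;
  return spec;
}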
