Skip to content

Commit c5ed524

Browse files
authored
Merge pull request #1071 from inocsin/collection_python_api
Collection: Python api support
2 parents c246159 + da4ec96 commit c5ed524

File tree

7 files changed

+50
-26
lines changed

7 files changed

+50
-26
lines changed

core/ir/GraphInputs.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
6868
inputs = flattened_inputs;
6969
input_signature = input_signature_;
7070
collection_inputs = collection_inputs_;
71+
LOG_DEBUG("Collection Input Size: " << collection_inputs_.size());
7172
}
7273

7374
} // namespace ir

core/partitioning/partitioning.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa
500500
LOG_DEBUG(
501501
"In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
502502
in_prog_pyt_blk_nodes.insert(
503-
in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
503+
in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
504504
}
505505
in_prog_trt_blk_nodes.clear();
506506
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,18 +214,30 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV
214214
} else if(input_ivalue.isCustomClass()) {
215215
core::ir::Input cur_input = (*(input_ivalue.toCustomClass<Input>())).toInternalInput();
216216
converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input)));
217+
} else if(input_ivalue.isPyObject()) {
218+
auto py_object_holder = input_ivalue.toPyObjectHolder();
219+
auto infer_type = py_object_holder->tryToInferType();
220+
auto type = infer_type.type();
221+
torch::jit::IValue ival = py_object_holder->toIValue(type);
222+
torch::jit::IValue converted_item;
223+
to_internal_input_signature(ival, converted_item);
224+
converted_ivalue = torch::jit::IValue(converted_item);
225+
} else {
226+
LOG_ERROR("Unknown input spec type");
217227
}
218228
}
219229

220230
core::CompileSpec init_compile_spec(CompileSpec external) {
221231
if (external.inputs.size() > 0) {
232+
LOG_DEBUG("init_compile_spec with input vector");
222233
std::vector<core::ir::Input> internal_inputs;
223234
for (auto i : external.inputs) {
224235
internal_inputs.push_back(i.toInternalInput());
225236
}
226237
core::CompileSpec internal(internal_inputs);
227238
return internal;
228239
} else {
240+
LOG_DEBUG("init_compile_spec with input signature");
229241
torch::jit::IValue converted_input_signature;
230242
to_internal_input_signature(external.input_signature.signature_ivalue, converted_input_signature);
231243
core::CompileSpec internal(converted_input_signature);

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "pybind11/stl.h"
33

44
#include "Python.h"
5+
#include "ATen/core/jit_type.h"
56
#include "core/compiler.h"
67
#include "core/conversion/conversion.h"
78
#include "tensorrt_classes.h"
@@ -179,7 +180,11 @@ PYBIND11_MODULE(_C, m) {
179180
.def_readwrite("format", &Input::format);
180181

181182
py::class_<InputSignature>(m, "InputSignature")
182-
.def(py::init<>())
183+
.def(pybind11::init([](py::object py_obj) {
184+
InputSignature input_signature;
185+
input_signature.signature_ivalue = torch::jit::toIValue(std::move(py_obj), c10::PyObjectType::get(), c10::nullopt);
186+
return input_signature;
187+
}))
183188
.def("__str__", &InputSignature::to_str)
184189
.def_readwrite("_signature_ivalue", &InputSignature::signature_ivalue);
185190

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:
168168

169169
return info
170170

171-
def _parse_input_signature(input_signature: Any) -> _C.InputSignature:
172-
print(input_signature)
171+
def _parse_input_signature(input_signature: Any):
173172
if isinstance(input_signature, tuple):
174173
input_list = []
175174
for item in input_signature:
@@ -180,7 +179,7 @@ def _parse_input_signature(input_signature: Any) -> _C.InputSignature:
180179
input_list = []
181180
for item in input_signature:
182181
input = _parse_input_signature(item)
183-
input_list.append(input)
182+
input_list.append(input)
184183
return input_list
185184
elif isinstance(input_signature, Input) or isinstance(input_signature, torch.Tensor):
186185
i = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature
@@ -202,17 +201,14 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
202201

203202
elif compile_spec["input_signature"] is not None:
204203
log(Level.Warning, "Input signature parsing is an experimental feature, behavior and APIs may change")
205-
signature =_parse_input_signature(compile_spec["input_signature"])
206-
print(signature)
207-
info.input_signature = signature
204+
signature = _parse_input_signature(compile_spec["input_signature"])
205+
info.input_signature = _C.InputSignature(signature) # py_object
208206

209207
else:
210208
raise KeyError(
211209
"Module input definitions are requried to compile module. Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec"
212210
)
213211

214-
#assert(len(info.inputs) > 0 or compile_spec["input_signature"] is not None, "Require at least one input definition to compile model")
215-
216212
if "enabled_precisions" in compile_spec:
217213
info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])
218214

tests/cpp/test_collection.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
3939
input_range.push_back({in0.sizes(), torch::kF16});
4040
torch_tensorrt::ts::CompileSpec compile_settings(input_range);
4141
compile_settings.require_full_compilation = true;
42-
compile_settings.min_block_size = 1;
42+
compile_settings.min_block_size = 3;
4343

4444
// // FP16 execution
4545
compile_settings.enabled_precisions = {torch::kHalf};
@@ -88,7 +88,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
8888

8989
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
9090
compile_settings.require_full_compilation = false;
91-
compile_settings.min_block_size = 1;
91+
compile_settings.min_block_size = 3;
9292

9393
// // FP16 execution
9494
compile_settings.enabled_precisions = {torch::kHalf};
@@ -153,7 +153,7 @@ TEST(CppAPITests, TestCollectionListInput) {
153153

154154
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
155155
compile_settings.require_full_compilation = false;
156-
compile_settings.min_block_size = 1;
156+
compile_settings.min_block_size = 3;
157157
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
158158

159159
// // FP16 execution
@@ -206,7 +206,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
206206

207207
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
208208
compile_settings.require_full_compilation = false;
209-
compile_settings.min_block_size = 1;
209+
compile_settings.min_block_size = 3;
210210

211211
// compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
212212

@@ -276,7 +276,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
276276

277277
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
278278
compile_settings.require_full_compilation = false;
279-
compile_settings.min_block_size = 1;
279+
compile_settings.min_block_size = 3;
280280

281281
// Need to skip the conversion of __getitem__ and ListConstruct
282282
compile_settings.torch_executed_ops.push_back("aten::__getitem__");
@@ -346,7 +346,7 @@ TEST(CppAPITests, TestCollectionComplexModel) {
346346

347347
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
348348
compile_settings.require_full_compilation = false;
349-
compile_settings.min_block_size = 1;
349+
compile_settings.min_block_size = 3;
350350

351351
// Need to skip the conversion of __getitem__ and ListConstruct
352352
compile_settings.torch_executed_ops.push_back("aten::__getitem__");

tests/py/test_collections.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ def setUp(self):
2929

3030
def test_compile(self):
3131
compile_spec = {
32-
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))),
32+
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
3333
"device": torchtrt.Device("gpu:0"),
34-
"enabled_precisions": {torch.float}
34+
"enabled_precisions": {torch.float},
35+
"require_full_compilation": False,
36+
"min_block_size": 3
3537
}
3638

3739
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -45,9 +47,11 @@ def setUp(self):
4547

4648
def test_compile(self):
4749
compile_spec = {
48-
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
50+
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
4951
"device": torchtrt.Device("gpu:0"),
50-
"enabled_precisions": {torch.float}
52+
"enabled_precisions": {torch.float},
53+
"require_full_compilation": False,
54+
"min_block_size": 3
5155
}
5256

5357
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -61,9 +65,11 @@ def setUp(self):
6165

6266
def test_compile(self):
6367
compile_spec = {
64-
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape))),
68+
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
6569
"device": torchtrt.Device("gpu:0"),
66-
"enabled_precisions": {torch.float}
70+
"enabled_precisions": {torch.float},
71+
"require_full_compilation": False,
72+
"min_block_size": 3
6773
}
6874

6975
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -79,9 +85,11 @@ def setUp(self):
7985

8086
def test_compile(self):
8187
compile_spec = {
82-
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
88+
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
8389
"device": torchtrt.Device("gpu:0"),
84-
"enabled_precisions": {torch.float}
90+
"enabled_precisions": {torch.float},
91+
"require_full_compilation": False,
92+
"min_block_size": 3
8593
}
8694

8795
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -98,9 +106,11 @@ def setUp(self):
98106

99107
def test_compile(self):
100108
compile_spec = {
101-
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)]),
109+
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
102110
"device": torchtrt.Device("gpu:0"),
103-
"enabled_precisions": {torch.float}
111+
"enabled_precisions": {torch.float},
112+
"require_full_compilation": False,
113+
"min_block_size": 3
104114
}
105115

106116
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)

0 commit comments

Comments
 (0)