Skip to content

Commit 9851bfe

Browse files
committed
add unset_fake_temporarily with minor changes
1 parent 6b292ea commit 9851bfe

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def compile(
513513

514514
if kwargs.get("debug", False):
515515
warnings.warn(
516-
"`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality",
516+
"`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",
517517
DeprecationWarning,
518518
stacklevel=2,
519519
)
@@ -1122,6 +1122,7 @@ def convert_exported_program_to_serialized_trt_engine(
11221122
Returns:
11231123
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
11241124
"""
1125+
11251126
if kwargs.get("debug", False):
11261127
warnings.warn(
11271128
"`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,9 @@ def native_group_norm(
325325

326326
shape = [1, group] + [1] * (rank - 2)
327327

328-
weight_torch = torch.ones(shape)
329-
bias_torch = torch.zeros(shape)
328+
with unset_fake_temporarily():
329+
weight_torch = torch.ones(shape)
330+
bias_torch = torch.zeros(shape)
330331

331332
weight_one = get_trt_tensor(ctx, weight_torch, f"{name}_weight_one", input.dtype)
332333
bias_zero = get_trt_tensor(ctx, bias_torch, f"{name}_bias_zero", input.dtype)

tools/perf/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Benchmark scripts depends on following Python packages in addition to requiremen
4444
Here is the list of `CompileSpec` options that can be provided directly to compile the PyTorch module
4545

4646
* `--backends` : Comma-separated string of backends. E.g.: torch, ts_trt, dynamo, torch_compile, inductor, onnx_trt
47-
* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module).
47+
* `--model` : Name of the model file (can be a TorchScript module or a TensorRT engine, the latter paired with `--is_trt_engine`). If the backend is `dynamo` or `torch_compile`, the input should be a PyTorch module (instead of a TorchScript module).
4848
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend)
4949
* `--onnx` : ONNX model file, which bypasses the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX model will be converted directly to a TRT engine.
5050
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
@@ -61,16 +61,16 @@ Eg:
6161
```
6262
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
6363
--model_torch ${MODELS_DIR}/vgg16_torch.pt \
64-
--precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
64+
--precision fp32,fp16 \
65+
--inputs "(1, 3, 224, 224)@fp32" \
6566
--batch_size 1 \
66-
--backends torch,ts_trt,dynamo,torch_compile,tensorrt \
67+
--backends torch,ts_trt,dynamo,torch_compile,inductor,onnx_trt \
6768
--report "vgg_perf_bs1.txt"
6869
```
6970

7071
Note:
7172

7273
1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for the `torch_tensorrt` backend.
73-
2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module.
7474

7575
### Example models
7676

tools/perf/perf_run.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,15 @@ def run_onnx_trt(
480480
onnx_path = params["onnx"]
481481
else:
482482
onnx_path = f"{params['model_torch']}-onnx-trt.onnx"
483-
torch.onnx.export(model, tuple(input_tensors), onnx_path, dynamo=True)
483+
len_output = len(model(*input_tensors))
484+
# to match the output names with Torch-TRT engine's
485+
torch.onnx.export(
486+
model,
487+
tuple(input_tensors),
488+
onnx_path,
489+
dynamo=True,
490+
output_names=[f"output{i}" for i in range(len_output)],
491+
)
484492
start_compile = timeit.default_timer()
485493
builder = trt.Builder(logger)
486494
network = builder.create_network(

0 commit comments

Comments
 (0)