33 changes: 13 additions & 20 deletions docsrc/tutorials/ptq.rst
@@ -167,19 +167,16 @@ a TensorRT calibrator by providing desired configuration. The following code dem
algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
device=torch.device('cuda:0'))

-    compile_spec = {
-          "inputs": [torch_tensorrt.Input((1, 3, 32, 32))],
-          "enabled_precisions": {torch.float, torch.half, torch.int8},
-          "calibrator": calibrator,
-          "device": {
-              "device_type": torch_tensorrt.DeviceType.GPU,
-              "gpu_id": 0,
-              "dla_core": 0,
-              "allow_gpu_fallback": False,
-              "disable_tf32": False
-          }
-    }
-    trt_mod = torch_tensorrt.compile(model, compile_spec)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
+                                            enabled_precisions={torch.float, torch.half, torch.int8},
+                                            calibrator=calibrator,
+                                            device={
+                                                 "device_type": torch_tensorrt.DeviceType.GPU,
+                                                 "gpu_id": 0,
+                                                 "dla_core": 0,
+                                                 "allow_gpu_fallback": False,
+                                                 "disable_tf32": False
+                                            })

In cases where a pre-existing calibration cache file is available, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
to use ``CacheCalibrator`` in INT8 mode.
@@ -188,13 +185,9 @@ to use ``CacheCalibrator`` to use in INT8 mode.

calibrator = torch_tensorrt.ptq.CacheCalibrator("./calibration.cache")

-    compile_settings = {
-          "inputs": [torch_tensorrt.Input([1, 3, 32, 32])],
-          "enabled_precisions": {torch.float, torch.half, torch.int8},
-          "calibrator": calibrator,
-    }
-
-    trt_mod = torch_tensorrt.compile(model, compile_settings)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input([1, 3, 32, 32])],
+                                            enabled_precisions={torch.float, torch.half, torch.int8},
+                                            calibrator=calibrator)

If you already have an existing calibrator class (implemented directly using the TensorRT API), you can set the ``calibrator`` field directly to an instance of your class, which can be very convenient.
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
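As a rough sketch of what such a custom calibrator looks like, the pure-Python schematic below mirrors the methods TensorRT expects (``get_batch_size``, ``get_batch``, ``read_calibration_cache``, ``write_calibration_cache``). This is illustrative only: the class name and batch handling are assumptions, and a real implementation would subclass e.g. ``tensorrt.IInt8EntropyCalibrator2`` and return device-memory pointers from ``get_batch``.

```python
import os
import numpy as np

# Illustrative-only schematic of the calibrator interface TensorRT drives.
# A real calibrator subclasses e.g. trt.IInt8EntropyCalibrator2 and returns
# device pointers from get_batch; here we return host arrays for clarity.
class SketchCalibrator:
    def __init__(self, batches, cache_file="calibration.cache"):
        self.batches = iter(batches)   # iterable of preprocessed input batches
        self.cache_file = cache_file

    def get_batch_size(self):
        # Batch size the calibration inputs were prepared with.
        return 1

    def get_batch(self, names):
        # Called repeatedly during calibration; returning None signals
        # that all calibration data has been consumed.
        try:
            return [next(self.batches)]
        except StopIteration:
            return None

    def read_calibration_cache(self):
        # Reuse scales from a previous calibration run if a cache exists.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        # Persist the computed scales so later builds can skip calibration.
        with open(self.cache_file, "wb") as f:
            f.write(cache)
```

An instance of such a class would then be passed as the ``calibrator`` argument to ``torch_tensorrt.compile``, exactly as in the examples above.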