diff --git a/docsrc/tutorials/ptq.rst b/docsrc/tutorials/ptq.rst
index 615864ef4a..3acb89bf41 100644
--- a/docsrc/tutorials/ptq.rst
+++ b/docsrc/tutorials/ptq.rst
@@ -167,19 +167,16 @@ a TensorRT calibrator by providing desired configuration. The following code dem
         algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
         device=torch.device('cuda:0'))
 
-    compile_spec = {
-        "inputs": [torch_tensorrt.Input((1, 3, 32, 32))],
-        "enabled_precisions": {torch.float, torch.half, torch.int8},
-        "calibrator": calibrator,
-        "device": {
-            "device_type": torch_tensorrt.DeviceType.GPU,
-            "gpu_id": 0,
-            "dla_core": 0,
-            "allow_gpu_fallback": False,
-            "disable_tf32": False
-        }
-    }
-    trt_mod = torch_tensorrt.compile(model, compile_spec)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
+                                     enabled_precisions={torch.float, torch.half, torch.int8},
+                                     calibrator=calibrator,
+                                     device={
+                                         "device_type": torch_tensorrt.DeviceType.GPU,
+                                         "gpu_id": 0,
+                                         "dla_core": 0,
+                                         "allow_gpu_fallback": False,
+                                         "disable_tf32": False
+                                     })
 
 In the cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
 to use ``CacheCalibrator`` to use in INT8 mode.
@@ -188,13 +185,9 @@ to use ``CacheCalibrator`` to use in INT8 mode.
 
     calibrator = torch_tensorrt.ptq.CacheCalibrator("./calibration.cache")
 
-    compile_settings = {
-        "inputs": [torch_tensorrt.Input([1, 3, 32, 32])],
-        "enabled_precisions": {torch.float, torch.half, torch.int8},
-        "calibrator": calibrator,
-    }
-
-    trt_mod = torch_tensorrt.compile(model, compile_settings)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input([1, 3, 32, 32])],
+                                     enabled_precisions={torch.float, torch.half, torch.int8},
+                                     calibrator=calibrator)
 
 If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
 For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py