From 92511aaf62f75d31cc376fceae24edf6c3eca108 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Mon, 22 Aug 2022 14:21:03 -0700 Subject: [PATCH] Changes done internally at Facebook bd46e8f292bf68fe6b87d2d5d206c89fda79a746 Shirong Wu Disable group ln fuse pass 6ce1d3bc19d75b266e99355c96daeff7054dcbf8 Wei Wei [fx2trt] set logging level to INFO at fx root 9d552dc3f69db9e4a249f80ef00803a9413e5d38 Wei Wei [fx2trt] change OSS method lower_to_trt() to compile() 6c4bdb8ac5823d161e4afc7c9d295f961aeeb0bf Mor Tzur fix engine holder test binary to fix contbuild_pytorch_fx2trt_build 636d0ab2a3d0f09267e25b8b8e7eedd4d91d791d Yinghai Lu [easy] remove random prints 5a97668307c26e69a89a4e02a535e319eaf3ce3d Wei Wei [ads] sequential linear fuse 508338ab343e407ee49605919508210b62ad9a52 Wei Wei [fx2trt] minor literal fix --- py/torch_tensorrt/fx/__init__.py | 1 - py/torch_tensorrt/fx/lower.py | 2 +- py/torch_tensorrt/fx/passes/pass_utils.py | 47 +++++++++++++++++++ .../fx/tracer/acc_tracer/acc_ops.py | 4 +- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/fx/__init__.py b/py/torch_tensorrt/fx/__init__.py index 493d749d09..c1c42c446f 100644 --- a/py/torch_tensorrt/fx/__init__.py +++ b/py/torch_tensorrt/fx/__init__.py @@ -11,6 +11,5 @@ from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa from .lower_setting import LowerSetting # noqa from .trt_module import TRTModule # noqa -from .lower import compile logging.basicConfig(level=logging.INFO) diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index deeee14178..59b59d580f 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -53,7 +53,7 @@ def compile( timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. save_timing_cache: Update timing cache with current timing cache data if set to True. cuda_graph_batch_size: Cuda graph batch size, default to be -1. - + dynamic_batch: batch dimension (dim=0) is dynamic. 
Returns: A torch.nn.Module lowered by TensorRT. """ diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index d430a67408..3fb88e04a9 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -102,6 +102,53 @@ def bounded_method(*args, **kwargs): return dec_for_method +def log_perf_before_after(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to log perf of the module before and after the pass + """ + + @wraps(pass_) + def check_perf_with_before_after_log( + module: fx.GraphModule, input: Input + ) -> fx.GraphModule: + def benchmark_torch_function(iters: int, f, *args) -> float: + """Estimates the average time duration for a single inference call in seconds + + If the input is batched, then the estimation is for the batched inference call. + + Args: + iters: number of inference iterations to run + f: a function to perform a single inference call + + Returns: + estimated average time duration in seconds for a single inference call + """ + with torch.inference_mode(): + f(*args) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + # print("== Start benchmark iterations") + with torch.inference_mode(): + start_event.record() + for _ in range(iters): + f(*args) + end_event.record() + torch.cuda.synchronize() + # print("== End benchmark iterations") + return (start_event.elapsed_time(end_event) * 1.0e-3) / iters + + time_before = benchmark_torch_function(100, lambda: module(*input)) + _LOGGER.info(f"[{pass_}] Perf Before(eager mode): {time_before}") + + module = pass_(module, input) + time_after = benchmark_torch_function(100, lambda: module(*input)) + _LOGGER.info(f"[{pass_}] Perf After(eager mode): {time_after}") + return module + + return check_perf_with_before_after_log + + def log_before_after(pass_: PassFunc) -> PassFunc: """ Wraps a pass function to log the module graph before and 
after the pass diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index d1a5322316..ccd572b9aa 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -778,7 +778,9 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module): assert callable(stochastic_depth) except Exception as e: - warnings.warn(f"Unable to import torchvision related libraries.: {e}") + warnings.warn( + f"Unable to import torchvision related libraries.: {e}. Please install torchvision lib in order to lower stochastic_depth" + ) else: @register_custom_acc_mapper_fn(