|
14 | 14 | _LOGGER: logging.Logger = logging.getLogger(__name__) |
15 | 15 |
|
16 | 16 |
|
17 | | -def or_none(args, i): |
18 | | - return args[i] if len(args) > i else None |
| 17 | +def args_bounds_check(args, i, replacement=None): |
| 18 | + return args[i] if len(args) > i else replacement |
19 | 19 |
|
20 | 20 |
|
21 | 21 | @dynamo_tensorrt_converter(torch.ops.aten.batch_norm) |
@@ -59,17 +59,24 @@ def aten_ops_div( |
59 | 59 | # If both are TRTTensor, both are cast to float32 |
60 | 60 | if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor): |
61 | 61 | kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor( |
62 | | - network, kwargs_new["input"], kwargs_new["other"] |
| 62 | + network, |
| 63 | + kwargs_new["input"], |
| 64 | + kwargs_new["other"], |
| 65 | + name, |
63 | 66 | ) |
64 | 67 | # If one is TRTTensor, it is cast to float32 |
65 | 68 | elif isinstance(args[0], TRTTensor) and ( |
66 | 69 | kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32 |
67 | 70 | ): |
68 | | - kwargs_new["input"] = cast_trt_tensor(network, kwargs_new["input"], trt.float32) |
| 71 | + kwargs_new["input"] = cast_trt_tensor( |
| 72 | + network, kwargs_new["input"], trt.float32, name |
| 73 | + ) |
69 | 74 | elif isinstance(args[1], TRTTensor) and ( |
70 | 75 | kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32 |
71 | 76 | ): |
72 | | - kwargs_new["other"] = cast_trt_tensor(network, kwargs_new["other"], trt.float32) |
| 77 | + kwargs_new["other"] = cast_trt_tensor( |
| 78 | + network, kwargs_new["other"], trt.float32, name |
| 79 | + ) |
73 | 80 | rounding_mode = kwargs.get("rounding_mode") |
74 | 81 | if rounding_mode is None: |
75 | 82 | return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name) |
@@ -136,10 +143,10 @@ def aten_ops_embedding( |
136 | 143 | name, |
137 | 144 | input=args[1], |
138 | 145 | weight=args[0], |
139 | | - max_norm=or_none(args, 2), |
140 | | - norm_type=or_none(args, 3), |
141 | | - scale_grad_by_freq=or_none(args, 4), |
142 | | - sparse=or_none(args, 5), |
| 146 | + max_norm=args_bounds_check(args, 2), |
| 147 | + norm_type=args_bounds_check(args, 3), |
| 148 | + scale_grad_by_freq=args_bounds_check(args, 4), |
| 149 | + sparse=args_bounds_check(args, 5), |
143 | 150 | ) |
144 | 151 |
|
145 | 152 |
|
@@ -311,11 +318,11 @@ def aten_ops_clamp( |
311 | 318 | return impl.elementwise.clamp( |
312 | 319 | network, |
313 | 320 | target, |
314 | | - SourceIR.ACC, |
| 321 | + SourceIR.ATEN, |
315 | 322 | name, |
316 | 323 | input_val=args[0], |
317 | | - min_val=or_none(args, 1), |
318 | | - max_val=or_none(args, 2), |
| 324 | + min_val=args_bounds_check(args, 1), |
| 325 | + max_val=args_bounds_check(args, 2), |
319 | 326 | ) |
320 | 327 |
|
321 | 328 |
|
@@ -349,5 +356,5 @@ def aten_ops_slice( |
349 | 356 | args[1], |
350 | 357 | args[2], |
351 | 358 | args[3], |
352 | | - args[4], |
| 359 | + args_bounds_check(args, 4, replacement=1), |
353 | 360 | ) |
0 commit comments