
Commit 1e56b61

zewenli98 authored and laikhtewari committed

feat: support adaptive avg pool 2d and 3d dynamo converters (#2632)

1 parent 182344a · commit 1e56b61

File tree

3 files changed: +477 additions, -3 deletions


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 32 additions & 1 deletion
@@ -2226,7 +2226,12 @@ def aten_ops_avg_pool(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
-def aten_ops_adaptive_avg_pool(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_adaptive_avg_pool1d(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -2243,6 +2248,32 @@ def aten_ops_adaptive_avg_pool(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_adaptive_avg_poolNd(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.adaptive_avg_poolNd(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        output_size=args[1],
+    )
+
+
 def max_pool_param_validator(pool_node: Node) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 229 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_positive_dim,
+)
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
@@ -169,3 +172,228 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
 
     output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
     return output
+
+
+def adaptive_avg_poolNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    output_size: Sequence[int],
+) -> TRTTensor:
+    input_shape = input.shape
+    input_rank = len(input_shape)
+    output_rank = len(output_size)
+    need_reshape_back = False
+
+    if input_rank == output_rank + 1:  # reshape to 4D/5D for TRT pooling
+        input = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
+        )
+        need_reshape_back = True
+        input_shape = input.shape
+        input_rank = len(input_shape)
+
+    extend_len = len(output_size)
+    output_size = list(output_size)
+    original_input = input
+
+    # repeat_interleave the input if the dim of output is larger than input
+    insert_axises = []
+    for axis in range(1, extend_len + 1):
+        axis = -axis
+        positive_axis = get_positive_dim(
+            axis, input_rank
+        )  # convert to positive axis, which is for calculating new shapes below
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
+        diff = output_dim - input_dim
+        if diff > 0:  # the dim of output is larger than input
+            times = output_dim // input_dim
+            remainder = output_dim % input_dim
+            if (
+                diff == 2 and remainder == 2
+            ):  # case 1: output_dim - input_dim == 2 and is not an integral multiple
+                insert_axises.append(axis)
+                remainder -= 1
+                output_size[axis] -= 1
+
+            if (
+                remainder + 1 == input_dim
+            ):  # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
+                remainder = 0
+                times += 1
+
+            flags = []  # record the axis that needs to be repeated
+            concat_list = []
+            for j in range(
+                input_dim
+            ):  # iterate the input dim to see which dim needs to be repeated or not
+                single_elem = impl.select.select(
+                    ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
+                )
+                new_shape = list(single_elem.shape)
+                new_shape.insert(positive_axis, 1)
+                single_elem = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_{axis}_{j}",
+                    single_elem,
+                    new_shape,
+                )
+                if remainder > 0 or j in flags:
+                    concat_list.extend([single_elem] * (times + 1))
+                    remainder -= 2
+                    flags.append(input_dim - j - 1)
+                else:
+                    concat_list.extend([single_elem] * times)
+                out = impl.cat.cat(
+                    ctx, target, source_ir, f"{name}_cat_{axis}_{j}", concat_list, axis
+                )
+            input = out
+
+    stride = tuple(
+        input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
+    )
+    kernel_size = tuple(
+        input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
+        for i in range(extend_len)
+    )
+
+    # Don't have to pool, directly return
+    if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
+        if need_reshape_back:  # reshape back
+            input = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_back",
+                input,
+                (*input.shape[1:],),
+            )
+        return input
+
+    layer = ctx.net.add_pooling_nd(
+        input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
+    )
+    layer.stride_nd = stride
+    set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)
+
+    output = layer.get_output(0)
+
+    # For case 1, we need to split the output and insert the mid of input
+    for axis in insert_axises:
+        positive_axis = get_positive_dim(axis, input_rank)
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
+        if input_dim % 2 == 1:
+            prev_one = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_prev_one_{axis}",
+                output,
+                axis,
+                output_dim // 2 - 1,
+            )
+            extend_shape = list(prev_one.shape)
+            extend_shape.insert(positive_axis, 1)
+            prev_one = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_extend_shape_{axis}",
+                prev_one,
+                extend_shape,
+            )
+            prev_two = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_prev_two_{axis}",
+                output,
+                axis,
+                output_dim // 2 - 2,
+            )
+            prev_two = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_two_shape_reshape_{axis}",
+                prev_two,
+                extend_shape,
+            )
+            prev_one_two_diff = impl.elementwise.sub(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_prev_one_two_diff_{axis}",
+                prev_one,
+                prev_two,
+            )
+
+            mid = impl.elementwise.add(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_mid_{axis}",
+                prev_one,
+                prev_one_two_diff,
+            )
+            split_output = impl.split.split(
+                ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
+            )
+            split_output.insert(1, mid)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+        else:
+            mid1 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2 - 1,
+            )
+            new_shape = list(mid1.shape)
+            new_shape.insert(positive_axis, 1)
+            mid1 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
+            )
+            mid2 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2,
+            )
+            mid2 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
+            )
+            split_output = impl.split.split(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_split_{axis}",
+                output,
+                [output_dim // 2, 1, output_dim // 2],
+                axis,
+            )
+            split_output[1] = mid1
+            split_output.insert(2, mid2)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+
+    if need_reshape_back:  # reshape back
+        output = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
+        )
+
+    return output
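As context for the implementation above: when no output dim exceeds its input dim, the converter reduces adaptive pooling to a static average pool with stride = input_dim // output_dim and kernel = input_dim - (output_dim - 1) * stride. A standalone sketch of that arithmetic (plain Python; the helper name is mine, not from the diff):

# Standalone illustration of the stride/kernel derivation used in
# adaptive_avg_poolNd (hypothetical helper, not part of the commit).
def derive_pool_params(in_dim: int, out_dim: int) -> tuple[int, int]:
    stride = in_dim // out_dim
    kernel = in_dim - (out_dim - 1) * stride
    return stride, kernel


# When in_dim is a multiple of out_dim, the static windows tile the input
# exactly, matching PyTorch's adaptive average pooling. stride == kernel == 1
# means the output equals the input, which is the converter's early-return path.
for in_dim, out_dim in [(9, 3), (8, 4), (7, 7)]:
    stride, kernel = derive_pool_params(in_dim, out_dim)
    print(f"in={in_dim} out={out_dim} -> stride={stride}, kernel={kernel}")
# in=9 out=3 -> stride=3, kernel=3
# in=8 out=4 -> stride=2, kernel=2
# in=7 out=7 -> stride=1, kernel=1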
