@@ -337,15 +337,16 @@ def _check_per_channel_group_params(
337337 # For now group quantization is only supported for 4b weights
338338 assert quant_params .is_qc4w , "Only 4b group quantization is supported"
339339
340- def define_tensor (
340+ def define_tensor ( # noqa: C901
341341 self ,
342342 tensor : torch .fx .Node ,
343343 xnn_graph : XNNGraph ,
344344 vals_to_ids : Dict [torch .fx .Node , int ],
345345 convert_to_nhwc : bool = False ,
346- swap_nc_for_depthwise_weights : bool = False ,
346+ swap_in_out_for_weights : bool = False ,
347347 quant_params : Optional [QuantParams ] = None ,
348348 fp32_static_weights : bool = False ,
349+ groups : int = 1 ,
349350 ) -> None :
350351 """
351352 Defines an tensor value into the XNNGraph
@@ -357,16 +358,21 @@ def define_tensor(
357358 their corresponding ids in XNNGraph
358359 convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
359360 reflect the nhwc memory format.
360- swap_nc_for_depthwise_weights : bool to indicate whether tensor shape
361- should be permuted such that the N and C dimensions are
362- swapped , which should be used for depthwise convolution
361+ swap_in_out_for_weights : bool to indicate whether tensor shape should be
362+            permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width),
363+            which should be used for depthwise/transpose convolution
363364 weights. This is only valid for tensors which hold
364365 constant data. If used along with convert_to_nhwc, this
365366 swap will happen before converting to nhwc.
366367 quant_params: Quantization meta data for this tensor, None if it is not quantized
367368 fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
369+ groups: number of groups for swap_in_out_for_weights
368370 """
369371
372+ assert (
373+ swap_in_out_for_weights or groups == 1
374+        ), "groups is only applicable when swap_in_out_for_weights is set"
375+
370376 if tensor in vals_to_ids :
371377 return
372378
@@ -394,15 +400,16 @@ def define_tensor(
394400 xnn_graph ,
395401 vals_to_ids ,
396402 convert_to_nhwc ,
397- swap_nc_for_depthwise_weights ,
403+ swap_in_out_for_weights ,
398404 quant_params ,
399405 fp32_static_weights ,
406+ groups ,
400407 )
401408
402409 # convert tensor shape must reflect memory format, default is contiguous, so
403410 # only permute shape if we are converting the tensor to nhwc format
404- if swap_nc_for_depthwise_weights :
405- dims = [dims [1 ], dims [0 ]] + dims [2 :]
411+ if swap_in_out_for_weights :
412+ dims = [dims [1 ] * groups , dims [0 ] // groups ] + dims [2 :]
406413 if convert_to_nhwc :
407414 check_or_raise (len (dims ) == 4 , "Converting to nhwc requires 4d tensor" )
408415 dims = [dims [i ] for i in PERM_NCHW_TO_NHWC ]
@@ -422,16 +429,16 @@ def define_tensor(
422429 )
423430
424431 # Override the quant params axis since we have
425- # updated the weights for depthwise, with that the out_channels dim
432+ # updated the weights for depthwise/ transposed_conv2d , with that the out_channels dim
426433 # will be dims[3] instead of dims[0]. Let's update the per_channel
427434 # quant axis to match the new weight tensor before serializing
428- if swap_nc_for_depthwise_weights and (
429- quant_params and quant_params .per_channel
430- ):
435+ if swap_in_out_for_weights and (quant_params and quant_params .per_channel ):
431436 if quant_params .axis == 0 :
432437 quant_params .axis = len (dims ) - 1
438+ elif quant_params .axis == 1 :
439+ quant_params .axis = 0
433440 else :
434- assert f"Unsupported weight per channel quantization axis for depthwise conv2d: { quant_params .axis } , expecting 0."
434-                assert False, f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
441+                assert False, f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d: {quant_params.axis}, expecting 0 / 1."
435442
436443 # Serialize tensor value
437444 ser_val = (
@@ -492,9 +499,10 @@ def get_serialized_buffer_index(
492499 xnn_graph : XNNGraph ,
493500 vals_to_ids : Dict [torch .fx .Node , int ],
494501 convert_to_nhwc : bool ,
495- swap_nc_for_depthwise_weights : bool ,
502+ swap_in_out_for_weights : bool ,
496503 quant_params : Optional [QuantParams ],
497504 fp32_static_weights : bool = False ,
505+ groups : int = 1 ,
498506 ) -> int :
499507 """
500508 If tensor holds some constant data, serialize it and return the
@@ -507,24 +515,30 @@ def get_serialized_buffer_index(
507515 their corresponding ids in XNNGraph
508516 convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
509517 reflect the nhwc memory format.
510- swap_nc_for_depthwise_weights : bool to indicate whether tensor shape
511- should be permuted such that the N and C dimensions are
512- swapped , which should be used for depthwise convolution
518+ swap_in_out_for_weights : bool to indicate whether tensor shape should be
519+            permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width),
520+            which should be used for depthwise/transpose convolution
513521 weights. This is only valid for tensors which hold
514522 constant data. If used along with convert_to_nhwc, this
515523 swap will happen before converting to nhwc.
516524 quant_params: Quantization meta data for this tensor, None if it is not quantize
517525 fp32_static_weights: bool to indicate whether tensor is fp32 static weights
526+            groups: number of groups for swap_in_out_for_weights
518527
519528 Returns:
520529 buffer_idx: idx of the serialized data. 0 If not associated constant
521530 data
522531 """
532+
533+ assert (
534+ swap_in_out_for_weights or groups == 1
535+        ), "groups is only applicable when swap_in_out_for_weights is set"
536+
523537 # The get_attr node is the input to quant_params.
524538 get_attr_node = tensor if quant_params is None else quant_params .q_input
525539 if not is_param_node (self .exported_program , get_attr_node ):
526540 check_or_raise (
527- not swap_nc_for_depthwise_weights ,
541+ not swap_in_out_for_weights ,
528542 "Swapping N and C dimensions is only valid for constant data tensors" ,
529543 )
530544 return 0
@@ -541,9 +555,16 @@ def get_serialized_buffer_index(
541555 # ensure that the const is fp32
542556 const_val = const_val .to (dtype = torch .float32 ).contiguous ()
543557
544- if swap_nc_for_depthwise_weights :
545- const_val = const_val .permute (
546- dims = ((1 , 0 ) + tuple (range (2 , const_val .dim ())))
558+ if swap_in_out_for_weights :
559+ # Permute and reshape the tensor from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
560+ # which should be used for depthwise/transpose convolution weights for XNNPACK
561+ shape = const_val .shape
562+ const_val = const_val .reshape (
563+ (groups , const_val .shape [0 ] // groups ) + const_val .shape [1 :]
564+ )
565+ const_val = const_val .permute ((0 , 2 , 1 ) + tuple (range (3 , const_val .dim ())))
566+ const_val = const_val .reshape (
567+ (shape [1 ] * groups , shape [0 ] // groups ) + shape [2 :]
547568 ).contiguous ()
548569
549570 if convert_to_nhwc :
0 commit comments