     _ATEN_OP_OR_TORCH_FN_TABLE,
     _register_layout_cls,
     _get_layout_tensor_constructor,
+    LayoutType,
 )
+from typing import ClassVar
+from dataclasses import dataclass
 
 aten = torch.ops.aten
 
+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+    def extra_repr(self):
+        return f"inner_k_tiles={self.inner_k_tiles}"
+
+
 def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is an int8 quantized Tensor"""
     return (
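
For orientation, a hedged sketch of how a custom layout would plug into this new dataclass-based scheme. `MyLayoutType`, its `alignment` field, and its padding policy are hypothetical illustrations, not part of this diff; the `register_layout_cls` decorator used below is the class-keyed version introduced in a later hunk.

@dataclass(frozen=True)
class MyLayoutType(LayoutType):  # hypothetical example layout, not in this diff
    alignment: int = 16

    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        # pad the innermost dimension up to a multiple of `alignment`,
        # mirroring the padding TensorCoreTiledLayoutType does above
        pad = (-input.shape[-1]) % self.alignment
        return torch.nn.functional.pad(input, (0, pad))

@register_layout_cls(MyLayoutType)
class MyAQTLayout(AQTLayout):  # storage class keyed by the layout type
    ...
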
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
     """
     Base class for the layout tensor for `AffineQuantizedTensor`
     """
-    # this should be set for each layout class during registration
-    extended_layout: Optional[str] = None
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass
 
-    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_layout_type(self) -> LayoutType:
         pass
 
     @classmethod
@@ -64,9 +89,15 @@ def from_plain(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         pass
 
+    def __repr__(self):
+        int_data, scale, zero_point = self.get_plain()
+        layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        extended_layout: str = "plain",
-        # TODO: this is only for "tensor_core_tiled", need to figure out
-        # the proper API for this arg
-        inner_k_tiles: Optional[int] = None,
+        layout_type: LayoutType = PlainLayoutType(),
     ):
         original_shape = input_float.shape
-        if extended_layout == "tensor_core_tiled":
-            orig_out_features, orig_in_features = input_float.shape
-            in_features = find_multiple(orig_in_features, 1024)
-            out_features = find_multiple(orig_out_features, 8)
-            input_float = torch.nn.functional.pad(
-                input_float,
-                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-            )
+        input_float = layout_type.pre_process(input_float)
 
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        int_data = layout_type.post_process(int_data)
 
-        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
-        # TODO: this is temporary, need to come up with the proper UX
-        if extended_layout == "tensor_core_tiled":
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
-        else:
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
         return cls(
             layout_tensor,
             block_size,
@@ -229,8 +247,8 @@ def from_float(
         )
 
     @property
-    def extended_layout(self) -> str:
-        return self.layout_tensor.extended_layout
+    def layout_type(self) -> LayoutType:
+        return self.layout_tensor.layout_type
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 def implements(aten_ops_or_torch_fn):
     return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)
 
-def register_layout_cls(extended_layout: str):
-    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
+def register_layout_cls(layout_type_class: type(LayoutType)):
+    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)
 
-def get_layout_tensor_constructor(extended_layout: str):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
+def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
+    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)
 
-@register_layout_cls("plain")
+@register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
     """
     Layout storage class for plain layout for affine quantized tensor; it stores int_data, scale, zero_point
@@ -330,6 +348,7 @@ def __new__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = int_data.device
@@ -346,34 +365,39 @@ def __init__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
         self.zero_point = zero_point
+        self.layout_type = layout_type
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
+        return ["int_data", "scale", "zero_point"], [self.layout_type]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, zero_point, layout_type)
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
             self.zero_point.to(kwargs["device"]),
+            self.layout_type,
         )
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
             fn(self.zero_point),
+            self.layout_type,
         )
 
     @classmethod
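
Aside on the `__tensor_flatten__`/`__tensor_unflatten__` change above: the layout dataclass is non-tensor state, so it has to travel in the `tensor_attributes` list rather than as a named inner tensor. A minimal round-trip sketch, assuming `layout_tensor` is a PlainAQTLayout instance (variable names are illustrative):

inner_names, attrs = layout_tensor.__tensor_flatten__()
tensors = {name: getattr(layout_tensor, name) for name in inner_names}
rebuilt = type(layout_tensor).__tensor_unflatten__(
    tensors, attrs, layout_tensor.shape, layout_tensor.stride()
)
# frozen dataclasses compare by value, so the layout survives the round trip
assert rebuilt.get_layout_type() == layout_tensor.get_layout_type()
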
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.int_data, self.scale, self.zero_point
 
+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
     @classmethod
     def from_plain(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
-        return cls(int_data, scale, zero_point)
+        assert isinstance(layout_type, PlainLayoutType)
+        return cls(int_data, scale, zero_point, layout_type)
 
-@register_layout_cls("tensor_core_tiled")
+@register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
     Layout storage class for tensor_core_tiled layout for affine quantized tensor; this is for int4 only,
@@ -427,6 +456,7 @@ def __new__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = packed_weight.device
@@ -443,31 +473,40 @@ def __init__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
+        self.layout_type = layout_type
 
     def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], [self.transposed]
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        transposed, = tensor_attributes
-        return cls(packed_weight, scale_and_zero, transposed)
+        transposed, layout_type, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, layout_type)
 
     @classmethod
-    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType
+    ):
+        assert isinstance(layout_type, TensorCoreTiledLayoutType)
         # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
         # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero, False)
+        return cls(packed_weight, scale_and_zero, False, layout_type)
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -477,18 +516,15 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
-            self.transposed
+            self.transposed,
+            self.layout_type,
         )
 
     def _apply_fn_to_data(self, fn):
         self.packed_weight = fn(self.packed_weight)
         self.scale_and_zero = fn(self.scale_and_zero)
         return self
 
-    def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
@@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
             ZeroPointDomain,
             quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
         int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
         return int_data, scale, zero
 
+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
 def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     """
     Quantized version of F.linear operator
@@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         is_cuda and
         input_is_int8 and
         input_tensor.dtype == weight_qtensor.dtype and
-        input_tensor.extended_layout == "plain" and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(input_tensor.layout_type, PlainLayoutType) and
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         #
         # 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.dtype == torch.bfloat16 and
         len(weight_qtensor.shape) == 2 and
         weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-        weight_qtensor.extended_layout == "tensor_core_tiled"
+        isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
     ):
         assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
         assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.block_size[0] == 1 and
         weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
         weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         # TODO: enable cpu and mps efficient path
         # per channel int8 weight only quantized mm
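To make the API change concrete, a hedged usage sketch of the new entry point. The module path and the quantization parameters below are assumptions chosen for illustration; only the `layout_type` keyword and the two LayoutType dataclasses come from this diff.

import torch
from torchao.dtypes.affine_quantized_tensor import (  # assumed module path
    AffineQuantizedTensor,
    TensorCoreTiledLayoutType,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# before: from_float(..., extended_layout="tensor_core_tiled", inner_k_tiles=8)
# after: the layout and its parameters travel together in one dataclass
aqt = AffineQuantizedTensor.from_float(
    weight,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=(1, 32),                  # groupwise, group size 32 (illustrative)
    target_dtype=torch.int32,
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
    layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8),
)

The dispatch logic in _quantized_linear_op then selects the tinygemm int4 path with isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType) rather than comparing strings, so layout-specific arguments like inner_k_tiles no longer leak into from_float's signature.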