 import torch
 
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    quantize_affine,
+    dequantize_affine,
+)
 from torchao.dtypes.utils import (
     LayoutType,
     PlainLayoutType,
 )
 
 aten = torch.ops.aten
 
+# TODO: move to torchao/utils.py
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
@@ -140,10 +171,10 @@ def from_float(
         layout_type: LayoutType = PlainLayoutType(),
     ):
         mapping_type = MappingType.SYMMETRIC
-        block_size = input_float.shape
+        block_size = (1, input_float.shape[-1])
         dtype = torch.int16
-        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
-        int_data = (input_float / scale).to(torch.int8)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
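+        # Rough illustration of what the two calls above compute (an approximation,
+        # not the exact torchao formula): with block_size = (1, in_features) and a
+        # symmetric mapping, scale[i] ~= input_float[i].abs().max() / (2**15 - 1) and
+        # int_data[i] ~= round(input_float[i] / scale[i]) clamped to the int16 range.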
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
         layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
         return cls(layout_tensor, input_float.shape)
@@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None):
         if output_dtype is None:
             output_dtype = torch.get_default_dtype()
         int_data, scale = self.layout_tensor.get_plain()
-        return int_data.to(output_dtype) * scale
+        transposed = False
+        block_size = (1, int_data.shape[-1])
+        if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
+            transposed = True
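+        # zero_point is None because from_float uses symmetric quantization; the plain
+        # layout never physically transposes int_data (aten.t below only flips a flag),
+        # so we dequantize in the original orientation and transpose the result if needed.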
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
+        if transposed:
+            res = res.t()
+        return res
 
     def __repr__(self):
         return (
@@ -203,6 +241,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         kwargs = {}
@@ -219,22 +258,24 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
+        self.transposed = transposed
         self.layout_type = layout_type
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale"], [self.layout_type]
+        return ["int_data", "scale"], [self.transposed, self.layout_type]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
-        return cls(int_data, scale, layout_type)
+        transposed, layout_type, = tensor_attributes
+        return cls(int_data, scale, transposed, layout_type)
 
     @classmethod
     def from_plain(
@@ -247,12 +288,13 @@ def from_plain(
         extra metadata for packing etc.
         """
         assert isinstance(layout_type, PlainLayoutType)
-        return cls(int_data, scale, layout_type)
+        return cls(int_data, scale, False, layout_type)
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
+            self.transposed,
             self.layout_type,
         )
 
@@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
+        # Tensor parallel support START
+        elif func in [aten._to_copy.default, aten.clone.default]:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+        elif func is aten.split.Tensor:
+            int_data_list = func(args[0].int_data, *args[1:], **kwargs)
+            scale_list = func(args[0].scale, *args[1:], **kwargs)
+            out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
+            return out
+        elif func is aten.empty_like.default:
+            int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
+            return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
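+                # a slice along dim 1 (columns) leaves the per-row scale unchanged;
+                # view(-1, 1) just keeps it broadcastable against the sliced int_data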
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+            else:
+                raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        elif func is aten.t.default:
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
+
+        # Tensor parallel support END
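+        # Hedged sketch of when these handlers fire (an assumption, not exercised in
+        # this file): sharding the quantized weight with DTensor, e.g. roughly
+        #   from torch.distributed._tensor import distribute_tensor, Shard
+        #   sharded_weight = distribute_tensor(quantized_weight, mesh, [Shard(0)])
+        # ends up dispatching ops like aten.split/aten.slice/aten.clone to the layout
+        # tensor, and aten.t shows up when F.linear transposes the sharded weight.
+        # `quantized_weight` and `mesh` are hypothetical names used for illustration.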
+
         raise NotImplementedError(
-            f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported"
+            f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
         )
 
 #####################################################
@@ -315,15 +385,6 @@ def _(func, types, args, kwargs):
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
     )
 
-
-class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)
-
 #####################
 # Factory functions #
 #####################
@@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
333394########
334395# Test #
335396########
336-
337- def test ():
397+ def main ():
338398 from torchao .utils import benchmark_model
339-
399+
400+ class M (torch .nn .Module ):
401+ def __init__ (self ) -> None :
402+ super ().__init__ ()
403+ self .linear = torch .nn .Linear (1024 , 128 )
404+
405+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
406+ return self .linear (x )
407+
340408 m = M ()
341- example_inputs = (100 * torch .randn (1024 , 1024 ),)
409+ example_inputs = (100 * torch .randn (512 , 1024 ),)
342410 NUM_WARMUPS = 10
343411 NUM_RUNS = 100
344-
412+
345413 for _ in range (NUM_WARMUPS ):
346414 m (* example_inputs )
347415 print ("before quantization:" , benchmark_model (m , NUM_RUNS , example_inputs ))
348-
416+
349417 compiled = torch .compile (m , mode = "max-autotune" )
350418 for _ in range (NUM_WARMUPS ):
351419 compiled (* example_inputs )
352420 print ("after compile:" , benchmark_model (compiled , NUM_RUNS , example_inputs ))
353-
421+
354422 # convert weights to quantized weights
355423 m .linear .weight = torch .nn .Parameter (
356424 to_my_dtype (m .linear .weight ), requires_grad = False
357425 )
358-
426+
359427 for _ in range (NUM_WARMUPS ):
360428 m (* example_inputs )
361-
429+
362430 print ("after quantization:" , benchmark_model (m , NUM_RUNS , example_inputs ))
363-
431+
364432 m = torch .compile (m , mode = "max-autotune" )
365-
433+
366434 for _ in range (NUM_WARMUPS ):
367435 m (* example_inputs )
368-
436+
369437 # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
370438 # we plan to add custom op example in the future and that will help us to get speedup
371439 print ("after quantization and compile:" , benchmark_model (m , NUM_RUNS , example_inputs ))
372440
373441if __name__ == "__main__" :
374- test ()
442+ main ()