@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -25,7 +25,10 @@
     ZeroPointDomain,
 )
 from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    _get_per_token_block_size,
+    get_group_qparams_symmetric,
+)
 
 
 # =================
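The newly imported `_get_per_token_block_size` drives the per-token path further down. A minimal sketch of the convention, assuming the helper returns a block of size 1 along every leading dim and the full width along the last (the helper name comes from the import above; this reimplementation is illustrative only):

```python
import torch

def per_token_block_size(x: torch.Tensor) -> list:
    # one quantization group per "token": the block spans the full last
    # dim and has size 1 along every leading dim
    return [1] * (x.dim() - 1) + [x.shape[-1]]

x = torch.randn(2, 8, 16)                # (batch, seq, hidden)
assert per_token_block_size(x) == [1, 1, 16]
```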
@@ -346,8 +349,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         scales, zero_points = get_groupwise_affine_qparams(
             self.weight, n_bit, self.groupsize, self.scales_precision,
         )
-        w_fq = _Int4WeightOnlyFakeQuantize.apply(
-            self.weight, scales, zero_points, qmin, qmax, self.groupsize,
+        w_fq = fake_quantize_per_channel_group(
+            self.weight,
+            scales,
+            zero_points,
+            qmin,
+            qmax,
+            self.groupsize,
+            ZeroPointDomain.FLOAT,
         )
         return F.linear(x, w_fq)
 
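The call above preserves the tinygemm numerics through `ZeroPointDomain.FLOAT`, where the zero point lives in the float domain, offset by the mid-point of the quantized range. A minimal sketch of those numerics as I read them from torchao's quant primitives (values are illustrative, not the library code):

```python
import torch

def fq_float_zero_point(x, scale, zp, qmin, qmax):
    # tinygemm-style convention: the float zero point is shifted by the
    # mid-point of [qmin, qmax] before quantizing
    mid = (qmax + qmin + 1) / 2
    min_val = zp - scale * mid
    q = torch.clamp(torch.round((x - min_val) / scale), qmin, qmax)
    return q * scale + min_val   # dequantize back to float

x = torch.randn(4, 32)
scale = torch.full((4, 1), 0.1)   # one scale per group
zp = torch.zeros(4, 1)            # float-domain zero point per group
out = fq_float_zero_point(x, scale, zp, 0, 15)
```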
@@ -370,39 +379,6 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
 # | QUANT PRIMITIVES |
 # ========================
 
-class _Int4WeightOnlyFakeQuantize(torch.autograd.Function):
-    """
-    Implementation of int4 grouped per channel weight-only fake quantize
-    intended to match the numerics of the efficient int4 tinygemm kernel.
-    """
-
-    @staticmethod
-    def forward(ctx, input, scales, zero_points, quant_min, quant_max, groupsize):
-        assert groupsize > 1
-        assert input.shape[-1] % groupsize == 0
-        assert input.dim() == 2
-        n_bit = 4
-        block_size = (1, groupsize)
-        quant_min = 0
-        quant_max = 2 ** n_bit - 1
-        (fq, mask) = fake_quantize_affine_cachemask(
-            input,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
-        ctx.save_for_backward(mask)
-        return fq
-
-    @staticmethod
-    def backward(ctx, gy):
-        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None
-
 class _GenericFakeQuantize(torch.autograd.Function):
     """
     Implementation of generic fake quantize with backward STE.
@@ -412,71 +388,73 @@ class _GenericFakeQuantize(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        quant_min: int,
+        quant_max: int,
+        block_size: List[int],
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+    ) -> torch.Tensor:
         # Note: for bf16 inputs, casting them to fp32 has the unexpected
         # side effect of reducing memory footprint significantly, presumably
         # because bf16 * fp32 kernels are not as memory efficient
         assert input.dtype == torch.float32
         assert scales.dtype == torch.float32
         assert zero_points.dtype == torch.int32
-        q = input.mul(1.0 / scales).round().add(zero_points)
-        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
-        mask = torch.logical_and((q >= quant_min), (q <= quant_max))
+
+        (fq, mask) = fake_quantize_affine_cachemask(
+            input,
+            block_size,
+            scales,
+            zero_points,
+            torch.int32,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+
         ctx.save_for_backward(mask)
-        return dq
+        return fq
 
     @staticmethod
    def backward(ctx, gy):
         (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None
-
-# TODO: move this to core
-quantized_decomposed_lib.define(
-    "fake_quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, "
-    "int quant_min, int quant_max, int group_size) -> Tensor"
-)
+        return gy * mask, None, None, None, None, None, None
 
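The surviving `backward` is a plain straight-through estimator: the `mask` cached by `fake_quantize_affine_cachemask` records which values stayed inside `[quant_min, quant_max]`, and gradients are zeroed wherever clamping occurred. A standalone sketch mirroring the inline math this commit deletes above (values are illustrative):

```python
import torch

x = torch.tensor([0.1, 0.6, 5.0, -5.0])
scale, zp, qmin, qmax = 0.5, 0, -2, 2

q = (x / scale).round() + zp
mask = (q >= qmin) & (q <= qmax)             # cached by forward
dq = q.clamp(qmin, qmax).sub(zp).mul(scale)  # fake-quantized output

grad_out = torch.ones_like(x)                # incoming gradient
grad_in = grad_out * mask                    # STE: zero where clamped
# here grad_in == [1., 1., 0., 0.]: the last two entries clamped at qmax/qmin
```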
-@impl(quantized_decomposed_lib, "fake_quantize_per_channel_group", "CompositeImplicitAutograd")
 def fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
     zero_points: torch.Tensor,
     quant_min: int,
     quant_max: int,
     group_size: int,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
 ) -> torch.Tensor:
     assert group_size > 1
     assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
-    grouped_input = input.reshape(-1, group_size).to(torch.float32)
-    scales = scales.reshape(-1, 1)
-    zero_points = zero_points.reshape(-1, 1)
-    fq = _GenericFakeQuantize.apply(
-        grouped_input, scales, zero_points, quant_min, quant_max,
+    block_size = (1, group_size)
+    return _GenericFakeQuantize.apply(
+        input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain,
     )
-    return fq.reshape_as(input).to(input.dtype)
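With the op registration dropped, `fake_quantize_per_channel_group` is now an ordinary Python function. A hedged usage sketch with symmetric int4 group qparams; the int32 cast is an assumption added to satisfy the dtype assert in `_GenericFakeQuantize`, since `get_group_qparams_symmetric` returns float zero points, and the shapes are illustrative:

```python
import torch

weight = torch.randn(16, 64, dtype=torch.float32)
n_bit, group_size = 4, 32
qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1   # -8, 7

scales, zero_points = get_group_qparams_symmetric(weight, n_bit, group_size)
zero_points = zero_points.to(torch.int32)  # assumed cast, see assert above

w_fq = fake_quantize_per_channel_group(
    weight, scales, zero_points, qmin, qmax, group_size,
)
assert w_fq.shape == weight.shape
```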
-
-# TODO: move this to core
-quantized_decomposed_lib.define(
-    "fake_quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
-    "int quant_min, int quant_max) -> Tensor"
-)
 
-@impl(quantized_decomposed_lib, "fake_quantize_per_token", "CompositeImplicitAutograd")
 def fake_quantize_per_token(
     input: torch.Tensor,
     scales: torch.Tensor,
     zero_points: torch.Tensor,
     quant_min: int,
     quant_max: int,
 ) -> torch.Tensor:
-    # TODO: we won't need this import anymore once we move this to core
     from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check
 
     _per_token_quant_qparam_dim_check(input, scales, zero_points)
+    block_size = _get_per_token_block_size(input)
     fq_input = input.to(torch.float32)
     fq = _GenericFakeQuantize.apply(
-        fq_input, scales, zero_points, quant_min, quant_max,
+        fq_input, scales, zero_points, quant_min, quant_max, block_size,
     )
     return fq.reshape_as(input).to(input.dtype)
 
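For activations, a hedged sketch of calling the updated `fake_quantize_per_token`; the per-token qparams are computed inline here for illustration only (the real QAT modules use torchao's choose-qparams helpers, and the epsilon is an assumption):

```python
import torch

x = torch.randn(2, 8, 16)              # (batch, seq, hidden)
qmin, qmax = -128, 127                 # int8 range

# one (scale, zero_point) per token, keeping a trailing dim of size 1
amax = x.abs().amax(dim=-1, keepdim=True)
scales = (amax / qmax).clamp(min=1e-9).to(torch.float32)
zero_points = torch.zeros_like(scales, dtype=torch.int32)

x_fq = fake_quantize_per_token(x, scales, zero_points, qmin, qmax)
assert x_fq.shape == x.shape and x_fq.dtype == x.dtype
```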