@@ -1,8 +1,8 @@
-from typing import Any, NamedTuple, Optional, Tuple
+from typing import Any, NamedTuple, Optional, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils._triton import has_triton
 
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
@@ -75,7 +75,7 @@ def to_original(self):
     def __torch_dispatch__(cls, func, types, args, kwargs):
         config = None
 
-        def unwrap(x: cls):
+        def unwrap(x):
             nonlocal config
             if config is None:
                 config = x.config
@@ -151,7 +151,16 @@ def _(func, types, args, kwargs):
     if torch.is_autocast_enabled("cuda"):
         dtype = torch.get_autocast_gpu_dtype()
         args = tuple(x.to(dtype) if x is not None else x for x in args)
-    return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs)
+    return _Int8MixedPrecisionTrainingLinearFunction.apply(*args, **kwargs)
+
+
+class Int8MixedPrecisionTrainingLinear(nn.Linear):
+    def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.config = config
+
+    def forward(self, input: Tensor) -> Tensor:
+        return _Int8MixedPrecisionTrainingLinearFunction.apply(input, self.weight, self.bias, self.config)
 
 
 def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
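
A minimal sketch (not part of the commit) of driving the new module-swap Linear directly. It assumes `Int8MixedPrecisionTrainingConfig` can be constructed with its defaults (its definition is outside this diff) and a CUDA GPU with Triton available, since the default config routes the forward matmul through `_dynamic_int8_mm`:

import torch

# assumption: the config class (defined elsewhere in this file) takes no required arguments
config = Int8MixedPrecisionTrainingConfig()
linear = Int8MixedPrecisionTrainingLinear(1024, 1024, config=config).cuda().bfloat16()

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = linear(x)           # routes through _Int8MixedPrecisionTrainingLinearFunction
y.sum().backward()      # produces gradients for x, linear.weight, and linear.bias
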
@@ -184,26 +193,46 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
     return out.view(*A.shape[:-1], out.shape[-1])
 
 
-class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function):
+@torch.compiler.allow_in_graph  # this is required for module-swap, but not for tensor subclass
+class _Int8MixedPrecisionTrainingLinearFunction(torch.autograd.Function):
     @staticmethod
-    def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]):
-        if weight.config.output:
-            out = _dynamic_int8_mm(input, weight._data.T)
+    def forward(
+        ctx,
+        input: Tensor,
+        weight: Union[Int8MixedPrecisionTrainingLinearWeight, Tensor],
+        bias: Optional[Tensor],
+        config: Optional[Int8MixedPrecisionTrainingConfig] = None,
+    ):
+        # unpack tensor subclass and dequant if necessary.
+        # NOTE: we have to do this inside autograd.Function so that autograd works correctly.
+        if isinstance(weight, Int8MixedPrecisionTrainingLinearWeight):
+            config = weight.config  # override `config` input argument
+            weight = weight._data
+
+        ctx.config = config
+        ctx.save_for_backward(input, weight)
+        ctx.bias = bias is not None
+
+        # for NF4Tensor, this will dequantize the tensor.
+        # NOTE: not all quantized tensor subclasses implement .to() this way.
+        # e.g. AffineQuantizedTensor.to(dtype=dtype) returns the same AQT tensor.
+        # casting weight dtype may also introduce unintended behavior.
+        # e.g. FP32 activations and BF16 weight (both plain tensors), which should raise an error,
+        # but now we cast BF16 weight to FP32 instead (and return results in FP32).
+        weight = weight.to(input.dtype)
+
+        if config.output:
+            out = _dynamic_int8_mm(input, weight.T)
         else:
-            out = input @ weight._data.T
+            out = input @ weight.T
         out = out + bias if bias is not None else out
         return out
 
-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        input, weight, bias = inputs
-        ctx.config = weight.config
-        ctx.save_for_backward(input, weight._data)
-        ctx.bias = bias is not None
-
     @staticmethod
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
+        weight = weight.to(input.dtype)  # dequant NF4
+
         grad_input = grad_weight = grad_bias = None
 
         if ctx.needs_input_grad[0]:
@@ -224,12 +253,28 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[2] and ctx.bias:
             grad_bias = grad_output.sum(0)
 
-        return grad_input, grad_weight, grad_bias
-
-
-def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG):
-    return _get_linear_subclass_inserter(
-        Int8MixedPrecisionTrainingLinearWeight,
-        config=config,
-        allow_requires_grad=True,
-    )
+        return grad_input, grad_weight, grad_bias, None
+
+
+def int8_mixed_precision_training(
+    config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG,
+    *,
+    module_swap: bool = False,
+):
+    # TODO: skip small layers that don't have perf gain.
+    if module_swap:
+        # module swap implementation
+        def convert_linear(linear: nn.Linear):
+            linear.__class__ = Int8MixedPrecisionTrainingLinear
+            linear.config = config
+            return linear
+
+        return convert_linear
+
+    else:
+        # tensor subclass implementation
+        return _get_linear_subclass_inserter(
+            Int8MixedPrecisionTrainingLinearWeight,
+            config=config,
+            allow_requires_grad=True,
+        )
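
For reference, a minimal usage sketch of the two paths selected by `module_swap` (not part of this commit). In torchao these transforms are normally applied through `quantize_`; the manual loop below only uses names defined in this file, and the model and shapes are illustrative. A CUDA GPU with Triton is assumed for the INT8 matmuls:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda().bfloat16()

# module-swap path: each nn.Linear keeps its parameters, but its class is changed
# in place to Int8MixedPrecisionTrainingLinear and a `config` attribute is attached.
convert = int8_mixed_precision_training(module_swap=True)
for module in model.modules():
    if isinstance(module, nn.Linear):
        convert(module)

# tensor-subclass path (default): the inserter returned by _get_linear_subclass_inserter
# wraps each weight in Int8MixedPrecisionTrainingLinearWeight instead of swapping the class.
# convert = int8_mixed_precision_training()

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
model(x).sum().backward()  # INT8 matmuls are used where the config flags enable them
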