We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b33fd09 commit ca8c85aCopy full SHA for ca8c85a
torchao/prototype/quantized_training/int8_mixed_precision.py
@@ -231,10 +231,7 @@ def forward(
231
@staticmethod
232
def backward(ctx, grad_output):
233
input, weight = ctx.saved_tensors
234
- if isinstance(weight, Int8MixedPrecisionTrainingLinearWeight):
235
- weight = weight._data
236
- elif hasattr(weight, "get_original_weight"):
237
- weight = weight.get_original_weight() # dequant NF4
+ weight = weight.to(input.dtype) # dequant NF4
238
239
grad_input = grad_weight = grad_bias = None
240
0 commit comments