Skip to content

Commit ca8c85a

Browse files
committed
update backward pass
1 parent b33fd09 commit ca8c85a

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

torchao/prototype/quantized_training/int8_mixed_precision.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,7 @@ def forward(
231231
@staticmethod
232232
def backward(ctx, grad_output):
233233
input, weight = ctx.saved_tensors
234-
if isinstance(weight, Int8MixedPrecisionTrainingLinearWeight):
235-
weight = weight._data
236-
elif hasattr(weight, "get_original_weight"):
237-
weight = weight.get_original_weight() # dequant NF4
234+
weight = weight.to(input.dtype) # dequant NF4
238235

239236
grad_input = grad_weight = grad_bias = None
240237

0 commit comments

Comments
 (0)