This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit a9bae50 (parent: fcf8a20)

Update

[ghstack-poisoned]

2 files changed (+2, -2 lines)


float8_experimental/float8_linear.py (1 addition, 1 deletion)

@@ -296,7 +296,7 @@ def cast_x_to_float8(
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
+            autocast_dtype = torch.get_autocast_dtype("cuda")
             x = x.to(autocast_dtype)

         if self.scaling_type_x is TensorScalingType.DELAYED:
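
This diff swaps torch.get_autocast_gpu_dtype(), which newer PyTorch releases deprecate, for the device-generic torch.get_autocast_dtype(device_type). A minimal standalone sketch of the updated pattern follows; the helper name cast_to_autocast_dtype is illustrative and not part of this repository:

import torch

def cast_to_autocast_dtype(x: torch.Tensor) -> torch.Tensor:
    # If autocast is active, cast the input to the dtype autocast
    # would use for CUDA ops.
    if torch.is_autocast_enabled():
        # Device-generic replacement for the deprecated
        # torch.get_autocast_gpu_dtype().
        autocast_dtype = torch.get_autocast_dtype("cuda")
        x = x.to(autocast_dtype)
    return x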

float8_experimental/float8_linear_utils.py (1 addition, 1 deletion)

@@ -274,7 +274,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_layers = get_float8_layers(model)

     if len(fp8_layers) == 0:
-        log.warn(
+        log.warning(
             "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
         )
         return
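
The second fix replaces Logger.warn with Logger.warning; in Python's standard logging module, warn() is a deprecated alias for warning(). A minimal sketch of the corrected call; the logger setup here is illustrative, since the module configures its own logger elsewhere:

import logging

log = logging.getLogger(__name__)

# logging.Logger.warn is a deprecated alias; warning() is the supported name.
log.warning(
    "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
)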
