     get_float8_layers,
     sync_float8_amax_and_scale_history,
 )
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_delayed,
+    hp_tensor_to_float8_dynamic,
+)
+from torchao.float8.float8_tensor import (
+    LinearMMConfig,
+    GemmInputRole,
+    ScaledMMConfig,
+)
 from torchao.float8.float8_utils import e4m3_dtype
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -353,5 +360,65 @@ def test_sync_amax_func_cuda_graph_success():
     assert "skipping cudagraphs due to mutaton on input" not in stderr[0]
 
 
+@unittest.skipIf(
+    not is_cuda_8_9,
+    "CUDA with float8 support not available",
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.float32,
+        torch.bfloat16,
+        torch.float16,
+    ],
+)
+def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
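+    # Cast the same high-precision tensor to float8 twice, once in eager mode
+    # and once under torch.compile, and check that scale and data match bitwise.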
+    scaling_type_weight = ScalingType.DYNAMIC
+    torch.manual_seed(42)
+    hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
+    hp_tensor2 = hp_tensor1.detach().clone()
+    float8_config = Float8LinearConfig(
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+    )
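+    # One ScaledMMConfig per gemm of the linear layer (output, grad_input,
+    # grad_weight), using the fast-accum and padding settings from the config.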
+    linear_mm_config = LinearMMConfig(
+        # output
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_output.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_input.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+    )
+    float8_eager = hp_tensor_to_float8_dynamic(
+        hp_tensor1,
+        torch.float8_e4m3fn,
+        linear_mm_config,
+        gemm_input_role=GemmInputRole.WEIGHT,
+    )
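+    # Reset dynamo caches so the compiled cast below is compiled fresh.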
+    torch._dynamo.reset()
+    float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
+        hp_tensor2,
+        torch.float8_e4m3fn,
+        linear_mm_config,
+        gemm_input_role=GemmInputRole.WEIGHT,
+    )
+    assert torch.equal(float8_eager._scale, float8_compile._scale)
+    assert torch.equal(float8_eager._data, float8_compile._data)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])