Commit ca7407e

[JAX] Update tolerance of distributed layernorm MLP for FP8 (#1971)
Update tolerance of distributed layernorm MLP for FP8
Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent 86c5097 commit ca7407e

File tree

1 file changed: +18 -0 lines changed

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 18 additions & 0 deletions
@@ -389,6 +389,24 @@ def _test_layernorm_mlp(
         atol = 0.04
         rtol = 11
 
+        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
+        # Triton backend by default. The error of
+        # the Triton FP8 gemm has been verified to be less than or equal
+        # to the error of the cuDNN FP8 gemm w.r.t. a float32 ground truth.
+        # However, Triton can auto-tune a different kernel for the single-GPU
+        # and multi-GPU runs in this test, meaning the diff between single GPU
+        # and multi-GPU can be larger in some cases, even though both are
+        # within tolerance of the float32 ground truth.
+        jax_triton_gemm_precision_tolerance_update = (
+            with_jax_gemm
+            and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
+            and dtype == jnp.bfloat16
+            and activation_type == ("gelu", "linear")
+        )
+        if jax_triton_gemm_precision_tolerance_update:
+            atol = 0.08
+            rtol = 15
+
         assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)
 
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
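For context, the sketch below illustrates the gating pattern the diff adds: widen the comparison tolerances only for the configuration where the auto-tuned Triton GEMM can make the single-GPU and multi-GPU outputs diverge more, even though each stays close to a float32 ground truth. This is a minimal, self-contained sketch, not the test suite's own `assert_allclose` helper; the function name `compare_outputs`, the use of `numpy.testing.assert_allclose`, and the `rtol` numbers are illustrative assumptions, while the `atol` values mirror the diff.

```python
import numpy as np


def compare_outputs(out_sharded, out_single, *, loosen_for_autotuned_gemm: bool):
    """Check a sharded (multi-GPU) result against the single-GPU reference."""
    # Baseline tolerances; the atol values mirror the diff, rtol is illustrative.
    atol, rtol = 0.04, 0.01
    if loosen_for_autotuned_gemm:
        # The FP8 GEMM backend may auto-tune different kernels for the
        # single-GPU and multi-GPU runs, so the two results can drift further
        # apart even though each stays close to a float32 ground truth;
        # widen the tolerances for that configuration only.
        atol, rtol = 0.08, 0.02
    np.testing.assert_allclose(
        np.asarray(out_sharded, dtype=np.float32),
        np.asarray(out_single, dtype=np.float32),
        atol=atol,
        rtol=rtol,
    )


# Tiny demo with synthetic data standing in for the two MLP outputs.
reference = np.random.default_rng(0).standard_normal((4, 8)).astype(np.float32)
perturbed = reference + 0.05  # passes with the loosened tolerances
compare_outputs(perturbed, reference, loosen_for_autotuned_gemm=True)
```

In the actual test, the gate is the `jax_triton_gemm_precision_tolerance_update` flag computed from `with_jax_gemm`, the FP8 recipe, the dtype, and the activation type, as shown in the diff above.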
