File tree Expand file tree Collapse file tree 2 files changed +31
-7
lines changed Expand file tree Collapse file tree 2 files changed +31
-7
lines changed Original file line number Diff line number Diff line change 1111
1212import torch
1313
14+ from torchao .utils import is_MI300
15+
1416logger : logging .Logger = logging .getLogger ()
1517
1618
@@ -52,7 +54,7 @@ class Float8TypeConfig:
5254 """
5355 Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
5456
55- Currently, ROCm only supports fnuz variants.
57+ Currently, ROCm supports (1) fnuz variants on MI300 and (2) OCP F8 variants on MI350/Navi4.
5658 """
5759
5860 # The preferred e4m3 type.
@@ -62,12 +64,9 @@ class Float8TypeConfig:
6264 e5m2_dtype = torch .float8_e5m2
6365
6466 def __post_init__ (self ):
65- if torch .version .hip and torch .cuda .is_available ():
66- prop = torch .cuda .get_device_properties (0 )
67- MI300_ARCH = ("gfx940" , "gfx941" , "gfx942" )
68- if prop .gcnArchName .split (":" )[0 ] in MI300_ARCH :
69- self .e4m3_dtype = torch .float8_e4m3fnuz
70- self .e5m2_dtype = torch .float8_e5m2fnuz
67+ if torch .version .hip and torch .cuda .is_available () and is_MI300 ():
68+ self .e4m3_dtype = torch .float8_e4m3fnuz
69+ self .e5m2_dtype = torch .float8_e5m2fnuz
7170
7271
7372# User defined type for using the individual F8 type based on config
Original file line number Diff line number Diff line change @@ -606,6 +606,15 @@ def _torch_version_at_least(min_version):
606606 return is_fbcode () or version ("torch" ) >= min_version
607607
608608
609+ # Supported AMD GPU Models and their LLVM gfx Codes:
610+ #
611+ # | AMD GPU Model | LLVM gfx Code |
612+ # |---------------|------------------------|
613+ # | Navi4 | gfx1200, gfx1201 |
614+ # | MI300X | gfx940, gfx941, gfx942 |
615+ # | MI350 | gfx950 |
616+
617+
609618def is_MI300 ():
610619 if torch .cuda .is_available () and torch .version .hip :
611620 mxArchName = ["gfx940" , "gfx941" , "gfx942" ]
@@ -616,6 +625,22 @@ def is_MI300():
616625 return False
617626
618627
def is_MI350():
    """Report whether device 0 is an AMD MI350 (gfx950) GPU.

    Returns:
        bool: True only when a CUDA/HIP device is available, PyTorch
        reports a HIP version, and the ``gcnArchName`` of device 0
        contains ``"gfx950"``; False otherwise.
    """
    # Guard clause: without an available device on a HIP build there is
    # nothing to query.
    if not (torch.cuda.is_available() and torch.version.hip):
        return False
    gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
    # Substring match, since the reported name may carry extra suffixes
    # beyond the bare gfx code.
    return "gfx950" in gcn_arch
634+
635+
def is_Navi4():
    """Report whether device 0 is an AMD Navi4 (gfx1200 / gfx1201) GPU.

    Returns:
        bool: True only when a CUDA/HIP device is available, PyTorch
        reports a HIP version, and the ``gcnArchName`` of device 0
        contains ``"gfx1200"`` or ``"gfx1201"``; False otherwise.
    """
    if torch.cuda.is_available() and torch.version.hip:
        navi4_arch_names = ("gfx1200", "gfx1201")
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        # Each gfx code must be tested against the arch string on its own.
        # The previous `"gfx1200" or "gfx1201" in archName` parsed as
        # `"gfx1200" or (...)`, which is always truthy, so every ROCm
        # device was reported as Navi4.
        return any(code in arch_name for code in navi4_arch_names)
    return False
642+
643+
619644def is_sm_at_least_89 ():
620645 return (
621646 torch .cuda .is_available ()
You can’t perform that action at this time.
0 commit comments