diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py
index b75d959b41..9bd7512acc 100644
--- a/torchao/dtypes/uintx/marlin_qqq_tensor.py
+++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -179,7 +179,7 @@ def __tensor_unflatten__(
     def get_plain(self):
         from torchao.quantization.marlin_qqq import (
             unpack_from_marlin_qqq,
-        )  # avoid circular import
+        )
 
         int_data_expanded, s_group_expanded, s_channel_expanded = (
             unpack_from_marlin_qqq(
@@ -207,7 +207,7 @@ def from_plain(
         from torchao.quantization.marlin_qqq import (
             const,
             pack_to_marlin_qqq,
-        )  # avoid circular import
+        )
 
         assert isinstance(_layout, MarlinQQQLayout)
 
diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py
index 2a84dd1813..103e23544e 100644
--- a/torchao/dtypes/uintx/marlin_sparse_layout.py
+++ b/torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -195,7 +195,7 @@ def __tensor_unflatten__(
     def get_plain(self):
         from torchao.sparsity.marlin import (
             unpack_from_marlin_24,
-        )  # avoid circular import
+        )
 
         int_data_expanded, scales_expanded = unpack_from_marlin_24(
             self.int_data,
@@ -220,7 +220,7 @@ def from_plain(
         from torchao.sparsity.marlin import (
             const,
             pack_to_marlin_24,
-        )  # avoid circular import
+        )
 
         assert isinstance(_layout, MarlinSparseLayout)
 
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index c7f32cd3fa..866d7d199c 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -11,6 +11,8 @@
 
 import torch
 
+from torchao.utils import is_MI300
+
 logger: logging.Logger = logging.getLogger()
 
 
@@ -58,7 +60,7 @@ class Float8TypeConfig:
     """
     Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
 
-    Currently, ROCm only supports fnuz variants.
+    Currently, ROCm supports the fnuz variants on MI300 and the OCP float8 variants on MI350/Navi4.
     """
 
     # The preferred e4m3 type.
@@ -68,12 +70,9 @@ class Float8TypeConfig:
     e5m2_dtype = torch.float8_e5m2
 
     def __post_init__(self):
-        if torch.version.hip and torch.cuda.is_available():
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            self.e4m3_dtype = torch.float8_e4m3fnuz
+            self.e5m2_dtype = torch.float8_e5m2fnuz
 
 
 # User defined type for using the individual F8 type based on config
diff --git a/torchao/utils.py b/torchao/utils.py
index 7a17c1b104..a38155edf5 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -604,16 +604,41 @@ def _torch_version_at_least(min_version):
     return is_fbcode() or version("torch") >= min_version
 
 
+# Supported AMD GPU Models and their LLVM gfx Codes:
+#
+# | AMD GPU Model | LLVM gfx Code          |
+# |---------------|------------------------|
+# | Navi4         | gfx1200, gfx1201       |
+# | MI300X        | gfx940, gfx941, gfx942 |
+# | MI350         | gfx950                 |
+
+
 def is_MI300():
     if torch.cuda.is_available() and torch.version.hip:
         mxArchName = ["gfx940", "gfx941", "gfx942"]
-        archName = torch.cuda.get_device_properties().gcnArchName
+        archName = torch.cuda.get_device_properties(0).gcnArchName
         for arch in mxArchName:
             if arch in archName:
                 return True
     return False
 
 
+def is_MI350():
+    if torch.cuda.is_available() and torch.version.hip:
+        archName = torch.cuda.get_device_properties(0).gcnArchName
+        if "gfx950" in archName:
+            return True
+    return False
+
+
+def is_Navi4():
+    if torch.cuda.is_available() and torch.version.hip:
+        archName = torch.cuda.get_device_properties(0).gcnArchName
+        if "gfx1200" in archName or "gfx1201" in archName:
+            return True
+    return False
+
+
 def is_sm_at_least_89():
     return (
         torch.cuda.is_available()
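
A quick usage sketch follows (not part of the patch; it only assumes the APIs visible in the diff above). Float8TypeConfig.__post_init__ now defers MI300 detection to torchao.utils.is_MI300, so on a ROCm MI300 machine the config switches to the fnuz dtypes, while MI350/Navi4 and non-ROCm builds keep the OCP defaults.

import torch

from torchao.float8.config import Float8TypeConfig
from torchao.utils import is_MI300, is_MI350, is_Navi4

# Each helper returns False on non-ROCm builds or when no GPU is visible.
print("MI300:", is_MI300(), "MI350:", is_MI350(), "Navi4:", is_Navi4())

# __post_init__ selects the fnuz pair only on MI300; everywhere else the OCP
# defaults (float8_e4m3fn / float8_e5m2) are kept.
cfg = Float8TypeConfig()
print("e4m3:", cfg.e4m3_dtype, "e5m2:", cfg.e5m2_dtype)

Note that is_MI350 and is_Navi4 are not consulted by Float8TypeConfig in this patch; they only report the detected architecture.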