diff --git a/test/test_utils.py b/test/test_utils.py index c5bbf45a96..3bc16c20c0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,14 +16,14 @@ class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ - ("2.5.0a0+git9f17037", "2.5.0", True), - ("2.5.0a0+git9f17037", "2.4.0", True), - ("2.5.0.dev20240708+cu121", "2.5.0", True), - ("2.5.0.dev20240708+cu121", "2.4.0", True), - ("2.5.0", "2.4.0", True), - ("2.5.0", "2.5.0", True), - ("2.4.0", "2.4.0", True), - ("2.4.0", "2.5.0", False), + ("2.5.0a0+git9f17037", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0a0+git9f17037", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0.dev20240708+cu121", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0.dev20240708+cu121", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0", "2.4.0", True), # [2, 5, 0] > [2, 4, 0] + ("2.5.0", "2.5.0", True), # [2, 5, 0] >= [2, 5, 0] + ("2.4.0", "2.4.0", True), # [2, 4, 0] >= [2, 4, 0] + ("2.4.0", "2.5.0", False), # [2, 4, 0] < [2, 5, 0] ] for torch_version, compare_version, expected_result in test_cases: diff --git a/torchao/utils.py b/torchao/utils.py index a32166d556..298a0d176a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -348,27 +348,33 @@ def _is_float8_type(dtype: torch.dtype) -> bool: def parse_version(version_string): - # Extract just the X.Y.Z part from the version string - match = re.match(r"(\d+\.\d+\.\d+)", version_string) + """ + Parse version string representing pre-release with -1 + + Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0] + """ + # Check for pre-release indicators + is_prerelease = bool(re.search(r"(git|dev)", version_string)) + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) if match: - version = match.group(1) - return [int(x) for x in version.split(".")] + major, minor, patch = map(int, match.groups()) + if is_prerelease: + patch = -1 + return [major, minor, patch] else: raise ValueError(f"Invalid version string format: {version_string}") -def compare_versions(v1, v2): - v1_parts = parse_version(v1) - v2_parts = parse_version(v2) - return (v1_parts > v2_parts) - (v1_parts < v2_parts) - - def is_fbcode(): return not hasattr(torch.version, "git_version") def torch_version_at_least(min_version): - return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 + if is_fbcode(): + return True + + # Parser for local identifiers + return parse_version(torch.__version__) >= parse_version(min_version) def _deprecated_torch_version_at_least(version_str: str) -> str: @@ -983,13 +989,13 @@ def is_sm_at_least_100(): def check_cpu_version(device, version="2.6.0"): if isinstance(device, torch.device): device = device.type - return device == "cpu" and compare_versions(torch.__version__, version) >= 0 + return device == "cpu" and torch_version_at_least(version) def check_xpu_version(device, version="2.8.0"): if isinstance(device, torch.device): device = device.type - return device == "xpu" and compare_versions(torch.__version__, version) >= 0 + return device == "xpu" and torch_version_at_least(version) def ceil_div(a, b):