Skip to content

Commit c8aff33

Browse files
authored
Revert "pin nightly to 2.5.0.dev20240709+cu121 (#505)"
This reverts commit cc871c5.
1 parent 6e7cf71 commit c8aff33

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
gpu-arch-version: "12.1"
3434
- name: CUDA Nightly
3535
runs-on: linux.g5.12xlarge.nvidia.gpu
36-
torch-spec: '--pre torch==2.5.0.dev20240709+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
36+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
3737
gpu-arch-type: "cuda"
3838
gpu-arch-version: "12.1"
3939
- name: CPU 2.2.2
@@ -48,7 +48,7 @@ jobs:
4848
gpu-arch-version: ""
4949
- name: CPU Nightly
5050
runs-on: linux.4xlarge
51-
torch-spec: '--pre torch==2.5.0.dev20240709+cpu --index-url https://download.pytorch.org/whl/nightly/cpu'
51+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
5252
gpu-arch-type: "cpu"
5353
gpu-arch-version: ""
5454

test/dtypes/test_uint4.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,19 @@
44
PerChannelSymmetricWeightUInt4Tensor,
55
)
66
import unittest
7+
from unittest import TestCase, main
78
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
89
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
910

1011
from torch._export import capture_pre_autograd_graph
12+
from torch._export import dynamic_dim
1113
from torch.testing._internal.common_quantization import (
1214
NodeSpec as ns,
1315
QuantizationTestCase,
1416
)
17+
from torchao.quantization.utils import (
18+
compute_error,
19+
)
1520
from torchao.quantization.quant_api import (
1621
_replace_with_custom_fn_if_matches_filter,
1722
)
@@ -25,6 +30,7 @@
2530
QuantizationAnnotation,
2631
)
2732
import copy
33+
from packaging import version
2834

2935

3036
def _apply_weight_only_uint4_quant(model):
@@ -223,4 +229,4 @@ def forward(self, x):
223229
)
224230

225231
if __name__ == "__main__":
226-
unittest.main()
232+
main()

0 commit comments

Comments
 (0)