
Commit e7b33bc

Unpin CUDA Nightly (#1064)
1 parent e6ceb95 commit e7b33bc

File tree

4 files changed: +26 -11 lines

.github/workflows/regression_test.yml
torchao/prototype/quantized_training/bitnet.py
torchao/prototype/quantized_training/int8.py
torchao/prototype/quantized_training/int8_mixed_precision.py


.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions
@@ -38,9 +38,9 @@ jobs:
             torch-spec: 'torch==2.4.0'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.1"
-          - name: CUDA Nightly (Oct 1)
+          - name: CUDA Nightly
             runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: '--pre torch==2.6.0.dev20241001+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.1"

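
With the pin removed, the job installs whatever the latest cu121 nightly wheel is when it starts, instead of the fixed 2.6.0.dev20241001 build. As an illustration only (not part of this commit), a CI step could log which build a run actually resolved:

# Illustrative only (not in this commit): log which nightly the unpinned
# '--pre torch' spec actually installed on the CUDA Nightly job.
import torch

print(torch.__version__)   # e.g. a 2.x dev build tagged +cu121
print(torch.version.cuda)  # CUDA toolkit version the wheel was built against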

torchao/prototype/quantized_training/bitnet.py

Lines changed: 2 additions & 3 deletions
@@ -92,9 +92,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         # return new unwrapped object
         return out

-    # new signature https://github.com/pytorch/pytorch/pull/136129
-    # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5
-    def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None):
+    # FSDP all-gather extension v1
+    def fsdp_pre_all_gather(self, mesh):
         # quantize and pack into 2-bit to save comm bandwidth
         if self._precomputed_scale is not None:
             scale = self._precomputed_scale
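
For reference, the v1 extension contract that bitnet.py keeps here is a pre/post hook pair on the tensor subclass: fsdp_pre_all_gather(self, mesh) hands FSDP the tensors to all-gather plus metadata, and fsdp_post_all_gather rebuilds the unsharded weight. Below is a minimal sketch under that assumption; the class name and plain int8 quantization are illustrative, not the 2-bit BitNet packing used in this file.

# Minimal sketch of the FSDP2 all-gather extension v1 hooks (illustrative
# class and quantization; bitnet.py packs ternary weights into 2 bits instead).
import torch

class QuantizedWeight(torch.Tensor):
    def fsdp_pre_all_gather(self, mesh):
        # quantize the local shard so less data goes over the wire
        scale = self.abs().amax() / 127
        int_data = (self / scale).round().clamp(-128, 127).to(torch.int8)
        # (tensors for FSDP to all-gather), metadata forwarded to the post hook
        return (int_data,), (scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (int_data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            # steady-state path: write into the pre-allocated unsharded buffer
            out.copy_(int_data.to(param_dtype) * scale)
            return
        # first use: return the unsharded weight and the inner tensors to keep alive
        return int_data.to(param_dtype) * scale, (int_data,)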

torchao/prototype/quantized_training/int8.py

Lines changed: 11 additions & 3 deletions
@@ -99,9 +99,17 @@ def __repr__(self):
             f"requires_grad={self.requires_grad})"
         )

-    # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype
-    # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5
-    def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None):
+    # FSDP all-gather extension v2
+    # https://github.com/pytorch/pytorch/pull/137005
+    # we need default values so this method still works with PyTorch 2.4 and 2.5
+    def fsdp_pre_all_gather(
+        self,
+        mesh,
+        outer_size=None,
+        outer_stride=None,
+        module=None,
+        mp_policy=None,
+    ):
         scale = self.scale
         if mp_policy is not None:
             scale = scale.to(mp_policy.param_dtype)
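
In the v2 extension the runtime passes extra arguments (the unsharded parameter's size and stride, the owning module, and the mixed-precision policy). Defaulting them all to None is what keeps the method callable from PyTorch 2.4/2.5, which still invoke it as fsdp_pre_all_gather(mesh). A hedged sketch of that pattern follows; attribute names beyond self.scale are illustrative.

# Hedged sketch of the v2-compatible signature: older PyTorch calls
# fsdp_pre_all_gather(mesh) only, newer nightlies also pass outer_size,
# outer_stride, module and mp_policy, so every extra argument defaults to None.
import torch

class Int8TrainingWeight(torch.Tensor):
    def fsdp_pre_all_gather(
        self,
        mesh,
        outer_size=None,     # unsharded shape, only supplied by newer PyTorch
        outer_stride=None,   # unsharded stride, only supplied by newer PyTorch
        module=None,         # owning nn.Module, only supplied by newer PyTorch
        mp_policy=None,      # MixedPrecisionPolicy, only supplied by newer PyTorch
    ):
        scale = self.scale
        if mp_policy is not None:
            # cast the scale to the policy's param_dtype when a policy is given
            scale = scale.to(mp_policy.param_dtype)
        # return (tensors to all-gather), metadata for the post hook
        # ('int_data' is an illustrative attribute name for the int8 payload)
        return (self.int_data, scale), None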

torchao/prototype/quantized_training/int8_mixed_precision.py

Lines changed: 11 additions & 3 deletions
@@ -110,9 +110,17 @@ def unwrap(x: cls):
         # return new unwrapped object
         return out

-    # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype
-    # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5
-    def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None):
+    # FSDP all-gather extension v2
+    # https://github.com/pytorch/pytorch/pull/137005
+    # we need default values so this method still works with PyTorch 2.4 and 2.5
+    def fsdp_pre_all_gather(
+        self,
+        mesh,
+        outer_size=None,
+        outer_stride=None,
+        module=None,
+        mp_policy=None,
+    ):
         # TODO: pre-quantize weight here -> reduce comm bandwidth.
         # we will need another tensor subclass to hold the quantized weight.
         data = self._data
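
To see where mp_policy.param_dtype comes from on the newer path, an illustrative FSDP2 setup looks like the following; the import path and availability differ across PyTorch releases, and it must run under torchrun with a process group initialized.

# Illustrative FSDP2 usage (needs torchrun / an initialized process group).
# The param_dtype chosen here is what newer PyTorch forwards to
# fsdp_pre_all_gather via the mp_policy argument.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
# (the torchao prototype would first swap Linear weights for its int8 subclass)
fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))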
