
Commit 15974c7

Merge branch 'main' into rocm_sparse_marlin
2 parents d2c7ce4 + 457c5b1 commit 15974c7

File tree: 208 files changed (+110712, -4801 lines)

Large commits have some content hidden by default; only a subset of the changed files is shown below.

.github/scripts/github_utils.py

Lines changed: 1 addition & 3 deletions
@@ -3,14 +3,12 @@
 import json
 import os
 import warnings
-
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 from urllib.error import HTTPError
 from urllib.parse import quote
 from urllib.request import Request, urlopen
 
-
 GITHUB_API_URL = "https://api.github.com"
 
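Note: the import reshuffles throughout this commit are consistent with a case-sensitive (ASCII) import sorter such as Ruff's isort rule, under which uppercase names sort before lowercase ones; that is why `cast` moves after `Union`. A minimal sketch of the ordering rule (the name list is illustrative, not from the source):

# Plain ASCII sort: uppercase letters precede lowercase, so `cast` lands last.
names = ["Any", "Callable", "cast", "Dict", "Union"]
print(sorted(names))  # ['Any', 'Callable', 'Dict', 'Union', 'cast']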

.github/scripts/gitutils.py

Lines changed: 2 additions & 2 deletions
@@ -9,14 +9,14 @@
 from typing import (
     Any,
     Callable,
-    cast,
     Dict,
     Iterator,
     List,
     Optional,
     Tuple,
     TypeVar,
     Union,
+    cast,
 )
 
 T = TypeVar("T")
@@ -45,7 +45,7 @@ def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
 
 
 def _check_output(items: List[str], encoding: str = "utf-8") -> str:
-    from subprocess import CalledProcessError, check_output, STDOUT
+    from subprocess import STDOUT, CalledProcessError, check_output
 
     try:
         return check_output(items, stderr=STDOUT).decode(encoding)
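Note on the `_check_output` hunk above: passing `stderr=STDOUT` folds diagnostics into the captured stdout, and on a non-zero exit the combined output is still available on the `CalledProcessError`. A standalone sketch of the pattern (the `git status` invocation is illustrative and assumes `git` is on PATH):

from subprocess import STDOUT, CalledProcessError, check_output

try:
    # stderr=STDOUT merges both streams into the returned bytes.
    out = check_output(["git", "status"], stderr=STDOUT).decode("utf-8")
except CalledProcessError as e:
    # The merged output survives on the exception for error reporting.
    out = e.output.decode("utf-8")
print(out)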

.github/scripts/label_utils.py

Lines changed: 2 additions & 3 deletions
@@ -1,11 +1,10 @@
 """GitHub Label Utilities."""
 
 import json
-
 from functools import lru_cache
-from typing import Any, List, Tuple, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Any, List, Tuple, Union
 
-from github_utils import gh_fetch_url_and_headers, GitHubComment
+from github_utils import GitHubComment, gh_fetch_url_and_headers
 
 # TODO: this is a temp workaround to avoid circular dependencies,
 # and should be removed once GitHubPR is refactored out of trymerge script.
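Note: `TYPE_CHECKING` (now sorted to the front of the import) is the standard device behind the circular-dependency workaround the TODO mentions: imports guarded by it are evaluated only by static type checkers, never at runtime. A minimal sketch; `comment_bodies` is a hypothetical helper, not part of this file:

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Seen by mypy/pyright only, so it cannot create a runtime import cycle.
    from github_utils import GitHubComment

def comment_bodies(comments: "List[GitHubComment]") -> List[str]:
    # The quoted annotation defers name resolution to type-checking time.
    return [c.body_text for c in comments]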

.github/scripts/trymerge.py

Lines changed: 6 additions & 10 deletions
@@ -23,44 +23,41 @@
 from typing import (
     Any,
     Callable,
-    cast,
     Dict,
     Iterable,
     List,
     NamedTuple,
     Optional,
     Pattern,
     Tuple,
+    cast,
 )
 from warnings import warn
 
 import yaml
 from github_utils import (
+    GitHubComment,
     gh_fetch_json_list,
     gh_fetch_merge_base,
     gh_fetch_url,
     gh_graphql,
     gh_post_commit_comment,
     gh_post_pr_comment,
     gh_update_pr_state,
-    GitHubComment,
 )
-
 from gitutils import (
+    GitRepo,
     are_ghstack_branches_in_sync,
     get_git_remote_name,
     get_git_repo_dir,
-    GitRepo,
     patterns_to_regex,
     retries_decorator,
 )
 from label_utils import (
     gh_add_labels,
     gh_remove_label,
-    has_required_labels,
-    LABEL_ERR_MSG,
 )
-from trymerge_explainer import get_revert_message, TryMergeExplainer
+from trymerge_explainer import TryMergeExplainer, get_revert_message
 
 # labels
 MERGE_IN_PROGRESS_LABEL = "merging"
@@ -1477,7 +1474,7 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
 
 
 def checks_to_markdown_bullets(
-    checks: List[Tuple[str, Optional[str], Optional[int]]]
+    checks: List[Tuple[str, Optional[str], Optional[int]]],
 ) -> List[str]:
     return [
         f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
@@ -1716,7 +1713,7 @@ def get_readable_drci_results(drci_classifications: Any) -> str:
     try:
         print(f"From Dr.CI checkrun summary: {drci_summary}")
         drci_classifications = json.loads(str(drci_summary))
-    except json.JSONDecodeError as error:
+    except json.JSONDecodeError:
         warn("Invalid Dr.CI checkrun summary")
         drci_classifications = {}
 
@@ -1887,7 +1884,6 @@ def do_revert_prs(
     dry_run: bool = False,
 ) -> None:
     # Prepare and push revert commits
-    commit_shas: List[str] = []
     for commit_sha, pr in shas_and_prs:
         revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
         revert_msg += extra_msg
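Note on the `json.JSONDecodeError` hunk: the `as error` binding was never used, which linters flag (Ruff/flake8 F841), so the bare handler is equivalent and cleaner. A standalone sketch of the resulting pattern; `parse_drci_summary` is a hypothetical wrapper around the logic above:

import json
from warnings import warn

def parse_drci_summary(drci_summary: str) -> dict:
    try:
        return json.loads(drci_summary)
    except json.JSONDecodeError:
        # No `as error` binding: the exception object itself is not needed.
        warn("Invalid Dr.CI checkrun summary")
        return {}

print(parse_drci_summary("not json"))  # {} (plus a UserWarning)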

.github/scripts/trymerge_explainer.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 import re
 from typing import List, Optional, Pattern, Tuple
 
-
 BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
 
 CIFLOW_LABEL = re.compile(r"^ciflow/.+")

benchmarks/bench_galore_fused_kernels.py

Lines changed: 0 additions & 2 deletions
@@ -8,9 +8,7 @@
 def run(args):
     dtype = getattr(torch, args.dtype)
     allow_tf32 = args.allow_tf32
-    fp8_fast_accum = False
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
-    kernel = args.kernel
     M, N = args.M, args.N
     rank = args.rank

benchmarks/benchmark_aq.py

Lines changed: 83 additions & 35 deletions
@@ -1,23 +1,26 @@
-"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
-"""
+"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs"""
+
+import copy
+
 import torch
+
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    int4_weight_only,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
+    quantize_,
+)
 from torchao.quantization.subclass import (
-    Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
+    unwrap_tensor_subclass,
 )
-from torchao.quantization.quant_api import (
-    int4_weight_only,
-    int8_weight_only,
-    int8_dynamic_activation_int8_weight,
-    quantize_,
-    _replace_with_custom_fn_if_matches_filter,
-)
-import copy
-from torchao.utils import unwrap_tensor_subclass
+
 
 def _int8wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -27,14 +30,20 @@ def _int8wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int8_woqtensors(mod, **kwargs)
 
+
 def _int8da_int8w_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(**kwargs),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod, **kwargs)
 
+
 def _int4wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
         kwargs_copy = kwargs.copy()
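Note: `quantize_` mutates the module in place, swapping eligible `nn.Linear` weights for quantized tensor subclasses, and `set_inductor_config=False` leaves inductor settings untouched, as in the hunk above. A minimal sketch of the call, assuming a CUDA device and the torch >= 2.4 path gated above:

import torch
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int8_weight,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 32)).to(torch.bfloat16).cuda()
# In place: the Linear weight becomes an int8 dynamic-quant tensor subclass.
quantize_(model, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
print(type(model[0].weight))  # a quantized tensor subclass, not a plain Tensor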
@@ -47,31 +56,43 @@ def _int4wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int4_woqtensors(mod, **kwargs)
 
+
 class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size
-    """
-    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
+    """Single linear for m * k * n problem size"""
+
+    def __init__(
+        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
+    ):
         super().__init__()
         self.m = m
         self.dtype = dtype
         self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
+            dtype=self.dtype, device=self.device
+        )
 
     def example_inputs(self):
-        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
+        return (
+            torch.randn(
+                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
+            ),
+        )
 
     def forward(self, x):
         x = self.linear(x)
         return x
 
+
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
     """
     The deprecated implementation for int8 dynamic quant API, used as a reference for
     numerics and performance
     """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
-    from torchao.quantization.quant_api import _is_linear
-    from torchao.quantization.quant_api import _get_subclass_inserter
+    from torchao.quantization.quant_api import (
+        _get_subclass_inserter,
+        _in_features_greater_than_16,
+        _is_linear,
+    )
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
     if filter_fn is None:
@@ -80,40 +101,54 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
     )
 
     _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
+        model,
+        _get_subclass_inserter(
+            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
+        ),
+        filter_fn,
     )
 
+
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
         The deprecated implementation for weight only quant API, used as a reference for
         numerics and performance
         """
-        from torchao.quantization.quant_api import _is_linear
-        from torchao.quantization.quant_api import _get_subclass_inserter
+        from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear
 
         filter_fn = kwargs.pop("filter_fn", _is_linear)
 
         _replace_with_custom_fn_if_matches_filter(
             model,
-            _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
+            _get_subclass_inserter(
+                deprecated_tenosr_subclass, enable_parametrization=True, **kwargs
+            ),
             filter_fn,
         )
 
     return _ref_change_linear_weights_to_woqtensors
 
-_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
+_ref_change_linear_weights_to_int8_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+)
+_ref_change_linear_weights_to_int4_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+)
 
 
 torch._dynamo.config.cache_size_limit = 50000
 
+
 @torch.no_grad
 def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}
 
-    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m = ToyLinearModel(
+        M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
+    ).eval()
     m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
     example_inputs = m.example_inputs()
@@ -130,26 +165,30 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
 
     # perf comparison
     from torchao.utils import benchmark_model
+
     # warmup
     WARMUP = 20
     RUNS = 100
 
     torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
+    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    m = torch.compile(m, mode="max-autotune", fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
     benchmark_model(m_bf16, WARMUP, example_inputs)
     bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
 
-    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
+    print(
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+    )
+
 
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
     all_shapes = [
@@ -158,16 +197,25 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
 
     print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
+        )
 
     print("_int8wo_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
+        )
 
     print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
+        _bench_quantized_tensor_subclass_perf(
+            _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
+        )
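Note: the benchmark body follows a reset → compile → warmup → measure pattern; `torch._dynamo.reset()` keeps compile caches from leaking between the three variants being timed. A condensed sketch, assuming a CUDA device and torchao's `benchmark_model(model, num_runs, example_inputs)` helper used above:

import torch
from torchao.utils import benchmark_model

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
example_inputs = (torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda"),)

torch._dynamo.reset()  # drop compile caches left over from other variants
model = torch.compile(model, mode="max-autotune", fullgraph=True)
benchmark_model(model, 20, example_inputs)  # warmup: triggers compilation
elapsed = benchmark_model(model, 100, example_inputs)  # timed runs
print(f"elapsed time: {elapsed}")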
