Commit 91e9d91

Updates
1 parent 30567b4 commit 91e9d91

65 files changed (+319, -302 lines)

benchmarks/_models/llama/eval.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from tokenizer import get_tokenizer
 
 import torchao
-from torchao._models.model import prepare_inputs_for_model
+from torchao._models.llm.model import prepare_inputs_for_model
 from torchao.quantization import (
     PerRow,
     PerTensor,
@@ -172,7 +172,7 @@ def run_evaluation(
     if "autoround" in quantization:
         from transformers import AutoTokenizer
 
-        from torchao._models.model import TransformerBlock
+        from torchao._models.llm.model import TransformerBlock
         from torchao.prototype.autoround.autoround_llm import (
             quantize_model_with_autoround_,
         )

benchmarks/_models/llama/generate.py

Lines changed: 17 additions & 111 deletions
@@ -7,20 +7,30 @@
 import time
 from datetime import datetime
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch._dynamo.config
 import torch._inductor.config
 
 import torchao
-from benchmarks._models.utils import (
+from torchao._models.utils import (
+    _load_model,
+    decode_n_tokens,
+    decode_one_token,
+    encode_tokens,
     get_arch_name,
+    prefill,
     write_json_result_local,
     write_json_result_ossci,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    default_device,
+    device_sync,
+    get_model_size_in_bytes,
+)
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
 torch.backends.cuda.enable_cudnn_sdp(True)
@@ -49,97 +59,12 @@ def device_timer(device):
         print(f"device={device} is not yet suppported")
 
 
-def device_sync(device):
-    if "cuda" in device:
-        torch.cuda.synchronize(device)
-    elif "xpu" in device:
-        torch.xpu.synchronize(device)
-    elif ("cpu" in device) or ("mps" in device):
-        pass
-    else:
-        print(f"device={device} is not yet suppported")
-
-
-default_device = (
-    "cuda"
-    if torch.cuda.is_available()
-    else "xpu"
-    if torch.xpu.is_available()
-    else "cpu"
-)
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from torchao._models.model import Transformer, prepare_inputs_for_model
-from torchao._models.tokenizer import get_tokenizer
-
-
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    logits = logits / max(temperature, 1e-5)
-
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[:, -1], temperature, top_k)
-    idx_next = multinomial_sample_one_no_sync(probs)
-    return idx_next, probs
-
-
-def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> torch.Tensor:
-    # input_pos: [B, S]
-    logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)[0]
-
-
-def decode_one_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # input_pos: [B, 1]
-    assert input_pos.shape[-1] == 1
-    logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)
-
-
-def decode_n_tokens(
-    model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    callback=lambda _: _,
-    **sampling_kwargs,
-):
-    new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
-        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-            next_token, next_prob = next_token.clone(), next_prob.clone()
-        input_pos += 1
-        # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
-        new_tokens.append(next_token.clone())
-        callback(new_tokens[-1])
-        new_probs.append(next_prob)
-        cur_token = next_token
-
-    return new_tokens, new_probs
+from torchao._models.llm.model import Transformer, prepare_inputs_for_model
+from torchao._models.llm.tokenizer import get_tokenizer
 
 
 def model_forward(model, x, input_pos):
@@ -230,25 +155,6 @@ def generate(
     return seq
 
 
-def encode_tokens(tokenizer, string, bos=True, device=default_device):
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
-
-
-def _load_model(checkpoint_path, device, precision):
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
-    with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
-    model.load_state_dict(checkpoint, assign=True)
-    model = model.to(device=device, dtype=precision)
-
-    return model.eval()
-
-
 B_INST, E_INST = "[INST]", "[/INST]"
 
 
@@ -575,8 +481,8 @@ def ffn_or_attn_only(mod, fqn):
             model, float8_dynamic_activation_float8_weight(granularity=granularity)
         )
     elif "autoquant_v2" in quantization:
-        from torchao._models.model import prepare_inputs_for_model
         from torchao._models._eval import InputRecorder
+        from torchao._models.llm.model import prepare_inputs_for_model
         from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
 
         calibration_seq_length = 256
@@ -665,8 +571,8 @@ def ffn_or_attn_only(mod, fqn):
         # do autoquantization
        model.finalize_autoquant()
     elif "autoquant" in quantization:
-        from torchao._models.model import prepare_inputs_for_model
         from torchao._models._eval import InputRecorder
+        from torchao._models.llm.model import prepare_inputs_for_model
 
         calibration_seq_length = 256
         inputs = (
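Note on the helpers this commit relocates from generate.py into torchao._models.utils: multinomial_sample_one_no_sync samples from a categorical distribution without the host-device synchronization that torch.multinomial triggers on CUDA. A minimal standalone sketch of the trick (the exponential-race form of the Gumbel-max trick; illustration only, not part of this commit):

import torch

# Dividing the probabilities by i.i.d. Exponential(1) noise and taking the
# argmax draws index i with probability probs[i]: argmax(probs / q) equals
# argmin(q / probs), and the minimum of independent exponential clocks
# q_i / p_i lands on index i with probability p_i / sum(p).
probs = torch.tensor([0.1, 0.2, 0.7])
q = torch.empty_like(probs).exponential_(1)
idx = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
# idx is distributed like torch.multinomial(probs, 1), but stays on-device.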

benchmarks/_models/llama/perf_profile.py

Lines changed: 2 additions & 2 deletions
@@ -116,8 +116,8 @@
 import torch
 from torch.nn.attention import SDPBackend
 
-from torchao._models.model import Transformer
-from torchao._models.tokenizer import get_tokenizer
+from torchao._models.llm.model import Transformer
+from torchao._models.llm.tokenizer import get_tokenizer
 from torchao.prototype.profiler import (
     CUDADeviceSpec,
     TransformerPerformanceCounter,

benchmarks/_models/sam/eval_combo.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from metrics import calculate_miou, create_result_entry
 
 import torchao
-from benchmarks._models.utils import (
+from torchao._models.utils import (
     get_arch_name,
     write_json_result_local,
     write_json_result_ossci,

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 2 additions & 2 deletions
@@ -22,13 +22,13 @@
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 
-from torchao._models.model import (
+from torchao import quantize_
+from torchao._models.llm.model import (
     ModelArgs,
     RMSNorm,
     Transformer,
     transformer_configs,
 )
-from torchao import quantize_
 from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import (
     bitnet_training,

examples/sam2_amg_server/annotate_with_rle.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 )
 from tqdm import tqdm
 
-from benchmarks._models.sam2.utils.amg import area_from_rle, rle_to_mask
+from torchao._models.sam2.utils.amg import area_from_rle, rle_to_mask
 
 
 def timestamped_print(*args, **kwargs):

examples/sam2_amg_server/cli.py

Lines changed: 3 additions & 3 deletions
@@ -12,9 +12,9 @@
     show_anns,
 )
 
-from benchmarks._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-from benchmarks._models.sam2.build_sam import build_sam2
-from benchmarks._models.sam2.utils.amg import rle_to_mask
+from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+from torchao._models.sam2.build_sam import build_sam2
+from torchao._models.sam2.utils.amg import rle_to_mask
 
 
 def main_docstring():

examples/sam2_amg_server/cli_on_modal.py

Lines changed: 4 additions & 4 deletions
@@ -84,10 +84,10 @@ def build(self):
             from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
             from sam2.build_sam import build_sam2
         else:
-            from benchmarks._models.sam2.automatic_mask_generator import (
+            from torchao._models.sam2.automatic_mask_generator import (
                 SAM2AutomaticMaskGenerator,
             )
-            from benchmarks._models.sam2.build_sam import build_sam2
+            from torchao._models.sam2.build_sam import build_sam2
 
         os.chdir(f"{TARGET}ao_src_0/examples/sam2_amg_server")
         import sys
@@ -139,11 +139,11 @@ def build(self):
             from sam2.utils.amg import mask_to_rle_pytorch as mask_to_rle_pytorch_2
             from sam2.utils.amg import rle_to_mask
         else:
-            from benchmarks._models.sam2.utils.amg import (
+            from torchao._models.sam2.utils.amg import (
                 mask_to_rle_pytorch_2,
                 rle_to_mask,
             )
-            from benchmarks._models.sam2.utils.amg import area_from_rle
+            from torchao._models.sam2.utils.amg import area_from_rle
 
         self.np = np
         self.tio = tio
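The build() hunks above prefer the installed sam2 package and fall back to the vendored copy, which this commit moves from benchmarks._models.sam2 to torchao._models.sam2. A minimal sketch of the same fallback shape; USE_UPSTREAM_SAM2 is a hypothetical stand-in for whatever condition the real build() checks:

# Hypothetical flag; the actual condition is outside the hunks shown above.
USE_UPSTREAM_SAM2 = False

if USE_UPSTREAM_SAM2:
    from sam2.build_sam import build_sam2
else:
    from torchao._models.sam2.build_sam import build_sam2  # vendored copy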

examples/sam2_amg_server/compare_rle_lists.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import torch
 
 
-# from benchmarks._models.sam2.utils.amg import rle_to_mask
+# from torchao._models.sam2.utils.amg import rle_to_mask
 def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
     """Compute a binary mask from an uncompressed RLE."""
     h, w = rle["size"]
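The local rle_to_mask above is a standalone fallback so the script need not import torchao; only its first lines appear in this hunk. Below is a sketch of how such an uncompressed-RLE decoder typically looks in the SAM codebase — an assumption about the elided body, not a copy of it:

from typing import Any, Dict

import numpy as np

def rle_to_mask_sketch(rle: Dict[str, Any]) -> np.ndarray:
    # "size" is (h, w); "counts" holds alternating run lengths of 0s and 1s
    # in column-major order, starting with 0s.
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity = not parity
    return mask.reshape(w, h).transpose()  # back to row-major (h, w)

# A 2x2 checkerboard: one 0, two 1s, one 0 (under the layout assumed above).
print(rle_to_mask_sketch({"size": [2, 2], "counts": [1, 2, 1]}).astype(int))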

examples/sam2_amg_server/compile_export_utils.py

Lines changed: 5 additions & 5 deletions
@@ -4,7 +4,7 @@
 
 import torch
 
-from benchmarks._models.sam2.sam2_image_predictor import SAM2ImagePredictor
+from torchao._models.sam2.sam2_image_predictor import SAM2ImagePredictor
 
 # Tools used to avoid compilation cold start and dynamo cache lookups
 # We take the compiled model and export it using the largest
@@ -519,12 +519,12 @@ def set_fast(
     # A bunch of extra compiles at module level
     # Note that this can cause recompilations!
     # We might want to guard on that
-    benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0 = torch.compile(
+    torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0 = torch.compile(
         fullgraph=True, dynamic=True
-    )(benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0)
-    benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1 = torch.compile(
+    )(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0)
+    torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1 = torch.compile(
         fullgraph=True, dynamic=True
-    )(benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1)
+    )(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1)
     mask_generator.calculate_stability_score = torch.compile(
         fullgraph=True, dynamic=True
     )(mask_generator.calculate_stability_score)
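The set_fast hunk above rebinds module-level functions to their torch.compile-wrapped versions, so any caller that resolves the attribute through the module transparently picks up the compiled variant. A minimal sketch of the pattern (generic, not the repo's exact code; mask_to_rle_pytorch_2 lives in the amg module per the cli_on_modal.py diff):

import torch

import torchao._models.sam2.utils.amg as amg  # path introduced by this commit

# Compile once at setup time and rebind the module attribute.
# fullgraph=True raises on graph breaks instead of silently falling back;
# dynamic=True avoids recompiling when tensor shapes vary between calls.
amg.mask_to_rle_pytorch_2 = torch.compile(fullgraph=True, dynamic=True)(
    amg.mask_to_rle_pytorch_2
)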
