 import time
 from datetime import datetime
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 import torch._dynamo.config
 import torch._inductor.config

 import torchao
-from benchmarks._models.utils import (
+from torchao._models.utils import (
+    _load_model,
+    decode_n_tokens,
+    decode_one_token,
+    encode_tokens,
     get_arch_name,
+    prefill,
     write_json_result_local,
     write_json_result_ossci,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    default_device,
+    device_sync,
+    get_model_size_in_bytes,
+)

 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
 torch.backends.cuda.enable_cudnn_sdp(True)
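
A note on the `torchao.utils` imports above: `default_device` and `device_sync` were previously defined inline in this script (they are deleted in the next hunk) and are now consumed from the library. A minimal sketch of what those helpers do, mirroring the deleted block, with a `hasattr` guard added as a defensive assumption for builds without XPU support:

```python
import torch

# Probe backends in priority order: CUDA, then XPU, then CPU fallback.
default_device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if hasattr(torch, "xpu") and torch.xpu.is_available()
    else "cpu"
)


def device_sync(device: str) -> None:
    # Block the host until all queued kernels on `device` finish; required
    # for accurate wall-clock timing of asynchronous GPU work.
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        torch.xpu.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass  # CPU ops are synchronous already
    else:
        print(f"device={device} is not yet supported")
```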
@@ -49,97 +59,12 @@ def device_timer(device):
         print(f"device={device} is not yet supported")


-def device_sync(device):
-    if "cuda" in device:
-        torch.cuda.synchronize(device)
-    elif "xpu" in device:
-        torch.xpu.synchronize(device)
-    elif ("cpu" in device) or ("mps" in device):
-        pass
-    else:
-        print(f"device={device} is not yet supported")
-
-
-default_device = (
-    "cuda"
-    if torch.cuda.is_available()
-    else "xpu"
-    if torch.xpu.is_available()
-    else "cpu"
-)
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from torchao._models.model import Transformer, prepare_inputs_for_model
-from torchao._models.tokenizer import get_tokenizer
-
-
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    logits = logits / max(temperature, 1e-5)
-
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[:, -1], temperature, top_k)
-    idx_next = multinomial_sample_one_no_sync(probs)
-    return idx_next, probs
-
-
-def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> torch.Tensor:
-    # input_pos: [B, S]
-    logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)[0]
-
-
-def decode_one_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # input_pos: [B, 1]
-    assert input_pos.shape[-1] == 1
-    logits = model(x, input_pos)
-    return sample(logits, **sampling_kwargs)
-
-
-def decode_n_tokens(
-    model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    callback=lambda _: _,
-    **sampling_kwargs,
-):
-    new_tokens, new_probs = [], []
-    for i in range(num_new_tokens):
-        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-            next_token, next_prob = next_token.clone(), next_prob.clone()
-        input_pos += 1
-        # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
-        new_tokens.append(next_token.clone())
-        callback(new_tokens[-1])
-        new_probs.append(next_prob)
-        cur_token = next_token
-
-    return new_tokens, new_probs
+from torchao._models.llm.model import Transformer, prepare_inputs_for_model
+from torchao._models.llm.tokenizer import get_tokenizer


 def model_forward(model, x, input_pos):
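
The generation helpers deleted above now live in `torchao._models.utils`. The non-obvious piece is `multinomial_sample_one_no_sync`: instead of `torch.multinomial`, which forces a host-device synchronization, it uses the exponential-race (Gumbel-max style) trick: dividing the probabilities by i.i.d. `Exp(1)` noise and taking the argmax selects index `i` with probability `p_i`, entirely on-device. A standalone sketch of the relocated sampling path, matching the deleted code apart from the added comments:

```python
from typing import Optional, Tuple

import torch


def multinomial_sample_one_no_sync(probs: torch.Tensor) -> torch.Tensor:
    # argmax(p_i / E_i) with E_i ~ Exp(1) draws index i with probability p_i,
    # without the CUDA sync that torch.multinomial would incur.
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> torch.Tensor:
    logits = logits / max(temperature, 1e-5)  # clamp to avoid division by zero
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)  # smallest logit kept by top-k
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    return torch.nn.functional.softmax(logits, dim=-1)


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sample the next token from the logits of the final position only.
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
```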
@@ -230,25 +155,6 @@ def generate(
     return seq


-def encode_tokens(tokenizer, string, bos=True, device=default_device):
-    tokens = tokenizer.encode(string)
-    if bos:
-        tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
-
-
-def _load_model(checkpoint_path, device, precision):
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
-    with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
-    model.load_state_dict(checkpoint, assign=True)
-    model = model.to(device=device, dtype=precision)
-
-    return model.eval()
-
-
 B_INST, E_INST = "[INST]", "[/INST]"


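
Worth noting in the relocated `_load_model`: the `Transformer` is instantiated under `torch.device("meta")`, so no parameter memory is allocated at construction time, and `load_state_dict(..., assign=True)` then adopts the mmap'd checkpoint tensors directly rather than copying into preallocated storage. A generic sketch of the pattern, where `build_model` is a hypothetical constructor standing in for `Transformer.from_name(checkpoint_path.parent.name)`:

```python
import torch
from pathlib import Path


def load_mmap_checkpoint(build_model, checkpoint_path: Path, device: str,
                         precision: torch.dtype) -> torch.nn.Module:
    # mmap=True leaves tensor data on disk until first touch; weights_only=True
    # restricts unpickling to tensor payloads for safety.
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    # Build the module skeleton on the meta device: shapes and dtypes only,
    # no real storage, so construction is cheap even for multi-GB models.
    with torch.device("meta"):
        model = build_model()
    # assign=True replaces the meta parameters with the checkpoint tensors
    # outright; copying into meta storage would fail.
    model.load_state_dict(checkpoint, assign=True)
    return model.to(device=device, dtype=precision).eval()
```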
@@ -575,8 +481,8 @@ def ffn_or_attn_only(mod, fqn):
             model, float8_dynamic_activation_float8_weight(granularity=granularity)
         )
     elif "autoquant_v2" in quantization:
-        from torchao._models.model import prepare_inputs_for_model
         from torchao._models._eval import InputRecorder
+        from torchao._models.llm.model import prepare_inputs_for_model
         from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

         calibration_seq_length = 256
@@ -665,8 +571,8 @@ def ffn_or_attn_only(mod, fqn):
         # do autoquantization
         model.finalize_autoquant()
     elif "autoquant" in quantization:
-        from torchao._models.model import prepare_inputs_for_model
         from torchao._models._eval import InputRecorder
+        from torchao._models.llm.model import prepare_inputs_for_model

         calibration_seq_length = 256
         inputs = (