
Commit d224653

Added cpu support for llama generate.py/eval.py (#1307)
1 parent 129316d


2 files changed: +22 -20


torchao/_models/llama/eval.py

Lines changed: 3 additions & 8 deletions
@@ -10,27 +10,22 @@
 from generate import (
     _load_model,
     device_sync,
-
 )
-from torchao.quantization.quant_api import (
+from torchao.quantization import (
     quantize_,
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
     fpx_weight_only,
     uintx_weight_only,
-    unwrap_tensor_subclass,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
 )
-from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
-from torchao.quantization.granularity import PerRow, PerTensor
-
+from torchao.quantization import PerRow, PerTensor
 from tokenizer import get_tokenizer
 import time
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
 
 def run_evaluation(
     checkpoint_path: Path,
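
Note on the import changes: `quantize_`, the weight-only configs, and the `PerRow`/`PerTensor` granularities are all re-exported from the top-level `torchao.quantization` namespace, and `unwrap_tensor_subclass` now comes from `torchao.utils`. A minimal sketch of the updated usage (the toy model is illustrative, not from the diff):

    import torch
    from torchao.quantization import quantize_, int8_weight_only
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # illustrative model

    # quantize_ swaps each Linear's weight for an int8 tensor subclass in place
    quantize_(model, int8_weight_only())

    # before PyTorch 2.5, torch.compile cannot trace tensor subclasses directly,
    # so they are unwrapped first
    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(model)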

torchao/_models/llama/generate.py

Lines changed: 19 additions & 12 deletions
@@ -67,7 +67,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
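
The new context manager is the portable replacement for the deprecated, CUDA-only `torch.backends.cuda.sdp_kernel`: `torch.nn.attention.sdpa_kernel` takes an `SDPBackend` (or a list of them) and also applies on CPU. A minimal standalone sketch, with illustrative tensor shapes:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)

    # Force the math backend, the only one available on every device; the old
    # enable_flash=False / enable_mem_efficient=False flags expressed the same choice
    with sdpa_kernel(SDPBackend.MATH):
        out = F.scaled_dot_product_attention(q, k, v)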
@@ -345,15 +345,19 @@ def main(
     prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
     if memory_profile:
-        torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        if device != "cuda":
+            print("Memory profiling only works on CUDA")
+        else:
+            torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
     start = -1 if compile else 0
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if device == "cuda":
+                torch.cuda.reset_peak_memory_stats() # MKG
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
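
Both guards in this hunk follow the same rule: `torch.cuda` calls raise on a CPU-only machine, so they are gated on the `device` string. The reset matters because PyTorch's peak-memory counters are high-water marks that only come down via an explicit reset; a minimal sketch of the reset/read pairing, assuming a CUDA machine:

    import torch

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()  # zero the high-water mark
        x = torch.randn(4096, 4096, device="cuda") @ torch.randn(4096, 4096, device="cuda")
        torch.cuda.synchronize()
        # peak allocator usage since the reset, in GB
        print(f"peak: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")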
@@ -421,15 +425,18 @@ def callback(x):
             print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")
 
         if memory_profile and i==0:
-            snapshot = torch.cuda.memory._snapshot()
-            with open(f"{memory_profile}.pickle", 'wb') as f:
-                from pickle import dump
-                dump(snapshot, f)
-            print(
-                f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
-                "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
-            )
-            break
+            if device != "cuda":
+                print("Memory profiling only works on CUDA")
+            else:
+                snapshot = torch.cuda.memory._snapshot()
+                with open(f"{memory_profile}.pickle", 'wb') as f:
+                    from pickle import dump
+                    dump(snapshot, f)
+                print(
+                    f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
+                    "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
+                )
+                break
 
     print("==========")
 
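
For reference, the guarded branch above performs the same capture sequence as this standalone sketch; `_record_memory_history` and `_snapshot` are private APIs, with arguments copied from the diff:

    import pickle
    import torch

    if torch.cuda.is_available():
        # start recording allocation events, keeping Python stack traces
        torch.cuda.memory._record_memory_history(
            True, trace_alloc_max_entries=250000, trace_alloc_record_context=True
        )
        y = torch.randn(2048, 2048, device="cuda") @ torch.randn(2048, 2048, device="cuda")
        snapshot = torch.cuda.memory._snapshot()  # segments plus the recorded trace
        with open("memory_profile.pickle", "wb") as f:
            pickle.dump(snapshot, f)
        # render to HTML with torch's viewer:
        #   python torch/cuda/_memory_viz.py trace_plot memory_profile.pickle -o profile.html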
