
Commit 12e3c65

Merge branch 'main' into 1520-enable-torchao-experimental-embedding-quant

2 parents: 49039c7 + 5f8f35d

9 files changed, +151 −87 lines changed

.github/workflows/pull.yml

Lines changed: 2 additions & 18 deletions
@@ -950,27 +950,11 @@ jobs:
       run: |
         export TORCHCHAT_ROOT=${PWD}
         echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
-    - name: Load or install ET
-      id: install-et
-      uses: actions/cache@v4
-      with:
-        path: |
-          ./et-build
-          ./torchchat/utils/scripts
-        key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh', '**/build_native.sh') }}
-    - if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
-      continue-on-error: true
+    - name: Install ExecuTorch
       run: |
         echo "Installing ExecuTorch"
+        export TORCHCHAT_ROOT=${PWD}
         bash torchchat/utils/scripts/install_et.sh
-    - name: Install ExecuTorch python
-      run: |
-        echo "Install ExecuTorch python"
-        export TORCHCHAT_ROOT=$PWD
-        export ET_BUILD_DIR="et-build"
-        ENABLE_ET_PYBIND="${1:-true}"
-        source "torchchat/utils/scripts/install_utils.sh"
-        install_executorch_python_libs $ENABLE_ET_PYBIND
     - name: Install runner
       run: |
         echo "Installing runner"

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-a96eeb1c7d7ba24cf0ccfc105141729acfed22bf
+7513042f39515af4c643bc1f9399952ad7f4f904

install/install_torch.sh

Lines changed: 7 additions & 0 deletions
@@ -66,6 +66,13 @@ then
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
     #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
   )
+elif [[ -x "$(command -v npu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.7.0.dev20250310+cpu"
+    torchvision=="0.22.0.dev20250310"
+    torchtune=="0.6.0"
+  )
 else
   REQUIREMENTS_TO_INSTALL=(
     torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"

torchchat/cli/builder.py

Lines changed: 5 additions & 7 deletions
@@ -29,7 +29,7 @@
 from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_or_xpu_device,
+    is_supported_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time
@@ -74,10 +74,8 @@ class BuilderArgs:
 
     def __post_init__(self):
         if self.device is None:
-            if torch.cuda.is_available():
-                self.device = "cuda"
-            elif torch.xpu.is_available():
-                self.device = "xpu"
+            if torch.accelerator.is_available():
+                self.device = torch.accelerator.current_accelerator().type
             else:
                 self.device = "cpu"
 
@@ -539,7 +537,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
 
     if builder_args.dso_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )
@@ -573,7 +571,7 @@ def do_nothing(max_batch_size, max_seq_length):
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
 
     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
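With this change the default device is no longer special-cased per backend: PyTorch's generic accelerator API reports whichever backend is active, which is how NPU gets picked up without another `elif`. A minimal standalone sketch of that resolution logic, assuming a PyTorch build new enough to expose `torch.accelerator` (the function name `resolve_default_device` is ours, for illustration):

import torch

def resolve_default_device() -> str:
    # Ask PyTorch's generic accelerator API for the active backend
    # ("cuda", "xpu", "mps", "npu", ...); fall back to CPU when none is present.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    return "cpu"

print(resolve_default_device())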

torchchat/cli/cli.py

Lines changed: 8 additions & 2 deletions
@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps", "xpu"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+        choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
     )
     model_config_parser.add_argument(
         "--attention-backend",
@@ -359,6 +359,12 @@ def _add_generation_args(parser, verb: str) -> None:
         default=1,
         help="Number of samples",
     )
+    generator_parser.add_argument(
+        "--accumulate-tokens",
+        type=int,
+        default=8,
+        help="Number of generated tokens to accumulate before calling the callback on each one of them.",
+    )
 
     generator_parser.add_argument(
         "--image-prompts",

torchchat/generate.py

Lines changed: 61 additions & 32 deletions
@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )
 
 
@@ -530,12 +532,13 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
-        new_tokens, new_probs = [], []
+        new_tokens = []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -554,38 +557,58 @@ def decode_n_tokens(
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
-            callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-            if need_probs or next_prob is None:
+
+            done_generating = _i == num_new_tokens - 2
+            if need_probs:
+                callback(new_tokens[-1], done_generating=done_generating)
+            if not need_probs or next_prob is None:
                 yield out_token, None
             else:
-                new_probs.append(next_prob.clone())
                 yield out_token, next_prob.clone()
             cur_token = next_token
 
-            # encountered eos
-            if next_token.item() == eos_token_id or (
-                eot_id is not None and next_token.item() == eot_id
-            ):
-                encountered_eos = True
-                final_token, next_prob = self.decode_one_token(
-                    model,
-                    cur_token,
-                    input_pos,
-                    need_probs,
-                    batch=batch,
-                    **sampling_kwargs,
-                )
-                input_pos += 1
-                yield cur_token.clone(), next_prob.clone()
-                break
+            if need_probs:
+                # encountered eos
+                if next_token.item() == eos_token_id or (
+                    eot_id is not None and next_token.item() == eot_id
+                ):
+                    encountered_eos = True
+                    final_token, next_prob = self.decode_one_token(
+                        model,
+                        cur_token,
+                        input_pos,
+                        need_probs,
+                        batch=batch,
+                        **sampling_kwargs,
+                    )
+                    input_pos += 1
+                    yield cur_token.clone(), next_prob.clone()
+                    break
+            else:
+                callback_pos = _i % accumulate_tokens + 1
+                if done_generating or callback_pos == accumulate_tokens:
+                    callback_num = min(accumulate_tokens, callback_pos)
+                    for i in range(callback_num, 0, -1):
+                        callback(new_tokens[-i], done_generating=done_generating)
+
+                        token_item = new_tokens[-i].item()
+                        # encountered eos
+                        if token_item == eos_token_id or (
+                            eot_id is not None and token_item == eot_id
+                        ):
+                            encountered_eos = True
+                            input_pos += 1
+                            yield new_tokens[-i].clone(), None
+                            break
+                    if encountered_eos:
+                        break
 
         if not encountered_eos:
             eos_token = torch.tensor(
                 [eos_token_id if eot_id is None else eot_id],
                 dtype=cur_token.dtype,
                 device=cur_token.device,
             )
-            new_tokens.append(eos_token.clone())
             eos_token, next_prob = self.decode_one_token(
                 model,
                 eos_token.view(1, -1),
@@ -685,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
@@ -788,14 +812,14 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
         else:
-            generated_tokens = []
             for generated_token, _ in self.decode_n_tokens(
                 model,
                 next_token,
                 input_pos,
                 max_new_tokens - 1,
                 batch=batch,
                 callback=callback,
+                accumulate_tokens=accumulate_tokens,
                 need_probs=False,
                 eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
                 eot_id=(
@@ -806,7 +830,6 @@ def generate(
                 attention_backend=attention_backend,
                 **sampling_kwargs,
             ):
-                generated_tokens.append(generated_token.view(-1))
                 yield generated_token, None
 
         generate_stats = {
@@ -1185,6 +1208,7 @@ def callback(x, *, done_generating=False):
             chat_mode=generator_args.chat_mode,
             batch=batch,
             callback=callback,
+            accumulate_tokens=generator_args.accumulate_tokens,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
@@ -1213,8 +1237,10 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                else:
+                elif self.builder_args.device == "xpu":
                     print(prof.key_averages().table(sort_by="self_xpu_time_total"))
+                elif self.builder_args.device == "npu":
+                    print(prof.key_averages().table(sort_by="self_npu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")
 
             if start_pos >= max_seq_length:
@@ -1229,11 +1255,7 @@ def callback(x, *, done_generating=False):
                 t - aggregate_metrics.get("time_to_first_token", 0)
             )
 
-            if jit_compile:
-                print(
-                    f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
-                )
-            else:
+            if not jit_compile:
                 # aggregate_metrics will not append when is jit_compile, which will affect the average numbers.
                 aggregate_metrics["tokens_per_sec"].append(tokens_sec)
                 aggregate_metrics["first_token_per_sec"].append(first_token_sec)
@@ -1257,6 +1279,10 @@ def callback(x, *, done_generating=False):
                 logging.info(
                     f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***"
                 )
+                if jit_compile:
+                    logging.info(
+                        f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
+                    )
             print("\n========================================\n")
             if start_pos >= max_seq_length:
                 if generator_args.chat_mode:
@@ -1299,8 +1325,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")
 
 
 
@@ -1595,7 +1623,6 @@ def sample(
 
     return idx_next, probs
 
-
 def run_generator(
     args,
     rank: Optional[int] =None
@@ -1628,8 +1655,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
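The rewritten non-`need_probs` path buffers tokens and only flushes the callback every `accumulate_tokens` steps (or on the final step), which cuts per-token callback overhead during streaming decode. A minimal sketch of just that batching arithmetic, detached from the model and tokenizer — the tokens, callback, and function name below are placeholders, not torchchat APIs:

def flush_in_batches(tokens, accumulate_tokens, callback, eos_token_id=2):
    # Replays the accumulate-and-flush pattern from decode_n_tokens:
    # the callback fires in bursts of up to `accumulate_tokens` tokens.
    new_tokens = []
    num_new_tokens = len(tokens)
    for _i, tok in enumerate(tokens):
        new_tokens.append(tok)
        done_generating = _i == num_new_tokens - 1
        callback_pos = _i % accumulate_tokens + 1        # position within the current burst
        if done_generating or callback_pos == accumulate_tokens:
            callback_num = min(accumulate_tokens, callback_pos)
            for i in range(callback_num, 0, -1):          # flush oldest to newest
                callback(new_tokens[-i], done_generating=done_generating)
                if new_tokens[-i] == eos_token_id:        # stop once EOS has been flushed
                    return


flush_in_batches(
    tokens=list(range(100, 110)) + [2],                   # 2 stands in for the EOS id
    accumulate_tokens=4,
    callback=lambda tok, done_generating: print(tok, done_generating),
)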

torchchat/utils/build_utils.py

Lines changed: 15 additions & 14 deletions
@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
         torch.cuda.synchronize(device)
     elif "xpu" in device:
         torch.xpu.synchronize(device)
+    elif "npu" in device:
+        torch.npu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -275,33 +277,32 @@ def is_mps_available() -> bool:
     # MPS, is that you?
     return True
 
+def select_device() -> str:
+    if torch.accelerator.is_available():
+        device = torch.accelerator.current_accelerator().type
+        if device == "mps" and not is_mps_available():
+            return "cpu"
+        return device
+    else:
+        return "cpu"
 
 def get_device_str(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
         return device
     else:
        return str(device)
 
 
 def get_device(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
         return torch.device(device)
 
 
 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"
 
-def is_cuda_or_cpu_or_xpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
+def is_supported_device(device) -> bool:
+    device_str = str(device)
+    return is_cpu_device(device) or any(dev in device_str for dev in ('cuda', 'xpu', 'npu'))
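`select_device()` centralizes the "fast" device resolution that `get_device_str()` and `get_device()` previously duplicated, and `is_supported_device()` extends the old CUDA/XPU/CPU check to NPU strings. A rough usage sketch, assuming the torchchat package is importable and the PyTorch build exposes `torch.accelerator`:

from torchchat.utils.build_utils import get_device, is_supported_device

# "fast" resolves via select_device(): the active accelerator type, else "cpu".
device = get_device("fast")
print(device)

for d in ("cpu", "cuda:0", "xpu", "npu:0", "mps"):
    # AOTI artifacts (DSO / PT2 packages) are only loaded on supported devices;
    # note that "mps" is deliberately absent from the supported list.
    print(d, is_supported_device(d))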
