From 5b3eff170bee6c32d590f8bc0135201bbd484e9f Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Wed, 15 Jan 2025 19:09:53 -0800 Subject: [PATCH 1/5] Add encoded size to start_pos --- torchchat/generate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchchat/generate.py b/torchchat/generate.py index ad933687d..ebbe56d5e 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1192,6 +1192,7 @@ def callback(x, *, done_generating=False): max_seq_length=max_seq_length, attention_backend=self.builder_args.attention_backend, ) + start_pos += encoded.size(0) for token_tensor, metrics in generator_func: if token_tensor is not None: start_pos += token_tensor.size(0) From 722fb785673044f8e7430ae820b666b8ea73bb75 Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Wed, 15 Jan 2025 20:05:30 -0800 Subject: [PATCH 2/5] Only in chat mode --- torchchat/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index ebbe56d5e..7f37386ac 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1192,7 +1192,8 @@ def callback(x, *, done_generating=False): max_seq_length=max_seq_length, attention_backend=self.builder_args.attention_backend, ) - start_pos += encoded.size(0) + if generator_args.chat_mode: + start_pos += encoded.size(0) for token_tensor, metrics in generator_func: if token_tensor is not None: start_pos += token_tensor.size(0) From f8f7ad26ab68300071730ce9912db684910d128a Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Thu, 16 Jan 2025 20:21:54 -0800 Subject: [PATCH 3/5] Debug logs --- torchchat/generate.py | 10 ++++++++++ torchchat/model.py | 2 ++ 2 files changed, 12 insertions(+) diff --git a/torchchat/generate.py b/torchchat/generate.py index 7f37386ac..3f562d20e 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -576,6 +576,9 @@ def decode_n_tokens( **sampling_kwargs, ) input_pos += 1 + if os.getenv('DEBUG_CACHE'): + print(f"final token input_pos: {input_pos}") + yield cur_token.clone(), next_prob.clone() break if not encountered_eos: @@ -1174,6 +1177,7 @@ def callback(x, *, done_generating=False): prof = torch.profiler.profile() t0 = time.perf_counter() num_tokens_generated = 0 + local_token_tensor = [] with prof: generator_func = self.generate( self.model, @@ -1196,6 +1200,9 @@ def callback(x, *, done_generating=False): start_pos += encoded.size(0) for token_tensor, metrics in generator_func: if token_tensor is not None: + if os.getenv('DEBUG_CACHE'): + print(f"Token tensor: {token_tensor}") + local_token_tensor.append(token_tensor.tolist()[0]) start_pos += token_tensor.size(0) num_tokens_generated += token_tensor.size(0) if metrics is not None: @@ -1204,6 +1211,9 @@ def callback(x, *, done_generating=False): jit_compile = is_first_sample and ( generator_args.compile or generator_args.compile_prefill ) + if os.getenv('DEBUG_CACHE'): + print(f"local_token_tensor: {local_token_tensor}") + print(self.tokenizer.decode(local_token_tensor)) compilation_time = time.perf_counter() - t0 device_sync(device=self.builder_args.device) t = time.perf_counter() - t0 diff --git a/torchchat/model.py b/torchchat/model.py index c01ff1262..58f5b46a8 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -723,6 +723,8 @@ def distribute(self, device_mesh: DeviceMesh): def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" + if os.getenv('DEBUG_CACHE'): + print("Transformer forward input pos", 
input_pos) mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] if self.tok_embeddings: From ba5d527a5afcf8302bf25db5595671a796cb35ae Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Thu, 23 Jan 2025 13:15:42 -0800 Subject: [PATCH 4/5] debug --- torchchat/generate.py | 113 ++++++++++++++++++++++-------------------- torchchat/model.py | 9 +++- 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 3f562d20e..edc46b257 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +import pydevd_pycharm import torch import torch._dynamo.config import torch._inductor.config @@ -575,9 +576,9 @@ def decode_n_tokens( batch=batch, **sampling_kwargs, ) - input_pos += 1 if os.getenv('DEBUG_CACHE'): - print(f"final token input_pos: {input_pos}") + print(f"last cur_token: {cur_token}") + print(f"final token {final_token} input_pos: {input_pos}") yield cur_token.clone(), next_prob.clone() break @@ -839,7 +840,8 @@ def _callback(self, x, *, buffer, done_generating): done_generating = True buffer = buffer[:-1] # drop the eot_id from the output buffer if len(buffer) == 4 or done_generating: - print("".join(buffer), end="", flush=True) + if not os.getenv('DEBUG_CACHE'): + print("".join(buffer), end="", flush=True) buffer.clear() def _gen_model_input( @@ -1201,7 +1203,7 @@ def callback(x, *, done_generating=False): for token_tensor, metrics in generator_func: if token_tensor is not None: if os.getenv('DEBUG_CACHE'): - print(f"Token tensor: {token_tensor}") + # print(f"Token tensor: {token_tensor}") local_token_tensor.append(token_tensor.tolist()[0]) start_pos += token_tensor.size(0) num_tokens_generated += token_tensor.size(0) @@ -1216,57 +1218,58 @@ def callback(x, *, done_generating=False): print(self.tokenizer.decode(local_token_tensor)) compilation_time = time.perf_counter() - t0 device_sync(device=self.builder_args.device) - t = time.perf_counter() - t0 - if hasattr(prof, "export_chrome_trace"): - if self.builder_args.device == "cpu": - print(prof.key_averages().table(sort_by="self_cpu_time_total")) - elif self.builder_args.device == "cuda": - print(prof.key_averages().table(sort_by="self_cuda_time_total")) - else: - print(prof.key_averages().table(sort_by="self_xpu_time_total")) - prof.export_chrome_trace(f"{self.profile}.json") - - if start_pos >= max_seq_length: - print( - f"[Max Sequence Length {max_seq_length} Reached. Ending Conversation.]" - ) - print("---------------------------------------------------") - - tokens_sec = (num_tokens_generated + 1) / t - first_token_sec = 1 / aggregate_metrics.get("time_to_first_token", 0) - next_tokens_sec = num_tokens_generated / ( - t - aggregate_metrics.get("time_to_first_token", 0) - ) - - if jit_compile: - print( - f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds" - ) - else: - # aggregate_metrics will not append when is jit_compile, which will affect the average numbers. 
- aggregate_metrics["tokens_per_sec"].append(tokens_sec) - aggregate_metrics["first_token_per_sec"].append(first_token_sec) - aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) - - logging.info( - f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ - \nGenerated {num_tokens_generated} tokens \ - \nTime for inference {i + 1}: {t:.04f} sec total \ - \nTime to first token: {aggregate_metrics.get('time_to_first_token', 0):.04f} sec \ -with {'sequential' if generator_args.sequential_prefill else 'parallel'} prefill.\ - \n\n Total throughput: {tokens_sec:.04f} tokens/sec, {1 / tokens_sec:.04f} s/token \ - \nFirst token throughput: {first_token_sec:.04f} tokens/sec, {1 / first_token_sec:.04f} s/token \ - \n Next token throughput: {next_tokens_sec:.04f} tokens/sec, {1 / next_tokens_sec:.04f} s/token \ - " - ) - logging.info( - f"\nBandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" - ) - if i == 0: - logging.info( - f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***" - ) - print("\n========================================\n") +# t = time.perf_counter() - t0 +# if hasattr(prof, "export_chrome_trace"): +# if self.builder_args.device == "cpu": +# print(prof.key_averages().table(sort_by="self_cpu_time_total")) +# elif self.builder_args.device == "cuda": +# print(prof.key_averages().table(sort_by="self_cuda_time_total")) +# else: +# print(prof.key_averages().table(sort_by="self_xpu_time_total")) +# prof.export_chrome_trace(f"{self.profile}.json") +# +# if start_pos >= max_seq_length: +# print( +# f"[Max Sequence Length {max_seq_length} Reached. Ending Conversation.]" +# ) +# print("---------------------------------------------------") +# +# tokens_sec = (num_tokens_generated + 1) / t +# first_token_sec = 1 / aggregate_metrics.get("time_to_first_token", 0) +# next_tokens_sec = num_tokens_generated / ( +# t - aggregate_metrics.get("time_to_first_token", 0) +# ) +# +# if jit_compile: +# print( +# f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds" +# ) +# else: +# # aggregate_metrics will not append when is jit_compile, which will affect the average numbers. +# aggregate_metrics["tokens_per_sec"].append(tokens_sec) +# aggregate_metrics["first_token_per_sec"].append(first_token_sec) +# aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) +# +# logging.info( +# f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ +# \nGenerated {num_tokens_generated} tokens \ +# \nTime for inference {i + 1}: {t:.04f} sec total \ +# \nTime to first token: {aggregate_metrics.get('time_to_first_token', 0):.04f} sec \ +# with {'sequential' if generator_args.sequential_prefill else 'parallel'} prefill.\ +# \n\n Total throughput: {tokens_sec:.04f} tokens/sec, {1 / tokens_sec:.04f} s/token \ +# \nFirst token throughput: {first_token_sec:.04f} tokens/sec, {1 / first_token_sec:.04f} s/token \ +# \n Next token throughput: {next_tokens_sec:.04f} tokens/sec, {1 / next_tokens_sec:.04f} s/token \ +# " +# ) +# logging.info( +# f"\nBandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" +# ) +# if i == 0: +# logging.info( +# f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. 
***" +# ) +# print("\n========================================\n") + print() if start_pos >= max_seq_length: if generator_args.chat_mode: break diff --git a/torchchat/model.py b/torchchat/model.py index 58f5b46a8..c912b59be 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -724,7 +724,7 @@ def distribute(self, device_mesh: DeviceMesh): def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" if os.getenv('DEBUG_CACHE'): - print("Transformer forward input pos", input_pos) + print(f"Transformer forward input pos: {input_pos}", end=" | ") mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] if self.tok_embeddings: @@ -745,6 +745,13 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int if self.config.logits_scaling: x = x / self.config.logits_scaling # print(f"output shape: {x.shape}") + if os.getenv('DEBUG_CACHE'): + def get_cache_sample(input_pos): + return self.layers['0'].attention.kv_cache[0].k_cache[0][0][input_pos][0].tolist() + last_input_pos = input_pos[-1].item() + print(f"Transformer after forward cache (input_pos={last_input_pos-1}; cache_sample={get_cache_sample(last_input_pos-1)}) " + f"(input_pos={last_input_pos}; cache_sample={get_cache_sample(last_input_pos)}) (input_pos={last_input_pos+1}; cache_sample={get_cache_sample(last_input_pos+1)})") + return x From 44d9a07b56f12e36c71ee0401c25460dd232ca63 Mon Sep 17 00:00:00 2001 From: nlpfollower Date: Thu, 23 Jan 2025 13:42:43 -0800 Subject: [PATCH 5/5] fix --- torchchat/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index c912b59be..65d1ef39e 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -724,7 +724,10 @@ def distribute(self, device_mesh: DeviceMesh): def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" if os.getenv('DEBUG_CACHE'): - print(f"Transformer forward input pos: {input_pos}", end=" | ") + def get_cache_sample(input_pos): + return self.layers['0'].attention.kv_cache[0].k_cache[0][0][input_pos][0].tolist() + first_pos = input_pos[0].item() + print(f"Transformer before forward input pos: {input_pos}; (pos={first_pos} cache_sample={get_cache_sample(first_pos)})" , end=" | ") mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] if self.tok_embeddings: @@ -748,9 +751,8 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int if os.getenv('DEBUG_CACHE'): def get_cache_sample(input_pos): return self.layers['0'].attention.kv_cache[0].k_cache[0][0][input_pos][0].tolist() - last_input_pos = input_pos[-1].item() - print(f"Transformer after forward cache (input_pos={last_input_pos-1}; cache_sample={get_cache_sample(last_input_pos-1)}) " - f"(input_pos={last_input_pos}; cache_sample={get_cache_sample(last_input_pos)}) (input_pos={last_input_pos+1}; cache_sample={get_cache_sample(last_input_pos+1)})") + first_pos = input_pos[0].item() + print(f"Transformer after forward cache (pos={first_pos}; cache_sample={get_cache_sample(first_pos)})") return x
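For reference, patches 1-2 advance start_pos by the size of the encoded prompt so that, in chat mode, the next prefill lands after the tokens already written to the KV cache. Below is a minimal standalone sketch of that bookkeeping, using simplified stand-in names (run_turn, the hard-coded "generated" tokens) rather than torchchat's actual generator API:

    import torch

    # Minimal sketch of the chat-mode position bookkeeping (simplified stand-in
    # names, not torchchat's API). The KV cache is indexed by absolute position,
    # so each turn must start where the previous prompt + response ended.
    def run_turn(encoded: torch.Tensor, start_pos: int, chat_mode: bool = True) -> int:
        # Prefill writes the whole prompt at positions
        # [start_pos, start_pos + encoded.size(0)).
        if chat_mode:
            start_pos += encoded.size(0)
        # Decode: every yielded token tensor advances the position by its length.
        for token_tensor in (torch.tensor([42]), torch.tensor([7]), torch.tensor([99])):
            start_pos += token_tensor.size(0)
        return start_pos

    pos = 0
    pos = run_turn(torch.tensor([1, 2, 3, 4]), pos)  # 4 prompt + 3 generated -> 7
    pos = run_turn(torch.tensor([5, 6]), pos)        # 2 prompt + 3 generated -> 12

Because the cache is indexed by absolute position, dropping the prompt-length increment would make the next turn overwrite earlier entries, which is the kind of mismatch the DEBUG_CACHE instrumentation in patches 3-5 appears intended to surface.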
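Patches 3-5 gate their prints behind a DEBUG_CACHE environment variable and sample a slice of the first layer's K cache before and after forward to confirm which positions were actually written. The following is a self-contained sketch of the same pattern against a toy cache tensor; the cache shape, forward_step, and cache_sample names are illustrative assumptions, not torchchat's internals:

    import os

    import torch

    def debug_cache_enabled() -> bool:
        return bool(os.getenv("DEBUG_CACHE"))

    # Toy K cache with layout [batch, heads, seq_len, head_dim]; in the real
    # model the cache lives inside each attention layer.
    k_cache = torch.zeros(1, 8, 32, 64)

    def cache_sample(pos: int, n: int = 4) -> list:
        """Return the first n values of head 0's key vector at position `pos`."""
        return k_cache[0, 0, pos, :n].tolist()

    def forward_step(input_pos: torch.Tensor) -> None:
        if debug_cache_enabled():
            first = input_pos[0].item()
            print(f"before forward: input_pos={input_pos.tolist()}, "
                  f"sample@{first}={cache_sample(first)}")
        # ... the model forward would compute K/V and write them at `input_pos` ...
        k_cache[0, :, input_pos, :] = 1.0  # stand-in for the real cache update
        if debug_cache_enabled():
            last = input_pos[-1].item()
            print(f"after forward: sample@{last}={cache_sample(last)}")

    forward_step(torch.arange(0, 4))   # prefill positions 0..3
    forward_step(torch.tensor([4]))    # first decode step

Sampling a single head's key vector at one or two positions around input_pos is usually enough to tell whether a step wrote to the expected cache slots, without dumping the whole tensor.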