From 6a6be219bf15112be195bbe0bd17251cd1d7ade4 Mon Sep 17 00:00:00 2001
From: nlpfollower
Date: Thu, 16 Jan 2025 20:21:54 -0800
Subject: [PATCH] Debug logs

---
 torchchat/generate.py | 10 ++++++++++
 torchchat/model.py    |  2 ++
 2 files changed, 12 insertions(+)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index a14ece1ad..811f70ab6 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -574,6 +574,9 @@ def decode_n_tokens(
                     **sampling_kwargs,
                 )
                 input_pos += 1
+                if os.getenv('DEBUG_CACHE'):
+                    print(f"final token input_pos: {input_pos}")
+
                 yield cur_token.clone(), next_prob.clone()
                 break
         if not encountered_eos:
@@ -1170,6 +1173,7 @@ def callback(x, *, done_generating=False):
                 prof = torch.profiler.profile()
             t0 = time.perf_counter()
             num_tokens_generated = 0
+            local_token_tensor = []
             with prof:
                 generator_func = self.generate(
                     self.model,
@@ -1191,6 +1195,9 @@ def callback(x, *, done_generating=False):
             start_pos += encoded.size(0)
             for token_tensor, metrics in generator_func:
                 if token_tensor is not None:
+                    if os.getenv('DEBUG_CACHE'):
+                        print(f"Token tensor: {token_tensor}")
+                    local_token_tensor.append(token_tensor.tolist()[0])
                     start_pos += token_tensor.size(0)
                     num_tokens_generated += token_tensor.size(0)
                 if metrics is not None:
@@ -1199,6 +1206,9 @@ def callback(x, *, done_generating=False):
             jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
+            if os.getenv('DEBUG_CACHE'):
+                print(f"local_token_tensor: {local_token_tensor}")
+                print(self.tokenizer.decode(local_token_tensor))
             compilation_time = time.perf_counter() - t0
             device_sync(device=self.builder_args.device)
             t = time.perf_counter() - t0
diff --git a/torchchat/model.py b/torchchat/model.py
index f50d2a8be..ceb297b7f 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -723,6 +723,8 @@ def distribute(self, device_mesh: DeviceMesh):

     def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
+        if os.getenv('DEBUG_CACHE'):
+            print("Transformer forward input pos", input_pos)
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         if self.tok_embeddings: