[MTK] Add support for Llama 3.2 and code updates to align with current ET API for dynamic dim #6726
Changes from all commits
```diff
@@ -675,8 +675,8 @@ def load_weights(self, state_dict, state_dict_start_idx):
             )
         else:
             if self.config.tie_word_embeddings:
-                lm_head_weight_key = "embed_tokens.weight"
-                lm_head_bias_key = "embed_tokens.bias"
+                lm_head_weight_key = f"{prefix}embed_tokens.weight"
+                lm_head_bias_key = f"{prefix}embed_tokens.bias"
             else:
                 lm_head_weight_key = "lm_head.weight"
                 lm_head_bias_key = "lm_head.bias"
```
```diff
@@ -751,15 +751,16 @@ def get_example_inputs(
                 for _ in range(2 * self.num_blocks)
             ],
         )
         # Specify dims that would be dynamic during calibration
         # Note: Assume cache size fixed shape as torch dynamic shape cannot handle dim 3 being
         # combination of 2 dynamic dims
         if get_dym_shape:
             nt = Dim("num_token", max=num_token)
             cache_dims = tuple(({} for _ in range(2 * self.num_blocks)))
             dynamic_shapes = (
-                {0: None, 1: nt, 2: None},
-                {0: None, 1: None, 2: nt, 3: nt + cache_size},
-                {0: None, 1: None, 2: nt, 3: None},
+                {0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
+                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: nt + cache_size},
+                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: Dim.STATIC},
                 cache_dims,
             )
             return example_inputs, dynamic_shapes
```

**Contributor:** I think I may have asked this in the previous PR - what does `dynamic_shapes` do?

**Author:** Hi @cccclai, `dynamic_shapes` is passed to `torch.export.export_for_training` to indicate that the shapes of some of the inputs may vary during calibration. After calibration, the `dynamic_shapes` argument is no longer needed.
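For readers unfamiliar with the export API, here is a minimal, self-contained sketch of how a `dynamic_shapes` tuple like the one in the hunk above is consumed. The toy module, tensor sizes, and `max` value are illustrative assumptions, not the actual MTK Llama export code; it assumes a recent PyTorch that provides `Dim.STATIC`.

```python
import torch
from torch.export import Dim, export_for_training


# Hypothetical single-input module standing in for one decoder chunk.
class TinyModel(torch.nn.Module):
    def forward(self, tokens):
        return tokens * 2


model = TinyModel()
# Illustrative (batch, num_token, hidden) example input used for tracing.
example_inputs = (torch.zeros(1, 32, 64),)

# Dim marks an axis as dynamic; Dim.STATIC pins the remaining axes
# (the current export API prefers Dim.STATIC over None for static axes,
# which is what this PR updates the code to use).
nt = Dim("num_token", max=128)
dynamic_shapes = ({0: Dim.STATIC, 1: nt, 2: Dim.STATIC},)

# The exported program then accepts calibration inputs whose num_token
# axis varies, up to the declared max.
ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
print(ep)
```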
New file (model config):

```diff
@@ -0,0 +1,23 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "bos_token_id": 128000,
+  "eos_token_id": 128001,
+  "head_dim": 64,
+  "hidden_size": 2048,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "model_type": "llama",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 16,
+  "num_key_value_heads": 8,
+  "rms_norm_eps": 1e-05,
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.45.0.dev0",
+  "vocab_size": 128256,
+  "tokenizer": "pretrained_fast"
+}
```
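As a small aside, this config is plain JSON and can be inspected directly. A sketch, assuming a hypothetical local path for the file (its location in the repo is not shown in this diff):

```python
import json

# "config.json" is a hypothetical local path for illustration only.
with open("config.json") as f:
    cfg = json.load(f)

# The fields most relevant to the export flow above: 16 decoder layers,
# grouped-query attention with 8 KV heads, and tied input/output embeddings
# (the reason load_weights falls back to the embed_tokens.* keys for lm_head).
print(cfg["num_hidden_layers"], cfg["num_key_value_heads"], cfg["tie_word_embeddings"])
```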
New file (special tokens):

```diff
@@ -0,0 +1,16 @@
+{
+  "bos_token": {
+    "content": "<|begin_of_text|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|eot_id|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
```
**Reviewer:** Any specific reason to remove this line?

**Author:** Oh, I didn't remove it; I shifted it above to line 422. The reason is that during `chunk.load_weights` (line 437), the weights are popped from the original `state_dict`. For the `tie_word_embeddings=True` case, this would mean the embedding weights are no longer present in the `state_dict` by the time we get to the embedding layer, hence I shifted it up.
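To illustrate the ordering concern, a hypothetical minimal reproduction (key names, shapes, and helper functions are made up and do not match the real checkpoint or model code): once a chunk pops the shared weight out of the `state_dict`, a tied embedding resolved afterwards can no longer find its tensor, so the embedding lookup has to happen first.

```python
import torch

# Made-up state_dict with a single tied weight.
state_dict = {"embed_tokens.weight": torch.randn(128, 16)}


def load_chunk_weights(sd):
    # Stands in for chunk.load_weights, which pops the entries it consumes.
    return {"lm_head.weight": sd.pop("embed_tokens.weight")}


def load_embedding(sd):
    # Stands in for building the embedding layer from the same state_dict.
    return torch.nn.Embedding.from_pretrained(sd["embed_tokens.weight"])


# Resolving the embedding first works: the tensor is still in the dict.
embedding = load_embedding(state_dict)
chunk_weights = load_chunk_weights(state_dict)

# Doing it in the opposite order would raise KeyError in the tied-embedding
# case, because the shared weight has already been popped -- which is why the
# embedding lookup was moved earlier in the file.
```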