[MTK] Add support for Llama 3.2 and code updates to align with current ET API for dynamic dim #6726
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6726
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit fc75330 with merge base 97a4600.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cccclai left a comment
Mostly nit comments, thank you for adding support for 3.2.
| "eos_token_id_tensor": torch.tensor(tokenizer.eos_token_id), | ||
| "response_cap": args.response_cap, | ||
| }, | ||
| keep_in_memory=True |
Mind sharing what keep_in_memory does?
Hi @cccclai, keep_in_memory stores the dataset in RAM instead of caching it to disk. It was a temporary workaround for an OSError I encountered on my end. I have since resolved the issue and will remove this argument in the next commit.
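For anyone reading along, a minimal sketch of what the flag does, using a toy dataset and a hypothetical preprocessing function rather than the PR's actual calibration pipeline:

from datasets import Dataset

# Minimal sketch, not the PR's code: keep_in_memory=True keeps the processed
# dataset in RAM instead of writing it to the on-disk Arrow cache.
raw = Dataset.from_dict({"text": ["hello world", "a calibration prompt"]})

def add_length(example):
    # Hypothetical preprocessing step standing in for the real calibration map.
    return {"num_chars": len(example["text"])}

processed = raw.map(add_length, keep_in_memory=True)  # no cache file is written
print(processed[0])

Since the flag only changes where the processed dataset lives, dropping it should not affect the calibration results.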
if get_dym_shape:
    nt = Dim("num_token", max=num_token)
    cache_dims = tuple(({} for _ in range(2 * self.num_blocks)))
    dynamic_shapes = (
I think I may have asked in the previous PR: what does dynamic_shapes do? It would probably be good to add a comment to explain.
Hi @cccclai, dynamic_shapes is passed to torch.export.export_for_training to indicate that the shapes of some inputs may vary during calibration. After calibration, the dynamic_shapes argument is no longer needed.
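As a rough illustration (a toy module on a recent PyTorch that provides torch.export.export_for_training, not the MTK Llama chunk), dynamic_shapes marks the token dimension as symbolic so a single exported program can be calibrated with prompts of different lengths:

import torch
from torch.export import Dim, export_for_training

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, tokens_embed):
        return self.proj(tokens_embed)

# Dim declares a symbolic size; here dim 1 (num_token) of the first input may
# vary up to 128 between calibration samples.
num_token = Dim("num_token", max=128)
example_input = (torch.randn(1, 32, 16),)  # (batch, num_token, hidden)
exported = export_for_training(
    TinyModel(),
    example_input,
    dynamic_shapes=({1: num_token},),
)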
cal_dataset = None
if args.dataset is not None:
    cal_dataset = load_dataset("text", data_files=args.dataset, split="train")
embedding_layer = get_embedding_layer(config, weight_dir, state_dict)
Any specific reason to remove this line?
Oh, I didn't remove it; I moved it up to line 422. During chunk.load_weights (line 437) the weights are popped from the original state_dict, so in the tied-word-embedding case the embedding weights would no longer be present in the state_dict by the time we get the embedding layer. Hence I moved it up.
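To illustrate the ordering issue with hypothetical names (the real helpers take more arguments and handle many keys):

import torch

# Minimal sketch: load_chunk_weights consumes entries from the shared
# state_dict, so the embedding weight must be read before it runs.
state_dict = {"model.embed_tokens.weight": torch.randn(8, 4)}

def get_embedding_layer(state_dict):
    emb = torch.nn.Embedding(8, 4)
    emb.weight.data.copy_(state_dict["model.embed_tokens.weight"])  # read only
    return emb

def load_chunk_weights(state_dict):
    # Stands in for chunk.load_weights: pops weights out of the shared dict.
    return {k: state_dict.pop(k) for k in list(state_dict)}

embedding_layer = get_embedding_layer(state_dict)  # must happen before the pop
_ = load_chunk_weights(state_dict)
print("model.embed_tokens.weight" in state_dict)   # False: key has been consumed

With tied embeddings the LM head reuses that same tensor, so once the chunk has popped it there is nothing left under the embedding key to look up.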
Thanks @neuropilot-captain! Do we also need to make changes to the
There are some lint errors. Could you send a fix?
Hi @cmodi-meta, yes, the script would need to be modified to include Llama 3.2 1B and 3B in the if/else. I can modify the script and update it in the next commit.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This PR contains the following modifications/updates: