vzantedeschi
Contributor

Thank you for the very useful repo!

I made a small change to flops_counter.calculate_flops() so that non-tensor arguments can be passed to model.generate() (e.g., kwargs = {..., "max_new_tokens": 10}). I couldn't find a way to pass non-tensor arguments without changing the code.
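For illustration only, here is a minimal sketch of the general idea of letting non-tensor values through untouched while still handling tensors. It is not the diff from this PR: `prepare_kwargs` and `FakeTensor` are hypothetical names, and `FakeTensor` merely stands in for `torch.Tensor` so the snippet runs without PyTorch.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Mimic Tensor.to(): return a copy placed on the requested device.
        return FakeTensor(self.data, device)


def prepare_kwargs(kwargs, device):
    """Move tensor-like values to `device`; pass every other value through unchanged."""
    prepared = {}
    for name, value in kwargs.items():
        if isinstance(value, FakeTensor):  # real code would check torch.Tensor instead
            prepared[name] = value.to(device)
        else:
            # Non-tensor arguments such as max_new_tokens are forwarded as-is.
            prepared[name] = value
    return prepared


kwargs = {"input_ids": FakeTensor([1, 2, 3]), "max_new_tokens": 10}
prepared = prepare_kwargs(kwargs, "cuda:0")
```

With this shape, `prepared` could then be unpacked into a call such as `model.generate(**prepared)` without the integer argument being treated as a tensor.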

Other minor changes: a small refactor, and an error is now raised when forward_mode is neither forward nor generate, so that a mistake such as a typo in the mode name is reported immediately instead of surfacing later as an unrelated runtime error.
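The mode check described above could look roughly like the following. This is a sketch under assumptions, not the merged code: `run_model`, `VALID_FORWARD_MODES`, and `EchoModel` are hypothetical names introduced here, and `ValueError` is an assumed choice of exception.

```python
VALID_FORWARD_MODES = ("forward", "generate")  # the two modes named in this PR


def run_model(model, forward_mode, kwargs):
    """Dispatch to model(...) or model.generate(...), rejecting unknown modes early."""
    if forward_mode not in VALID_FORWARD_MODES:
        # Fail fast so a typo like "generat" is reported clearly.
        raise ValueError(
            f"forward_mode must be one of {VALID_FORWARD_MODES}, got {forward_mode!r}"
        )
    if forward_mode == "forward":
        return model(**kwargs)
    return model.generate(**kwargs)


class EchoModel:
    """Hypothetical dummy model that reports which entry point was called."""

    def __call__(self, **kwargs):
        return ("forward", kwargs)

    def generate(self, **kwargs):
        return ("generate", kwargs)


result = run_model(EchoModel(), "generate", {"max_new_tokens": 10})
```

Checking the mode before dispatching keeps the error message close to its cause, which is the motivation given for the change.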

@MrYxJ
Owner

MrYxJ commented Feb 7, 2024 via email

@MrYxJ MrYxJ merged commit 29ca584 into MrYxJ:main Mar 1, 2024