Skip to content

Conversation

ryan-mangeno
Copy link
Contributor

@ryan-mangeno ryan-mangeno commented Aug 28, 2025

adding support to run granite embedding small, and it primarily pulls the modern bert architecture - https://huggingface.co/ibm-granite/granite-embedding-small-english-r2, currently working on it still, havent figured out the pre-tokenizer type or if I need to impliment it, also for the ubatch size the assert fails in llama-graph.cpp, hacked it to accept ubatch size of 1 for testing, but it seems to keep failing there and not sure why,

if I comment out of the line in llama-graph.cpp

assert(!ubatch.equal_seqs());

then it works

@ryan-mangeno ryan-mangeno marked this pull request as draft August 28, 2025 17:05
@ryan-mangeno
Copy link
Contributor Author

ryan-mangeno commented Aug 28, 2025

@gabe-l-hart thanks in advance :)

@ryan-mangeno
Copy link
Contributor Author

@gabe-l-hart thanks in advance :)

also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

@CISC
Copy link
Collaborator

CISC commented Aug 28, 2025

You may want to check out an earlier attempt at ModernBert in #14014

@gabe-l-hart
Copy link
Collaborator

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

@gabe-l-hart
Copy link
Collaborator

also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

In general, we want to keep things as generic as possible, so since this uses the ModernBertModel architecture from transformers, it's best to keep the implementation here similarly robust unless there's a concrete reason to subset the transformers architecture to just work for granite (eg there's some non-trivial code path in the transformers version that would make sense as a separate architecture).

@github-actions github-actions bot added the python python script changes label Aug 28, 2025
@ryan-mangeno
Copy link
Contributor Author

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

will do

@ryan-mangeno
Copy link
Contributor Author

ryan-mangeno commented Sep 3, 2025

@gabe-l-hart im looking into modern berts research paper, I cant find a mention of symmetric sliding window attention but rather local sliding window attention so I am going to opt to use LLAMA_SWA_TYPE_LOCAL versus LLAMA_SWA_TYPE_SYMMETRIC used in the previous attempt. It also uses global attention every third layer so I am going to implement this stuff and then it should be ready for a review :)

@gabe-l-hart
Copy link
Collaborator

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

… per previous attempt, added local sliding window attention that alternates every third layer
@ryan-mangeno
Copy link
Contributor Author

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

ok 👍 , made some changes but not sure if its fully ready yet, I will ping you when I think its ready if thats ok

@ryan-mangeno
Copy link
Contributor Author

ryan-mangeno commented Sep 4, 2025

status update - I found out that modern bert uses an alternating rope method , per https://arxiv.org/pdf/2412.13663

In ModernBERT, every third layer employs global
attention with a RoPE theta of 160,000 and the
remaining layers use a 128 token, local sliding window attention with a RoPE theta of 10,000.

I am currently figuring out how to implement this

@ehoogeveen-medweb
Copy link

status update - I found out that modern bert uses an alternating rope method , per arxiv.org/pdf/2412.13663

In ModernBERT, every third layer employs global
attention with a RoPE theta of 160,000 and the
remaining layers use a 128 token, local sliding window attention with a RoPE theta of 10,000.

I am currently figuring out how to implement this

IIUC this matches how sliding window attention is handled for Gemma3:

hparams.rope_freq_base_train_swa = 10000.0f;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
python python script changes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants