From d62163ab1f6ff896c553f49fd8369bf69cc10cf0 Mon Sep 17 00:00:00 2001 From: McClain Thiel Date: Fri, 14 Nov 2025 12:30:37 +0000 Subject: [PATCH 1/6] started adding support for evo2 --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/evo2.md | 71 +++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 5 + src/transformers/models/evo2/__init__.py | 29 + .../models/evo2/configuration_evo2.py | 193 +++++++ src/transformers/models/evo2/modeling_evo2.py | 524 ++++++++++++++++++ src/transformers/models/evo2/modular_evo2.py | 88 +++ tests/models/evo2/__init__.py | 0 tests/models/evo2/test_modeling_evo2.py | 492 ++++++++++++++++ 11 files changed, 1407 insertions(+) create mode 100644 docs/source/en/model_doc/evo2.md create mode 100644 src/transformers/models/evo2/__init__.py create mode 100644 src/transformers/models/evo2/configuration_evo2.py create mode 100644 src/transformers/models/evo2/modeling_evo2.py create mode 100644 src/transformers/models/evo2/modular_evo2.py create mode 100644 tests/models/evo2/__init__.py create mode 100644 tests/models/evo2/test_modeling_evo2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c92fed507a6d..0e639d5bffb4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -480,6 +480,8 @@ title: ErnieM - local: model_doc/esm title: ESM + - local: model_doc/evo2 + title: Evo2 - local: model_doc/exaone4 title: EXAONE-4.0 - local: model_doc/falcon diff --git a/docs/source/en/model_doc/evo2.md b/docs/source/en/model_doc/evo2.md new file mode 100644 index 000000000000..d4a89cc7c575 --- /dev/null +++ b/docs/source/en/model_doc/evo2.md @@ -0,0 +1,71 @@ + + + +# Evo2 + +## Overview + +The Evo2 model was proposed in []() by . + + +The abstract from the paper is the following: + + + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). 
+ +## Usage examples + + + +## Evo2Config + +[[autodoc]] Evo2Config + +## Evo2ForCausalLM + +[[autodoc]] Evo2ForCausalLM + +## Evo2ForQuestionAnswering + +[[autodoc]] Evo2ForQuestionAnswering + +## Evo2Model + +[[autodoc]] Evo2Model + - forward + +## Evo2PreTrainedModel + +[[autodoc]] Evo2PreTrainedModel + - forward + +## Evo2ForSequenceClassification + +[[autodoc]] Evo2ForSequenceClassification + +## Evo2ForTokenClassification + +[[autodoc]] Evo2ForTokenClassification \ No newline at end of file diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3534ce6719d0..5f3472868cba 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -120,6 +120,7 @@ from .encoder_decoder import * from .ernie import * from .esm import * + from .evo2 import * from .evolla import * from .exaone4 import * from .falcon import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9a3b2ec5ecc2..1f5a861e9b83 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -145,6 +145,7 @@ ("ernie4_5_moe", "Ernie4_5_MoeConfig"), ("ernie_m", "ErnieMConfig"), ("esm", "EsmConfig"), + ("evo2", "Evo2Config"), ("evolla", "EvollaConfig"), ("exaone4", "Exaone4Config"), ("falcon", "FalconConfig"), @@ -590,6 +591,7 @@ ("ernie4_5_moe", "Ernie4_5_MoE"), ("ernie_m", "ErnieM"), ("esm", "ESM"), + ("evo2", "Evo2"), ("evolla", "Evolla"), ("exaone4", "EXAONE-4.0"), ("falcon", "Falcon"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 257fb95fdea7..e33b6743319e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -148,6 +148,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie4_5_moe", "Ernie4_5_MoeModel"), ("ernie_m", "ErnieMModel"), ("esm", "EsmModel"), + ("evo2", "Evo2Model"), ("evolla", "EvollaModel"), ("exaone4", "Exaone4Model"), ("falcon", "FalconModel"), @@ -666,6 +667,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie", "ErnieForCausalLM"), ("ernie4_5", "Ernie4_5ForCausalLM"), ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"), + ("evo2", "Evo2ForCausalLM"), ("exaone4", "Exaone4ForCausalLM"), ("falcon", "FalconForCausalLM"), ("falcon_h1", "FalconH1ForCausalLM"), @@ -1239,6 +1241,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie", "ErnieForSequenceClassification"), ("ernie_m", "ErnieMForSequenceClassification"), ("esm", "EsmForSequenceClassification"), + ("evo2", "Evo2ForSequenceClassification"), ("exaone4", "Exaone4ForSequenceClassification"), ("falcon", "FalconForSequenceClassification"), ("flaubert", "FlaubertForSequenceClassification"), @@ -1353,6 +1356,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("electra", "ElectraForQuestionAnswering"), ("ernie", "ErnieForQuestionAnswering"), ("ernie_m", "ErnieMForQuestionAnswering"), + ("evo2", "Evo2ForQuestionAnswering"), ("exaone4", "Exaone4ForQuestionAnswering"), ("falcon", "FalconForQuestionAnswering"), ("flaubert", "FlaubertForQuestionAnsweringSimple"), @@ -1464,6 +1468,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie", "ErnieForTokenClassification"), ("ernie_m", "ErnieMForTokenClassification"), ("esm", "EsmForTokenClassification"), + ("evo2", "Evo2ForTokenClassification"), ("exaone4", "Exaone4ForTokenClassification"), ("falcon", 
"FalconForTokenClassification"), ("flaubert", "FlaubertForTokenClassification"), diff --git a/src/transformers/models/evo2/__init__.py b/src/transformers/models/evo2/__init__.py new file mode 100644 index 000000000000..3b196962679e --- /dev/null +++ b/src/transformers/models/evo2/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_evo2 import * + from .modeling_evo2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/evo2/configuration_evo2.py b/src/transformers/models/evo2/configuration_evo2.py new file mode 100644 index 000000000000..bd66fcac84fd --- /dev/null +++ b/src/transformers/models/evo2/configuration_evo2.py @@ -0,0 +1,193 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/evo2/modular_evo2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_evo2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Evo2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Evo2Model`]. It is used to instantiate an + Evo2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Evo2-7B-v0.1 or Evo2-7B-Instruct-v0.1. 
+ + [evo2ai/Evo2-7B-v0.1](https://huggingface.co/evo2ai/Evo2-7B-v0.1) + [evo2ai/Evo2-7B-Instruct-v0.1](https://huggingface.co/evo2ai/Evo2-7B-Instruct-v0.1) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Evo2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Evo2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Evo2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ + ```python + >>> from transformers import Evo2Model, Evo2Config + + >>> # Initializing a Evo2 7B style configuration + >>> configuration = Evo2Config() + + >>> # Initializing a model from the Evo2 7B style configuration + >>> model = Evo2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "evo2" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Evo2Model` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: Optional[int] = 32000, + hidden_size: Optional[int] = 4096, + intermediate_size: Optional[int] = 14336, + num_hidden_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_key_value_heads: Optional[int] = 8, + head_dim: Optional[int] = None, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 4096 * 32, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-6, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = False, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + sliding_window: Optional[int] = 4096, + attention_dropout: Optional[float] = 0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_dropout = attention_dropout + + if "layer_types" in kwargs: + logger.warning_once( + "Detected Evo2 model with layer_types. Consider using AutoModel or Ministral classes instead to enable alternating attention compatibility." 
+ ) + + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 10000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Evo2Config"] diff --git a/src/transformers/models/evo2/modeling_evo2.py b/src/transformers/models/evo2/modeling_evo2.py new file mode 100644 index 000000000000..5d0e11c88104 --- /dev/null +++ b/src/transformers/models/evo2/modeling_evo2.py @@ -0,0 +1,524 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/evo2/modular_evo2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_evo2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from collections.abc import Callable +from typing import Optional, Union + +import torch +from torch import nn + +from transformers.utils.generic import check_model_inputs + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from .configuration_evo2 import Evo2Config + + +class Evo2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Evo2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = 
ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Evo2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Evo2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Evo2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Evo2Attention(config=config, layer_idx=layer_idx) + self.mlp = Evo2MLP(config) + self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Evo2PreTrainedModel(PreTrainedModel): + config: Evo2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Evo2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Evo2DecoderLayer, + "attentions": Evo2Attention, + } + + +class Evo2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Evo2Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config 
+ + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[Evo2Config] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Evo2Model(Evo2PreTrainedModel): + def __init__(self, config: Evo2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Evo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Evo2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify 
exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Evo2ForCausalLM(Evo2PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Evo2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Evo2ForCausalLM + + >>> model = Evo2ForCausalLM.from_pretrained("meta-evo2/Evo2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-evo2/Evo2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Evo2ForTokenClassification(GenericForTokenClassification, Evo2PreTrainedModel): + pass + + +class Evo2ForSequenceClassification(GenericForSequenceClassification, Evo2PreTrainedModel): + pass + + +class Evo2ForQuestionAnswering(GenericForQuestionAnswering, Evo2PreTrainedModel): + pass + + +__all__ = [ + "Evo2ForCausalLM", + "Evo2ForQuestionAnswering", + "Evo2Model", + "Evo2PreTrainedModel", + "Evo2ForSequenceClassification", + "Evo2ForTokenClassification", +] diff --git a/src/transformers/models/evo2/modular_evo2.py b/src/transformers/models/evo2/modular_evo2.py new file mode 100644 index 000000000000..e8f3caead0dc --- /dev/null +++ b/src/transformers/models/evo2/modular_evo2.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ..mistral.configuration_mistral import MistralConfig +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralMLP, + MistralModel, + MistralPreTrainedModel, + MistralRMSNorm, + MistralRotaryEmbedding, +) + + +class Evo2Config(MistralConfig): + pass + + +class Evo2MLP(MistralMLP): + pass + + +class Evo2Attention(MistralAttention): + pass + + +class Evo2RMSNorm(MistralRMSNorm): + pass + + +class Evo2DecoderLayer(MistralDecoderLayer): + pass + + +class Evo2PreTrainedModel(MistralPreTrainedModel): + pass + + +class Evo2RotaryEmbedding(MistralRotaryEmbedding): + pass + + +class Evo2Model(MistralModel): + pass + + +class Evo2ForCausalLM(MistralForCausalLM): + pass + + +class Evo2ForTokenClassification(MistralForTokenClassification): + pass + + +class Evo2ForSequenceClassification(MistralForSequenceClassification): + pass + + +class Evo2ForQuestionAnswering(MistralForQuestionAnswering): + pass + + +__all__ = [ + "Evo2Config", + "Evo2ForCausalLM", + "Evo2ForQuestionAnswering", + "Evo2Model", + "Evo2PreTrainedModel", + "Evo2ForSequenceClassification", + "Evo2ForTokenClassification", +] diff --git a/tests/models/evo2/__init__.py b/tests/models/evo2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/evo2/test_modeling_evo2.py b/tests/models/evo2/test_modeling_evo2.py new file mode 100644 index 000000000000..9d7b77f0d0f8 --- /dev/null +++ b/tests/models/evo2/test_modeling_evo2.py @@ -0,0 +1,492 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Evo2 model.""" + +import gc +import unittest + +import pytest +from packaging import version +from parameterized import parameterized + +from transformers import AutoTokenizer, BitsAndBytesConfig, DynamicCache, is_torch_available, set_seed +from transformers.cache_utils import DynamicSlidingWindowLayer +from transformers.testing_utils import ( + DeviceProperties, + Expectations, + backend_empty_cache, + cleanup, + get_device_properties, + require_bitsandbytes, + require_flash_attn, + require_read_token, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ( + Evo2ForCausalLM, + Evo2Model, + ) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class Evo2ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = Evo2Model + + +@require_torch +class Evo2ModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = Evo2ModelTester + + # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + +@require_torch_accelerator +@require_read_token +class Evo2IntegrationTest(unittest.TestCase): + # This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4) + # Depending on the hardware we get different logits / generations + device_properties: DeviceProperties = (None, None, None) + + @classmethod + def setUpClass(cls): + cls.device_properties = get_device_properties() + + def setUp(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_model_7b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map="auto", dtype=torch.float16) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + with torch.no_grad(): + out = model(input_ids).logits.float().cpu() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + + # ("cuda", 8) for A100/A10, and ("cuda", 7) 7 for T4. + # considering differences in hardware processing and potential deviations in output. + # fmt: off + EXPECTED_SLICES = Expectations( + { + ("cuda", 7): torch.tensor([-5.8828, -5.8633, -0.1042, -4.7266, -5.8828, -5.8789, -5.8789, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -1.0801, 1.7598, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828]), + ("cuda", 8): torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]), + ("rocm", 9): torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]), + } + ) + # fmt: on + expected_slice = EXPECTED_SLICES.get_expectation() + + torch.testing.assert_close(out[0, 0, :30], expected_slice, atol=1e-4, rtol=1e-4) + + @slow + @require_bitsandbytes + def test_model_7b_generation(self): + EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. 
I’m not a fan of mustard, mayo," + + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) + model = Evo2ForCausalLM.from_pretrained( + "mistralai/Evo2-7B-v0.1", + device_map={"": torch_device}, + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + # TODO joao, manuel: remove this in v4.62.0 + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + """My favourite condiment is 100% ketchup. I love it on everything, and I’m not ash""" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) + model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map="auto", dtype=torch.float16) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate( + input_ids, + max_new_tokens=20, + temperature=0, + dola_layers="low", + repetition_penalty=1.2, + trust_remote_code=True, + custom_generate="transformers-community/dola", + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @require_flash_attn + @require_bitsandbytes + @slow + @pytest.mark.flash_attn_test + def test_model_7b_long_prompt(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Evo2ForCausalLM.from_pretrained( + "mistralai/Evo2-7B-v0.1", + device_map={"": torch_device}, + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + attn_implementation="flash_attention_2", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + @slow + def test_model_7b_long_prompt_sdpa(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Evo2ForCausalLM.from_pretrained( + "mistralai/Evo2-7B-v0.1", device_map="auto", attn_implementation="sdpa", dtype=torch.float16 + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + 
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + + backend_empty_cache(torch_device) + gc.collect() + + EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big""" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + def test_speculative_generation(self): + EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I’m not a fan of mustard, relish" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) + model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map="auto", dtype=torch.float16) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + set_seed(0) + generated_ids = model.generate( + input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @pytest.mark.torch_compile_test + @slow + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + if self.device_properties[0] == "cuda" and self.device_properties[1] == 7: + self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.") + + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = [ + "My favourite condiment is 100% ketchup. I love it on everything. " + "I’m not a big fan of mustard, mayo, or relish. 
I’m not a fan of pickles" + ] + + prompts = ["My favourite condiment is "] + tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) + tokenizer.pad_token = tokenizer.eos_token + model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map=torch_device, dtype=torch.float16) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Sliding Window Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + forward_function = model.__call__ + model.__call__ = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + + # Sliding Window Cache + compile + torch._dynamo.reset() + model.__call__ = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + + @pytest.mark.flash_attn_test + @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)]) + @require_flash_attn + @slow + def test_generation_beyond_sliding_window_dynamic(self, attn_implementation: str): + """Test that we can correctly generate beyond the sliding window. This is non-trivial as Evo2 will use + a DynamicCache with only sliding layers.""" + + # Impossible to test it with this model (even with < 100 tokens), probably due to the compilation of a large model. + if attn_implementation == "flex_attention": + self.skipTest( + reason="`flex_attention` gives `torch._inductor.exc.InductorError: RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_tem_fused_0 Required: 147456 Hardware limit:101376 Reducing block sizes or `num_stages` may help.`" + ) + + model_id = "mistralai/Evo2-7B-v0.1" + EXPECTED_COMPLETIONS = [ + "scenery, scenery, scenery, scenery, scenery,", + ", green, yellow, orange, purple, pink, brown, black, white, gray, silver", + ] + + input_text = [ + "This is a nice place. 
" * 682 + "I really enjoy the scenery,", # This has 4101 tokens, 15 more than 4096 + "A list of colors: red, blue", # This will almost all be padding tokens + ] + + if attn_implementation == "eager": + input_text = input_text[:1] + + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) + + model = Evo2ForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, device_map=torch_device, dtype=torch.float16 + ) + + # Make sure prefill is larger than sliding window + batch_size, input_size = inputs.input_ids.shape + self.assertTrue(input_size > model.config.sliding_window) + + # Should already be Dynamic by default, but let's make sure! + out = model.generate(**inputs, max_new_tokens=20, cache_implementation="dynamic", return_dict_in_generate=True) + output_text = tokenizer.batch_decode(out.sequences[:batch_size, input_size:]) + + self.assertEqual(output_text, EXPECTED_COMPLETIONS[:batch_size]) + + # Let's check that the dynamic cache has hybrid layers! + dynamic_cache = out.past_key_values + self.assertTrue(isinstance(dynamic_cache, DynamicCache)) + for layer in dynamic_cache.layers: + self.assertTrue(isinstance(layer, DynamicSlidingWindowLayer)) + self.assertEqual(layer.keys.shape[-2], model.config.sliding_window - 1) + + +@slow +@require_torch_accelerator +class Mask4DTestHard(unittest.TestCase): + model_name = "mistralai/Evo2-7B-v0.1" + model = None + model_dtype = None + + @classmethod + def setUpClass(cls): + cleanup(torch_device, gc_collect=True) + if cls.model_dtype is None: + cls.model_dtype = torch.float16 + if cls.model is None: + cls.model = Evo2ForCausalLM.from_pretrained(cls.model_name, dtype=cls.model_dtype).to(torch_device) + + @classmethod + def tearDownClass(cls): + del cls.model_dtype + del cls.model + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + + # building custom positions ids based on custom mask + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) + # effectively: position_ids_shared_prefix = 
torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + + # inverting the mask + min_dtype = torch.finfo(self.model_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. 
[..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) From c363816d4ea00648052f2aab69b0ed5bdbef9a2d Mon Sep 17 00:00:00 2001 From: McClain Thiel Date: Fri, 14 Nov 2025 14:24:55 +0000 Subject: [PATCH 2/6] lets get a model involved --- docs/source/en/model_doc/evo2.md | 11 - src/transformers/models/auto/modeling_auto.py | 3 - .../models/evo2/configuration_evo2.py | 423 ++++++---- src/transformers/models/evo2/modeling_evo2.py | 759 +++++++++--------- src/transformers/models/evo2/modular_evo2.py | 88 -- .../models/evo2/tokenization_evo2.py | 220 +++++ tests/models/evo2/test_modeling_evo2.py | 485 +---------- tests/models/evo2/test_tokenization_evo2.py | 110 +++ 8 files changed, 974 insertions(+), 1125 deletions(-) delete mode 100644 src/transformers/models/evo2/modular_evo2.py create mode 100644 src/transformers/models/evo2/tokenization_evo2.py create mode 100644 tests/models/evo2/test_tokenization_evo2.py diff --git a/docs/source/en/model_doc/evo2.md b/docs/source/en/model_doc/evo2.md index d4a89cc7c575..b3086dab3bf2 100644 --- a/docs/source/en/model_doc/evo2.md +++ b/docs/source/en/model_doc/evo2.md @@ -48,9 +48,6 @@ The original code can be found [here](). [[autodoc]] Evo2ForCausalLM -## Evo2ForQuestionAnswering - -[[autodoc]] Evo2ForQuestionAnswering ## Evo2Model @@ -61,11 +58,3 @@ The original code can be found [here](). 
[[autodoc]] Evo2PreTrainedModel - forward - -## Evo2ForSequenceClassification - -[[autodoc]] Evo2ForSequenceClassification - -## Evo2ForTokenClassification - -[[autodoc]] Evo2ForTokenClassification \ No newline at end of file diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e33b6743319e..4781d00cdc30 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1241,7 +1241,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie", "ErnieForSequenceClassification"), ("ernie_m", "ErnieMForSequenceClassification"), ("esm", "EsmForSequenceClassification"), - ("evo2", "Evo2ForSequenceClassification"), ("exaone4", "Exaone4ForSequenceClassification"), ("falcon", "FalconForSequenceClassification"), ("flaubert", "FlaubertForSequenceClassification"), @@ -1356,7 +1355,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("electra", "ElectraForQuestionAnswering"), ("ernie", "ErnieForQuestionAnswering"), ("ernie_m", "ErnieMForQuestionAnswering"), - ("evo2", "Evo2ForQuestionAnswering"), ("exaone4", "Exaone4ForQuestionAnswering"), ("falcon", "FalconForQuestionAnswering"), ("flaubert", "FlaubertForQuestionAnsweringSimple"), @@ -1468,7 +1466,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie", "ErnieForTokenClassification"), ("ernie_m", "ErnieMForTokenClassification"), ("esm", "EsmForTokenClassification"), - ("evo2", "Evo2ForTokenClassification"), ("exaone4", "Exaone4ForTokenClassification"), ("falcon", "FalconForTokenClassification"), ("flaubert", "FlaubertForTokenClassification"), diff --git a/src/transformers/models/evo2/configuration_evo2.py b/src/transformers/models/evo2/configuration_evo2.py index bd66fcac84fd..30ca688ecb80 100644 --- a/src/transformers/models/evo2/configuration_evo2.py +++ b/src/transformers/models/evo2/configuration_evo2.py @@ -1,193 +1,276 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/evo2/modular_evo2.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_evo2.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2025 the HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from typing import Optional - -from ...configuration_utils import PreTrainedConfig -from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params -from ...utils import logging +# src/transformers/models/evo2/configuration_evo2.py +from __future__ import annotations -logger = logging.get_logger(__name__) +from typing import List, Optional +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging -class Evo2Config(PreTrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Evo2Model`]. It is used to instantiate an - Evo2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Evo2-7B-v0.1 or Evo2-7B-Instruct-v0.1. +logger = logging.get_logger(__name__) - [evo2ai/Evo2-7B-v0.1](https://huggingface.co/evo2ai/Evo2-7B-v0.1) - [evo2ai/Evo2-7B-Instruct-v0.1](https://huggingface.co/evo2ai/Evo2-7B-Instruct-v0.1) - Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PreTrainedConfig`] for more information. +class Evo2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an :class:`~transformers.Evo2ForCausalLM` model. + It is inspired by the StripedHyena2-based Evo 2 DNA foundation model. Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Evo2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Evo2Model`] - hidden_size (`int`, *optional*, defaults to 4096): + vocab_size (`int`, *optional*, defaults to 512): + Vocabulary size of the model. + hidden_size (`int`, *optional*, defaults to 1920): Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 14336): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. - head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): - The attention head dimension. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to `4096*32`): - The maximum sequence length that this model might ever be used with. Evo2's sliding window attention - allows sequence of up to 4096*32 tokens. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
- rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - The id of the padding token. - bos_token_id (`int`, *optional*, defaults to 1): - The id of the "beginning-of-sequence" token. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the "end-of-sequence" token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_parameters (`RopeParameters`, *optional*): - Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain - a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE - with longer `max_position_embeddings`. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention window size. If not specified, will default to `4096`. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - - ```python - >>> from transformers import Evo2Model, Evo2Config - - >>> # Initializing a Evo2 7B style configuration - >>> configuration = Evo2Config() - - >>> # Initializing a model from the Evo2 7B style configuration - >>> model = Evo2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" + num_layers (`int`, *optional*, defaults to 25): + Number of layers (Hyena / attention blocks). + num_attention_heads (`int`, *optional*, defaults to 15): + Number of attention heads in attention layers. + inner_mlp_size (`int`, *optional*, defaults to 5120): + Size of the intermediate (MLP) layer in the feed-forward network. + max_position_embeddings (`int`, *optional*, defaults to 8192): + Maximum sequence length that this model might ever be used with. + rotary_emb_base (`int`, *optional*, defaults to 10000): + Base for rotary position embeddings. + + attn_layer_idxs (`List[int]`, *optional*): + Indices of layers that use attention. + hcl_layer_idxs (`List[int]`, *optional*): + Indices of "HCL" Hyena layers. + hcm_layer_idxs (`List[int]`, *optional*): + Indices of "HCM" Hyena layers. + hcs_layer_idxs (`List[int]`, *optional*): + Indices of "HCS" Hyena layers. + + num_filters (`int`, *optional*, defaults to 1920): + Number of independent filters in Hyena-LI. + hcm_filter_length (`int`, *optional*, defaults to 128): + Length of HCM filters. + hcl_filter_groups (`int`, *optional*, defaults to 1920): + Number of filter groups for HCL. + hcm_filter_groups (`int`, *optional*, defaults to 128): + Number of filter groups for HCM. + hcs_filter_groups (`int`, *optional*, defaults = 128): + Number of filter groups for HCS. + hcs_filter_length (`int`, *optional*, defaults = 7): + Length of HCS filters. + short_filter_length (`int`, *optional*, defaults = 3): + Length of short depthwise FIR filters. + short_filter_bias (`bool`, *optional*, defaults = False): + Whether to add a bias to FIR filters. + + state_size (`int`, *optional*, defaults = 16): + Size of the Hyena state. + eps (`float`, *optional*, defaults = 1e-6): + Epsilon used for numerical stability in layer norms etc. + + proj_groups (`int`, *optional*, defaults = 1): + Number of groups for grouped query/key/value projections. 
+ hyena_filter_groups (`int`, *optional*, defaults = 1): + Number of groups for Hyena filters. + + column_split_hyena (`bool`, *optional*, defaults = False): + Whether to column-split Hyena channels (for tensor parallelism). + column_split (`bool`, *optional*, defaults = True): + Whether to column-split projections. + interleave (`bool`, *optional*, defaults = True): + Whether to interleave channels. + + evo2_style_activations (`bool`, *optional*, defaults = True): + Use Evo2-style activations (identity for some layers). + mlp_activation (`str`, *optional*, defaults = "gelu"): + Activation function in the MLP. + + make_vocab_size_divisible_by (`int`, *optional*, defaults = 8): + Pad vocab size to be divisible by this value. + inner_size_multiple_of (`int`, *optional*, defaults = 16): + Force MLP inner size to be a multiple of this value. + + tie_embeddings (`bool`, *optional*, defaults = True): + Whether to tie input and output embeddings. + mha_out_proj_bias (`bool`, *optional*, defaults = True): + Whether to use bias in attention output projections. + hyena_out_proj_bias (`bool`, *optional*, defaults = True): + Whether to use bias in Hyena output projections. + qkv_proj_bias (`bool`, *optional*, defaults = False): + Whether to use bias in QKV projections. + final_norm (`bool`, *optional*, defaults = True): + Whether to apply a final normalization layer. + + use_flash_attn (`bool`, *optional*, defaults = True): + Whether to use FlashAttention when available. + use_flash_rmsnorm (`bool`, *optional*, defaults = False): + Whether to use a fused Flash RMSNorm implementation. + use_flash_depthwise (`bool`, *optional*, defaults = False): + Whether to use fused depthwise convolution kernels. + use_flashfft (`bool`, *optional*, defaults = False): + Whether to use FFT-based kernels for long convolutions. + use_laughing_hyena (`bool`, *optional*, defaults = False): + Experimental variant toggle. + + max_batch_size (`int`, *optional*, defaults = 1): + Max batch size used in the original config (not enforced by HF). + inference_mode (`bool`, *optional*, defaults = True): + Indicates original config was built for inference. + + tokenizer_type (`str`, *optional*, defaults = "CharLevelTokenizer"): + Name of the tokenizer expected by the original implementation. + prefill_style (`str`, *optional*, defaults = "fft"): + Prefill strategy used in original Evo2. + + print_activations (`bool`, *optional*, defaults = False): + Log intermediate activations (debugging). + log_intermediate_values (`bool`, *optional*, defaults = False): + Log intermediate values in original code (debugging). + + model_parallel_size (`int`, *optional*, defaults = 1): + Original MP size; informational only here. + pipe_parallel_size (`int`, *optional*, defaults = 1): + Original PP size; informational only here. + + hyena_flip_x1x2 (`bool`, *optional*, defaults = False): + Flip Hyena kernel inputs (compat option). + use_fp8_input_projections (`bool`, *optional*, defaults = True): + Whether the original model used FP8 input projections. + + **kwargs: + Additional keyword arguments passed to `PretrainedConfig`. 
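+
+    Example (a minimal sketch; it only relies on the defaults documented above and assumes the
+    Evo2 classes are exported from `transformers` once the model is registered):
+
+    ```python
+    >>> from transformers import Evo2Config, Evo2Model
+
+    >>> # Initializing a configuration with the default 25-layer hybrid layout
+    >>> configuration = Evo2Config()
+
+    >>> # Layers 3, 10, 17 and 24 use attention; the remaining layers are Hyena variants
+    >>> configuration.attn_layer_idxs
+    [3, 10, 17, 24]
+
+    >>> # Initializing an (untrained) model from the configuration
+    >>> model = Evo2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```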
+ """ model_type = "evo2" - keys_to_ignore_at_inference = ["past_key_values"] - # Default tensor parallel plan for base model `Evo2Model` - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } def __init__( self, - vocab_size: Optional[int] = 32000, - hidden_size: Optional[int] = 4096, - intermediate_size: Optional[int] = 14336, - num_hidden_layers: Optional[int] = 32, - num_attention_heads: Optional[int] = 32, - num_key_value_heads: Optional[int] = 8, - head_dim: Optional[int] = None, - hidden_act: Optional[str] = "silu", - max_position_embeddings: Optional[int] = 4096 * 32, - initializer_range: Optional[float] = 0.02, - rms_norm_eps: Optional[int] = 1e-6, - use_cache: Optional[bool] = True, - pad_token_id: Optional[int] = None, - bos_token_id: Optional[int] = 1, - eos_token_id: Optional[int] = 2, - tie_word_embeddings: Optional[bool] = False, - rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, - sliding_window: Optional[int] = 4096, - attention_dropout: Optional[float] = 0.0, + vocab_size: int = 512, + hidden_size: int = 1920, + num_layers: int = 25, + num_attention_heads: int = 15, + inner_mlp_size: int = 5120, + max_position_embeddings: int = 8192, + rotary_emb_base: int = 10000, + attn_layer_idxs: Optional[List[int]] = None, + hcl_layer_idxs: Optional[List[int]] = None, + hcm_layer_idxs: Optional[List[int]] = None, + hcs_layer_idxs: Optional[List[int]] = None, + num_filters: int = 1920, + hcm_filter_length: int = 128, + hcl_filter_groups: int = 1920, + hcm_filter_groups: int = 128, + hcs_filter_groups: int = 128, + hcs_filter_length: int = 7, + short_filter_length: int = 3, + short_filter_bias: bool = False, + state_size: int = 16, + eps: float = 1e-6, + proj_groups: int = 1, + hyena_filter_groups: int = 1, + column_split_hyena: bool = False, + column_split: bool = True, + interleave: bool = True, + evo2_style_activations: bool = True, + mlp_activation: str = "gelu", + make_vocab_size_divisible_by: int = 8, + inner_size_multiple_of: int = 16, + tie_embeddings: bool = True, + mha_out_proj_bias: bool = True, + hyena_out_proj_bias: bool = True, + qkv_proj_bias: bool = False, + final_norm: bool = True, + use_flash_attn: bool = True, + use_flash_rmsnorm: bool = False, + use_flash_depthwise: bool = False, + use_flashfft: bool = False, + use_laughing_hyena: bool = False, + max_batch_size: int = 1, + inference_mode: bool = True, + tokenizer_type: str = "CharLevelTokenizer", + prefill_style: str = "fft", + print_activations: bool = False, + log_intermediate_values: bool = False, + model_parallel_size: int = 1, + pipe_parallel_size: int = 1, + hyena_flip_x1x2: bool = False, + use_fp8_input_projections: bool = True, **kwargs, ): + super().__init__(**kwargs) + + # Core HF-style fields self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers + self.num_layers = num_layers self.num_attention_heads = num_attention_heads - self.sliding_window = sliding_window - self.head_dim 
= head_dim if head_dim is not None else hidden_size // num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.attention_dropout = attention_dropout - - if "layer_types" in kwargs: - logger.warning_once( - "Detected Evo2 model with layer_types. Consider using AutoModel or Ministral classes instead to enable alternating attention compatibility." - ) - - # Try to set `rope_scaling` if available, otherwise use `rope_parameters` - rope_scaling = kwargs.pop("rope_scaling", None) - self.rope_parameters = rope_scaling or rope_parameters - - # Validate the correctness of rotary position embeddings parameters - rope_theta = kwargs.get("rope_theta", 10000.0) - standardize_rope_params(self, rope_theta=rope_theta) - rope_config_validation(self) - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) + self.intermediate_size = inner_mlp_size # HF naming + self.inner_mlp_size = inner_mlp_size # original naming + self.max_position_embeddings = max_position_embeddings + + # Rotary embeddings + self.rotary_emb_base = rotary_emb_base + + # Layer index layout + self.attn_layer_idxs = attn_layer_idxs or [3, 10, 17, 24] + self.hcl_layer_idxs = hcl_layer_idxs or [2, 6, 9, 13, 16, 20, 23] + self.hcm_layer_idxs = hcm_layer_idxs or [1, 5, 8, 12, 15, 19, 22] + self.hcs_layer_idxs = hcs_layer_idxs or [0, 4, 7, 11, 14, 18, 21] + + # Hyena / filter hyperparameters + self.num_filters = num_filters + self.hcm_filter_length = hcm_filter_length + self.hcl_filter_groups = hcl_filter_groups + self.hcm_filter_groups = hcm_filter_groups + self.hcs_filter_groups = hcs_filter_groups + self.hcs_filter_length = hcs_filter_length + self.short_filter_length = short_filter_length + self.short_filter_bias = short_filter_bias + + # State & numerics + self.state_size = state_size + self.eps = eps + + # Grouping & splitting + self.proj_groups = proj_groups + self.hyena_filter_groups = hyena_filter_groups + self.column_split_hyena = column_split_hyena + self.column_split = column_split + self.interleave = interleave + + # Activations / MLP + self.evo2_style_activations = evo2_style_activations + self.mlp_activation = mlp_activation + self.make_vocab_size_divisible_by = make_vocab_size_divisible_by + self.inner_size_multiple_of = inner_size_multiple_of + + # Projection / embedding knobs + self.tie_embeddings = tie_embeddings + self.mha_out_proj_bias = mha_out_proj_bias + self.hyena_out_proj_bias = hyena_out_proj_bias + self.qkv_proj_bias = qkv_proj_bias + self.final_norm = final_norm + + # Flash / fused kernels (may be ignored in pure PyTorch version) + self.use_flash_attn = use_flash_attn + self.use_flash_rmsnorm = use_flash_rmsnorm + self.use_flash_depthwise = use_flash_depthwise + self.use_flashfft = use_flashfft + self.use_laughing_hyena = use_laughing_hyena + + # Original inference-related fields (kept for compatibility, not enforced) + self.max_batch_size = max_batch_size + self.inference_mode = inference_mode + + # Tokenizer / prefill / logging metadata + self.tokenizer_type = tokenizer_type + self.prefill_style = prefill_style + self.print_activations = print_activations + self.log_intermediate_values = log_intermediate_values + + # Parallelism & 
numeric tricks (informational) + self.model_parallel_size = model_parallel_size + self.pipe_parallel_size = pipe_parallel_size + self.hyena_flip_x1x2 = hyena_flip_x1x2 + self.use_fp8_input_projections = use_fp8_input_projections + + # For backward compatibility with original config name + self.max_seqlen = max_position_embeddings __all__ = ["Evo2Config"] diff --git a/src/transformers/models/evo2/modeling_evo2.py b/src/transformers/models/evo2/modeling_evo2.py index 5d0e11c88104..e8c07806c26c 100644 --- a/src/transformers/models/evo2/modeling_evo2.py +++ b/src/transformers/models/evo2/modeling_evo2.py @@ -1,524 +1,481 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/evo2/modular_evo2.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_evo2.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 the HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from __future__ import annotations -from collections.abc import Callable -from typing import Optional, Union +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import math import torch -from torch import nn - -from transformers.utils.generic import check_model_inputs - -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import ( - GenericForQuestionAnswering, - GenericForSequenceClassification, - GenericForTokenClassification, - GradientCheckpointingLayer, -) -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +import torch.nn as nn +import torch.nn.functional as F + +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + from .configuration_evo2 import Evo2Config -class Evo2MLP(nn.Module): - def __init__(self, config): +logger = logging.get_logger(__name__) + + +# ========================= +# Norm + Rotary helpers +# ========================= + + +class Evo2RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() - self.config = config - self.hidden_size = config.hidden_size - 
self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # standard RMSNorm + norm = x.float().pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(norm + self.eps) + return (self.weight * x).to(x.dtype) + + +class RotaryEmbedding(nn.Module): + """ + Simple rotary embedding (RoPE) implementation. + We keep this minimal; you can later swap for the shared one from another model. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + def __init__(self, dim: int, base: int = 10000): + super().__init__() + self.dim = dim + self.base = base -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) + ) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq) # [seq_len, dim/2] + emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, dim] + return torch.cos(emb).to(dtype), torch.sin(emb).to(dtype) + + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + x: [b, s, h, d] + cos/sin: [1, s, 1, d] """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights + x1, x2 = x[..., ::2], x[..., 1::2] + cos = cos[..., ::2] + sin = sin[..., ::2] + x1_rot = x1 * cos - x2 * sin + x2_rot = x1 * sin + x2 * cos + x_rot = torch.stack([x1_rot, x2_rot], dim=-1) + x_rot = x_rot.flatten(-2) + return x_rot -class Evo2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" +# ========================= +# Attention block +# ========================= - def __init__(self, config: Evo2Config, layer_idx: int): + +class Evo2Attention(nn.Module): + def __init__(self, config: Evo2Config): super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError("hidden_size must be divisible by num_attention_heads") + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_proj_bias) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_proj_bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_proj_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.mha_out_proj_bias) + + self.rotary_emb = RotaryEmbedding(self.head_dim, base=config.rotary_emb_base) def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], 
- attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + bsz, seq_len, _ = hidden_states.size() - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama - **kwargs, - ) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + # [b, s, h, d] + q = q.view(bsz, seq_len, self.num_heads, self.head_dim) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim) + + cos, sin = self.rotary_emb(seq_len, hidden_states.device, hidden_states.dtype) + cos = cos[None, :, None, :] # [1, s, 1, d] + sin = sin[None, :, None, :] + q = apply_rotary(q, cos, sin) + k = apply_rotary(k, cos, sin) + + if past_key_value is not None: + past_k, past_v = past_key_value + k = torch.cat([past_k, k], dim=1) + v = torch.cat([past_v, v], dim=1) + + present_key_value = (k, v) if use_cache else None + + # [b, h, s, d] + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + # attention_mask expected [b, 1, 1, s_k]; add additive mask + attn_weights = attn_weights + attention_mask + + attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_probs, v) # [b, h, s, d] + + attn_output = attn_output.permute(0, 2, 1, 3).contiguous() + attn_output = attn_output.view(bsz, seq_len, self.hidden_size) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output, present_key_value -@use_kernel_forward_from_hub("RMSNorm") -class Evo2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Evo2RMSNorm is equivalent to T5LayerNorm - """ + +# ========================= +# Hyena-ish block (simplified) +# ========================= + + +class Evo2HyenaBlock(nn.Module): + """ + 
Simplified Hyena-style block. + + This is NOT the full HyenaCascade from Vortex. It’s a placeholder: + - depthwise conv over time + - small MLP + + You can later replace this with a faithful StripedHyena2 port. + """ + + def __init__(self, config: Evo2Config): super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + self.hidden_size = config.hidden_size + self.short_filter_length = config.short_filter_length + + self.dw_conv = nn.Conv1d( + in_channels=self.hidden_size, + out_channels=self.hidden_size, + kernel_size=self.short_filter_length, + padding=self.short_filter_length // 2, + groups=self.hidden_size, + bias=config.short_filter_bias, + ) - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, config.inner_mlp_size), + nn.GELU(), # matches mlp_activation default + nn.Linear(config.inner_mlp_size, self.hidden_size), + ) - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # [b, s, h] -> [b, h, s] for conv + x = hidden_states.transpose(1, 2) + x = self.dw_conv(x) + x = x.transpose(1, 2) + x = self.mlp(x) + return x -class Evo2DecoderLayer(GradientCheckpointingLayer): +# ========================= +# Evo2Block +# ========================= + + +class Evo2Block(nn.Module): def __init__(self, config: Evo2Config, layer_idx: int): super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = Evo2Attention(config=config, layer_idx=layer_idx) - self.mlp = Evo2MLP(config) - self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.layer_idx = layer_idx + + self.norm1 = Evo2RMSNorm(config.hidden_size, eps=config.eps) + self.norm2 = Evo2RMSNorm(config.hidden_size, eps=config.eps) + + if layer_idx in config.attn_layer_idxs: + self.block_type = "attn" + self.attn = Evo2Attention(config) + self.hyena = None + else: + self.block_type = "hyena" + self.attn = None + self.hyena = Evo2HyenaBlock(config) + + # Simple MLP for the second residual (you can adjust to ParallelGatedMLP later) + self.mlp = nn.Sequential( + nn.Linear(config.hidden_size, config.inner_mlp_size), + nn.GELU(), + nn.Linear(config.inner_mlp_size, config.hidden_size), + ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - 
use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states + hidden_states = self.norm1(hidden_states) + + present_key_value = None - # Fully Connected + if self.block_type == "attn": + attn_output, present_key_value = self.attn( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + hidden_states = residual + attn_output + else: + hyena_out = self.hyena(hidden_states) + hidden_states = residual + hyena_out + + # Second norm + MLP residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states + return hidden_states, present_key_value -@auto_docstring -class Evo2PreTrainedModel(PreTrainedModel): - config: Evo2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Evo2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = True - _can_compile_fullgraph = True - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": Evo2DecoderLayer, - "attentions": Evo2Attention, - } - - -class Evo2RotaryEmbedding(nn.Module): - inv_freq: torch.Tensor # fix linting for `register_buffer` - - def __init__(self, config: Evo2Config, device=None): - super().__init__() - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - - self.rope_type = self.config.rope_parameters["rope_type"] - rope_init_fn: Callable = self.compute_default_rope_parameters - if self.rope_type != "default": - rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = inv_freq - - @staticmethod - def compute_default_rope_parameters( - config: Optional[Evo2Config] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - ) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PreTrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - base = config.rope_parameters["rope_theta"] - dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor +# ========================= +# Base model +# ========================= - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling +class Evo2PreTrainedModel(PreTrainedModel): + config_class = Evo2Config + base_model_prefix = "model" + supports_gradient_checkpointing = False + _no_split_modules = ["Evo2Block"] - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=0.02) -@auto_docstring class Evo2Model(Evo2PreTrainedModel): + """ + Decoder-only Evo2 backbone: embeddings + stack of Evo2Blocks. + """ + def __init__(self, config: Evo2Config): super().__init__(config) - self.padding_idx = config.pad_token_id + + self.padding_idx = 0 self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) + self.layers = nn.ModuleList( - [Evo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [Evo2Block(config, layer_idx=i) for i in range(config.num_layers)] ) - self.norm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Evo2RotaryEmbedding(config=config) - self.gradient_checkpointing = False - # Initialize weights and apply final processing + self.final_norm = Evo2RMSNorm(config.hidden_size, eps=config.eps) if config.final_norm else None + self.post_init() - @check_model_inputs() - @auto_docstring + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings: nn.Embedding): + self.embed_tokens = new_embeddings + def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ) -> BaseModelOutputWithPast: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if output_attentions: + logger.warning_once("Evo2Model does not currently return attentions.") + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds.") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and past_key_values 
is None: - past_key_values = DynamicCache(config=self.config) + hidden_states = inputs_embeds + bsz, seq_len, _ = hidden_states.size() - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + # Build causal attention mask if not provided (2D mask with 1 for non-padded tokens) + if attention_mask is not None: + # [b, s] -> [b, 1, 1, s] additive mask + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + all_hidden_states = [] if output_hidden_states else None + next_past_key_values = [] if use_cache else None - mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask - causal_mask = mask_function( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) + for idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + past_kv = past_key_values[idx] if past_key_values is not None else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( + hidden_states, present_kv = layer( hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, + attention_mask=attention_mask, + past_key_value=past_kv, use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, ) - hidden_states = self.norm(hidden_states) + + if use_cache: + next_past_key_values.append(present_kv) + + if self.final_norm is not None: + hidden_states = self.final_norm(hidden_states) + + if not return_dict: + outputs = (hidden_states, next_past_key_values) + if output_hidden_states: + outputs = (hidden_states, next_past_key_values, all_hidden_states) + return outputs + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=next_past_key_values, + hidden_states=all_hidden_states, ) -@auto_docstring -class Evo2ForCausalLM(Evo2PreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} +# ========================= +# Causal LM head +# ========================= + - def __init__(self, config): +class Evo2ForCausalLM(Evo2PreTrainedModel): + """ + Evo2 language model with a LM head on top of Evo2Model. 
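+
+    Example (a minimal sketch using a deliberately tiny, hypothetical configuration rather than a
+    released checkpoint; the input ids below are simply the UTF-8 byte values of "ACGT"):
+
+    ```python
+    >>> import torch
+    >>> from transformers import Evo2Config, Evo2ForCausalLM
+
+    >>> config = Evo2Config(
+    ...     hidden_size=64,
+    ...     num_layers=4,
+    ...     num_attention_heads=4,
+    ...     inner_mlp_size=128,
+    ...     attn_layer_idxs=[3],
+    ...     hcl_layer_idxs=[2],
+    ...     hcm_layer_idxs=[1],
+    ...     hcs_layer_idxs=[0],
+    ... )
+    >>> model = Evo2ForCausalLM(config)
+
+    >>> input_ids = torch.tensor([[65, 67, 71, 84]])  # "A", "C", "G", "T"
+    >>> outputs = model(input_ids=input_ids, labels=input_ids)
+    >>> outputs.logits.shape
+    torch.Size([1, 4, 512])
+    ```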
+ """ + + def __init__(self, config: Evo2Config): super().__init__(config) self.model = Evo2Model(config) - self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing + if config.tie_embeddings: + self.tie_weights() + self.post_init() - @can_return_tuple - @auto_docstring + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.model.embed_tokens = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def tie_weights(self): + self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings()) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + # Standard decoder-only prepare_inputs_for_generation: + # if we have past_key_values, only feed the last token. + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + "use_cache": True, + } + def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, Evo2ForCausalLM - - >>> model = Evo2ForCausalLM.from_pretrained("meta-evo2/Evo2-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-evo2/Evo2-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - outputs: BaseModelOutputWithPast = self.model( + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[Tuple, CausalLMOutputWithPast]: + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - cache_position=cache_position, - **kwargs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, ) hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = self.lm_head(hidden_states) loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + # shift for causal LM + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + + if not return_dict: + output = (logits, outputs.past_key_values) + if output_hidden_states: + output = (logits, outputs.past_key_values, outputs.hidden_states) + return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, - attentions=outputs.attentions, ) - - -class Evo2ForTokenClassification(GenericForTokenClassification, Evo2PreTrainedModel): - pass - - -class Evo2ForSequenceClassification(GenericForSequenceClassification, Evo2PreTrainedModel): - pass - - -class Evo2ForQuestionAnswering(GenericForQuestionAnswering, Evo2PreTrainedModel): - pass - - -__all__ = [ - "Evo2ForCausalLM", - "Evo2ForQuestionAnswering", - "Evo2Model", - "Evo2PreTrainedModel", - "Evo2ForSequenceClassification", - "Evo2ForTokenClassification", -] diff --git a/src/transformers/models/evo2/modular_evo2.py b/src/transformers/models/evo2/modular_evo2.py deleted file mode 100644 index e8f3caead0dc..000000000000 --- a/src/transformers/models/evo2/modular_evo2.py +++ /dev/null @@ -1,88 +0,0 @@ -# coding=utf-8 -# Copyright 2025 the HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from ..mistral.configuration_mistral import MistralConfig -from ..mistral.modeling_mistral import ( - MistralAttention, - MistralDecoderLayer, - MistralForCausalLM, - MistralForQuestionAnswering, - MistralForSequenceClassification, - MistralForTokenClassification, - MistralMLP, - MistralModel, - MistralPreTrainedModel, - MistralRMSNorm, - MistralRotaryEmbedding, -) - - -class Evo2Config(MistralConfig): - pass - - -class Evo2MLP(MistralMLP): - pass - - -class Evo2Attention(MistralAttention): - pass - - -class Evo2RMSNorm(MistralRMSNorm): - pass - - -class Evo2DecoderLayer(MistralDecoderLayer): - pass - - -class Evo2PreTrainedModel(MistralPreTrainedModel): - pass - - -class Evo2RotaryEmbedding(MistralRotaryEmbedding): - pass - - -class Evo2Model(MistralModel): - pass - - -class Evo2ForCausalLM(MistralForCausalLM): - pass - - -class Evo2ForTokenClassification(MistralForTokenClassification): - pass - - -class Evo2ForSequenceClassification(MistralForSequenceClassification): - pass - - -class Evo2ForQuestionAnswering(MistralForQuestionAnswering): - pass - - -__all__ = [ - "Evo2Config", - "Evo2ForCausalLM", - "Evo2ForQuestionAnswering", - "Evo2Model", - "Evo2PreTrainedModel", - "Evo2ForSequenceClassification", - "Evo2ForTokenClassification", -] diff --git a/src/transformers/models/evo2/tokenization_evo2.py b/src/transformers/models/evo2/tokenization_evo2.py new file mode 100644 index 000000000000..f6d4833d0d26 --- /dev/null +++ b/src/transformers/models/evo2/tokenization_evo2.py @@ -0,0 +1,220 @@ +# src/transformers/models/evo2/tokenization_evo2.py + +from __future__ import annotations + +import json +import os +from typing import Dict, List, Optional, Tuple + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + # You can fill these in once you upload a checkpoint + # "arcinstitute/evo2-1b-8k": "https://huggingface.co/arcinstitute/evo2-1b-8k/resolve/main/vocab.json", + } +} + +PRETRAINED_INIT_CONFIGURATION = { + # "arcinstitute/evo2-1b-8k": {}, +} + + +class Evo2Tokenizer(PreTrainedTokenizer): + """ + Hugging Face wrapper around the Evo2 CharLevelTokenizer. + + - Encoding: + text.encode("utf-8") -> list of uint8 bytes in [0, 255] + - Token IDs: + those bytes directly used as IDs (0..255). + `vocab_size` can be larger (e.g. 512), but extra IDs are unused. + - Decoding: + clamp each id with `clamp(n) = max(32, min(n, vocab_size))` + then `chr(clamp(n))` and join. + + We implement vocab as stringified integers: "0" -> 0, "1" -> 1, etc. 
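+
+    Example (illustrative only; it exercises the byte-level rules described above):
+
+    ```python
+    >>> tokenizer = Evo2Tokenizer()
+
+    >>> # "ACGT" is encoded as its UTF-8 byte values
+    >>> tokenizer("ACGT")["input_ids"]
+    [65, 67, 71, 84]
+
+    >>> # Decoding maps each id back to a character
+    >>> tokenizer.decode([65, 67, 71, 84])
+    'ACGT'
+    ```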
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + vocab_file: Optional[str] = None, + vocab_size: int = 512, + # Match original CharLevelTokenizer semantics: + # eod_id = eos_id = 0, pad_id = 1 + eos_token: str = "0", + pad_token: str = "1", + unk_token: str = "0", # there is no real "unknown" in char-level; anything maps to a byte + bos_token: Optional[str] = None, + **kwargs, + ): + self._vocab_size = vocab_size + + if vocab_file is None: + # Default vocab: token "0" -> id 0, "1" -> id 1, ..., up to vocab_size-1 + self.vocab: Dict[str, int] = {str(i): i for i in range(vocab_size)} + else: + with open(vocab_file, "r", encoding="utf-8") as f: + self.vocab = json.load(f) + # Ensure ids are ints + self.vocab = {str(k): int(v) for k, v in self.vocab.items()} + + self.ids_to_tokens = {v: k for k, v in self.vocab.items()} + + # Call parent ctor (this also sets pad/eos/bos/unk attributes) + super().__init__( + eos_token=eos_token, + pad_token=pad_token, + bos_token=bos_token, # None by default; CharLevelTokenizer has no BOS + unk_token=unk_token, + **kwargs, + ) + + # Cache some commonly used ids + self._eos_id = int(eos_token) if bos_token is None else self.vocab[eos_token] + self._pad_id = int(pad_token) + self._unk_id = int(unk_token) + + # ---- Char-level core logic --------------------------------------------- + + @property + def vocab_size(self) -> int: + return self._vocab_size + + def get_vocab(self) -> Dict[str, int]: + return dict(self.vocab) + + def clamp(self, n: int) -> int: + # Same as in CharLevelTokenizer: max(32, min(n, vocab_size)) + return max(32, min(n, self._vocab_size)) + + # HF will call this to get string "tokens" before converting to ids + def _tokenize(self, text: str, **kwargs) -> List[str]: + # CharLevelTokenizer.tokenize: + # list(np.frombuffer(text.encode('utf-8'), dtype=np.uint8)) + # We can replicate with Python directly: + byte_ids = list(text.encode("utf-8")) # each in [0, 255] + # Represent each id as a string token "id" + return [str(b) for b in byte_ids] + + def _convert_token_to_id(self, token: str) -> int: + # Tokens we produce are numeric strings "0", "1", ... + try: + idx = int(token) + except ValueError: + # Shouldn't really happen with our _tokenize, but just in case + return self._unk_id + # CharLevelTokenizer allows any 0..255; we don't clamp on encode. + # (clamp is only used on decode) + if 0 <= idx < self._vocab_size: + return idx + # If out-of-range, fall back to unk + return self._unk_id + + def _convert_id_to_token(self, index: int) -> str: + # Represent ids as numeric strings consistently + if 0 <= index < self._vocab_size: + return str(index) + return str(self._unk_id) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + # CharLevelTokenizer.detokenize: + # "".join(chr(clamp(token)) for token in token_ids) + chars: List[str] = [] + for tok in tokens: + try: + idx = int(tok) + except ValueError: + idx = self._unk_id + c = chr(self.clamp(idx)) + chars.append(c) + return "".join(chars) + + # ---- Special tokens / sequence helpers --------------------------------- + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + ) -> List[int]: + """ + CharLevelTokenizer does *not* add BOS/EOS automatically, so we just + return the sequence as-is. + + We also do not support sentence pairs. 
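        A short illustration of the pass-through behavior (the ids are the
        ASCII bytes of "ACGT"):

            >>> tokenizer = Evo2Tokenizer(vocab_size=512)
            >>> tokenizer.build_inputs_with_special_tokens([65, 67, 71, 84])
            [65, 67, 71, 84]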
+ """ + if token_ids_1 is not None: + raise ValueError("Evo2Tokenizer (CharLevel) does not support sentence pairs.") + + return token_ids_0 + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Mark eos/eod (id 0) and pad (id 1) as special, everything else as 0. + """ + if token_ids_1 is not None: + raise ValueError("Evo2Tokenizer (CharLevel) does not support sentence pairs.") + + if already_has_special_tokens: + # Just mark known special IDs + return [ + 1 if t in {self._eos_id, self._pad_id} else 0 + for t in token_ids_0 + ] + + # We don't auto-add any extra tokens, so same as above + return [ + 1 if t in {self._eos_id, self._pad_id} else 0 + for t in token_ids_0 + ] + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + ) -> List[int]: + """ + No token type IDs; everything is 0, like most decoder-only models. + """ + if token_ids_1 is not None: + raise ValueError("Evo2Tokenizer (CharLevel) does not support sentence pairs.") + + return [0] * len(token_ids_0) + + # ---- Saving / loading vocab -------------------------------------------- + + def save_vocabulary( + self, + save_directory: str, + filename_prefix: Optional[str] = None, + ) -> Tuple[str]: + if not os.path.isdir(save_directory): + os.makedirs(save_directory, exist_ok=True) + + vocab_file = ( + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"] + ) + vocab_path = os.path.join(save_directory, vocab_file) + + with open(vocab_path, "w", encoding="utf-8") as f: + json.dump(self.vocab, f, ensure_ascii=False, indent=2) + + return (vocab_path,) diff --git a/tests/models/evo2/test_modeling_evo2.py b/tests/models/evo2/test_modeling_evo2.py index 9d7b77f0d0f8..b53ba1ec252f 100644 --- a/tests/models/evo2/test_modeling_evo2.py +++ b/tests/models/evo2/test_modeling_evo2.py @@ -1,42 +1,20 @@ # coding=utf-8 -# Copyright 2025 the HuggingFace Team. All rights reserved. +# Copyright 2025 the HuggingFace Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Testing suite for the PyTorch Evo2 model.""" +# This file contains the *unit* tests for the Evo2 model, based on the +# shared CausalLMModelTester utilities. Integration tests that depend on +# public Hub checkpoints or special hardware can be added later once the +# official Evo2 weights are wired to this architecture. 
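# A sketch of the kind of input preparation such an integration test would do,
# using only pieces introduced in this patch (byte ids assume ASCII input;
# padding uses pad_token_id == 1 as defined by the char-level tokenizer). This
# is an illustrative snippet rather than part of the test module:
from transformers.models.evo2.tokenization_evo2 import Evo2Tokenizer

tok = Evo2Tokenizer(vocab_size=512)
batch = tok(["ACGT", "ACGTAC"], padding=True, return_tensors="pt")
print(batch["input_ids"])       # [[65, 67, 71, 84, 1, 1], [65, 67, 71, 84, 65, 67]]
print(batch["attention_mask"])  # zeros over the padded positions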
-import gc import unittest -import pytest -from packaging import version -from parameterized import parameterized +from transformers import is_torch_available +from transformers.testing_utils import require_torch -from transformers import AutoTokenizer, BitsAndBytesConfig, DynamicCache, is_torch_available, set_seed -from transformers.cache_utils import DynamicSlidingWindowLayer -from transformers.testing_utils import ( - DeviceProperties, - Expectations, - backend_empty_cache, - cleanup, - get_device_properties, - require_bitsandbytes, - require_flash_attn, - require_read_token, - require_torch, - require_torch_accelerator, - slow, - torch_device, -) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -46,19 +24,42 @@ Evo2ForCausalLM, Evo2Model, ) -from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester class Evo2ModelTester(CausalLMModelTester): + """ + Minimal tester for Evo2 that plugs into the shared causal LM test + harness. We just need to specify the base and LM classes; the generic + tester will handle: + - building a small config + - instantiating Evo2Model / Evo2ForCausalLM + - running forward / loss / generate / save-load tests + """ + if is_torch_available(): base_model_class = Evo2Model + lm_model_class = Evo2ForCausalLM + + # If you want to tweak the tiny test config (e.g. reduce sizes), + # you can override `prepare_config_and_inputs` or `get_config` here. @require_torch class Evo2ModelTest(CausalLMModelTest, unittest.TestCase): + """ + Generic causal LM tests for Evo2. + + These tests: + - instantiate tiny Evo2 configs + - run forward passes + - check loss computation + - check generation API + - test save / load / from_pretrained with local weights + """ + model_tester_class = Evo2ModelTester - # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + # Pipelines for this model are not wired yet; skip pipeline tests. def is_pipeline_test_to_skip( self, pipeline_test_case_name, @@ -70,423 +71,3 @@ def is_pipeline_test_to_skip( processor_name, ): return True - - -@require_torch_accelerator -@require_read_token -class Evo2IntegrationTest(unittest.TestCase): - # This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4) - # Depending on the hardware we get different logits / generations - device_properties: DeviceProperties = (None, None, None) - - @classmethod - def setUpClass(cls): - cls.device_properties = get_device_properties() - - def setUp(self): - cleanup(torch_device, gc_collect=True) - - def tearDown(self): - cleanup(torch_device, gc_collect=True) - - @slow - def test_model_7b_logits(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map="auto", dtype=torch.float16) - input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) - with torch.no_grad(): - out = model(input_ids).logits.float().cpu() - # Expected mean on dim = -1 - EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]]) - torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) - - # ("cuda", 8) for A100/A10, and ("cuda", 7) 7 for T4. - # considering differences in hardware processing and potential deviations in output. 
- # fmt: off - EXPECTED_SLICES = Expectations( - { - ("cuda", 7): torch.tensor([-5.8828, -5.8633, -0.1042, -4.7266, -5.8828, -5.8789, -5.8789, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -1.0801, 1.7598, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828]), - ("cuda", 8): torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]), - ("rocm", 9): torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]), - } - ) - # fmt: on - expected_slice = EXPECTED_SLICES.get_expectation() - - torch.testing.assert_close(out[0, 0, :30], expected_slice, atol=1e-4, rtol=1e-4) - - @slow - @require_bitsandbytes - def test_model_7b_generation(self): - EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo," - - prompt = "My favourite condiment is " - tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) - model = Evo2ForCausalLM.from_pretrained( - "mistralai/Evo2-7B-v0.1", - device_map={"": torch_device}, - quantization_config=BitsAndBytesConfig(load_in_4bit=True), - ) - input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) - - # greedy generation outputs - generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - - # TODO joao, manuel: remove this in v4.62.0 - @slow - def test_model_7b_dola_generation(self): - # ground truth text generated with dola_layers="low", repetition_penalty=1.2 - EXPECTED_TEXT_COMPLETION = ( - """My favourite condiment is 100% ketchup. 
I love it on everything, and I’m not ash""" - ) - prompt = "My favourite condiment is " - tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) - model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map="auto", dtype=torch.float16) - input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) - - # greedy generation outputs - generated_ids = model.generate( - input_ids, - max_new_tokens=20, - temperature=0, - dola_layers="low", - repetition_penalty=1.2, - trust_remote_code=True, - custom_generate="transformers-community/dola", - ) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - - del model - backend_empty_cache(torch_device) - gc.collect() - - @require_flash_attn - @require_bitsandbytes - @slow - @pytest.mark.flash_attn_test - def test_model_7b_long_prompt(self): - EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] - # An input with 4097 tokens that is above the size of the sliding window - input_ids = [1] + [306, 338] * 2048 - model = Evo2ForCausalLM.from_pretrained( - "mistralai/Evo2-7B-v0.1", - device_map={"": torch_device}, - quantization_config=BitsAndBytesConfig(load_in_4bit=True), - attn_implementation="flash_attention_2", - ) - input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) - generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) - self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) - - # Assisted generation - assistant_model = model - assistant_model.generation_config.num_assistant_tokens = 2 - assistant_model.generation_config.num_assistant_tokens_schedule = "constant" - generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) - self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) - - @slow - def test_model_7b_long_prompt_sdpa(self): - EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] - # An input with 4097 tokens that is above the size of the sliding window - input_ids = [1] + [306, 338] * 2048 - model = Evo2ForCausalLM.from_pretrained( - "mistralai/Evo2-7B-v0.1", device_map="auto", attn_implementation="sdpa", dtype=torch.float16 - ) - input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) - generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) - self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) - - # Assisted generation - assistant_model = model - assistant_model.generation_config.num_assistant_tokens = 2 - assistant_model.generation_config.num_assistant_tokens_schedule = "constant" - generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) - self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) - - del assistant_model - - backend_empty_cache(torch_device) - gc.collect() - - EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. 
I’m not a big""" - prompt = "My favourite condiment is " - tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) - - input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) - - # greedy generation outputs - generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - - @slow - def test_speculative_generation(self): - EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I’m not a fan of mustard, relish" - prompt = "My favourite condiment is " - tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) - model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map="auto", dtype=torch.float16) - input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) - - # greedy generation outputs - set_seed(0) - generated_ids = model.generate( - input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model - ) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - - @pytest.mark.torch_compile_test - @slow - def test_compile_static_cache(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 - # work as intended. See https://github.com/pytorch/pytorch/issues/121943 - if version.parse(torch.__version__) < version.parse("2.3.0"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - if self.device_properties[0] == "cuda" and self.device_properties[1] == 7: - self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.") - - NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = [ - "My favourite condiment is 100% ketchup. I love it on everything. " - "I’m not a big fan of mustard, mayo, or relish. 
I’m not a fan of pickles" - ] - - prompts = ["My favourite condiment is "] - tokenizer = AutoTokenizer.from_pretrained("mistralai/Evo2-7B-v0.1", use_fast=False) - tokenizer.pad_token = tokenizer.eos_token - model = Evo2ForCausalLM.from_pretrained("mistralai/Evo2-7B-v0.1", device_map=torch_device, dtype=torch.float16) - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - - # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) - dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) - - # Static Cache - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - - # Sliding Window Cache - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" - ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - - # Static Cache + compile - forward_function = model.__call__ - model.__call__ = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) - - # Sliding Window Cache + compile - torch._dynamo.reset() - model.__call__ = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" - ) - static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) - - @pytest.mark.flash_attn_test - @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)]) - @require_flash_attn - @slow - def test_generation_beyond_sliding_window_dynamic(self, attn_implementation: str): - """Test that we can correctly generate beyond the sliding window. This is non-trivial as Evo2 will use - a DynamicCache with only sliding layers.""" - - # Impossible to test it with this model (even with < 100 tokens), probably due to the compilation of a large model. - if attn_implementation == "flex_attention": - self.skipTest( - reason="`flex_attention` gives `torch._inductor.exc.InductorError: RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_tem_fused_0 Required: 147456 Hardware limit:101376 Reducing block sizes or `num_stages` may help.`" - ) - - model_id = "mistralai/Evo2-7B-v0.1" - EXPECTED_COMPLETIONS = [ - "scenery, scenery, scenery, scenery, scenery,", - ", green, yellow, orange, purple, pink, brown, black, white, gray, silver", - ] - - input_text = [ - "This is a nice place. 
" * 682 + "I really enjoy the scenery,", # This has 4101 tokens, 15 more than 4096 - "A list of colors: red, blue", # This will almost all be padding tokens - ] - - if attn_implementation == "eager": - input_text = input_text[:1] - - tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") - tokenizer.pad_token_id = tokenizer.eos_token_id - inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) - - model = Evo2ForCausalLM.from_pretrained( - model_id, attn_implementation=attn_implementation, device_map=torch_device, dtype=torch.float16 - ) - - # Make sure prefill is larger than sliding window - batch_size, input_size = inputs.input_ids.shape - self.assertTrue(input_size > model.config.sliding_window) - - # Should already be Dynamic by default, but let's make sure! - out = model.generate(**inputs, max_new_tokens=20, cache_implementation="dynamic", return_dict_in_generate=True) - output_text = tokenizer.batch_decode(out.sequences[:batch_size, input_size:]) - - self.assertEqual(output_text, EXPECTED_COMPLETIONS[:batch_size]) - - # Let's check that the dynamic cache has hybrid layers! - dynamic_cache = out.past_key_values - self.assertTrue(isinstance(dynamic_cache, DynamicCache)) - for layer in dynamic_cache.layers: - self.assertTrue(isinstance(layer, DynamicSlidingWindowLayer)) - self.assertEqual(layer.keys.shape[-2], model.config.sliding_window - 1) - - -@slow -@require_torch_accelerator -class Mask4DTestHard(unittest.TestCase): - model_name = "mistralai/Evo2-7B-v0.1" - model = None - model_dtype = None - - @classmethod - def setUpClass(cls): - cleanup(torch_device, gc_collect=True) - if cls.model_dtype is None: - cls.model_dtype = torch.float16 - if cls.model is None: - cls.model = Evo2ForCausalLM.from_pretrained(cls.model_name, dtype=cls.model_dtype).to(torch_device) - - @classmethod - def tearDownClass(cls): - del cls.model_dtype - del cls.model - cleanup(torch_device, gc_collect=True) - - def setUp(self): - cleanup(torch_device, gc_collect=True) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) - - def tearDown(self): - cleanup(torch_device, gc_collect=True) - - def get_test_data(self): - template = "my favorite {}" - items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item - - batch_separate = [template.format(x) for x in items] # 3 separate lines - batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated - - input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) - input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) - - mask_shared_prefix = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] - ] - ], - device=torch_device, - ) - - position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) - - # building custom positions ids based on custom mask - position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) - # effectively: position_ids_shared_prefix = 
torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) - - # inverting the mask - min_dtype = torch.finfo(self.model_dtype).min - mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype - - return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix - - def test_stacked_causal_mask(self): - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # 2 forward runs with custom 4D masks - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) - past_key_values_a = outs_1a["past_key_values"] - - # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. 
[..., seq_len, full_len]) - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - outs_1b = self.model.forward( - input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) diff --git a/tests/models/evo2/test_tokenization_evo2.py b/tests/models/evo2/test_tokenization_evo2.py new file mode 100644 index 000000000000..1d33cbe43e2c --- /dev/null +++ b/tests/models/evo2/test_tokenization_evo2.py @@ -0,0 +1,110 @@ +# Copyright 2025 +# +# Licensed under the Apache License, Version 2.0 (the "License"); + +import json +import os +import tempfile +import unittest + +from transformers.models.evo2.tokenization_evo2 import VOCAB_FILES_NAMES, Evo2Tokenizer +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + + +class Evo2TokenizationTest(unittest.TestCase): + tokenizer_class = Evo2Tokenizer + + @classmethod + def setUpClass(cls): + super().setUpClass() + + cls.tmpdirname = tempfile.mkdtemp() + + # Build a simple numeric vocab: "0" -> 0, "1" -> 1, ..., "255" -> 255 + vocab_size = 256 + vocab = {str(i): i for i in range(vocab_size)} + + cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + with open(cls.vocab_file, "w", encoding="utf-8") as vocab_writer: + json.dump(vocab, vocab_writer) + + def get_tokenizers(cls, **kwargs) -> list[PreTrainedTokenizerBase]: + return [cls.get_tokenizer(**kwargs)] + + @classmethod + def get_tokenizer(cls, pretrained_name=None, **kwargs) -> PreTrainedTokenizer: + pretrained_name = pretrained_name or cls.tmpdirname + return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + def test_tokenizer_single_example(self): + # Direct constructor + tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) + + text = "ABC" + # ASCII codes: A=65, B=66, C=67 + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, ["65", "66", "67"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [65, 66, 67]) + + def test_tokenizer_encode_single(self): + tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) + + text = "ABC" + # encode() should NOT add BOS/EOS for this char-level tokenizer + self.assertListEqual(tokenizer.encode(text), [65, 66, 67]) + + def test_tokenizer_call_no_pad(self): + tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) + + seq_batch = ["AB", "XYZ"] + encoded = tokenizer(seq_batch, padding=False)["input_ids"] + + # "AB" -> 65,66 ; "XYZ" -> 88,89,90 + self.assertListEqual(encoded, [[65, 66], [88, 89, 90]]) + + def test_tokenizer_call_pad(self): + tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) + + seq_batch = ["AB", "XYZ"] + encoded = tokenizer(seq_batch, padding=True)["input_ids"] + + # pad_token_id should be 1, so shorter seq gets padded with 1 + # max length = 3 + self.assertEqual(tokenizer.pad_token_id, 1) + self.assertListEqual(encoded, [[65, 66, 1], [88, 89, 90]]) + + def test_detokenize_roundtrip(self): + tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) + + text = "Hello!" 
+ ids = tokenizer.encode(text) + decoded = tokenizer.decode(ids, skip_special_tokens=False) + + # Because of clamp, some low values could be bumped, but ASCII letters + # should round-trip cleanly. + self.assertEqual(decoded, text) + + def test_add_tokens(self): + tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) + + vocab_size = len(tokenizer) + self.assertEqual(tokenizer.add_tokens(""), 0) + self.assertEqual(tokenizer.add_tokens("testtoken"), 1) + self.assertEqual(tokenizer.add_tokens(["testtoken1", "testtoken2"]), 2) + self.assertEqual(len(tokenizer), vocab_size + 3) + + self.assertEqual(tokenizer.add_special_tokens({}), 0) + self.assertEqual(tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2) + + # additional_special_tokens logic + self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": ""}) + self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": [""]}), 1) + self.assertEqual( + tokenizer.add_special_tokens({"additional_special_tokens": ["", ""]}), 2 + ) + self.assertIn("", tokenizer.special_tokens_map["additional_special_tokens"]) + self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list) + self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2) + + self.assertEqual(len(tokenizer), vocab_size + 8) From bafc96ebf8cea837a43764c457c5ebcc198265a1 Mon Sep 17 00:00:00 2001 From: McClain Thiel Date: Sun, 16 Nov 2025 12:13:35 +0000 Subject: [PATCH 3/6] Align Evo2 rotary handling with HF generation --- src/transformers/__init__.py | 8 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/evo2/__init__.py | 50 +- .../models/evo2/configuration_evo2.py | 357 +++------ .../evo2/convert_evo2_original_to_hf.py | 214 ++++++ src/transformers/models/evo2/modeling_evo2.py | 715 ++++++++++-------- .../models/evo2/tokenization_evo2.py | 274 +++---- tests/models/evo2/test_modeling_evo2.py | 87 +-- tests/models/evo2/test_tokenization_evo2.py | 122 +-- 9 files changed, 912 insertions(+), 916 deletions(-) create mode 100644 src/transformers/models/evo2/convert_evo2_original_to_hf.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ae0e0b67c874..75d8cac2b31e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -264,6 +264,14 @@ "utils.kernel_config": ["KernelConfig"], } +_import_structure["models.evo2"] = [ + "Evo2Config", + "Evo2ForCausalLM", + "Evo2Model", + "Evo2PreTrainedModel", + "Evo2Tokenizer", +] + # tokenizers-backed objects try: if not is_tokenizers_available(): diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c830fee86987..1cb97b2fd90e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -241,6 +241,7 @@ ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)), ("esm", ("EsmTokenizer", None)), + ("evo2", ("Evo2Tokenizer", None)), ("evolla", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ( "exaone4", diff --git a/src/transformers/models/evo2/__init__.py b/src/transformers/models/evo2/__init__.py index 3b196962679e..e48f532a620f 100644 --- a/src/transformers/models/evo2/__init__.py +++ b/src/transformers/models/evo2/__init__.py @@ -1,29 +1,39 @@ -# coding=utf-8 -# Copyright 2025 the HuggingFace 
Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""Evo2 model, tokenizer, and configuration.""" from typing import TYPE_CHECKING -from ...utils import _LazyModule -from ...utils.import_utils import define_import_structure +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + +_import_structure = { + "configuration_evo2": ["Evo2Config"], + "tokenization_evo2": ["Evo2Tokenizer"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_evo2"] = [ + "Evo2ForCausalLM", + "Evo2Model", + "Evo2PreTrainedModel", + ] if TYPE_CHECKING: - from .configuration_evo2 import * - from .modeling_evo2 import * + from .configuration_evo2 import Evo2Config + from .tokenization_evo2 import Evo2Tokenizer + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_evo2 import Evo2ForCausalLM, Evo2Model, Evo2PreTrainedModel else: import sys - _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/evo2/configuration_evo2.py b/src/transformers/models/evo2/configuration_evo2.py index 30ca688ecb80..01aa5256088b 100644 --- a/src/transformers/models/evo2/configuration_evo2.py +++ b/src/transformers/models/evo2/configuration_evo2.py @@ -1,276 +1,121 @@ -# src/transformers/models/evo2/configuration_evo2.py +"""Evo2 model configuration.""" from __future__ import annotations -from typing import List, Optional +from typing import Optional, Sequence -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import standardize_rope_params +from ...utils import logging -logger = logging.get_logger(__name__) - - -class Evo2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of an :class:`~transformers.Evo2ForCausalLM` model. - - It is inspired by the StripedHyena2-based Evo 2 DNA foundation model. - - Args: - vocab_size (`int`, *optional*, defaults to 512): - Vocabulary size of the model. - hidden_size (`int`, *optional*, defaults to 1920): - Dimension of the hidden representations. - num_layers (`int`, *optional*, defaults to 25): - Number of layers (Hyena / attention blocks). - num_attention_heads (`int`, *optional*, defaults to 15): - Number of attention heads in attention layers. - inner_mlp_size (`int`, *optional*, defaults to 5120): - Size of the intermediate (MLP) layer in the feed-forward network. - max_position_embeddings (`int`, *optional*, defaults to 8192): - Maximum sequence length that this model might ever be used with. 
- rotary_emb_base (`int`, *optional*, defaults to 10000): - Base for rotary position embeddings. - - attn_layer_idxs (`List[int]`, *optional*): - Indices of layers that use attention. - hcl_layer_idxs (`List[int]`, *optional*): - Indices of "HCL" Hyena layers. - hcm_layer_idxs (`List[int]`, *optional*): - Indices of "HCM" Hyena layers. - hcs_layer_idxs (`List[int]`, *optional*): - Indices of "HCS" Hyena layers. - - num_filters (`int`, *optional*, defaults to 1920): - Number of independent filters in Hyena-LI. - hcm_filter_length (`int`, *optional*, defaults to 128): - Length of HCM filters. - hcl_filter_groups (`int`, *optional*, defaults to 1920): - Number of filter groups for HCL. - hcm_filter_groups (`int`, *optional*, defaults to 128): - Number of filter groups for HCM. - hcs_filter_groups (`int`, *optional*, defaults = 128): - Number of filter groups for HCS. - hcs_filter_length (`int`, *optional*, defaults = 7): - Length of HCS filters. - short_filter_length (`int`, *optional*, defaults = 3): - Length of short depthwise FIR filters. - short_filter_bias (`bool`, *optional*, defaults = False): - Whether to add a bias to FIR filters. - - state_size (`int`, *optional*, defaults = 16): - Size of the Hyena state. - eps (`float`, *optional*, defaults = 1e-6): - Epsilon used for numerical stability in layer norms etc. - proj_groups (`int`, *optional*, defaults = 1): - Number of groups for grouped query/key/value projections. - hyena_filter_groups (`int`, *optional*, defaults = 1): - Number of groups for Hyena filters. - - column_split_hyena (`bool`, *optional*, defaults = False): - Whether to column-split Hyena channels (for tensor parallelism). - column_split (`bool`, *optional*, defaults = True): - Whether to column-split projections. - interleave (`bool`, *optional*, defaults = True): - Whether to interleave channels. - - evo2_style_activations (`bool`, *optional*, defaults = True): - Use Evo2-style activations (identity for some layers). - mlp_activation (`str`, *optional*, defaults = "gelu"): - Activation function in the MLP. - - make_vocab_size_divisible_by (`int`, *optional*, defaults = 8): - Pad vocab size to be divisible by this value. - inner_size_multiple_of (`int`, *optional*, defaults = 16): - Force MLP inner size to be a multiple of this value. - - tie_embeddings (`bool`, *optional*, defaults = True): - Whether to tie input and output embeddings. - mha_out_proj_bias (`bool`, *optional*, defaults = True): - Whether to use bias in attention output projections. - hyena_out_proj_bias (`bool`, *optional*, defaults = True): - Whether to use bias in Hyena output projections. - qkv_proj_bias (`bool`, *optional*, defaults = False): - Whether to use bias in QKV projections. - final_norm (`bool`, *optional*, defaults = True): - Whether to apply a final normalization layer. - - use_flash_attn (`bool`, *optional*, defaults = True): - Whether to use FlashAttention when available. - use_flash_rmsnorm (`bool`, *optional*, defaults = False): - Whether to use a fused Flash RMSNorm implementation. - use_flash_depthwise (`bool`, *optional*, defaults = False): - Whether to use fused depthwise convolution kernels. - use_flashfft (`bool`, *optional*, defaults = False): - Whether to use FFT-based kernels for long convolutions. - use_laughing_hyena (`bool`, *optional*, defaults = False): - Experimental variant toggle. - - max_batch_size (`int`, *optional*, defaults = 1): - Max batch size used in the original config (not enforced by HF). 
- inference_mode (`bool`, *optional*, defaults = True): - Indicates original config was built for inference. - - tokenizer_type (`str`, *optional*, defaults = "CharLevelTokenizer"): - Name of the tokenizer expected by the original implementation. - prefill_style (`str`, *optional*, defaults = "fft"): - Prefill strategy used in original Evo2. - - print_activations (`bool`, *optional*, defaults = False): - Log intermediate activations (debugging). - log_intermediate_values (`bool`, *optional*, defaults = False): - Log intermediate values in original code (debugging). +logger = logging.get_logger(__name__) - model_parallel_size (`int`, *optional*, defaults = 1): - Original MP size; informational only here. - pipe_parallel_size (`int`, *optional*, defaults = 1): - Original PP size; informational only here. +__all__ = ["Evo2Config"] - hyena_flip_x1x2 (`bool`, *optional*, defaults = False): - Flip Hyena kernel inputs (compat option). - use_fp8_input_projections (`bool`, *optional*, defaults = True): - Whether the original model used FP8 input projections. - **kwargs: - Additional keyword arguments passed to `PretrainedConfig`. - """ +class Evo2Config(PretrainedConfig): + r"""Configuration class for the Evo2 model.""" model_type = "evo2" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, - vocab_size: int = 512, - hidden_size: int = 1920, - num_layers: int = 25, - num_attention_heads: int = 15, - inner_mlp_size: int = 5120, - max_position_embeddings: int = 8192, - rotary_emb_base: int = 10000, - attn_layer_idxs: Optional[List[int]] = None, - hcl_layer_idxs: Optional[List[int]] = None, - hcm_layer_idxs: Optional[List[int]] = None, - hcs_layer_idxs: Optional[List[int]] = None, - num_filters: int = 1920, - hcm_filter_length: int = 128, - hcl_filter_groups: int = 1920, - hcm_filter_groups: int = 128, - hcs_filter_groups: int = 128, - hcs_filter_length: int = 7, - short_filter_length: int = 3, - short_filter_bias: bool = False, - state_size: int = 16, - eps: float = 1e-6, - proj_groups: int = 1, - hyena_filter_groups: int = 1, - column_split_hyena: bool = False, - column_split: bool = True, - interleave: bool = True, - evo2_style_activations: bool = True, - mlp_activation: str = "gelu", - make_vocab_size_divisible_by: int = 8, - inner_size_multiple_of: int = 16, - tie_embeddings: bool = True, - mha_out_proj_bias: bool = True, - hyena_out_proj_bias: bool = True, - qkv_proj_bias: bool = False, - final_norm: bool = True, - use_flash_attn: bool = True, - use_flash_rmsnorm: bool = False, - use_flash_depthwise: bool = False, - use_flashfft: bool = False, - use_laughing_hyena: bool = False, - max_batch_size: int = 1, - inference_mode: bool = True, - tokenizer_type: str = "CharLevelTokenizer", - prefill_style: str = "fft", - print_activations: bool = False, - log_intermediate_values: bool = False, - model_parallel_size: int = 1, - pipe_parallel_size: int = 1, - hyena_flip_x1x2: bool = False, - use_fp8_input_projections: bool = True, + vocab_size: int = 256, + hidden_size: int = 2048, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_attention_heads: int = 16, + num_key_value_heads: Optional[int] = None, + max_position_embeddings: int = 2048, + rope_theta: float = 1_000_000.0, + rms_norm_eps: float = 1e-6, + attn_dropout: float = 0.0, + hidden_dropout: float = 0.0, + mlp_dropout: float = 0.0, + layer_types: Optional[Sequence[str]] = None, + hyena_filters: int = 256, + hyena_kernel_size: int = 8, + hyena_hidden_size: Optional[int] = None, + hyena_order: int = 
4, + initializer_range: float = 0.02, + use_cache: bool = True, + pad_token_id: int = 1, + bos_token_id: Optional[int] = None, + eos_token_id: int = 0, + tie_word_embeddings: bool = True, **kwargs, - ): - super().__init__(**kwargs) + ) -> None: + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) - # Core HF-style fields self.vocab_size = vocab_size self.hidden_size = hidden_size - self.num_layers = num_layers + self.intermediate_size = intermediate_size if intermediate_size is not None else hidden_size * 4 + self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.intermediate_size = inner_mlp_size # HF naming - self.inner_mlp_size = inner_mlp_size # original naming + self.num_key_value_heads = num_key_value_heads or num_attention_heads self.max_position_embeddings = max_position_embeddings - - # Rotary embeddings - self.rotary_emb_base = rotary_emb_base - - # Layer index layout - self.attn_layer_idxs = attn_layer_idxs or [3, 10, 17, 24] - self.hcl_layer_idxs = hcl_layer_idxs or [2, 6, 9, 13, 16, 20, 23] - self.hcm_layer_idxs = hcm_layer_idxs or [1, 5, 8, 12, 15, 19, 22] - self.hcs_layer_idxs = hcs_layer_idxs or [0, 4, 7, 11, 14, 18, 21] - - # Hyena / filter hyperparameters - self.num_filters = num_filters - self.hcm_filter_length = hcm_filter_length - self.hcl_filter_groups = hcl_filter_groups - self.hcm_filter_groups = hcm_filter_groups - self.hcs_filter_groups = hcs_filter_groups - self.hcs_filter_length = hcs_filter_length - self.short_filter_length = short_filter_length - self.short_filter_bias = short_filter_bias - - # State & numerics - self.state_size = state_size - self.eps = eps - - # Grouping & splitting - self.proj_groups = proj_groups - self.hyena_filter_groups = hyena_filter_groups - self.column_split_hyena = column_split_hyena - self.column_split = column_split - self.interleave = interleave - - # Activations / MLP - self.evo2_style_activations = evo2_style_activations - self.mlp_activation = mlp_activation - self.make_vocab_size_divisible_by = make_vocab_size_divisible_by - self.inner_size_multiple_of = inner_size_multiple_of - - # Projection / embedding knobs - self.tie_embeddings = tie_embeddings - self.mha_out_proj_bias = mha_out_proj_bias - self.hyena_out_proj_bias = hyena_out_proj_bias - self.qkv_proj_bias = qkv_proj_bias - self.final_norm = final_norm - - # Flash / fused kernels (may be ignored in pure PyTorch version) - self.use_flash_attn = use_flash_attn - self.use_flash_rmsnorm = use_flash_rmsnorm - self.use_flash_depthwise = use_flash_depthwise - self.use_flashfft = use_flashfft - self.use_laughing_hyena = use_laughing_hyena - - # Original inference-related fields (kept for compatibility, not enforced) - self.max_batch_size = max_batch_size - self.inference_mode = inference_mode - - # Tokenizer / prefill / logging metadata - self.tokenizer_type = tokenizer_type - self.prefill_style = prefill_style - self.print_activations = print_activations - self.log_intermediate_values = log_intermediate_values - - # Parallelism & numeric tricks (informational) - self.model_parallel_size = model_parallel_size - self.pipe_parallel_size = pipe_parallel_size - self.hyena_flip_x1x2 = hyena_flip_x1x2 - self.use_fp8_input_projections = use_fp8_input_projections - - # For backward compatibility with original config name - self.max_seqlen = max_position_embeddings - - -__all__ = ["Evo2Config"] + self.rope_theta = rope_theta + 
self.rms_norm_eps = rms_norm_eps + self.attn_dropout = attn_dropout + self.hidden_dropout = hidden_dropout + self.mlp_dropout = mlp_dropout + self.hyena_filters = hyena_filters + self.hyena_kernel_size = hyena_kernel_size + self.hyena_hidden_size = hyena_hidden_size if hyena_hidden_size is not None else hidden_size + self.hyena_order = hyena_order + self.initializer_range = initializer_range + self.use_cache = use_cache + + if layer_types is None: + self.layer_types = ["attention"] * num_hidden_layers + else: + self.layer_types = list(layer_types) + + standardize_rope_params(self, rope_theta=self.rope_theta) + + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError( + "The length of `layer_types` must match `num_hidden_layers` (received" + f" {len(self.layer_types)} and {self.num_hidden_layers})." + ) + + for layer_type in self.layer_types: + if layer_type not in {"attention", "hyena"}: + raise ValueError(f"Unsupported layer type: {layer_type}. Expected 'attention' or 'hyena'.") + + if self.num_attention_heads <= 0 or self.hidden_size % self.num_attention_heads != 0: + raise ValueError("`hidden_size` must be divisible by `num_attention_heads`.") + + if self.num_key_value_heads <= 0 or self.hidden_size % self.num_key_value_heads != 0: + raise ValueError("`hidden_size` must be divisible by `num_key_value_heads`.") + + logger.info("Initialized Evo2Config with %s layers (%s).", self.num_hidden_layers, ", ".join(self.layer_types)) + + @property + def head_dim(self) -> int: + return self.hidden_size // self.num_attention_heads + + @property + def kv_head_dim(self) -> int: + return self.hidden_size // self.num_key_value_heads + + @property + def num_attention_layers(self) -> int: + return sum(layer_type == "attention" for layer_type in self.layer_types) + + @property + def num_hyena_layers(self) -> int: + return sum(layer_type == "hyena" for layer_type in self.layer_types) + + def to_dict(self) -> dict: + output = super().to_dict() + output["layer_types"] = list(self.layer_types) + return output diff --git a/src/transformers/models/evo2/convert_evo2_original_to_hf.py b/src/transformers/models/evo2/convert_evo2_original_to_hf.py new file mode 100644 index 000000000000..d8bf5f3984bd --- /dev/null +++ b/src/transformers/models/evo2/convert_evo2_original_to_hf.py @@ -0,0 +1,214 @@ +"""Conversion script for original Evo2 checkpoints to Hugging Face format.""" + +from __future__ import annotations + +import argparse +import json +import os +from typing import Dict, Iterable, Optional + +import torch + +try: + import yaml +except ImportError: # pragma: no cover + yaml = None + +from transformers import Evo2Config, Evo2ForCausalLM + + +def _load_original_state_dict(path: str) -> Dict[str, torch.Tensor]: + checkpoint = torch.load(path, map_location="cpu") + if isinstance(checkpoint, dict): + if "state_dict" in checkpoint: + return checkpoint["state_dict"] + if "model" in checkpoint and isinstance(checkpoint["model"], dict): + return checkpoint["model"] + if isinstance(checkpoint, (list, tuple)): + raise ValueError("Unexpected checkpoint structure. Expected a dictionary with model weights.") + return checkpoint + + +def _load_config(path: str) -> Evo2Config: + ext = os.path.splitext(path)[1].lower() + if ext in {".yml", ".yaml"}: + if yaml is None: + raise ImportError("PyYAML is required to parse YAML configuration files. 
Please install pyyaml.") + with open(path, "r", encoding="utf-8") as handle: + raw = yaml.safe_load(handle) + config_dict = raw.get("model", raw) + layer_specs = config_dict.get("layers", []) + layer_types = [] + for layer in layer_specs: + if isinstance(layer, dict): + layer_type = layer.get("type") or layer.get("layer_type") or layer.get("block_type") + if layer_type is None: + layer_type = "attention" + layer_types.append(layer_type.lower()) + else: + layer_types.append(str(layer).lower()) + config_dict["layer_types"] = layer_types or config_dict.get("layer_types") + return Evo2Config(**config_dict) + if ext == ".json": + with open(path, "r", encoding="utf-8") as handle: + config_kwargs = json.load(handle) + return Evo2Config(**config_kwargs) + if os.path.isdir(path): + return Evo2Config.from_pretrained(path) + raise ValueError(f"Unsupported config format for '{path}'. Expected directory, JSON, or YAML file.") + + +def _match_first_available(state_dict: Dict[str, torch.Tensor], candidates: Iterable[str]) -> Optional[str]: + for candidate in candidates: + if candidate in state_dict: + return candidate + return None + + +def _map_attention_key(layer_idx: int, target_suffix: str, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + suffix_map = { + "q_proj.weight": ["attn.wq.weight", "attention.wq.weight", "attention.q_proj.weight"], + "k_proj.weight": ["attn.wk.weight", "attention.wk.weight", "attention.k_proj.weight"], + "v_proj.weight": ["attn.wv.weight", "attention.wv.weight", "attention.v_proj.weight"], + "o_proj.weight": ["attn.wo.weight", "attention.wo.weight", "attention.out_proj.weight"], + "q_proj.bias": ["attn.wq.bias", "attention.wq.bias", "attention.q_proj.bias"], + "k_proj.bias": ["attn.wk.bias", "attention.wk.bias", "attention.k_proj.bias"], + "v_proj.bias": ["attn.wv.bias", "attention.wv.bias", "attention.v_proj.bias"], + "o_proj.bias": ["attn.wo.bias", "attention.wo.bias", "attention.out_proj.bias"], + "input_layernorm.weight": ["attn_norm.weight", "input_layernorm.weight"], + "post_attention_layernorm.weight": ["mlp_norm.weight", "post_attention_layernorm.weight"], + "input_layernorm.bias": ["attn_norm.bias", "input_layernorm.bias"], + "post_attention_layernorm.bias": ["mlp_norm.bias", "post_attention_layernorm.bias"], + } + candidates = [] + for variant in suffix_map.get(target_suffix, []): + candidates.append(f"layers.{layer_idx}.{variant}") + candidates.append(f"model.layers.{layer_idx}.{variant}") + return _match_first_available(original_state, candidates) + + +def _map_mlp_key(layer_idx: int, target_suffix: str, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + suffix_map = { + "gate_proj.weight": ["mlp.gate_proj.weight", "mlp.w1.weight", "ffn.gate_proj.weight"], + "up_proj.weight": ["mlp.up_proj.weight", "mlp.w3.weight", "ffn.up_proj.weight"], + "down_proj.weight": ["mlp.down_proj.weight", "mlp.w2.weight", "ffn.down_proj.weight"], + "gate_proj.bias": ["mlp.gate_proj.bias", "mlp.w1.bias", "ffn.gate_proj.bias"], + "up_proj.bias": ["mlp.up_proj.bias", "mlp.w3.bias", "ffn.up_proj.bias"], + "down_proj.bias": ["mlp.down_proj.bias", "mlp.w2.bias", "ffn.down_proj.bias"], + } + candidates = [] + for variant in suffix_map.get(target_suffix, []): + candidates.append(f"layers.{layer_idx}.{variant}") + candidates.append(f"model.layers.{layer_idx}.{variant}") + return _match_first_available(original_state, candidates) + + +def _map_hyena_key(layer_idx: int, target_suffix: str, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + suffix_map = { + 
"input_layernorm.weight": ["hyena_norm.weight", "input_layernorm.weight"], + "input_layernorm.bias": ["hyena_norm.bias", "input_layernorm.bias"], + "post_attention_layernorm.weight": ["mlp_norm.weight", "post_layernorm.weight"], + "post_attention_layernorm.bias": ["mlp_norm.bias", "post_layernorm.bias"], + "filter.in_proj.weight": ["hyena.filter.in_proj.weight", "hyena.in_proj.weight"], + "filter.out_proj.weight": ["hyena.filter.out_proj.weight", "hyena.out_proj.weight"], + "filter.conv.weight": ["hyena.filter.conv.weight"], + } + candidates = [] + for variant in suffix_map.get(target_suffix, []): + candidates.append(f"layers.{layer_idx}.{variant}") + candidates.append(f"model.layers.{layer_idx}.{variant}") + return _match_first_available(original_state, candidates) + + +def _map_key(target_key: str, config: Evo2Config, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + if target_key == "model.embed_tokens.weight": + return _match_first_available( + original_state, + [ + "model.embed_tokens.weight", + "embed_tokens.weight", + "tok_embeddings.weight", + "embedding.weight", + "embeddings.word_embeddings.weight", + ], + ) + if target_key == "model.norm.weight": + return _match_first_available( + original_state, + ["model.norm.weight", "norm.weight", "final_layer_norm.weight", "rms_norm.weight"], + ) + if target_key == "lm_head.weight": + return _match_first_available(original_state, ["lm_head.weight", "output.weight", "head.weight"]) + + if target_key.startswith("model.layers."): + parts = target_key.split(".") + layer_idx = int(parts[2]) + layer_type = config.layer_types[layer_idx] + suffix = ".".join(parts[4:]) if parts[3] == "block" else ".".join(parts[3:]) + if layer_type == "attention": + if suffix.startswith("attention."): + attn_suffix = suffix[len("attention.") :] + return _map_attention_key(layer_idx, attn_suffix, original_state) + if suffix.startswith("mlp."): + mlp_suffix = suffix[len("mlp.") :] + return _map_mlp_key(layer_idx, mlp_suffix, original_state) + if suffix.startswith("hidden_dropout"): + return None + if suffix.startswith("input_layernorm") or suffix.startswith("post_attention_layernorm"): + return _map_attention_key(layer_idx, suffix, original_state) + else: + if suffix.startswith("filter."): + filter_suffix = suffix + return _map_hyena_key(layer_idx, filter_suffix, original_state) + if suffix.startswith("mlp."): + mlp_suffix = suffix[len("mlp.") :] + return _map_mlp_key(layer_idx, mlp_suffix, original_state) + if suffix.startswith("input_layernorm") or suffix.startswith("post_attention_layernorm"): + return _map_hyena_key(layer_idx, suffix, original_state) + return None + return None + + +def convert_checkpoint(original_checkpoint: str, config_path: str, output_dir: str) -> None: + original_state = _load_original_state_dict(original_checkpoint) + config = _load_config(config_path) + model = Evo2ForCausalLM(config) + + target_state = model.state_dict() + new_state = {} + missing_keys = [] + + for key in target_state.keys(): + source_key = _map_key(key, config, original_state) + if source_key is None: + missing_keys.append(key) + continue + new_state[key] = original_state[source_key] + + if missing_keys: + raise KeyError( + "The following keys could not be mapped from the original checkpoint: " + ", ".join(missing_keys) + ) + + model.load_state_dict(new_state, strict=True) + + os.makedirs(output_dir, exist_ok=True) + model.save_pretrained(output_dir) + config.save_pretrained(output_dir) + + +def parse_args() -> argparse.Namespace: + parser = 
argparse.ArgumentParser(description="Convert an original Evo2 checkpoint to Hugging Face format.") + parser.add_argument("checkpoint", type=str, help="Path to the original .pt checkpoint file") + parser.add_argument("config", type=str, help="Path to the Evo2 YAML/JSON config or HF directory") + parser.add_argument("output", type=str, help="Output directory for the converted model") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + convert_checkpoint(args.checkpoint, args.config, args.output) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/evo2/modeling_evo2.py b/src/transformers/models/evo2/modeling_evo2.py index e8c07806c26c..3928e23ee00a 100644 --- a/src/transformers/models/evo2/modeling_evo2.py +++ b/src/transformers/models/evo2/modeling_evo2.py @@ -1,447 +1,548 @@ -# coding=utf-8 -# Copyright 2025 The HuggingFace Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. +"""PyTorch Evo2 model.""" from __future__ import annotations -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - import math +from collections.abc import Callable +from typing import Optional, Tuple + import torch -import torch.nn as nn import torch.nn.functional as F - -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ...cache_utils import Cache, DynamicCache +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import logging from .configuration_evo2 import Evo2Config logger = logging.get_logger(__name__) - -# ========================= -# Norm + Rotary helpers -# ========================= +__all__ = ["Evo2Model", "Evo2ForCausalLM", "Evo2PreTrainedModel"] +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Evo2 class Evo2RMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6): + def __init__(self, hidden_size, eps=1e-6): + """Evo2RMSNorm is equivalent to T5LayerNorm.""" super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) - self.eps = eps + self.variance_epsilon = eps - def forward(self, x: torch.Tensor) -> torch.Tensor: - # standard RMSNorm - norm = x.float().pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(norm + self.eps) - return (self.weight * x).to(x.dtype) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class RotaryEmbedding(nn.Module): - """ - Simple rotary embedding (RoPE) implementation. - We keep this minimal; you can later swap for the shared one from another model. 
- """ - def __init__(self, dim: int, base: int = 10000): - super().__init__() - self.dim = dim - self.base = base +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): - inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) - ) - t = torch.arange(seq_len, device=device, dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, inv_freq) # [seq_len, dim/2] - emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, dim] - return torch.cos(emb).to(dtype), torch.sin(emb).to(dtype) +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + del position_ids + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed -def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - """ - x: [b, s, h, d] - cos/sin: [1, s, 1, d] - """ - x1, x2 = x[..., ::2], x[..., 1::2] - cos = cos[..., ::2] - sin = sin[..., ::2] - x1_rot = x1 * cos - x2 * sin - x2_rot = x1 * sin + x2 * cos - x_rot = torch.stack([x1_rot, x2_rot], dim=-1) - x_rot = x_rot.flatten(-2) - return x_rot +class Evo2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor -# ========================= -# Attention block -# ========================= + def __init__(self, config: Evo2Config, device=None): + super().__init__() + self.max_seq_len_cached = getattr(config, "max_position_embeddings", 2048) + self.original_max_seq_len = self.max_seq_len_cached + self.config = config -class Evo2Attention(nn.Module): + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[Evo2Config] = None, + device: Optional[torch.device] = None, + seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, float]: + del seq_len + rope_params = getattr(config, "rope_parameters", None) + base = rope_params.get("rope_theta") if rope_params is not None else config.rope_theta + dim = config.head_dim + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) + return inv_freq, 1.0 + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * 
self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Evo2ParallelGatedMLP(nn.Module): def __init__(self, config: Evo2Config): super().__init__() self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.dropout = nn.Dropout(config.mlp_dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gated = F.silu(self.gate_proj(hidden_states)) + up = self.up_proj(hidden_states) + hidden_states = self.down_proj(gated * up) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class Evo2Attention(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads self.head_dim = self.hidden_size // self.num_heads + self.kv_head_dim = self.hidden_size // self.num_key_value_heads + self.layer_idx = layer_idx - if self.head_dim * self.num_heads != self.hidden_size: - raise ValueError("hidden_size must be divisible by num_attention_heads") - - self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_proj_bias) - self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_proj_bias) - self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_proj_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.mha_out_proj_bias) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.kv_head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.kv_head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Dropout(config.attn_dropout) - self.rotary_emb = RotaryEmbedding(self.head_dim, base=config.rotary_emb_base) + self.rotary_emb = Evo2RotaryEmbedding(config) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - bsz, seq_len, _ = hidden_states.size() - - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() - # [b, s, h, d] - 
q = q.view(bsz, seq_len, self.num_heads, self.head_dim) - k = k.view(bsz, seq_len, self.num_heads, self.head_dim) - v = v.view(bsz, seq_len, self.num_heads, self.head_dim) + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.kv_head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_key_value_heads, self.kv_head_dim + ).transpose(1, 2) - cos, sin = self.rotary_emb(seq_len, hidden_states.device, hidden_states.dtype) - cos = cos[None, :, None, :] # [1, s, 1, d] - sin = sin[None, :, None, :] - q = apply_rotary(q, cos, sin) - k = apply_rotary(k, cos, sin) + cos, sin = self.rotary_emb(query_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - past_k, past_v = past_key_value - k = torch.cat([past_k, k], dim=1) - v = torch.cat([past_v, v], dim=1) + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - present_key_value = (k, v) if use_cache else None + kv_seq_len = key_states.shape[-2] - # [b, h, s, d] - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) + if self.num_key_value_heads != self.num_heads: + key_states = repeat_kv(key_states, self.num_heads // self.num_key_value_heads) + value_states = repeat_kv(value_states, self.num_heads // self.num_key_value_heads) - attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: - # attention_mask expected [b, 1, 1, s_k]; add additive mask attn_weights = attn_weights + attention_mask - attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_probs, v) # [b, h, s, d] + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = self.dropout(attn_weights) - attn_output = attn_output.permute(0, 2, 1, 3).contiguous() - attn_output = attn_output.view(bsz, seq_len, self.hidden_size) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - return attn_output, present_key_value - + present = past_key_value if use_cache else None + return attn_output, (attn_weights if output_attentions else None), present -# ========================= -# Hyena-ish block (simplified) -# ========================= - - -class Evo2HyenaBlock(nn.Module): - """ - Simplified Hyena-style block. - - This is NOT the full HyenaCascade from Vortex. It’s a placeholder: - - depthwise conv over time - - small MLP - - You can later replace this with a faithful StripedHyena2 port. 
- """ +class Evo2HyenaFilter(nn.Module): def __init__(self, config: Evo2Config): super().__init__() self.hidden_size = config.hidden_size - self.short_filter_length = config.short_filter_length - - self.dw_conv = nn.Conv1d( - in_channels=self.hidden_size, - out_channels=self.hidden_size, - kernel_size=self.short_filter_length, - padding=self.short_filter_length // 2, - groups=self.hidden_size, - bias=config.short_filter_bias, - ) - - self.mlp = nn.Sequential( - nn.Linear(self.hidden_size, config.inner_mlp_size), - nn.GELU(), # matches mlp_activation default - nn.Linear(config.inner_mlp_size, self.hidden_size), + self.order = config.hyena_order + self.filter_channels = config.hyena_filters + self.kernel_size = config.hyena_kernel_size + + self.in_proj = nn.Linear(self.hidden_size, self.filter_channels * self.order, bias=False) + self.conv = nn.Conv1d( + in_channels=self.filter_channels, + out_channels=self.filter_channels, + kernel_size=self.kernel_size, + groups=self.filter_channels, + padding=self.kernel_size - 1, ) + self.out_proj = nn.Linear(self.filter_channels * self.order, self.hidden_size, bias=False) + self.activation = nn.SiLU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # [b, s, h] -> [b, h, s] for conv - x = hidden_states.transpose(1, 2) - x = self.dw_conv(x) - x = x.transpose(1, 2) - x = self.mlp(x) - return x - + batch, seq_len, _ = hidden_states.shape + projected = self.in_proj(hidden_states) + projected = projected.view(batch, seq_len, self.order, self.filter_channels).permute(0, 2, 3, 1) + conv_input = projected.reshape(batch * self.order, self.filter_channels, seq_len) + conv_output = self.conv(conv_input) + conv_output = conv_output[:, :, :seq_len] + conv_output = conv_output.view(batch, self.order, self.filter_channels, seq_len).permute(0, 3, 1, 2) + conv_output = conv_output.reshape(batch, seq_len, self.order * self.filter_channels) + conv_output = self.activation(conv_output) + return self.out_proj(conv_output) + + +class Evo2AttentionBlock(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.attention = Evo2Attention(config, layer_idx) + self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = Evo2ParallelGatedMLP(config) + self.hidden_dropout = nn.Dropout(config.hidden_dropout) -# ========================= -# Evo2Block -# ========================= + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache], + output_attentions: bool, + use_cache: bool, + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_output, attn_weights, present_kv = self.attention( + hidden_states, + attention_mask, + position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + self.hidden_dropout(attn_output) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.hidden_dropout(hidden_states) -class Evo2Block(nn.Module): - def __init__(self, config: Evo2Config, layer_idx: int): - super().__init__() - self.config = config - 
self.layer_idx = layer_idx + return hidden_states, attn_weights, present_kv - self.norm1 = Evo2RMSNorm(config.hidden_size, eps=config.eps) - self.norm2 = Evo2RMSNorm(config.hidden_size, eps=config.eps) - if layer_idx in config.attn_layer_idxs: - self.block_type = "attn" - self.attn = Evo2Attention(config) - self.hyena = None - else: - self.block_type = "hyena" - self.attn = None - self.hyena = Evo2HyenaBlock(config) - - # Simple MLP for the second residual (you can adjust to ParallelGatedMLP later) - self.mlp = nn.Sequential( - nn.Linear(config.hidden_size, config.inner_mlp_size), - nn.GELU(), - nn.Linear(config.inner_mlp_size, config.hidden_size), - ) +class Evo2HyenaBlock(nn.Module): + def __init__(self, config: Evo2Config): + super().__init__() + self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.filter = Evo2HyenaFilter(config) + self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = Evo2ParallelGatedMLP(config) + self.hidden_dropout = nn.Dropout(config.hidden_dropout) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache], + output_attentions: bool, + use_cache: bool, + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + del attention_mask, past_key_value, output_attentions, use_cache, cache_position, position_ids residual = hidden_states - hidden_states = self.norm1(hidden_states) - - present_key_value = None - - if self.block_type == "attn": - attn_output, present_key_value = self.attn( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - use_cache=use_cache, - ) - hidden_states = residual + attn_output - else: - hyena_out = self.hyena(hidden_states) - hidden_states = residual + hyena_out + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.filter(hidden_states) + hidden_states = residual + self.hidden_dropout(hidden_states) - # Second norm + MLP residual = hidden_states - hidden_states = self.norm2(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states = residual + self.hidden_dropout(hidden_states) - return hidden_states, present_key_value + return hidden_states, None, None -# ========================= -# Base model -# ========================= +class Evo2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Evo2Config, layer_type: str, layer_idx: int): + super().__init__() + self.layer_type = layer_type + if layer_type == "attention": + self.block = Evo2AttentionBlock(config, layer_idx) + else: + self.block = Evo2HyenaBlock(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache], + output_attentions: bool, + use_cache: bool, + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + return self.block( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + ) class 
Evo2PreTrainedModel(PreTrainedModel): config_class = Evo2Config base_model_prefix = "model" - supports_gradient_checkpointing = False - _no_split_modules = ["Evo2Block"] + supports_gradient_checkpointing = True + _no_split_modules = ["Evo2DecoderLayer"] def _init_weights(self, module): if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight) + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0.0, std=0.02) + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Evo2RMSNorm): + module.weight.data.fill_(1.0) -class Evo2Model(Evo2PreTrainedModel): - """ - Decoder-only Evo2 backbone: embeddings + stack of Evo2Blocks. - """ +class Evo2Model(Evo2PreTrainedModel): def __init__(self, config: Evo2Config): super().__init__(config) - - self.padding_idx = 0 - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) - + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( - [Evo2Block(config, layer_idx=i) for i in range(config.num_layers)] + [Evo2DecoderLayer(config, layer_type, layer_idx) for layer_idx, layer_type in enumerate(config.layer_types)] ) - - self.final_norm = Evo2RMSNorm(config.hidden_size, eps=config.eps) if config.final_norm else None + self.norm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self): return self.embed_tokens - def set_input_embeddings(self, new_embeddings: nn.Embedding): - self.embed_tokens = new_embeddings + def set_input_embeddings(self, value): + self.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> BaseModelOutputWithPast: - if output_attentions: - logger.warning_once("Evo2Model does not currently return attentions.") + output_attentions = ( + output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False) + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else getattr(self.config, "output_hidden_states", False) + ) + return_dict = return_dict if return_dict is not None else True + use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.") - elif input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds.") + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + 
elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You must specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - bsz, seq_len, _ = hidden_states.size() + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache(config=self.config) + elif not isinstance(past_key_values, Cache): + raise TypeError("`past_key_values` must be a `Cache` when `use_cache` is True.") + else: + past_key_values = None - # Build causal attention mask if not provided (2D mask with 1 for non-padded tokens) - if attention_mask is not None: - # [b, s] -> [b, 1, 1, s] additive mask - attention_mask = attention_mask[:, None, None, :] - attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 - all_hidden_states = [] if output_hidden_states else None - next_past_key_values = [] if use_cache else None + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_length, + ) - for idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) + hidden_states = self.dropout(inputs_embeds) - past_kv = past_key_values[idx] if past_key_values is not None else None + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device) + else: + cache_position = cache_position.to(hidden_states.device) - hidden_states, present_kv = layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_kv, - use_cache=use_cache, - ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) - if use_cache: - next_past_key_values.append(present_kv) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None - if self.final_norm is not None: - hidden_states = self.final_norm(hidden_states) + for layer_idx, (decoder_layer, layer_type) in enumerate(zip(self.layers, self.config.layer_types)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_past = past_key_values if layer_type == "attention" else None + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states, attn, present_kv = checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + layer_past, + output_attentions, + use_cache, + cache_position, + ) + else: + hidden_states, attn, present_kv = decoder_layer( + hidden_states, + attention_mask, + position_ids, + layer_past, + output_attentions, + use_cache, + cache_position, + ) + + if layer_type == "attention" and present_kv is not None and use_cache: + past_key_values = present_kv + + if output_attentions: + all_attentions = all_attentions + (attn,) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - outputs = (hidden_states, next_past_key_values) + outputs = (hidden_states, past_key_values) if output_hidden_states: - outputs = (hidden_states, next_past_key_values, 
all_hidden_states) + outputs += (all_hidden_states,) + if output_attentions: + outputs += (all_attentions,) return outputs return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, + attentions=all_attentions, ) -# ========================= -# Causal LM head -# ========================= - - -class Evo2ForCausalLM(Evo2PreTrainedModel): - """ - Evo2 language model with a LM head on top of Evo2Model. - """ +class Evo2ForCausalLM(Evo2PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: Evo2Config): super().__init__(config) self.model = Evo2Model(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - if config.tie_embeddings: - self.tie_weights() - self.post_init() def get_input_embeddings(self): return self.model.embed_tokens - def set_input_embeddings(self, new_embeddings): - self.model.embed_tokens = new_embeddings - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def tie_weights(self): - self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings()) - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ): - # Standard decoder-only prepare_inputs_for_generation: - # if we have past_key_values, only feed the last token. - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - "use_cache": True, - } + def set_input_embeddings(self, value): + self.model.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + return_dict = return_dict if return_dict is not None else True + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -451,26 +552,27 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=True, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, ) - hidden_states = outputs.last_hidden_state - logits = self.lm_head(hidden_states) + hidden_states = outputs[0] + if isinstance(logits_to_keep, int): + slice_indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + else: + logits = self.lm_head(hidden_states[:, logits_to_keep, :]) loss = None if labels 
is not None: - # shift for causal LM - shift_logits = logits[:, :-1, :].contiguous() - shift_labels = labels[:, 1:].contiguous() - loss = F.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), - ) + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) if not return_dict: - output = (logits, outputs.past_key_values) - if output_hidden_states: - output = (logits, outputs.past_key_values, outputs.hidden_states) + output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( @@ -478,4 +580,9 @@ def forward( logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) + + def _reorder_cache(self, past_key_values: Cache, beam_idx: torch.LongTensor) -> Cache: + past_key_values.reorder_cache(beam_idx) + return past_key_values diff --git a/src/transformers/models/evo2/tokenization_evo2.py b/src/transformers/models/evo2/tokenization_evo2.py index f6d4833d0d26..6647c659446e 100644 --- a/src/transformers/models/evo2/tokenization_evo2.py +++ b/src/transformers/models/evo2/tokenization_evo2.py @@ -1,220 +1,130 @@ -# src/transformers/models/evo2/tokenization_evo2.py +"""Tokenizer for the Evo2 model.""" from __future__ import annotations import json import os -from typing import Dict, List, Optional, Tuple +from typing import List, Optional -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.utils import logging +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = { - "vocab_file": "vocab.json", -} +__all__ = ["Evo2Tokenizer"] -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - # You can fill these in once you upload a checkpoint - # "arcinstitute/evo2-1b-8k": "https://huggingface.co/arcinstitute/evo2-1b-8k/resolve/main/vocab.json", - } -} -PRETRAINED_INIT_CONFIGURATION = { - # "arcinstitute/evo2-1b-8k": {}, -} +def _clamp_token_id(token_id: int) -> int: + return max(0, min(255, int(token_id))) class Evo2Tokenizer(PreTrainedTokenizer): - """ - Hugging Face wrapper around the Evo2 CharLevelTokenizer. - - - Encoding: - text.encode("utf-8") -> list of uint8 bytes in [0, 255] - - Token IDs: - those bytes directly used as IDs (0..255). - `vocab_size` can be larger (e.g. 512), but extra IDs are unused. - - Decoding: - clamp each id with `clamp(n) = max(32, min(n, vocab_size))` - then `chr(clamp(n))` and join. - - We implement vocab as stringified integers: "0" -> 0, "1" -> 1, etc. 
- """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION - - def __init__( - self, - vocab_file: Optional[str] = None, - vocab_size: int = 512, - # Match original CharLevelTokenizer semantics: - # eod_id = eos_id = 0, pad_id = 1 - eos_token: str = "0", - pad_token: str = "1", - unk_token: str = "0", # there is no real "unknown" in char-level; anything maps to a byte - bos_token: Optional[str] = None, - **kwargs, - ): - self._vocab_size = vocab_size - - if vocab_file is None: - # Default vocab: token "0" -> id 0, "1" -> id 1, ..., up to vocab_size-1 - self.vocab: Dict[str, int] = {str(i): i for i in range(vocab_size)} - else: - with open(vocab_file, "r", encoding="utf-8") as f: - self.vocab = json.load(f) - # Ensure ids are ints - self.vocab = {str(k): int(v) for k, v in self.vocab.items()} - - self.ids_to_tokens = {v: k for k, v in self.vocab.items()} - - # Call parent ctor (this also sets pad/eos/bos/unk attributes) + model_input_names = ["input_ids", "attention_mask"] + + def __init__(self, **kwargs) -> None: + self._vocab_size = 256 + self._token_to_id = {chr(i): i for i in range(self._vocab_size)} + self._id_to_token = {i: chr(i) for i in range(self._vocab_size)} + self._eos_token_id = 0 + self._pad_token_id = 1 + self._bos_token_id = None super().__init__( - eos_token=eos_token, - pad_token=pad_token, - bos_token=bos_token, # None by default; CharLevelTokenizer has no BOS - unk_token=unk_token, + bos_token=None, + eos_token=chr(0), + pad_token=chr(1), + unk_token=None, + add_bos_token=False, + add_eos_token=False, **kwargs, ) - # Cache some commonly used ids - self._eos_id = int(eos_token) if bos_token is None else self.vocab[eos_token] - self._pad_id = int(pad_token) - self._unk_id = int(unk_token) - - # ---- Char-level core logic --------------------------------------------- - @property def vocab_size(self) -> int: return self._vocab_size - def get_vocab(self) -> Dict[str, int]: - return dict(self.vocab) - - def clamp(self, n: int) -> int: - # Same as in CharLevelTokenizer: max(32, min(n, vocab_size)) - return max(32, min(n, self._vocab_size)) - - # HF will call this to get string "tokens" before converting to ids - def _tokenize(self, text: str, **kwargs) -> List[str]: - # CharLevelTokenizer.tokenize: - # list(np.frombuffer(text.encode('utf-8'), dtype=np.uint8)) - # We can replicate with Python directly: - byte_ids = list(text.encode("utf-8")) # each in [0, 255] - # Represent each id as a string token "id" - return [str(b) for b in byte_ids] - - def _convert_token_to_id(self, token: str) -> int: - # Tokens we produce are numeric strings "0", "1", ... - try: - idx = int(token) - except ValueError: - # Shouldn't really happen with our _tokenize, but just in case - return self._unk_id - # CharLevelTokenizer allows any 0..255; we don't clamp on encode. 
- # (clamp is only used on decode) - if 0 <= idx < self._vocab_size: - return idx - # If out-of-range, fall back to unk - return self._unk_id + def get_vocab(self) -> dict[str, int]: + vocab = dict(self._token_to_id) + vocab.update(self.added_tokens_encoder) + return vocab + + def tokenize(self, text: str, **kwargs) -> List[int]: + del kwargs + return list(text.encode("utf-8")) + + def _tokenize(self, text: str) -> List[str]: + return [str(byte) for byte in text.encode("utf-8")] + + def _convert_token_to_id(self, token: str | int) -> int: + if isinstance(token, int): + return _clamp_token_id(token) + if token in self.added_tokens_encoder: + return self.added_tokens_encoder[token] + if token in self._token_to_id: + return self._token_to_id[token] + return _clamp_token_id(int(token)) def _convert_id_to_token(self, index: int) -> str: - # Represent ids as numeric strings consistently - if 0 <= index < self._vocab_size: - return str(index) - return str(self._unk_id) - - def convert_tokens_to_string(self, tokens: List[str]) -> str: - # CharLevelTokenizer.detokenize: - # "".join(chr(clamp(token)) for token in token_ids) - chars: List[str] = [] - for tok in tokens: - try: - idx = int(tok) - except ValueError: - idx = self._unk_id - c = chr(self.clamp(idx)) - chars.append(c) - return "".join(chars) - - # ---- Special tokens / sequence helpers --------------------------------- + index = _clamp_token_id(index) + if index in self.added_tokens_decoder: + return self.added_tokens_decoder[index] + return self._id_to_token[index] + + def convert_tokens_to_string(self, tokens: List[str | int]) -> str: + byte_values = [] + for token in tokens: + if isinstance(token, str) and token in self.added_tokens_encoder: + token = self.added_tokens_encoder[token] + token_id = _clamp_token_id(int(token)) + byte_values.append(token_id) + return "".join(chr(byte) for byte in byte_values) def build_inputs_with_special_tokens( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: - """ - CharLevelTokenizer does *not* add BOS/EOS automatically, so we just - return the sequence as-is. - - We also do not support sentence pairs. - """ - if token_ids_1 is not None: - raise ValueError("Evo2Tokenizer (CharLevel) does not support sentence pairs.") - - return token_ids_0 + if token_ids_1 is None: + return list(token_ids_0) + return list(token_ids_0) + list(token_ids_1) def get_special_tokens_mask( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False, + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: - """ - Mark eos/eod (id 0) and pad (id 1) as special, everything else as 0. 
- """ - if token_ids_1 is not None: - raise ValueError("Evo2Tokenizer (CharLevel) does not support sentence pairs.") - if already_has_special_tokens: - # Just mark known special IDs - return [ - 1 if t in {self._eos_id, self._pad_id} else 0 - for t in token_ids_0 - ] - - # We don't auto-add any extra tokens, so same as above - return [ - 1 if t in {self._eos_id, self._pad_id} else 0 - for t in token_ids_0 - ] + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + if token_ids_1 is None: + return [0] * len(token_ids_0) + return [0] * (len(token_ids_0) + len(token_ids_1)) def create_token_type_ids_from_sequences( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: - """ - No token type IDs; everything is 0, like most decoder-only models. - """ - if token_ids_1 is not None: - raise ValueError("Evo2Tokenizer (CharLevel) does not support sentence pairs.") - - return [0] * len(token_ids_0) - - # ---- Saving / loading vocab -------------------------------------------- + length = len(token_ids_0) if token_ids_1 is None else len(token_ids_0) + len(token_ids_1) + return [0] * length - def save_vocabulary( - self, - save_directory: str, - filename_prefix: Optional[str] = None, - ) -> Tuple[str]: + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) + file_name = (filename_prefix + "-" if filename_prefix else "") + "vocab.json" + path = os.path.join(save_directory, file_name) + with open(path, "w", encoding="utf-8") as f: + json.dump({str(i): i for i in range(self._vocab_size)}, f, ensure_ascii=False, indent=2) + return (path,) - vocab_file = ( - (filename_prefix + "-" if filename_prefix else "") - + VOCAB_FILES_NAMES["vocab_file"] - ) - vocab_path = os.path.join(save_directory, vocab_file) - - with open(vocab_path, "w", encoding="utf-8") as f: - json.dump(self.vocab, f, ensure_ascii=False, indent=2) - - return (vocab_path,) + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + spaces_between_special_tokens: bool = True, + **kwargs, + ) -> str: + del clean_up_tokenization_spaces, spaces_between_special_tokens + if skip_special_tokens: + token_ids = [ + token_id + for token_id in token_ids + if token_id not in {self.pad_token_id, self.eos_token_id} + ] + return "".join(chr(_clamp_token_id(token_id)) for token_id in token_ids) diff --git a/tests/models/evo2/test_modeling_evo2.py b/tests/models/evo2/test_modeling_evo2.py index b53ba1ec252f..acdd9c1a6ea9 100644 --- a/tests/models/evo2/test_modeling_evo2.py +++ b/tests/models/evo2/test_modeling_evo2.py @@ -1,73 +1,54 @@ -# coding=utf-8 -# Copyright 2025 the HuggingFace Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# -# This file contains the *unit* tests for the Evo2 model, based on the -# shared CausalLMModelTester utilities. Integration tests that depend on -# public Hub checkpoints or special hardware can be added later once the -# official Evo2 weights are wired to this architecture. 
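To make the byte-level scheme implemented by Evo2Tokenizer above concrete, here is a minimal usage sketch. It assumes the tokenizer is exported from transformers exactly as wired up in this patch; token ids are raw UTF-8 byte values, the pad id is 1, the eos id is 0, and decoding clamps ids into [0, 255].

    from transformers import Evo2Tokenizer

    tokenizer = Evo2Tokenizer()
    # Each character maps to its UTF-8 byte value: "A" -> 65, "C" -> 67, "G" -> 71, "T" -> 84.
    ids = tokenizer("ACGT")["input_ids"]   # [65, 67, 71, 84]
    text = tokenizer.decode(ids)           # "ACGT"
    # Batched encoding pads shorter sequences with pad_token_id == 1.
    batch = tokenizer(["AC", "ACGT"], padding=True)["input_ids"]
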
- import unittest +import pytest + +pytest.importorskip("parameterized") + from transformers import is_torch_available from transformers.testing_utils import require_torch from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester - if is_torch_available(): - import torch - - from transformers import ( - Evo2ForCausalLM, - Evo2Model, - ) + from transformers import Evo2ForCausalLM, Evo2Model class Evo2ModelTester(CausalLMModelTester): - """ - Minimal tester for Evo2 that plugs into the shared causal LM test - harness. We just need to specify the base and LM classes; the generic - tester will handle: - - building a small config - - instantiating Evo2Model / Evo2ForCausalLM - - running forward / loss / generate / save-load tests - """ - if is_torch_available(): base_model_class = Evo2Model - lm_model_class = Evo2ForCausalLM - # If you want to tweak the tiny test config (e.g. reduce sizes), - # you can override `prepare_config_and_inputs` or `get_config` here. + def __init__(self, parent, **kwargs): + super().__init__( + parent, + pad_token_id=1, + bos_token_id=None, + eos_token_id=0, + vocab_size=256, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=64, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + **kwargs, + ) + + def get_config(self): + config = super().get_config() + config.layer_types = ["attention"] * config.num_hidden_layers + config.hyena_filters = 8 + config.hyena_kernel_size = 3 + config.hyena_order = 2 + config.tie_word_embeddings = True + return config @require_torch class Evo2ModelTest(CausalLMModelTest, unittest.TestCase): - """ - Generic causal LM tests for Evo2. - - These tests: - - instantiate tiny Evo2 configs - - run forward passes - - check loss computation - - check generation API - - test save / load / from_pretrained with local weights - """ - model_tester_class = Evo2ModelTester - # Pipelines for this model are not wired yet; skip pipeline tests. - def is_pipeline_test_to_skip( - self, - pipeline_test_case_name, - config_class, - model_architecture, - tokenizer_name, - image_processor_name, - feature_extractor_name, - processor_name, - ): - return True + +if __name__ == "__main__": + unittest.main() diff --git a/tests/models/evo2/test_tokenization_evo2.py b/tests/models/evo2/test_tokenization_evo2.py index 1d33cbe43e2c..a5d50cd0be1a 100644 --- a/tests/models/evo2/test_tokenization_evo2.py +++ b/tests/models/evo2/test_tokenization_evo2.py @@ -1,110 +1,30 @@ -# Copyright 2025 -# -# Licensed under the Apache License, Version 2.0 (the "License"); +import pytest -import json -import os -import tempfile -import unittest +from transformers import Evo2Tokenizer -from transformers.models.evo2.tokenization_evo2 import VOCAB_FILES_NAMES, Evo2Tokenizer -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.tokenization_utils_base import PreTrainedTokenizerBase +@pytest.fixture +def tokenizer(): + return Evo2Tokenizer() -class Evo2TokenizationTest(unittest.TestCase): - tokenizer_class = Evo2Tokenizer - @classmethod - def setUpClass(cls): - super().setUpClass() +def test_round_trip_ascii(tokenizer): + text = "Hello, Evo2!" 
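+    # Token ids are the raw UTF-8 byte values of the text, e.g. "H" -> 72 and "!" -> 33.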
+ encoded = tokenizer(text)["input_ids"] + expected = list(text.encode("utf-8")) + assert encoded == expected + decoded = tokenizer.decode(encoded) + assert decoded == text - cls.tmpdirname = tempfile.mkdtemp() - # Build a simple numeric vocab: "0" -> 0, "1" -> 1, ..., "255" -> 255 - vocab_size = 256 - vocab = {str(i): i for i in range(vocab_size)} +def test_clamp_behavior(tokenizer): + tokens = [0, 1, 255, 300, -5] + decoded = tokenizer.decode(tokens) + expected = "".join(chr(max(0, min(255, token))) for token in tokens) + assert decoded == expected - cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) - with open(cls.vocab_file, "w", encoding="utf-8") as vocab_writer: - json.dump(vocab, vocab_writer) - def get_tokenizers(cls, **kwargs) -> list[PreTrainedTokenizerBase]: - return [cls.get_tokenizer(**kwargs)] - - @classmethod - def get_tokenizer(cls, pretrained_name=None, **kwargs) -> PreTrainedTokenizer: - pretrained_name = pretrained_name or cls.tmpdirname - return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs) - - def test_tokenizer_single_example(self): - # Direct constructor - tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) - - text = "ABC" - # ASCII codes: A=65, B=66, C=67 - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, ["65", "66", "67"]) - self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [65, 66, 67]) - - def test_tokenizer_encode_single(self): - tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) - - text = "ABC" - # encode() should NOT add BOS/EOS for this char-level tokenizer - self.assertListEqual(tokenizer.encode(text), [65, 66, 67]) - - def test_tokenizer_call_no_pad(self): - tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) - - seq_batch = ["AB", "XYZ"] - encoded = tokenizer(seq_batch, padding=False)["input_ids"] - - # "AB" -> 65,66 ; "XYZ" -> 88,89,90 - self.assertListEqual(encoded, [[65, 66], [88, 89, 90]]) - - def test_tokenizer_call_pad(self): - tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) - - seq_batch = ["AB", "XYZ"] - encoded = tokenizer(seq_batch, padding=True)["input_ids"] - - # pad_token_id should be 1, so shorter seq gets padded with 1 - # max length = 3 - self.assertEqual(tokenizer.pad_token_id, 1) - self.assertListEqual(encoded, [[65, 66, 1], [88, 89, 90]]) - - def test_detokenize_roundtrip(self): - tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) - - text = "Hello!" - ids = tokenizer.encode(text) - decoded = tokenizer.decode(ids, skip_special_tokens=False) - - # Because of clamp, some low values could be bumped, but ASCII letters - # should round-trip cleanly. 
- self.assertEqual(decoded, text) - - def test_add_tokens(self): - tokenizer = self.tokenizer_class(self.vocab_file, vocab_size=256) - - vocab_size = len(tokenizer) - self.assertEqual(tokenizer.add_tokens(""), 0) - self.assertEqual(tokenizer.add_tokens("testtoken"), 1) - self.assertEqual(tokenizer.add_tokens(["testtoken1", "testtoken2"]), 2) - self.assertEqual(len(tokenizer), vocab_size + 3) - - self.assertEqual(tokenizer.add_special_tokens({}), 0) - self.assertEqual(tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2) - - # additional_special_tokens logic - self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": ""}) - self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": [""]}), 1) - self.assertEqual( - tokenizer.add_special_tokens({"additional_special_tokens": ["", ""]}), 2 - ) - self.assertIn("", tokenizer.special_tokens_map["additional_special_tokens"]) - self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list) - self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2) - - self.assertEqual(len(tokenizer), vocab_size + 8) +def test_tokenize_returns_bytes(tokenizer): + text = "ABcd" + tokens = tokenizer.tokenize(text) + assert tokens == list(text.encode("utf-8")) From 29ddb218a57bbd2eb4987487f5698760481a3c56 Mon Sep 17 00:00:00 2001 From: McClain Thiel Date: Tue, 18 Nov 2025 18:48:34 +0000 Subject: [PATCH 4/6] I think the logits match now but would like to verify --- .../models/evo2/configuration_evo2.py | 66 ++++- .../models/evo2/convert_evo2_weights.py | 203 +++++++++++++ src/transformers/models/evo2/modeling_evo2.py | 213 +++++++++++--- .../evo2/evo2_1b_base_ground_truth_logits.pt | Bin 0 -> 45898 bytes tests/models/evo2/test_modeling_evo2.py | 271 +++++++++++++++++- 5 files changed, 708 insertions(+), 45 deletions(-) create mode 100644 src/transformers/models/evo2/convert_evo2_weights.py create mode 100644 tests/models/evo2/evo2_1b_base_ground_truth_logits.pt diff --git a/src/transformers/models/evo2/configuration_evo2.py b/src/transformers/models/evo2/configuration_evo2.py index 01aa5256088b..a301bfb6dc4e 100644 --- a/src/transformers/models/evo2/configuration_evo2.py +++ b/src/transformers/models/evo2/configuration_evo2.py @@ -15,14 +15,74 @@ class Evo2Config(PretrainedConfig): - r"""Configuration class for the Evo2 model.""" + r""" + This is the configuration class to store the configuration of a [`Evo2Model`]. It is used to instantiate an Evo2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Evo2-1b-base model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 512): + Vocabulary size of the Evo2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Evo2Model`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=None`, the model will use the same number of key/value heads as the number of + query heads. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rms_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the rms normalization layers. + attn_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden units. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + layer_types (`Sequence[str]`, *optional*): + List of layer types ("attention" or "hyena") for each layer. If None, defaults to all "attention". + hyena_filters (`int`, *optional*, defaults to 256): + Number of Hyena filter groups. + hyena_kernel_size (`int`, *optional*, defaults to 8): + Kernel size for the short convolution in Hyena. + hyena_hidden_size (`int`, *optional*): + Hidden size for Hyena layers. + hyena_order (`int`, *optional*, defaults to 4): + Order of the Hyena recurrence. + hyena_flip_x1x2 (`bool`, *optional*, defaults to False): + Whether to flip x1 and x2 in the Hyena gating mechanism. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to True): + Whether or not the model should return the last key/values attentions (not used by all models). + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 0): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to True): + Whether to tie weight embeddings + """ model_type = "evo2" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, - vocab_size: int = 256, + vocab_size: int = 512, hidden_size: int = 2048, intermediate_size: Optional[int] = None, num_hidden_layers: int = 24, @@ -39,6 +99,7 @@ def __init__( hyena_kernel_size: int = 8, hyena_hidden_size: Optional[int] = None, hyena_order: int = 4, + hyena_flip_x1x2: bool = False, initializer_range: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, @@ -71,6 +132,7 @@ def __init__( self.hyena_kernel_size = hyena_kernel_size self.hyena_hidden_size = hyena_hidden_size if hyena_hidden_size is not None else hidden_size self.hyena_order = hyena_order + self.hyena_flip_x1x2 = hyena_flip_x1x2 self.initializer_range = initializer_range self.use_cache = use_cache diff --git a/src/transformers/models/evo2/convert_evo2_weights.py b/src/transformers/models/evo2/convert_evo2_weights.py new file mode 100644 index 000000000000..a948c660b728 --- /dev/null +++ b/src/transformers/models/evo2/convert_evo2_weights.py @@ -0,0 +1,203 @@ +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+
+import torch
+from torch import nn
+from huggingface_hub import hf_hub_download
+
+from transformers import Evo2Config, Evo2ForCausalLM
+
+
+def convert_original_weights_to_transformers(original_weights):
+    """Convert weights from original Evo2 format to transformers format."""
+
+    # Create config based on the original model architecture (Evo2-1b-base):
+    # vocab_size=512, hidden_size=1920, 25 layers (21 hyena + 4 attention every 7th layer starting from 3)
+    layer_types = []
+    for i in range(25):
+        if i % 7 == 3:
+            layer_types.append("attention")
+        else:
+            layer_types.append("hyena")
+
+    config = Evo2Config(
+        vocab_size=512,
+        hidden_size=1920,
+        intermediate_size=5120,
+        num_hidden_layers=25,
+        num_attention_heads=15,  # 1920 / 128
+        num_key_value_heads=15,
+        layer_types=layer_types,
+        hyena_filters=128,  # Number of filter groups
+        hyena_order=3,  # 5760 / 1920 = 3
+        hyena_kernel_size=3,  # Short filter kernel size
+        tie_word_embeddings=True,
+    )
+
+    # Initialize new state dict
+    new_state_dict = {}
+
+    # Convert embeddings
+    new_state_dict["model.embed_tokens.weight"] = original_weights["embedding_layer.weight"]
+    new_state_dict["lm_head.weight"] = original_weights["unembed.weight"]
+
+    # Convert each layer
+    for layer_idx in range(25):
+        layer_type = layer_types[layer_idx]
+        orig_prefix = f"blocks.{layer_idx}"
+        new_prefix = f"model.layers.{layer_idx}.block"
+
+        # Common components: norms and MLP
+        new_state_dict[f"model.layers.{layer_idx}.block.input_layernorm.weight"] = original_weights[
+            f"{orig_prefix}.pre_norm.scale"
+        ]
+        new_state_dict[f"model.layers.{layer_idx}.block.post_attention_layernorm.weight"] = original_weights[
+            f"{orig_prefix}.post_norm.scale"
+        ]
+
+        # MLP layers
+        # Original: l1 (gate), l2 (up), l3 (down)
+        new_state_dict[f"{new_prefix}.mlp.gate_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l1.weight"]
+        new_state_dict[f"{new_prefix}.mlp.up_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l2.weight"]
+        new_state_dict[f"{new_prefix}.mlp.down_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l3.weight"]
+
+        if layer_type == "attention":
+            # Convert attention layer
+            # Original uses a combined Wqkv; we need separate q_proj, k_proj, v_proj
+            wqkv = original_weights[f"{orig_prefix}.inner_mha_cls.Wqkv.weight"]
+            hidden_size = config.hidden_size
+
+            # Split Wqkv into q, k, v
+            q, k, v = torch.split(wqkv, hidden_size, dim=0)
+            new_state_dict[f"model.layers.{layer_idx}.block.attention.q_proj.weight"] = q
+            new_state_dict[f"model.layers.{layer_idx}.block.attention.k_proj.weight"] = k
+            new_state_dict[f"model.layers.{layer_idx}.block.attention.v_proj.weight"] = v
+
+            # Output projection
+            new_state_dict[f"model.layers.{layer_idx}.block.attention.o_proj.weight"] = original_weights[
+                f"{orig_prefix}.inner_mha_cls.out_proj.weight"
+            ]
+
+            # Load rotary embedding inv_freq from original weights
+            if f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" in original_weights:
+                new_state_dict[f"model.layers.{layer_idx}.block.attention.rotary_emb.inv_freq"] = original_weights[
+                    f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq"
+                ]
+
+        else:
+            # Convert
hyena filter layer + new_state_dict[f"model.layers.{layer_idx}.block.filter.projections.weight"] = original_weights[ + f"{orig_prefix}.projections.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.short_filter_weight"] = original_weights[ + f"{orig_prefix}.filter.short_filter_weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.weight"] = original_weights[ + f"{orig_prefix}.out_filter_dense.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.bias"] = original_weights[ + f"{orig_prefix}.out_filter_dense.bias" + ] + + # Long filter parameters (FIR or IIR) + # These are not standard nn.Parameters in our implementation but we can load them into the state dict + # and then manually assign them in the model if needed, or just save them as part of the state dict + # since we registered them as buffers/parameters in the model (or should have). + # In our implementation, they are initialized as None. We need to make sure they are loaded. + + if f"{orig_prefix}.filter.h" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.h"] = original_weights[ + f"{orig_prefix}.filter.h" + ] + if f"{orig_prefix}.filter.D" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.D"] = original_weights[ + f"{orig_prefix}.filter.D" + ] + if f"{orig_prefix}.filter.log_poles" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.log_poles"] = original_weights[ + f"{orig_prefix}.filter.log_poles" + ] + if f"{orig_prefix}.filter.residues" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.residues"] = original_weights[ + f"{orig_prefix}.filter.residues" + ] + + # Final norm + new_state_dict["model.norm.weight"] = original_weights["norm.scale"] + + return new_state_dict, config + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_id", + default="arcinstitute/evo2_1b_base", + help="Hub model id", + ) + parser.add_argument( + "--output_dir", + default="evo2_converted", + help="The output directory to save the converted model", + ) + args = parser.parse_args() + + print(f"Downloading weights for {args.model_id}...") + weights_path = hf_hub_download(args.model_id, "evo2_1b_base.pt") + original_weights = torch.load(weights_path, map_location="cpu", weights_only=False) + + print("Converting weights...") + new_state_dict, config = convert_original_weights_to_transformers(original_weights) + + print("Loading into Evo2ForCausalLM...") + model = Evo2ForCausalLM(config) + + # Load state dict (strict=False because Hyena layers have optional parameters that might be missing if unused) + # But we want to make sure we load everything we have. + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) > 0: + print(missing_keys[:10]) + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) > 0: + print(unexpected_keys[:10]) + + # Manually assign filter parameters (h, D, log_poles, residues) if they were not loaded by load_state_dict + # because they were None in the model init. + # Actually, since we put them in new_state_dict, load_state_dict might complain if the model attributes are None. + # We might need to initialize them in the model first or just assign them directly. 
+ + for layer_idx in range(config.num_hidden_layers): + if config.layer_types[layer_idx] == "hyena": + filter_module = model.model.layers[layer_idx].block.filter + orig_prefix = f"blocks.{layer_idx}.filter" + + if f"{orig_prefix}.h" in original_weights: + filter_module.h = nn.Parameter(original_weights[f"{orig_prefix}.h"]) + if f"{orig_prefix}.D" in original_weights: + filter_module.D = nn.Parameter(original_weights[f"{orig_prefix}.D"]) + if f"{orig_prefix}.log_poles" in original_weights: + filter_module.log_poles = nn.Parameter(original_weights[f"{orig_prefix}.log_poles"]) + if f"{orig_prefix}.residues" in original_weights: + filter_module.residues = nn.Parameter(original_weights[f"{orig_prefix}.residues"]) + + print(f"Saving to {args.output_dir}...") + model.save_pretrained(args.output_dir) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/evo2/modeling_evo2.py b/src/transformers/models/evo2/modeling_evo2.py index 3928e23ee00a..6138d1240946 100644 --- a/src/transformers/models/evo2/modeling_evo2.py +++ b/src/transformers/models/evo2/modeling_evo2.py @@ -4,7 +4,7 @@ import math from collections.abc import Callable -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -81,7 +81,8 @@ def __init__(self, config: Evo2Config, device=None): rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) + # Register inv_freq as persistent so it can be loaded from checkpoints + self.register_buffer("inv_freq", inv_freq, persistent=True) self.original_inv_freq = inv_freq @staticmethod @@ -123,19 +124,29 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class Evo2ParallelGatedMLP(nn.Module): - def __init__(self, config: Evo2Config): + def __init__(self, config: Evo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size + self.layer_idx = layer_idx + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.dropout = nn.Dropout(config.mlp_dropout) + + # Evo2 style: only layer 0 has activation, rest use identity + self.use_activation = (layer_idx == 0) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - gated = F.silu(self.gate_proj(hidden_states)) - up = self.up_proj(hidden_states) - hidden_states = self.down_proj(gated * up) + z1 = self.gate_proj(hidden_states) + z2 = self.up_proj(hidden_states) + + # Apply SiLU only for layer 0, identity for others (Evo2 style) + if self.use_activation: + z1 = F.silu(z1) + + hidden_states = self.down_proj(z1 * z2) hidden_states = self.dropout(hidden_states) return hidden_states @@ -210,32 +221,146 @@ class Evo2HyenaFilter(nn.Module): def __init__(self, config: Evo2Config): super().__init__() self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads self.order = config.hyena_order - self.filter_channels = config.hyena_filters - self.kernel_size = config.hyena_kernel_size - - self.in_proj = nn.Linear(self.hidden_size, self.filter_channels * self.order, bias=False) - self.conv = nn.Conv1d( - in_channels=self.filter_channels, - out_channels=self.filter_channels, - kernel_size=self.kernel_size, - 
groups=self.filter_channels, - padding=self.kernel_size - 1, + self.short_filter_length = config.hyena_kernel_size + self.hyena_flip_x1x2 = config.hyena_flip_x1x2 + + # Projections: hidden_size -> 3 * hidden_size (for x, y, z) + self.projections = nn.Linear(self.hidden_size, self.order * self.hidden_size, bias=False) + + # Short filter (Conv1d) + self.short_filter_weight = nn.Parameter( + torch.randn(self.order * self.hidden_size, 1, self.short_filter_length) ) - self.out_proj = nn.Linear(self.filter_channels * self.order, self.hidden_size, bias=False) - self.activation = nn.SiLU() + + # Output projection + self.out_filter_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + + # Long filter parameters - check if FIR or IIR based on config + self.hyena_filter_groups = config.hyena_filters + self.channels_per_group = self.hidden_size // self.hyena_filter_groups + + # These parameters are optional and will be set dynamically when loading weights + # We register them as None initially + self.h = None + self.D = None + self.log_poles = None + self.residues = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch, seq_len, _ = hidden_states.shape - projected = self.in_proj(hidden_states) - projected = projected.view(batch, seq_len, self.order, self.filter_channels).permute(0, 2, 3, 1) - conv_input = projected.reshape(batch * self.order, self.filter_channels, seq_len) - conv_output = self.conv(conv_input) - conv_output = conv_output[:, :, :seq_len] - conv_output = conv_output.view(batch, self.order, self.filter_channels, seq_len).permute(0, 3, 1, 2) - conv_output = conv_output.reshape(batch, seq_len, self.order * self.filter_channels) - conv_output = self.activation(conv_output) - return self.out_proj(conv_output) + + # Project to 3 * hidden_size + u = self.projections(hidden_states) # [batch, seq_len, 3 * hidden_size] + + # Transpose for conv1d: [batch, 3 * hidden_size, seq_len] + u = u.transpose(1, 2) + + # Apply short filter (depthwise conv1d) + u = F.conv1d( + u.to(torch.float32), + self.short_filter_weight.to(torch.float32), + padding=self.short_filter_length - 1, + groups=self.order * self.hidden_size + )[:, :, :seq_len] + u = u.to(hidden_states.dtype) + + # Apply interleave to de-interleave the channels (following vortex model.py line 645) + # This reorders from [x1, x2, v, x1, x2, v, ...] to [x1, x1, ..., x2, x2, ..., v, v, ...] + # u is [batch, 3 * hidden_size, seq_len] + u_x1 = u[:, 0::3, :] # Every 3rd channel starting from 0 + u_x2 = u[:, 1::3, :] # Every 3rd channel starting from 1 + u_v = u[:, 2::3, :] # Every 3rd channel starting from 2 + u = torch.cat([u_x1, u_x2, u_v], dim=1) # [batch, 3 * hidden_size, seq_len] + + # Split into x2, x1, v + # Note: Vortex column_split returns x2, x1, v in that order + # u is now [batch, 3 * hidden_size, seq_len] + # We split along dim 1 + x2, x1, v = u.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1) + + # Vortex HyenaCascade.sequential_forward logic: + # x2, x1, v = column_split(...) 
+ # if self.hyena_flip_x1x2: x1, x2 = x2, x1 + if self.hyena_flip_x1x2: + x1, x2 = x2, x1 + + # Compute x1 * v (element-wise) + x1v = x1 * v + + # Apply long filter to x1v + # Both FIR and IIR use FFT convolution (IIR converts to FIR first) + if self.h is not None or (self.log_poles is not None and self.residues is not None): + # Compute filter h + if self.h is not None: + # FIR filter: use pre-computed h + h = self.h + else: + # IIR filter: convert to FIR using modal form (model.py compute_filter line 703) + # h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] + # Create time vector t = [0, 1, 2, ..., seq_len-1] + t = torch.arange(seq_len, device=hidden_states.device, dtype=torch.float32)[None, None, :] # [1, 1, L] + + # log_poles: [hidden_size, state_dim, 1], residues: [hidden_size, state_dim] + log_poles = self.log_poles.to(torch.float32) + residues = self.residues.to(torch.float32) + + # Compute h = sum(residues * exp(log_poles * t), dim=state_dim) + # Broadcasting: log_poles [D, S, 1] * t [1, 1, L] = [D, S, L] + h = (residues.unsqueeze(-1) * torch.exp(log_poles * t)).sum(dim=1) # [D, L] + h = h.unsqueeze(0) # [1, D, L] - matches FIR filter shape [num_groups, hidden_size/num_groups, L] + + # FFT convolution following vortex engine.py parallel_iir + fft_size = 2 * seq_len + + # Prepare filter: h shape is [num_groups, 1, filter_len] or [1, hidden_size, 1, filter_len] for IIR + h = h.to(torch.float32) + H = torch.fft.rfft(h, n=fft_size) / fft_size # [num_groups, 1, fft_len] + + # Apply adjust_filter_shape_for_broadcast logic + H = H.squeeze() # [num_groups, fft_len] for FIR or [hidden_size, fft_len] for IIR + + # For x1v: [batch, hidden_size, seq_len], we need H: [batch, hidden_size, fft_len] + if H.shape[0] != self.hidden_size: + # FIR case: Repeat H for each channel in group + # hidden_size = num_groups * channels_per_group + if self.hyena_filter_groups > 1: + H = H.repeat_interleave(self.channels_per_group, dim=0) # [hidden_size, fft_len] + # else: IIR case, H is already [hidden_size, fft_len] + + H = H.unsqueeze(0) # [1, hidden_size, fft_len] + + # FFT of input - use torch.fft.fft like original, not rfft + X_s = torch.fft.fft(x1v.to(torch.float32), n=fft_size) # [batch, hidden_size, fft_size] + X = X_s[..., : H.shape[-1]] # [batch, hidden_size, fft_len] + + # Multiply in frequency domain + y = torch.fft.irfft(X * H, n=fft_size, norm='forward')[..., :seq_len] + + # Add bias (direct connection) - note the order matches original: (y + x1v * D) * x2 + # Ensure both y and x1v are in float32 for numerical stability + if self.D is not None: + x1v_f32 = x1v.to(torch.float32) + D_f32 = self.D.to(torch.float32) + y = y + x1v_f32 * D_f32.unsqueeze(0).unsqueeze(-1) + + # Convert back to original dtype (matching Mamba2 pattern) + y = y.to(hidden_states.dtype) + else: + # No long filter + y = x1v + + # Apply gating: x2 * y + z = x2 * y + + # Transpose back: [batch, hidden_size, seq_len] -> [batch, seq_len, hidden_size] + z = z.transpose(1, 2) + + # Output projection + out = self.out_filter_dense(z) + + return out class Evo2AttentionBlock(nn.Module): @@ -244,7 +369,7 @@ def __init__(self, config: Evo2Config, layer_idx: int): self.attention = Evo2Attention(config, layer_idx) self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = Evo2ParallelGatedMLP(config) + self.mlp = Evo2ParallelGatedMLP(config, layer_idx) self.hidden_dropout = 
nn.Dropout(config.hidden_dropout) def forward( @@ -257,7 +382,8 @@ def forward( use_cache: bool, cache_position: Optional[torch.LongTensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - residual = hidden_states + # Keep residual in float32 for numerical stability (Mamba2 technique) + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states hidden_states = self.input_layernorm(hidden_states) attn_output, attn_weights, present_kv = self.attention( hidden_states, @@ -268,23 +394,24 @@ def forward( use_cache=use_cache, cache_position=cache_position, ) - hidden_states = residual + self.hidden_dropout(attn_output) + hidden_states = (residual + self.hidden_dropout(attn_output).to(residual.dtype)).to(hidden_states.dtype) - residual = hidden_states + # Keep residual in float32 for numerical stability + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + self.hidden_dropout(hidden_states) + hidden_states = (residual + self.hidden_dropout(hidden_states).to(residual.dtype)).to(hidden_states.dtype) return hidden_states, attn_weights, present_kv class Evo2HyenaBlock(nn.Module): - def __init__(self, config: Evo2Config): + def __init__(self, config: Evo2Config, layer_idx: int): super().__init__() self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.filter = Evo2HyenaFilter(config) self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = Evo2ParallelGatedMLP(config) + self.mlp = Evo2ParallelGatedMLP(config, layer_idx) self.hidden_dropout = nn.Dropout(config.hidden_dropout) def forward( @@ -298,15 +425,17 @@ def forward( cache_position: Optional[torch.LongTensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: del attention_mask, past_key_value, output_attentions, use_cache, cache_position, position_ids - residual = hidden_states + # Keep residual in float32 for numerical stability (Mamba2 technique) + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.filter(hidden_states) - hidden_states = residual + self.hidden_dropout(hidden_states) + hidden_states = (residual + self.hidden_dropout(hidden_states).to(residual.dtype)).to(hidden_states.dtype) - residual = hidden_states + # Keep residual in float32 for numerical stability + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + self.hidden_dropout(hidden_states) + hidden_states = (residual + self.hidden_dropout(hidden_states).to(residual.dtype)).to(hidden_states.dtype) return hidden_states, None, None @@ -318,7 +447,7 @@ def __init__(self, config: Evo2Config, layer_type: str, layer_idx: int): if layer_type == "attention": self.block = Evo2AttentionBlock(config, layer_idx) else: - self.block = Evo2HyenaBlock(config) + self.block = Evo2HyenaBlock(config, layer_idx) def forward( self, @@ -560,9 +689,11 @@ def forward( hidden_states = outputs[0] if isinstance(logits_to_keep, int): slice_indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None) - logits = self.lm_head(hidden_states[:, slice_indices, :]) 
+ # Mamba2 technique: convert to lm_head dtype, then to float32 for numerical stability + logits = self.lm_head(hidden_states[:, slice_indices, :].to(self.lm_head.weight.dtype)).float() else: - logits = self.lm_head(hidden_states[:, logits_to_keep, :]) + # Mamba2 technique: convert to lm_head dtype, then to float32 for numerical stability + logits = self.lm_head(hidden_states[:, logits_to_keep, :].to(self.lm_head.weight.dtype)).float() loss = None if labels is not None: diff --git a/tests/models/evo2/evo2_1b_base_ground_truth_logits.pt b/tests/models/evo2/evo2_1b_base_ground_truth_logits.pt new file mode 100644 index 0000000000000000000000000000000000000000..b7bb1bc0605a03491800c920d9c12f06f575de81 GIT binary patch literal 45898 zcmeHQYj7OZl^$ENB?~)#5c~rCB8g?o$darFLe{h-gS8N<9zR-L-S^yczVn^aS97QD+_~Met)r+okti)qWc{s6G$zu!GBv5{o>Wiw zNIJD+I5Rrbm$HUOt^U+tX2*avvazq*>aN`Rz+ih@UE;pU{5>>tx0xvKDueJ3X8O{D zsiE!%)8>q>xu8&^n=`xSfNCUdj;4ot(<5fl*UaMbF33I6x@GHjbC%8=t`52r^)2dxov&Tg};ynohi^SW{;;=s_(Q7bjjH)781wKBuK z{gtUvYhZAsEH#|&866me21yT%WQJ3_YRb%c<>e092q3$6q|Hx}OW9pR))z@tX3p=L z0pY651zkl@_U;@t7j_g|=AtijBsz-6%m6ogZz=8{fxY1eFEp#OxqtRI5+bHKLt!A!SRU4Jp^awlrFd1L>=dlIFV zc@u`4(NTO)V&jd5aDmU=0M3kJS}Jnyv3_D*xFCvb!LCm48zW!^=B z=R1Kv4=Ois2f5#EnfG*>_l}N^R&MJkDgEH1lJq=yW#GOouya-=ZqT}iZ|AJMKQsJb zx0TwJ9v*?$QG46M>^@qJF;9N{n@8|PiF2J{z6{zAJ64YHwQ|o?lla1wf(ke zwx4&koOO$X;aX4`4He`a>*E%#4km^(PFQ$-x5jtpsOWYw4{YEk8> z#mnWL|0wxq%$0vWY?!e(Yi~U$PBfe!%7*KI65`@cThgsdElX4qQbnORX_3Jz$vX2B ziu7kikFnD3GH!A>+Kpw7|1s$2Cqye85rxuUcY1dgK8%t!TEZB={gpnCyJzK|U%78PXl6|?9-mGCSnFm_b zOsLMW2`BSl9_i0E$)ZPP4~k4R8AC=Ui7gVxnXraPPvh+AEt{b= zaF4PzM@D+t&WIJz|NqLk(TUS(EOpAGPXFHmN6g@mctd;#`u{5XmtwR16X#t)@^E`g zO?KAV8j1FkR6NR5SG9g#{yNz=Vb_Tw@e6>K4_Zs!`8*?be#NG2&E z2h~`|rOM6-Kbwl1h#Zwm$LTd0FGw#D{?f$#CgMR@KM$PwD5_k0mowP_6AroN8!95s zH`KJZtqq!Q@Xo+^=Rd;gnD3I`75%mf=0-Q-*$Tychm`Am!axS*1+43JH|BO>R!*FlR z*qgPt9>fw2=ZC`i6YnNJU3o3ux5d^uuxNxWg~VfOGi zDk6w2NU!=%ha5GoKh9rEdEzHFl++`i!9N7fpZI^y%eVH7Sn154(EMqqGk==2Y5t_m zohi}hPY#66pX58om@jUDtF;CGzTo}eg#KUeet+_HrT-rnNw=pT5O0e6ocYr?VE*)Z zpZ=dchwA-*g(JKkaXoH<=q+%izP`@+jQ@FTa4E`jVlRCD?|}Y~(N??br*Qm#T(r64 z|AVl8qQ_Z3^=82ODZkGEdiLQ69Us@?7PuB#;Me2${a(}!zvGIsd2v+tYLV%T__DFW zeh9ul#ab79e~RT%=;yD)`U#44K%5Y_8_n?j>6@a--tTIe4NI00ugc^ftd;g$l0T!L zM=qq-=B@kju`UWW3WoWnR!%L_)VwE~IYo^YI~q#QWHLTFL-blcWBp6P$N#s(_H=Se_co)Vw)|%dNb8&nkCjr*icFOVP$>#naILm%8gG4=dOI4=Yy zWaFAF*&j8s9i*M~*;%t*Jwx_qPCXC3MiJP9*(3%nvPO2Sla4K{@z-+JjOHjDn@Z+n ztw~g2K`W5__2{3Q z zrs};u(d!duiRM8lKWwliYSwW2zspXa<-FJfP-q|NlPp{|dMN-vaZf*lPIvnUl)%XHJQ4 zJN^Gn`22Lp^Z5yTc6jd&M_dX2tvJWH1@hYh=Lm@1}UZRFcN;{29f*jG=d9uE|p3Q73B#+&^^`uK#rZ zWYXT3#OKc@VgAp6d#D;O_W!pxa$5~RHymTbtxt+2RoW1JBGO=>r9GSH8d{jO2b6#XqkIcnq(kJM&wu$OVY~)zy4o<{(q-?|HQvF)AMIfK+k_o)VlrufH);K zJN^Hd&*y*joaNc6T9kt=USwR4Tj1Jhfxjtu|NjH@{}Oln|G)0_AJ<9#LacZD|99-9F6M|kN@p|r=3NKurM*# zy2PmKC+jD;ep3DYt$W?Ilu>k;`NbN9?zwW`7;#NyLwj4rMAlTs&!46-W_*lW z;96{f!asjn{r#>@+WN`(_q%YdjJ)4~Bw3%@N}6CE)uPH%idRQ{hoO|wg3@1=}73*U^;?{}fco$&n$mecwCX^xC>R5>VPFUDIB@?$-Y zVM*qaEzwHpB_-W6NUx)~HWl>5d^uuxNxWg~VfOGiDk6w2NU!=%ha5GoKh9rEdGafe zFgmjE&!45=@1o~Ve^Gh<#0hx*Bz^vSLHqr#`1{lFPsK4JZh>em@YIB!KNh_niJyvMV}rc_bnERW!KQAvLJjxaTR0-Oe{79g;A(7vM^9JSh1`z9FH3yN zKEJul$>Li^sZ%~9D&RipcgHESed0w?Z&cXx;YqaX>>q+n-8MiC_uN}JBDa5Rja%Sq zY=Joi9{-;b*TeXKjX~r8Pm(nLf5{pD7r7(-XJP#RoHPFaIgJ0GzNF*-%Vo3{=MlHS zbhN->Q%A&OVzGEh%<-^e zwD5@Gha&rDU7fT^4eGeY>a%R9~O&XFB3OHZnzbTrQ(s~Vey2R4>N&( z948)%PZ?0O!!Qn@JXuTr9E&^%t6IEhW@8L{Fb|b@Bv1Col6Kvq>nID=hc)cab&juN zbreQa?JVn=5JuXFnc{e(FeCeLJSrszOVZ97X7rY9=aS=NB)O57%!2D&5|Yl3k7OfW 
zJgU2DlNKp4A8I%=l4MQVScj}6q#sxJfKAUKM{0|3=om&YKh@P(eh8IK+Bq-gkSJnT>+Izh$q&bPJ%-duNqXyNDr?@D(xS6S zCU5IV50X@*-1cR6ue zliQv0!|^@hKgRzriBa~7_s9Qhycp!Ki~k&dLi~E1c*vh(e$T-%r_7BcMyh%}Q?7AI zOvFt&P)ycKCS6Z8)=3m`=-wC$kus)^S1n!)`$|SHf+JE%aabqwlM>sol&oGk+sX)) ziDfKlr|ei}pKLbvB#Z9NUWD-?gSoPiUX7PgwWG|t-Y%RYN0#Tv7Uoe)mJuWWi|i3x z>i=J6@y3)W9U(JkJ!|kDfB$9%CYvZN#bYAz z!AB+Oc}0l??(0aU9MiipHL2>JR8RLvI<;dsGdk3lvW7>k{?uS*$AC4m5tFXm`M}_| z4xRSo$G>@G1`y#6zYe7QJqLk8X)Eu~3_sXyrQCmumwl}H!)3F96ZiEBcOtQ{I0Adm zH?4GLxGz2IlC3DZ^&;&&SK@LfbTIt;#w%2UUkOV8ew6C?MzXDA7WVkWJlGIV*#8F;9k+V` literal 0 HcmV?d00001 diff --git a/tests/models/evo2/test_modeling_evo2.py b/tests/models/evo2/test_modeling_evo2.py index acdd9c1a6ea9..c87e9c98bfbd 100644 --- a/tests/models/evo2/test_modeling_evo2.py +++ b/tests/models/evo2/test_modeling_evo2.py @@ -1,3 +1,4 @@ +import os import unittest import pytest @@ -5,12 +6,14 @@ pytest.importorskip("parameterized") from transformers import is_torch_available -from transformers.testing_utils import require_torch +from transformers.testing_utils import require_torch, slow from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): - from transformers import Evo2ForCausalLM, Evo2Model + import torch + + from transformers import Evo2ForCausalLM, Evo2Model, Evo2Tokenizer class Evo2ModelTester(CausalLMModelTester): @@ -50,5 +53,269 @@ class Evo2ModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Evo2ModelTester +@require_torch +@slow +class Evo2InferenceTest(unittest.TestCase): + """Test inference against ground truth logits from the original evo2_1b_base model.""" + + @staticmethod + def convert_original_weights_to_transformers(original_weights): + """Convert weights from original Evo2 format to transformers format.""" + from transformers import Evo2Config + + # Create config based on the original model architecture + # vocab_size=512, hidden_size=1920, 25 layers (21 hyena + 4 attention every 7th layer starting from 3) + layer_types = [] + for i in range(25): + if i % 7 == 3: + layer_types.append("attention") + else: + layer_types.append("hyena") + + config = Evo2Config( + vocab_size=512, + hidden_size=1920, + intermediate_size=5120, + num_hidden_layers=25, + num_attention_heads=15, # 1920 / 128 + num_key_value_heads=15, + layer_types=layer_types, + hyena_filters=128, # Number of filter groups + hyena_order=3, # 5760 / 1920 = 3 + hyena_kernel_size=3, # Short filter kernel size + tie_word_embeddings=True, + ) + + # Initialize new state dict + new_state_dict = {} + + # Convert embeddings + new_state_dict["model.embed_tokens.weight"] = original_weights["embedding_layer.weight"] + new_state_dict["lm_head.weight"] = original_weights["unembed.weight"] + + # Convert each layer + for layer_idx in range(25): + layer_type = layer_types[layer_idx] + orig_prefix = f"blocks.{layer_idx}" + new_prefix = f"model.layers.{layer_idx}.block" + + # Common components: norms and MLP + new_state_dict[f"model.layers.{layer_idx}.block.input_layernorm.weight"] = original_weights[ + f"{orig_prefix}.pre_norm.scale" + ] + new_state_dict[f"model.layers.{layer_idx}.block.post_attention_layernorm.weight"] = original_weights[ + f"{orig_prefix}.post_norm.scale" + ] + + # MLP layers + # Original: l1 (gate), l2 (up), l3 (down) + new_state_dict[f"{new_prefix}.mlp.gate_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l1.weight"] + new_state_dict[f"{new_prefix}.mlp.up_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l2.weight"] + 
new_state_dict[f"{new_prefix}.mlp.down_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l3.weight"] + + if layer_type == "attention": + # Convert attention layer + # Original uses Wqkv (combined), we need separate q_proj, k_proj, v_proj + wqkv = original_weights[f"{orig_prefix}.inner_mha_cls.Wqkv.weight"] + hidden_size = config.hidden_size + head_dim = hidden_size // config.num_attention_heads + + # Split Wqkv into q, k, v + q, k, v = torch.split(wqkv, hidden_size, dim=0) + new_state_dict[f"model.layers.{layer_idx}.block.attention.q_proj.weight"] = q + new_state_dict[f"model.layers.{layer_idx}.block.attention.k_proj.weight"] = k + new_state_dict[f"model.layers.{layer_idx}.block.attention.v_proj.weight"] = v + + # Output projection + new_state_dict[f"model.layers.{layer_idx}.block.attention.o_proj.weight"] = original_weights[ + f"{orig_prefix}.inner_mha_cls.out_proj.weight" + ] + + # Load rotary embedding inv_freq from original weights + if f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.attention.rotary_emb.inv_freq"] = original_weights[ + f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" + ] + + # Note: Original has out_proj.bias but our implementation doesn't use bias + else: + # Convert hyena filter layer + new_state_dict[f"model.layers.{layer_idx}.block.filter.projections.weight"] = original_weights[ + f"{orig_prefix}.projections.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.short_filter_weight"] = original_weights[ + f"{orig_prefix}.filter.short_filter_weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.weight"] = original_weights[ + f"{orig_prefix}.out_filter_dense.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.bias"] = original_weights[ + f"{orig_prefix}.out_filter_dense.bias" + ] + + # Long filter parameters (FIR or IIR) + if f"{orig_prefix}.filter.h" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.h"] = original_weights[ + f"{orig_prefix}.filter.h" + ] + if f"{orig_prefix}.filter.D" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.D"] = original_weights[ + f"{orig_prefix}.filter.D" + ] + if f"{orig_prefix}.filter.log_poles" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.log_poles"] = original_weights[ + f"{orig_prefix}.filter.log_poles" + ] + if f"{orig_prefix}.filter.residues" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.residues"] = original_weights[ + f"{orig_prefix}.filter.residues" + ] + + # Final norm + new_state_dict["model.norm.weight"] = original_weights["norm.scale"] + + return new_state_dict, config + + def test_weight_loading(self): + """Test that we can successfully load and convert weights from the original model.""" + from huggingface_hub import hf_hub_download + + # Download original weights + weights_path = hf_hub_download("arcinstitute/evo2_1b_base", "evo2_1b_base.pt") + original_weights = torch.load(weights_path, map_location="cpu", weights_only=False) + + # Convert weights to transformers format + new_state_dict, config = self.convert_original_weights_to_transformers(original_weights) + + # Create model and load converted weights + model = Evo2ForCausalLM(config) + + # Load state dict (strict=False because Hyena layers have optional parameters) + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + # Manually assign filter parameters 
(h, D, log_poles, residues) + for layer_idx in range(config.num_hidden_layers): + if config.layer_types[layer_idx] == "hyena": + filter_module = model.model.layers[layer_idx].block.filter + orig_prefix = f"blocks.{layer_idx}.filter" + + if f"{orig_prefix}.h" in original_weights: + filter_module.h = original_weights[f"{orig_prefix}.h"] + if f"{orig_prefix}.D" in original_weights: + filter_module.D = original_weights[f"{orig_prefix}.D"] + if f"{orig_prefix}.log_poles" in original_weights: + filter_module.log_poles = original_weights[f"{orig_prefix}.log_poles"] + if f"{orig_prefix}.residues" in original_weights: + filter_module.residues = original_weights[f"{orig_prefix}.residues"] + + # Check that only expected keys are missing/unexpected + # (Hyena filter parameters and rotary embeddings) + expected_patterns = ["filter.h", "filter.D", "filter.log_poles", "filter.residues", "rotary_emb.inv_freq"] + + for key in missing_keys: + self.assertTrue( + any(pattern in key for pattern in expected_patterns), + f"Unexpected missing key: {key}" + ) + + for key in unexpected_keys: + self.assertTrue( + any(pattern in key for pattern in expected_patterns), + f"Unexpected key in state dict: {key}" + ) + + print(f"✓ Successfully loaded weights ({len(missing_keys)} missing, {len(unexpected_keys)} unexpected)") + + def test_inference_shape(self): + """Test that the model can run inference and produces the correct output shape.""" + from huggingface_hub import hf_hub_download + + # Load ground truth for reference + ground_truth_path = os.path.join( + os.path.dirname(__file__), "evo2_1b_base_ground_truth_logits.pt" + ) + ground_truth = torch.load(ground_truth_path, map_location="cpu", weights_only=False) + + # Download and convert weights + weights_path = hf_hub_download("arcinstitute/evo2_1b_base", "evo2_1b_base.pt") + original_weights = torch.load(weights_path, map_location="cpu", weights_only=False) + new_state_dict, config = self.convert_original_weights_to_transformers(original_weights) + + # Create and load model + model = Evo2ForCausalLM(config) + model.load_state_dict(new_state_dict, strict=False) + + # Manually assign filter parameters (h, D, log_poles, residues) + # These can't be loaded via load_state_dict because they're None initially + for layer_idx in range(config.num_hidden_layers): + if config.layer_types[layer_idx] == "hyena": + filter_module = model.model.layers[layer_idx].block.filter + orig_prefix = f"blocks.{layer_idx}.filter" + + if f"{orig_prefix}.h" in original_weights: + filter_module.h = original_weights[f"{orig_prefix}.h"] + if f"{orig_prefix}.D" in original_weights: + filter_module.D = original_weights[f"{orig_prefix}.D"] + if f"{orig_prefix}.log_poles" in original_weights: + filter_module.log_poles = original_weights[f"{orig_prefix}.log_poles"] + if f"{orig_prefix}.residues" in original_weights: + filter_module.residues = original_weights[f"{orig_prefix}.residues"] + + model = model.to(torch.bfloat16) + model.eval() + + # Create tokenizer + tokenizer = Evo2Tokenizer() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + sequences = ground_truth["sequences"] + results = ground_truth["results"] + + # Test each sequence + for seq in sequences: + with self.subTest(sequence=seq): + # Get ground truth + gt_input_ids = results[seq]["input_ids"] + gt_logits = results[seq]["logits"] + + # Tokenize + tokens = tokenizer.tokenize(seq) + input_ids = torch.tensor([tokens], dtype=torch.long).to(device) + + # Verify input_ids match + 
self.assertTrue(
+                    torch.equal(input_ids.cpu(), gt_input_ids.unsqueeze(0)),
+                    f"Input IDs mismatch for sequence {seq!r}"
+                )
+
+                # Run inference
+                with torch.no_grad():
+                    outputs = model(input_ids)
+                    logits = outputs.logits
+
+                # Check shapes match
+                expected_shape = gt_logits.shape
+                actual_shape = logits.shape
+                self.assertEqual(
+                    actual_shape,
+                    expected_shape,
+                    f"Shape mismatch for {seq!r}: expected {expected_shape}, got {actual_shape}"
+                )
+
+                # Check that logits are finite (not NaN or Inf)
+                self.assertTrue(torch.isfinite(logits).all(), f"Non-finite values in logits for {seq!r}")
+
+                print(f"✓ {seq!r}: shape {actual_shape} OK, logits finite")
+
+                # Check logits values match ground truth
+                # Using relaxed tolerance for bfloat16
+                # rtol=2e-2, atol=2e-2 to absorb bfloat16 accumulation differences
+                torch.testing.assert_close(logits.cpu(), gt_logits.cpu(), rtol=0.02, atol=0.02)
+
+                print(f"✓ {seq!r}: shape {actual_shape} OK, logits match ground truth")
+
+
 if __name__ == "__main__":
     unittest.main()

From 857dc716ab7f30040f2977e870a45cd07b984f33 Mon Sep 17 00:00:00 2001
From: McClain Thiel
Date: Tue, 18 Nov 2025 22:06:38 +0000
Subject: [PATCH 5/6] fixed the weird tests that allowed it to pass

---
 docs/source/en/model_doc/evo2.md              | 34 +++++++++---
 run_evo2.py                                   | 29 ++++++++++
 .../models/evo2/configuration_evo2.py         |  4 +-
 .../models/evo2/convert_evo2_weights.py       | 53 ++++++++++---------
 src/transformers/models/evo2/modeling_evo2.py | 23 ++++++--
 5 files changed, 106 insertions(+), 37 deletions(-)
 create mode 100644 run_evo2.py

diff --git a/docs/source/en/model_doc/evo2.md b/docs/source/en/model_doc/evo2.md
index b3086dab3bf2..bf40c12ebf27 100644
--- a/docs/source/en/model_doc/evo2.md
+++ b/docs/source/en/model_doc/evo2.md
@@ -22,23 +22,43 @@ limitations under the License.

 ## Overview

-The Evo2 model was proposed in []() by .
-
+The Evo2 model was proposed in [Genome modeling and design across all domains of life with Evo 2](https://www.biorxiv.org/content/10.1101/2025.02.18.638918v1) by Garyk Brixi, Matthew G. Durrant, Jerome Ku, Michael Poli, et al.
+It is a biological foundation model trained on 9.3 trillion DNA base pairs from a curated genomic atlas spanning all domains of life.

 The abstract from the paper is the following:

-
+Evo 2 is a biological foundation model trained on 9.3 trillion DNA base pairs from a curated genomic atlas spanning all domains of life. The model features 7B and 40B parameter architectures and can process sequences up to 1 million base pairs at nucleotide-level resolution. It learns from DNA sequences alone to accurately predict the functional impacts of genetic variation, including noncoding pathogenic mutations and clinically significant BRCA1 variants, without task-specific finetuning. Mechanistic interpretability analyses reveal that Evo 2 autonomously learns a breadth of biological features, such as exon-intron boundaries, transcription factor binding sites, protein structural elements, and prophage genomic regions. Beyond its predictive capabilities, Evo 2 can generate mitochondrial, prokaryotic, and eukaryotic sequences at genome scale with greater naturalness and coherence than previous methods. Guiding Evo 2 via inference-time search enables controllable generation of epigenomic structure, for which the first inference-time scaling results in biology are demonstrated. 
The project makes Evo 2 fully open, including model parameters, training code, inference code, and the OpenGenome2 dataset, to accelerate the exploration and design of biological complexity. Tips: - +- Evo 2 is a genomic foundation model, meaning it is designed to process and generate DNA sequences. +- It uses the StripedHyena architecture, which combines attention with Hyena filters to handle long contexts efficiently. +- The model is trained on a massive dataset of 9.3 trillion base pairs. +- It can handle context lengths up to 1 million base pairs (though this specific implementation may be limited by available memory). -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +This model was contributed by [arcinstitute](https://huggingface.co/arcinstitute). +The original code can be found [here](https://github.com/ArcInstitute/evo2). +The model was converted to Hugging Face format by [McClain Thiel](mailto:mcclain.thiel@gmail.com). ## Usage examples - +```python +from transformers import Evo2Config, Evo2ForCausalLM, Evo2Tokenizer + +# Initialize model and tokenizer +config = Evo2Config() +model = Evo2ForCausalLM(config) +tokenizer = Evo2Tokenizer() + +# Encode input DNA sequence +sequence = "ACGTACGT" +input_ids = tokenizer.encode(sequence, return_tensors="pt") + +# Generate +output = model.generate(input_ids, max_length=20) +generated_sequence = tokenizer.decode(output[0]) +print(generated_sequence) +``` ## Evo2Config diff --git a/run_evo2.py b/run_evo2.py new file mode 100644 index 000000000000..09bf5e996010 --- /dev/null +++ b/run_evo2.py @@ -0,0 +1,29 @@ +import torch +from transformers import Evo2ForCausalLM, Evo2Tokenizer + +# Path to the converted model +model_path = "/tmp/evo2_hf" + +print(f"Loading model from {model_path}...") +model = Evo2ForCausalLM.from_pretrained(model_path) +tokenizer = Evo2Tokenizer.from_pretrained(model_path) + +# Move to GPU if available +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) + +# Input sequence (DNA) +sequence = "ACGTACGT" +print(f"Input: {sequence}") + +# Tokenize +input_ids = tokenizer.encode(sequence, return_tensors="pt").to(device) + +# Generate +print("Generating...") +with torch.no_grad(): + output = model.generate(input_ids, max_new_tokens=20) + +# Decode +generated_sequence = tokenizer.decode(output[0]) +print(f"Output: {generated_sequence}") diff --git a/src/transformers/models/evo2/configuration_evo2.py b/src/transformers/models/evo2/configuration_evo2.py index a301bfb6dc4e..15ee518cb1d5 100644 --- a/src/transformers/models/evo2/configuration_evo2.py +++ b/src/transformers/models/evo2/configuration_evo2.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Optional, Sequence +from typing import Optional, Sequence, List, Dict, Any from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import standardize_rope_params @@ -100,6 +100,7 @@ def __init__( hyena_hidden_size: Optional[int] = None, hyena_order: int = 4, hyena_flip_x1x2: bool = False, + hyena_filter_configurations: Optional[List[Dict[str, Any]]] = None, initializer_range: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, @@ -133,6 +134,7 @@ def __init__( self.hyena_hidden_size = hyena_hidden_size if hyena_hidden_size is not None else hidden_size self.hyena_order = hyena_order self.hyena_flip_x1x2 = hyena_flip_x1x2 + self.hyena_filter_configurations = hyena_filter_configurations self.initializer_range = 
initializer_range self.use_cache = use_cache diff --git a/src/transformers/models/evo2/convert_evo2_weights.py b/src/transformers/models/evo2/convert_evo2_weights.py index a948c660b728..84c8bddef161 100644 --- a/src/transformers/models/evo2/convert_evo2_weights.py +++ b/src/transformers/models/evo2/convert_evo2_weights.py @@ -15,6 +15,7 @@ import os import torch +from torch import nn from huggingface_hub import hf_hub_download from transformers import Evo2Config, Evo2ForCausalLM @@ -26,11 +27,32 @@ def convert_original_weights_to_transformers(original_weights): # Create config based on the original model architecture (Evo2-1b-base) # vocab_size=512, hidden_size=1920, 25 layers (21 hyena + 4 attention every 7th layer starting from 3) layer_types = [] + hyena_filter_configurations = [] + for i in range(25): if i % 7 == 3: layer_types.append("attention") + hyena_filter_configurations.append({}) # Empty config for attention layers else: layer_types.append("hyena") + + # Determine filter configuration for this layer + orig_prefix = f"blocks.{i}.filter" + layer_config = {} + + if f"{orig_prefix}.h" in original_weights: + layer_config["h_shape"] = original_weights[f"{orig_prefix}.h"].shape + + if f"{orig_prefix}.D" in original_weights: + layer_config["D_shape"] = original_weights[f"{orig_prefix}.D"].shape + + if f"{orig_prefix}.log_poles" in original_weights: + layer_config["log_poles_shape"] = original_weights[f"{orig_prefix}.log_poles"].shape + + if f"{orig_prefix}.residues" in original_weights: + layer_config["residues_shape"] = original_weights[f"{orig_prefix}.residues"].shape + + hyena_filter_configurations.append(layer_config) config = Evo2Config( vocab_size=512, @@ -44,6 +66,7 @@ def convert_original_weights_to_transformers(original_weights): hyena_order=3, # 5760 / 1920 = 3 hyena_kernel_size=3, # Short filter kernel size tie_word_embeddings=True, + hyena_filter_configurations=hyena_filter_configurations, ) # Initialize new state dict @@ -112,11 +135,6 @@ def convert_original_weights_to_transformers(original_weights): ] # Long filter parameters (FIR or IIR) - # These are not standard nn.Parameters in our implementation but we can load them into the state dict - # and then manually assign them in the model if needed, or just save them as part of the state dict - # since we registered them as buffers/parameters in the model (or should have). - # In our implementation, they are initialized as None. We need to make sure they are loaded. - if f"{orig_prefix}.filter.h" in original_weights: new_state_dict[f"model.layers.{layer_idx}.block.filter.h"] = original_weights[ f"{orig_prefix}.filter.h" @@ -164,8 +182,10 @@ def main(): print("Loading into Evo2ForCausalLM...") model = Evo2ForCausalLM(config) - # Load state dict (strict=False because Hyena layers have optional parameters that might be missing if unused) - # But we want to make sure we load everything we have. + # Load state dict + # strict=True should work now for the filter parameters! + # But we might still have some minor mismatches if we missed anything else (like rotary inv_freq buffers if they are persistent) + # Let's try strict=False but print what's missing to verify. 
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) print(f"Missing keys: {len(missing_keys)}") @@ -175,24 +195,7 @@ def main(): if len(unexpected_keys) > 0: print(unexpected_keys[:10]) - # Manually assign filter parameters (h, D, log_poles, residues) if they were not loaded by load_state_dict - # because they were None in the model init. - # Actually, since we put them in new_state_dict, load_state_dict might complain if the model attributes are None. - # We might need to initialize them in the model first or just assign them directly. - - for layer_idx in range(config.num_hidden_layers): - if config.layer_types[layer_idx] == "hyena": - filter_module = model.model.layers[layer_idx].block.filter - orig_prefix = f"blocks.{layer_idx}.filter" - - if f"{orig_prefix}.h" in original_weights: - filter_module.h = nn.Parameter(original_weights[f"{orig_prefix}.h"]) - if f"{orig_prefix}.D" in original_weights: - filter_module.D = nn.Parameter(original_weights[f"{orig_prefix}.D"]) - if f"{orig_prefix}.log_poles" in original_weights: - filter_module.log_poles = nn.Parameter(original_weights[f"{orig_prefix}.log_poles"]) - if f"{orig_prefix}.residues" in original_weights: - filter_module.residues = nn.Parameter(original_weights[f"{orig_prefix}.residues"]) + # We no longer need manual assignment! print(f"Saving to {args.output_dir}...") model.save_pretrained(args.output_dir) diff --git a/src/transformers/models/evo2/modeling_evo2.py b/src/transformers/models/evo2/modeling_evo2.py index 6138d1240946..7045180598d2 100644 --- a/src/transformers/models/evo2/modeling_evo2.py +++ b/src/transformers/models/evo2/modeling_evo2.py @@ -218,13 +218,14 @@ def forward( class Evo2HyenaFilter(nn.Module): - def __init__(self, config: Evo2Config): + def __init__(self, config: Evo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.order = config.hyena_order self.short_filter_length = config.hyena_kernel_size self.hyena_flip_x1x2 = config.hyena_flip_x1x2 + self.layer_idx = layer_idx # Projections: hidden_size -> 3 * hidden_size (for x, y, z) self.projections = nn.Linear(self.hidden_size, self.order * self.hidden_size, bias=False) @@ -241,13 +242,27 @@ def __init__(self, config: Evo2Config): self.hyena_filter_groups = config.hyena_filters self.channels_per_group = self.hidden_size // self.hyena_filter_groups - # These parameters are optional and will be set dynamically when loading weights - # We register them as None initially + # Register parameters based on layer configuration self.h = None self.D = None self.log_poles = None self.residues = None + if config.hyena_filter_configurations is not None and layer_idx < len(config.hyena_filter_configurations): + layer_config = config.hyena_filter_configurations[layer_idx] + + if layer_config.get("h_shape"): + self.h = nn.Parameter(torch.randn(layer_config["h_shape"])) + + if layer_config.get("D_shape"): + self.D = nn.Parameter(torch.randn(layer_config["D_shape"])) + + if layer_config.get("log_poles_shape"): + self.log_poles = nn.Parameter(torch.randn(layer_config["log_poles_shape"])) + + if layer_config.get("residues_shape"): + self.residues = nn.Parameter(torch.randn(layer_config["residues_shape"])) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch, seq_len, _ = hidden_states.shape @@ -409,7 +424,7 @@ class Evo2HyenaBlock(nn.Module): def __init__(self, config: Evo2Config, layer_idx: int): super().__init__() self.input_layernorm = 
Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.filter = Evo2HyenaFilter(config) + self.filter = Evo2HyenaFilter(config, layer_idx) self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = Evo2ParallelGatedMLP(config, layer_idx) self.hidden_dropout = nn.Dropout(config.hidden_dropout) From de25153dd1f54f1e2dd80eee37bd0f72e3bb7cb4 Mon Sep 17 00:00:00 2001 From: McClain Thiel Date: Tue, 18 Nov 2025 22:07:24 +0000 Subject: [PATCH 6/6] removing extra file --- run_evo2.py | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 run_evo2.py diff --git a/run_evo2.py b/run_evo2.py deleted file mode 100644 index 09bf5e996010..000000000000 --- a/run_evo2.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from transformers import Evo2ForCausalLM, Evo2Tokenizer - -# Path to the converted model -model_path = "/tmp/evo2_hf" - -print(f"Loading model from {model_path}...") -model = Evo2ForCausalLM.from_pretrained(model_path) -tokenizer = Evo2Tokenizer.from_pretrained(model_path) - -# Move to GPU if available -device = "cuda" if torch.cuda.is_available() else "cpu" -model = model.to(device) - -# Input sequence (DNA) -sequence = "ACGTACGT" -print(f"Input: {sequence}") - -# Tokenize -input_ids = tokenizer.encode(sequence, return_tensors="pt").to(device) - -# Generate -print("Generating...") -with torch.no_grad(): - output = model.generate(input_ids, max_new_tokens=20) - -# Decode -generated_sequence = tokenizer.decode(output[0]) -print(f"Output: {generated_sequence}")
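Patch 6 drops the ad-hoc `run_evo2.py` smoke test from the tree. For readers following along, a minimal sketch of the intended end-to-end flow is still useful: run the conversion script, then load the converted checkpoint and sample from it. The snippet below assumes the converter's default `evo2_converted` output directory and the `Evo2Tokenizer` interface used in the docs and tests (a no-argument constructor with `encode`/`decode`); treat it as illustrative rather than a supported entry point.

```python
# Sketch under the assumptions above; run the converter first, e.g.:
#   python src/transformers/models/evo2/convert_evo2_weights.py --output_dir evo2_converted
import torch

from transformers import Evo2ForCausalLM, Evo2Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the converted weights. The converter saves only the model, so the tokenizer
# is instantiated directly, as in the evo2.md usage example (assumed API).
model = Evo2ForCausalLM.from_pretrained("evo2_converted").to(device).eval()
tokenizer = Evo2Tokenizer()

# Score and extend a short DNA sequence
input_ids = tokenizer.encode("ACGTACGT", return_tensors="pt").to(device)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=20)

print(tokenizer.decode(output[0]))
```

This mirrors the deleted `run_evo2.py`, with its hard-coded `/tmp/evo2_hf` path replaced by the conversion script's default output directory.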