diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c92fed507a6d..0e639d5bffb4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -480,6 +480,8 @@ title: ErnieM - local: model_doc/esm title: ESM + - local: model_doc/evo2 + title: Evo2 - local: model_doc/exaone4 title: EXAONE-4.0 - local: model_doc/falcon diff --git a/docs/source/en/model_doc/evo2.md b/docs/source/en/model_doc/evo2.md new file mode 100644 index 000000000000..bf40c12ebf27 --- /dev/null +++ b/docs/source/en/model_doc/evo2.md @@ -0,0 +1,80 @@ + + + +# Evo2 + +## Overview + +The Evo2 model was proposed in [Genome modeling and design across all domains of life with Evo 2](https://www.biorxiv.org/content/10.1101/2024.02.27.582234v1) by Garyk Brixi, Matthew G. Durrant, Jerome Ku, Michael Poli, et al. +It is a biological foundation model trained on 9.3 trillion DNA base pairs from a curated genomic atlas spanning all domains of life. + +The abstract from the paper is the following: + +Evo 2 is a biological foundation model trained on 9.3 trillion DNA base pairs from a curated genomic atlas spanning all domains of life. The model features 7B and 40B parameter architectures and can process sequences up to 1 million base pairs at nucleotide-level resolution. It learns from DNA sequences alone to accurately predict the functional impacts of genetic variation, including noncoding pathogenic mutations and clinically significant BRCA1 variants, without task-specific finetuning. Mechanistic interpretability analyses reveal that Evo 2 autonomously learns a breadth of biological features, such as exon-intron boundaries, transcription factor binding sites, protein structural elements, and prophage genomic regions. Beyond its predictive capabilities, Evo 2 can generate mitochondrial, prokaryotic, and eukaryotic sequences at genome scale with greater naturalness and coherence than previous methods. Guiding Evo 2 via inference-time search enables controllable generation of epigenomic structure, for which the first inference-time scaling results in biology are demonstrated. The project makes Evo 2 fully open, including model parameters, training code, inference code, and the OpenGenome2 dataset, to accelerate the exploration and design of biological complexity. + +Tips: + +- Evo 2 is a genomic foundation model, meaning it is designed to process and generate DNA sequences. +- It uses the StripedHyena architecture, which combines attention with Hyena filters to handle long contexts efficiently. +- The model is trained on a massive dataset of 9.3 trillion base pairs. +- It can handle context lengths up to 1 million base pairs (though this specific implementation may be limited by available memory). + +This model was contributed by [arcinstitute](https://huggingface.co/arcinstitute). +The original code can be found [here](https://github.com/ArcInstitute/evo2). +The model was converted to Hugging Face format by [McClain Thiel](mailto:mcclain.thiel@gmail.com). 
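
Beyond generation (shown in the usage example below), a genomic language model such as Evo 2 is often used to score DNA sequences, for instance comparing a reference allele against a single-nucleotide variant via sequence log-likelihoods, in line with the variant-effect prediction described in the abstract. The sketch below illustrates that workflow with the classes added in this PR. It assumes a converted checkpoint is available locally; `evo2_converted` is only a placeholder matching the default output directory of `convert_evo2_weights.py`, and the scoring recipe is a minimal illustration, not the exact protocol used in the paper.

```python
import torch

from transformers import Evo2ForCausalLM, Evo2Tokenizer

# Placeholder path: the default output directory of convert_evo2_weights.py,
# or any directory/Hub repo holding a converted Evo2 checkpoint.
model_path = "evo2_converted"

tokenizer = Evo2Tokenizer()
model = Evo2ForCausalLM.from_pretrained(model_path)
model.eval()


def sequence_log_likelihood(sequence: str) -> float:
    """Sum of per-nucleotide log-probabilities of `sequence` under the model."""
    input_ids = tokenizer.encode(sequence, return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids).logits
    # Causal LM: logits at position i predict the token at position i + 1.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:]
    token_scores = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_scores.sum().item()


# Score a reference sequence against a single-nucleotide variant.
reference = "ACGTACGTACGTACGT"
variant = "ACGTACGAACGTACGT"
delta = sequence_log_likelihood(variant) - sequence_log_likelihood(reference)
print(f"Variant minus reference log-likelihood: {delta:.4f}")
```

A more negative delta suggests the model assigns lower likelihood to the variant than to the reference sequence; for real analyses, longer genomic context around the variant position is typically used.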
+ +## Usage examples + +```python +from transformers import Evo2Config, Evo2ForCausalLM, Evo2Tokenizer + +# Initialize model and tokenizer +config = Evo2Config() +model = Evo2ForCausalLM(config) +tokenizer = Evo2Tokenizer() + +# Encode input DNA sequence +sequence = "ACGTACGT" +input_ids = tokenizer.encode(sequence, return_tensors="pt") + +# Generate +output = model.generate(input_ids, max_length=20) +generated_sequence = tokenizer.decode(output[0]) +print(generated_sequence) +``` + +## Evo2Config + +[[autodoc]] Evo2Config + +## Evo2ForCausalLM + +[[autodoc]] Evo2ForCausalLM + + +## Evo2Model + +[[autodoc]] Evo2Model + - forward + +## Evo2PreTrainedModel + +[[autodoc]] Evo2PreTrainedModel + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ae0e0b67c874..75d8cac2b31e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -264,6 +264,14 @@ "utils.kernel_config": ["KernelConfig"], } +_import_structure["models.evo2"] = [ + "Evo2Config", + "Evo2ForCausalLM", + "Evo2Model", + "Evo2PreTrainedModel", + "Evo2Tokenizer", +] + # tokenizers-backed objects try: if not is_tokenizers_available(): diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3534ce6719d0..5f3472868cba 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -120,6 +120,7 @@ from .encoder_decoder import * from .ernie import * from .esm import * + from .evo2 import * from .evolla import * from .exaone4 import * from .falcon import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9a3b2ec5ecc2..1f5a861e9b83 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -145,6 +145,7 @@ ("ernie4_5_moe", "Ernie4_5_MoeConfig"), ("ernie_m", "ErnieMConfig"), ("esm", "EsmConfig"), + ("evo2", "Evo2Config"), ("evolla", "EvollaConfig"), ("exaone4", "Exaone4Config"), ("falcon", "FalconConfig"), @@ -590,6 +591,7 @@ ("ernie4_5_moe", "Ernie4_5_MoE"), ("ernie_m", "ErnieM"), ("esm", "ESM"), + ("evo2", "Evo2"), ("evolla", "Evolla"), ("exaone4", "EXAONE-4.0"), ("falcon", "Falcon"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 257fb95fdea7..4781d00cdc30 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -148,6 +148,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie4_5_moe", "Ernie4_5_MoeModel"), ("ernie_m", "ErnieMModel"), ("esm", "EsmModel"), + ("evo2", "Evo2Model"), ("evolla", "EvollaModel"), ("exaone4", "Exaone4Model"), ("falcon", "FalconModel"), @@ -666,6 +667,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ernie", "ErnieForCausalLM"), ("ernie4_5", "Ernie4_5ForCausalLM"), ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"), + ("evo2", "Evo2ForCausalLM"), ("exaone4", "Exaone4ForCausalLM"), ("falcon", "FalconForCausalLM"), ("falcon_h1", "FalconH1ForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c830fee86987..1cb97b2fd90e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -241,6 +241,7 @@ ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)), 
("esm", ("EsmTokenizer", None)), + ("evo2", ("Evo2Tokenizer", None)), ("evolla", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ( "exaone4", diff --git a/src/transformers/models/evo2/__init__.py b/src/transformers/models/evo2/__init__.py new file mode 100644 index 000000000000..e48f532a620f --- /dev/null +++ b/src/transformers/models/evo2/__init__.py @@ -0,0 +1,39 @@ +"""Evo2 model, tokenizer, and configuration.""" + +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + +_import_structure = { + "configuration_evo2": ["Evo2Config"], + "tokenization_evo2": ["Evo2Tokenizer"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_evo2"] = [ + "Evo2ForCausalLM", + "Evo2Model", + "Evo2PreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_evo2 import Evo2Config + from .tokenization_evo2 import Evo2Tokenizer + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_evo2 import Evo2ForCausalLM, Evo2Model, Evo2PreTrainedModel +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/evo2/configuration_evo2.py b/src/transformers/models/evo2/configuration_evo2.py new file mode 100644 index 000000000000..15ee518cb1d5 --- /dev/null +++ b/src/transformers/models/evo2/configuration_evo2.py @@ -0,0 +1,185 @@ +"""Evo2 model configuration.""" + +from __future__ import annotations + +from typing import Optional, Sequence, List, Dict, Any + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import standardize_rope_params +from ...utils import logging + + +logger = logging.get_logger(__name__) + +__all__ = ["Evo2Config"] + + +class Evo2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Evo2Model`]. It is used to instantiate an Evo2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Evo2-1b-base model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 512): + Vocabulary size of the Evo2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Evo2Model`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=None`, the model will use the same number of key/value heads as the number of + query heads. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. 
+ rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rms_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the rms normalization layers. + attn_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden units. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + layer_types (`Sequence[str]`, *optional*): + List of layer types ("attention" or "hyena") for each layer. If None, defaults to all "attention". + hyena_filters (`int`, *optional*, defaults to 256): + Number of Hyena filter groups. + hyena_kernel_size (`int`, *optional*, defaults to 8): + Kernel size for the short convolution in Hyena. + hyena_hidden_size (`int`, *optional*): + Hidden size for Hyena layers. + hyena_order (`int`, *optional*, defaults to 4): + Order of the Hyena recurrence. + hyena_flip_x1x2 (`bool`, *optional*, defaults to False): + Whether to flip x1 and x2 in the Hyena gating mechanism. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to True): + Whether or not the model should return the last key/values attentions (not used by all models). + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 0): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to True): + Whether to tie weight embeddings + """ + + model_type = "evo2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 512, + hidden_size: int = 2048, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_attention_heads: int = 16, + num_key_value_heads: Optional[int] = None, + max_position_embeddings: int = 2048, + rope_theta: float = 1_000_000.0, + rms_norm_eps: float = 1e-6, + attn_dropout: float = 0.0, + hidden_dropout: float = 0.0, + mlp_dropout: float = 0.0, + layer_types: Optional[Sequence[str]] = None, + hyena_filters: int = 256, + hyena_kernel_size: int = 8, + hyena_hidden_size: Optional[int] = None, + hyena_order: int = 4, + hyena_flip_x1x2: bool = False, + hyena_filter_configurations: Optional[List[Dict[str, Any]]] = None, + initializer_range: float = 0.02, + use_cache: bool = True, + pad_token_id: int = 1, + bos_token_id: Optional[int] = None, + eos_token_id: int = 0, + tie_word_embeddings: bool = True, + **kwargs, + ) -> None: + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else hidden_size * 4 + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rms_norm_eps = rms_norm_eps + self.attn_dropout = attn_dropout + self.hidden_dropout = hidden_dropout + self.mlp_dropout = mlp_dropout + self.hyena_filters = hyena_filters + self.hyena_kernel_size = hyena_kernel_size + 
self.hyena_hidden_size = hyena_hidden_size if hyena_hidden_size is not None else hidden_size + self.hyena_order = hyena_order + self.hyena_flip_x1x2 = hyena_flip_x1x2 + self.hyena_filter_configurations = hyena_filter_configurations + self.initializer_range = initializer_range + self.use_cache = use_cache + + if layer_types is None: + self.layer_types = ["attention"] * num_hidden_layers + else: + self.layer_types = list(layer_types) + + standardize_rope_params(self, rope_theta=self.rope_theta) + + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError( + "The length of `layer_types` must match `num_hidden_layers` (received" + f" {len(self.layer_types)} and {self.num_hidden_layers})." + ) + + for layer_type in self.layer_types: + if layer_type not in {"attention", "hyena"}: + raise ValueError(f"Unsupported layer type: {layer_type}. Expected 'attention' or 'hyena'.") + + if self.num_attention_heads <= 0 or self.hidden_size % self.num_attention_heads != 0: + raise ValueError("`hidden_size` must be divisible by `num_attention_heads`.") + + if self.num_key_value_heads <= 0 or self.hidden_size % self.num_key_value_heads != 0: + raise ValueError("`hidden_size` must be divisible by `num_key_value_heads`.") + + logger.info("Initialized Evo2Config with %s layers (%s).", self.num_hidden_layers, ", ".join(self.layer_types)) + + @property + def head_dim(self) -> int: + return self.hidden_size // self.num_attention_heads + + @property + def kv_head_dim(self) -> int: + return self.hidden_size // self.num_key_value_heads + + @property + def num_attention_layers(self) -> int: + return sum(layer_type == "attention" for layer_type in self.layer_types) + + @property + def num_hyena_layers(self) -> int: + return sum(layer_type == "hyena" for layer_type in self.layer_types) + + def to_dict(self) -> dict: + output = super().to_dict() + output["layer_types"] = list(self.layer_types) + return output diff --git a/src/transformers/models/evo2/convert_evo2_original_to_hf.py b/src/transformers/models/evo2/convert_evo2_original_to_hf.py new file mode 100644 index 000000000000..d8bf5f3984bd --- /dev/null +++ b/src/transformers/models/evo2/convert_evo2_original_to_hf.py @@ -0,0 +1,214 @@ +"""Conversion script for original Evo2 checkpoints to Hugging Face format.""" + +from __future__ import annotations + +import argparse +import json +import os +from typing import Dict, Iterable, Optional + +import torch + +try: + import yaml +except ImportError: # pragma: no cover + yaml = None + +from transformers import Evo2Config, Evo2ForCausalLM + + +def _load_original_state_dict(path: str) -> Dict[str, torch.Tensor]: + checkpoint = torch.load(path, map_location="cpu") + if isinstance(checkpoint, dict): + if "state_dict" in checkpoint: + return checkpoint["state_dict"] + if "model" in checkpoint and isinstance(checkpoint["model"], dict): + return checkpoint["model"] + if isinstance(checkpoint, (list, tuple)): + raise ValueError("Unexpected checkpoint structure. Expected a dictionary with model weights.") + return checkpoint + + +def _load_config(path: str) -> Evo2Config: + ext = os.path.splitext(path)[1].lower() + if ext in {".yml", ".yaml"}: + if yaml is None: + raise ImportError("PyYAML is required to parse YAML configuration files. 
Please install pyyaml.") + with open(path, "r", encoding="utf-8") as handle: + raw = yaml.safe_load(handle) + config_dict = raw.get("model", raw) + layer_specs = config_dict.get("layers", []) + layer_types = [] + for layer in layer_specs: + if isinstance(layer, dict): + layer_type = layer.get("type") or layer.get("layer_type") or layer.get("block_type") + if layer_type is None: + layer_type = "attention" + layer_types.append(layer_type.lower()) + else: + layer_types.append(str(layer).lower()) + config_dict["layer_types"] = layer_types or config_dict.get("layer_types") + return Evo2Config(**config_dict) + if ext == ".json": + with open(path, "r", encoding="utf-8") as handle: + config_kwargs = json.load(handle) + return Evo2Config(**config_kwargs) + if os.path.isdir(path): + return Evo2Config.from_pretrained(path) + raise ValueError(f"Unsupported config format for '{path}'. Expected directory, JSON, or YAML file.") + + +def _match_first_available(state_dict: Dict[str, torch.Tensor], candidates: Iterable[str]) -> Optional[str]: + for candidate in candidates: + if candidate in state_dict: + return candidate + return None + + +def _map_attention_key(layer_idx: int, target_suffix: str, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + suffix_map = { + "q_proj.weight": ["attn.wq.weight", "attention.wq.weight", "attention.q_proj.weight"], + "k_proj.weight": ["attn.wk.weight", "attention.wk.weight", "attention.k_proj.weight"], + "v_proj.weight": ["attn.wv.weight", "attention.wv.weight", "attention.v_proj.weight"], + "o_proj.weight": ["attn.wo.weight", "attention.wo.weight", "attention.out_proj.weight"], + "q_proj.bias": ["attn.wq.bias", "attention.wq.bias", "attention.q_proj.bias"], + "k_proj.bias": ["attn.wk.bias", "attention.wk.bias", "attention.k_proj.bias"], + "v_proj.bias": ["attn.wv.bias", "attention.wv.bias", "attention.v_proj.bias"], + "o_proj.bias": ["attn.wo.bias", "attention.wo.bias", "attention.out_proj.bias"], + "input_layernorm.weight": ["attn_norm.weight", "input_layernorm.weight"], + "post_attention_layernorm.weight": ["mlp_norm.weight", "post_attention_layernorm.weight"], + "input_layernorm.bias": ["attn_norm.bias", "input_layernorm.bias"], + "post_attention_layernorm.bias": ["mlp_norm.bias", "post_attention_layernorm.bias"], + } + candidates = [] + for variant in suffix_map.get(target_suffix, []): + candidates.append(f"layers.{layer_idx}.{variant}") + candidates.append(f"model.layers.{layer_idx}.{variant}") + return _match_first_available(original_state, candidates) + + +def _map_mlp_key(layer_idx: int, target_suffix: str, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + suffix_map = { + "gate_proj.weight": ["mlp.gate_proj.weight", "mlp.w1.weight", "ffn.gate_proj.weight"], + "up_proj.weight": ["mlp.up_proj.weight", "mlp.w3.weight", "ffn.up_proj.weight"], + "down_proj.weight": ["mlp.down_proj.weight", "mlp.w2.weight", "ffn.down_proj.weight"], + "gate_proj.bias": ["mlp.gate_proj.bias", "mlp.w1.bias", "ffn.gate_proj.bias"], + "up_proj.bias": ["mlp.up_proj.bias", "mlp.w3.bias", "ffn.up_proj.bias"], + "down_proj.bias": ["mlp.down_proj.bias", "mlp.w2.bias", "ffn.down_proj.bias"], + } + candidates = [] + for variant in suffix_map.get(target_suffix, []): + candidates.append(f"layers.{layer_idx}.{variant}") + candidates.append(f"model.layers.{layer_idx}.{variant}") + return _match_first_available(original_state, candidates) + + +def _map_hyena_key(layer_idx: int, target_suffix: str, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + suffix_map = { + 
"input_layernorm.weight": ["hyena_norm.weight", "input_layernorm.weight"], + "input_layernorm.bias": ["hyena_norm.bias", "input_layernorm.bias"], + "post_attention_layernorm.weight": ["mlp_norm.weight", "post_layernorm.weight"], + "post_attention_layernorm.bias": ["mlp_norm.bias", "post_layernorm.bias"], + "filter.in_proj.weight": ["hyena.filter.in_proj.weight", "hyena.in_proj.weight"], + "filter.out_proj.weight": ["hyena.filter.out_proj.weight", "hyena.out_proj.weight"], + "filter.conv.weight": ["hyena.filter.conv.weight"], + } + candidates = [] + for variant in suffix_map.get(target_suffix, []): + candidates.append(f"layers.{layer_idx}.{variant}") + candidates.append(f"model.layers.{layer_idx}.{variant}") + return _match_first_available(original_state, candidates) + + +def _map_key(target_key: str, config: Evo2Config, original_state: Dict[str, torch.Tensor]) -> Optional[str]: + if target_key == "model.embed_tokens.weight": + return _match_first_available( + original_state, + [ + "model.embed_tokens.weight", + "embed_tokens.weight", + "tok_embeddings.weight", + "embedding.weight", + "embeddings.word_embeddings.weight", + ], + ) + if target_key == "model.norm.weight": + return _match_first_available( + original_state, + ["model.norm.weight", "norm.weight", "final_layer_norm.weight", "rms_norm.weight"], + ) + if target_key == "lm_head.weight": + return _match_first_available(original_state, ["lm_head.weight", "output.weight", "head.weight"]) + + if target_key.startswith("model.layers."): + parts = target_key.split(".") + layer_idx = int(parts[2]) + layer_type = config.layer_types[layer_idx] + suffix = ".".join(parts[4:]) if parts[3] == "block" else ".".join(parts[3:]) + if layer_type == "attention": + if suffix.startswith("attention."): + attn_suffix = suffix[len("attention.") :] + return _map_attention_key(layer_idx, attn_suffix, original_state) + if suffix.startswith("mlp."): + mlp_suffix = suffix[len("mlp.") :] + return _map_mlp_key(layer_idx, mlp_suffix, original_state) + if suffix.startswith("hidden_dropout"): + return None + if suffix.startswith("input_layernorm") or suffix.startswith("post_attention_layernorm"): + return _map_attention_key(layer_idx, suffix, original_state) + else: + if suffix.startswith("filter."): + filter_suffix = suffix + return _map_hyena_key(layer_idx, filter_suffix, original_state) + if suffix.startswith("mlp."): + mlp_suffix = suffix[len("mlp.") :] + return _map_mlp_key(layer_idx, mlp_suffix, original_state) + if suffix.startswith("input_layernorm") or suffix.startswith("post_attention_layernorm"): + return _map_hyena_key(layer_idx, suffix, original_state) + return None + return None + + +def convert_checkpoint(original_checkpoint: str, config_path: str, output_dir: str) -> None: + original_state = _load_original_state_dict(original_checkpoint) + config = _load_config(config_path) + model = Evo2ForCausalLM(config) + + target_state = model.state_dict() + new_state = {} + missing_keys = [] + + for key in target_state.keys(): + source_key = _map_key(key, config, original_state) + if source_key is None: + missing_keys.append(key) + continue + new_state[key] = original_state[source_key] + + if missing_keys: + raise KeyError( + "The following keys could not be mapped from the original checkpoint: " + ", ".join(missing_keys) + ) + + model.load_state_dict(new_state, strict=True) + + os.makedirs(output_dir, exist_ok=True) + model.save_pretrained(output_dir) + config.save_pretrained(output_dir) + + +def parse_args() -> argparse.Namespace: + parser = 
argparse.ArgumentParser(description="Convert an original Evo2 checkpoint to Hugging Face format.") + parser.add_argument("checkpoint", type=str, help="Path to the original .pt checkpoint file") + parser.add_argument("config", type=str, help="Path to the Evo2 YAML/JSON config or HF directory") + parser.add_argument("output", type=str, help="Output directory for the converted model") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + convert_checkpoint(args.checkpoint, args.config, args.output) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/evo2/convert_evo2_weights.py b/src/transformers/models/evo2/convert_evo2_weights.py new file mode 100644 index 000000000000..84c8bddef161 --- /dev/null +++ b/src/transformers/models/evo2/convert_evo2_weights.py @@ -0,0 +1,206 @@ +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os + +import torch +from torch import nn +from huggingface_hub import hf_hub_download + +from transformers import Evo2Config, Evo2ForCausalLM + + +def convert_original_weights_to_transformers(original_weights): + """Convert weights from original Evo2 format to transformers format.""" + + # Create config based on the original model architecture (Evo2-1b-base) + # vocab_size=512, hidden_size=1920, 25 layers (21 hyena + 4 attention every 7th layer starting from 3) + layer_types = [] + hyena_filter_configurations = [] + + for i in range(25): + if i % 7 == 3: + layer_types.append("attention") + hyena_filter_configurations.append({}) # Empty config for attention layers + else: + layer_types.append("hyena") + + # Determine filter configuration for this layer + orig_prefix = f"blocks.{i}.filter" + layer_config = {} + + if f"{orig_prefix}.h" in original_weights: + layer_config["h_shape"] = original_weights[f"{orig_prefix}.h"].shape + + if f"{orig_prefix}.D" in original_weights: + layer_config["D_shape"] = original_weights[f"{orig_prefix}.D"].shape + + if f"{orig_prefix}.log_poles" in original_weights: + layer_config["log_poles_shape"] = original_weights[f"{orig_prefix}.log_poles"].shape + + if f"{orig_prefix}.residues" in original_weights: + layer_config["residues_shape"] = original_weights[f"{orig_prefix}.residues"].shape + + hyena_filter_configurations.append(layer_config) + + config = Evo2Config( + vocab_size=512, + hidden_size=1920, + intermediate_size=5120, + num_hidden_layers=25, + num_attention_heads=15, # 1920 / 128 + num_key_value_heads=15, + layer_types=layer_types, + hyena_filters=128, # Number of filter groups + hyena_order=3, # 5760 / 1920 = 3 + hyena_kernel_size=3, # Short filter kernel size + tie_word_embeddings=True, + hyena_filter_configurations=hyena_filter_configurations, + ) + + # Initialize new state dict + new_state_dict = {} + + # Convert embeddings + new_state_dict["model.embed_tokens.weight"] = original_weights["embedding_layer.weight"] + new_state_dict["lm_head.weight"] = original_weights["unembed.weight"] + + # Convert each layer + for 
layer_idx in range(25): + layer_type = layer_types[layer_idx] + orig_prefix = f"blocks.{layer_idx}" + new_prefix = f"model.layers.{layer_idx}.block" + + # Common components: norms and MLP + new_state_dict[f"model.layers.{layer_idx}.block.input_layernorm.weight"] = original_weights[ + f"{orig_prefix}.pre_norm.scale" + ] + new_state_dict[f"model.layers.{layer_idx}.block.post_attention_layernorm.weight"] = original_weights[ + f"{orig_prefix}.post_norm.scale" + ] + + # MLP layers + # Original: l1 (gate), l2 (up), l3 (down) + new_state_dict[f"{new_prefix}.mlp.gate_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l1.weight"] + new_state_dict[f"{new_prefix}.mlp.up_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l2.weight"] + new_state_dict[f"{new_prefix}.mlp.down_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l3.weight"] + + if layer_type == "attention": + # Convert attention layer + # Original uses Wqkv (combined), we need separate q_proj, k_proj, v_proj + wqkv = original_weights[f"{orig_prefix}.inner_mha_cls.Wqkv.weight"] + hidden_size = config.hidden_size + + # Split Wqkv into q, k, v + q, k, v = torch.split(wqkv, hidden_size, dim=0) + new_state_dict[f"model.layers.{layer_idx}.block.attention.q_proj.weight"] = q + new_state_dict[f"model.layers.{layer_idx}.block.attention.k_proj.weight"] = k + new_state_dict[f"model.layers.{layer_idx}.block.attention.v_proj.weight"] = v + + # Output projection + new_state_dict[f"model.layers.{layer_idx}.block.attention.o_proj.weight"] = original_weights[ + f"{orig_prefix}.inner_mha_cls.out_proj.weight" + ] + + # Load rotary embedding inv_freq from original weights + if f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.attention.rotary_emb.inv_freq"] = original_weights[ + f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" + ] + + else: + # Convert hyena filter layer + new_state_dict[f"model.layers.{layer_idx}.block.filter.projections.weight"] = original_weights[ + f"{orig_prefix}.projections.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.short_filter_weight"] = original_weights[ + f"{orig_prefix}.filter.short_filter_weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.weight"] = original_weights[ + f"{orig_prefix}.out_filter_dense.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.bias"] = original_weights[ + f"{orig_prefix}.out_filter_dense.bias" + ] + + # Long filter parameters (FIR or IIR) + if f"{orig_prefix}.filter.h" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.h"] = original_weights[ + f"{orig_prefix}.filter.h" + ] + if f"{orig_prefix}.filter.D" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.D"] = original_weights[ + f"{orig_prefix}.filter.D" + ] + if f"{orig_prefix}.filter.log_poles" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.log_poles"] = original_weights[ + f"{orig_prefix}.filter.log_poles" + ] + if f"{orig_prefix}.filter.residues" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.residues"] = original_weights[ + f"{orig_prefix}.filter.residues" + ] + + # Final norm + new_state_dict["model.norm.weight"] = original_weights["norm.scale"] + + return new_state_dict, config + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_id", + default="arcinstitute/evo2_1b_base", + help="Hub model id", + ) + parser.add_argument( + 
"--output_dir", + default="evo2_converted", + help="The output directory to save the converted model", + ) + args = parser.parse_args() + + print(f"Downloading weights for {args.model_id}...") + weights_path = hf_hub_download(args.model_id, "evo2_1b_base.pt") + original_weights = torch.load(weights_path, map_location="cpu", weights_only=False) + + print("Converting weights...") + new_state_dict, config = convert_original_weights_to_transformers(original_weights) + + print("Loading into Evo2ForCausalLM...") + model = Evo2ForCausalLM(config) + + # Load state dict + # strict=True should work now for the filter parameters! + # But we might still have some minor mismatches if we missed anything else (like rotary inv_freq buffers if they are persistent) + # Let's try strict=False but print what's missing to verify. + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) > 0: + print(missing_keys[:10]) + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) > 0: + print(unexpected_keys[:10]) + + # We no longer need manual assignment! + + print(f"Saving to {args.output_dir}...") + model.save_pretrained(args.output_dir) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/evo2/modeling_evo2.py b/src/transformers/models/evo2/modeling_evo2.py new file mode 100644 index 000000000000..7045180598d2 --- /dev/null +++ b/src/transformers/models/evo2/modeling_evo2.py @@ -0,0 +1,734 @@ +"""PyTorch Evo2 model.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ...cache_utils import Cache, DynamicCache +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_evo2 import Evo2Config + + +logger = logging.get_logger(__name__) + +__all__ = ["Evo2Model", "Evo2ForCausalLM", "Evo2PreTrainedModel"] + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Evo2 +class Evo2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """Evo2RMSNorm is equivalent to T5LayerNorm.""" + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, 
position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + del position_ids + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Evo2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: Evo2Config, device=None): + super().__init__() + self.max_seq_len_cached = getattr(config, "max_position_embeddings", 2048) + self.original_max_seq_len = self.max_seq_len_cached + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + # Register inv_freq as persistent so it can be loaded from checkpoints + self.register_buffer("inv_freq", inv_freq, persistent=True) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[Evo2Config] = None, + device: Optional[torch.device] = None, + seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, float]: + del seq_len + rope_params = getattr(config, "rope_parameters", None) + base = rope_params.get("rope_theta") if rope_params is not None else config.rope_theta + dim = config.head_dim + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) + return inv_freq, 1.0 + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Evo2ParallelGatedMLP(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.layer_idx = layer_idx + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.dropout = nn.Dropout(config.mlp_dropout) + + # Evo2 style: only layer 0 has activation, rest use identity + self.use_activation = (layer_idx == 0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + z1 = self.gate_proj(hidden_states) + z2 = 
self.up_proj(hidden_states) + + # Apply SiLU only for layer 0, identity for others (Evo2 style) + if self.use_activation: + z1 = F.silu(z1) + + hidden_states = self.down_proj(z1 * z2) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class Evo2Attention(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_head_dim = self.hidden_size // self.num_key_value_heads + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.kv_head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.kv_head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Dropout(config.attn_dropout) + + self.rotary_emb = Evo2RotaryEmbedding(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.kv_head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_key_value_heads, self.kv_head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(query_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + kv_seq_len = key_states.shape[-2] + + if self.num_key_value_heads != self.num_heads: + key_states = repeat_kv(key_states, self.num_heads // self.num_key_value_heads) + value_states = repeat_kv(value_states, self.num_heads // self.num_key_value_heads) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = self.dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + present = past_key_value if use_cache else None + return attn_output, (attn_weights if output_attentions else None), present + + +class Evo2HyenaFilter(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.order = config.hyena_order + self.short_filter_length = config.hyena_kernel_size + self.hyena_flip_x1x2 = config.hyena_flip_x1x2 + self.layer_idx = layer_idx + + # Projections: hidden_size 
-> 3 * hidden_size (for x, y, z) + self.projections = nn.Linear(self.hidden_size, self.order * self.hidden_size, bias=False) + + # Short filter (Conv1d) + self.short_filter_weight = nn.Parameter( + torch.randn(self.order * self.hidden_size, 1, self.short_filter_length) + ) + + # Output projection + self.out_filter_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + + # Long filter parameters - check if FIR or IIR based on config + self.hyena_filter_groups = config.hyena_filters + self.channels_per_group = self.hidden_size // self.hyena_filter_groups + + # Register parameters based on layer configuration + self.h = None + self.D = None + self.log_poles = None + self.residues = None + + if config.hyena_filter_configurations is not None and layer_idx < len(config.hyena_filter_configurations): + layer_config = config.hyena_filter_configurations[layer_idx] + + if layer_config.get("h_shape"): + self.h = nn.Parameter(torch.randn(layer_config["h_shape"])) + + if layer_config.get("D_shape"): + self.D = nn.Parameter(torch.randn(layer_config["D_shape"])) + + if layer_config.get("log_poles_shape"): + self.log_poles = nn.Parameter(torch.randn(layer_config["log_poles_shape"])) + + if layer_config.get("residues_shape"): + self.residues = nn.Parameter(torch.randn(layer_config["residues_shape"])) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch, seq_len, _ = hidden_states.shape + + # Project to 3 * hidden_size + u = self.projections(hidden_states) # [batch, seq_len, 3 * hidden_size] + + # Transpose for conv1d: [batch, 3 * hidden_size, seq_len] + u = u.transpose(1, 2) + + # Apply short filter (depthwise conv1d) + u = F.conv1d( + u.to(torch.float32), + self.short_filter_weight.to(torch.float32), + padding=self.short_filter_length - 1, + groups=self.order * self.hidden_size + )[:, :, :seq_len] + u = u.to(hidden_states.dtype) + + # Apply interleave to de-interleave the channels (following vortex model.py line 645) + # This reorders from [x1, x2, v, x1, x2, v, ...] to [x1, x1, ..., x2, x2, ..., v, v, ...] + # u is [batch, 3 * hidden_size, seq_len] + u_x1 = u[:, 0::3, :] # Every 3rd channel starting from 0 + u_x2 = u[:, 1::3, :] # Every 3rd channel starting from 1 + u_v = u[:, 2::3, :] # Every 3rd channel starting from 2 + u = torch.cat([u_x1, u_x2, u_v], dim=1) # [batch, 3 * hidden_size, seq_len] + + # Split into x2, x1, v + # Note: Vortex column_split returns x2, x1, v in that order + # u is now [batch, 3 * hidden_size, seq_len] + # We split along dim 1 + x2, x1, v = u.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1) + + # Vortex HyenaCascade.sequential_forward logic: + # x2, x1, v = column_split(...) 
+ # if self.hyena_flip_x1x2: x1, x2 = x2, x1 + if self.hyena_flip_x1x2: + x1, x2 = x2, x1 + + # Compute x1 * v (element-wise) + x1v = x1 * v + + # Apply long filter to x1v + # Both FIR and IIR use FFT convolution (IIR converts to FIR first) + if self.h is not None or (self.log_poles is not None and self.residues is not None): + # Compute filter h + if self.h is not None: + # FIR filter: use pre-computed h + h = self.h + else: + # IIR filter: convert to FIR using modal form (model.py compute_filter line 703) + # h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] + # Create time vector t = [0, 1, 2, ..., seq_len-1] + t = torch.arange(seq_len, device=hidden_states.device, dtype=torch.float32)[None, None, :] # [1, 1, L] + + # log_poles: [hidden_size, state_dim, 1], residues: [hidden_size, state_dim] + log_poles = self.log_poles.to(torch.float32) + residues = self.residues.to(torch.float32) + + # Compute h = sum(residues * exp(log_poles * t), dim=state_dim) + # Broadcasting: log_poles [D, S, 1] * t [1, 1, L] = [D, S, L] + h = (residues.unsqueeze(-1) * torch.exp(log_poles * t)).sum(dim=1) # [D, L] + h = h.unsqueeze(0) # [1, D, L] - matches FIR filter shape [num_groups, hidden_size/num_groups, L] + + # FFT convolution following vortex engine.py parallel_iir + fft_size = 2 * seq_len + + # Prepare filter: h shape is [num_groups, 1, filter_len] or [1, hidden_size, 1, filter_len] for IIR + h = h.to(torch.float32) + H = torch.fft.rfft(h, n=fft_size) / fft_size # [num_groups, 1, fft_len] + + # Apply adjust_filter_shape_for_broadcast logic + H = H.squeeze() # [num_groups, fft_len] for FIR or [hidden_size, fft_len] for IIR + + # For x1v: [batch, hidden_size, seq_len], we need H: [batch, hidden_size, fft_len] + if H.shape[0] != self.hidden_size: + # FIR case: Repeat H for each channel in group + # hidden_size = num_groups * channels_per_group + if self.hyena_filter_groups > 1: + H = H.repeat_interleave(self.channels_per_group, dim=0) # [hidden_size, fft_len] + # else: IIR case, H is already [hidden_size, fft_len] + + H = H.unsqueeze(0) # [1, hidden_size, fft_len] + + # FFT of input - use torch.fft.fft like original, not rfft + X_s = torch.fft.fft(x1v.to(torch.float32), n=fft_size) # [batch, hidden_size, fft_size] + X = X_s[..., : H.shape[-1]] # [batch, hidden_size, fft_len] + + # Multiply in frequency domain + y = torch.fft.irfft(X * H, n=fft_size, norm='forward')[..., :seq_len] + + # Add bias (direct connection) - note the order matches original: (y + x1v * D) * x2 + # Ensure both y and x1v are in float32 for numerical stability + if self.D is not None: + x1v_f32 = x1v.to(torch.float32) + D_f32 = self.D.to(torch.float32) + y = y + x1v_f32 * D_f32.unsqueeze(0).unsqueeze(-1) + + # Convert back to original dtype (matching Mamba2 pattern) + y = y.to(hidden_states.dtype) + else: + # No long filter + y = x1v + + # Apply gating: x2 * y + z = x2 * y + + # Transpose back: [batch, hidden_size, seq_len] -> [batch, seq_len, hidden_size] + z = z.transpose(1, 2) + + # Output projection + out = self.out_filter_dense(z) + + return out + + +class Evo2AttentionBlock(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.attention = Evo2Attention(config, layer_idx) + self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = Evo2ParallelGatedMLP(config, layer_idx) + self.hidden_dropout = nn.Dropout(config.hidden_dropout) + + def forward( + 
self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache], + output_attentions: bool, + use_cache: bool, + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + # Keep residual in float32 for numerical stability (Mamba2 technique) + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_output, attn_weights, present_kv = self.attention( + hidden_states, + attention_mask, + position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = (residual + self.hidden_dropout(attn_output).to(residual.dtype)).to(hidden_states.dtype) + + # Keep residual in float32 for numerical stability + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = (residual + self.hidden_dropout(hidden_states).to(residual.dtype)).to(hidden_states.dtype) + + return hidden_states, attn_weights, present_kv + + +class Evo2HyenaBlock(nn.Module): + def __init__(self, config: Evo2Config, layer_idx: int): + super().__init__() + self.input_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.filter = Evo2HyenaFilter(config, layer_idx) + self.post_attention_layernorm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = Evo2ParallelGatedMLP(config, layer_idx) + self.hidden_dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache], + output_attentions: bool, + use_cache: bool, + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + del attention_mask, past_key_value, output_attentions, use_cache, cache_position, position_ids + # Keep residual in float32 for numerical stability (Mamba2 technique) + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.filter(hidden_states) + hidden_states = (residual + self.hidden_dropout(hidden_states).to(residual.dtype)).to(hidden_states.dtype) + + # Keep residual in float32 for numerical stability + residual = hidden_states.to(torch.float32) if hidden_states.dtype == torch.bfloat16 else hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = (residual + self.hidden_dropout(hidden_states).to(residual.dtype)).to(hidden_states.dtype) + + return hidden_states, None, None + + +class Evo2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Evo2Config, layer_type: str, layer_idx: int): + super().__init__() + self.layer_type = layer_type + if layer_type == "attention": + self.block = Evo2AttentionBlock(config, layer_idx) + else: + self.block = Evo2HyenaBlock(config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_value: Optional[Cache], + output_attentions: bool, + use_cache: bool, + cache_position: 
Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + return self.block( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + ) + + +class Evo2PreTrainedModel(PreTrainedModel): + config_class = Evo2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Evo2DecoderLayer"] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Evo2RMSNorm): + module.weight.data.fill_(1.0) + + + +class Evo2Model(Evo2PreTrainedModel): + def __init__(self, config: Evo2Config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [Evo2DecoderLayer(config, layer_type, layer_idx) for layer_idx, layer_type in enumerate(config.layer_types)] + ) + self.norm = Evo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False) + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else getattr(self.config, "output_hidden_states", False) + ) + return_dict = return_dict if return_dict is not None else True + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You must specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache(config=self.config) + elif not isinstance(past_key_values, Cache): + raise TypeError("`past_key_values` must be a `Cache` when `use_cache` is True.") + else: + past_key_values = None + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_length, + ) + + hidden_states = self.dropout(inputs_embeds) + + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, 
device=hidden_states.device) + else: + cache_position = cache_position.to(hidden_states.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for layer_idx, (decoder_layer, layer_type) in enumerate(zip(self.layers, self.config.layer_types)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_past = past_key_values if layer_type == "attention" else None + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states, attn, present_kv = checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + layer_past, + output_attentions, + use_cache, + cache_position, + ) + else: + hidden_states, attn, present_kv = decoder_layer( + hidden_states, + attention_mask, + position_ids, + layer_past, + output_attentions, + use_cache, + cache_position, + ) + + if layer_type == "attention" and present_kv is not None and use_cache: + past_key_values = present_kv + + if output_attentions: + all_attentions = all_attentions + (attn,) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + outputs = (hidden_states, past_key_values) + if output_hidden_states: + outputs += (all_hidden_states,) + if output_attentions: + outputs += (all_attentions,) + return outputs + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +class Evo2ForCausalLM(Evo2PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config: Evo2Config): + super().__init__(config) + self.model = Evo2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + return_dict = return_dict if return_dict is not None else True + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + if isinstance(logits_to_keep, int): + slice_indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else 
slice(None) + # Mamba2 technique: convert to lm_head dtype, then to float32 for numerical stability + logits = self.lm_head(hidden_states[:, slice_indices, :].to(self.lm_head.weight.dtype)).float() + else: + # Mamba2 technique: convert to lm_head dtype, then to float32 for numerical stability + logits = self.lm_head(hidden_states[:, logits_to_keep, :].to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _reorder_cache(self, past_key_values: Cache, beam_idx: torch.LongTensor) -> Cache: + past_key_values.reorder_cache(beam_idx) + return past_key_values diff --git a/src/transformers/models/evo2/tokenization_evo2.py b/src/transformers/models/evo2/tokenization_evo2.py new file mode 100644 index 000000000000..6647c659446e --- /dev/null +++ b/src/transformers/models/evo2/tokenization_evo2.py @@ -0,0 +1,130 @@ +"""Tokenizer for the Evo2 model.""" + +from __future__ import annotations + +import json +import os +from typing import List, Optional + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +__all__ = ["Evo2Tokenizer"] + + +def _clamp_token_id(token_id: int) -> int: + return max(0, min(255, int(token_id))) + + +class Evo2Tokenizer(PreTrainedTokenizer): + model_input_names = ["input_ids", "attention_mask"] + + def __init__(self, **kwargs) -> None: + self._vocab_size = 256 + self._token_to_id = {chr(i): i for i in range(self._vocab_size)} + self._id_to_token = {i: chr(i) for i in range(self._vocab_size)} + self._eos_token_id = 0 + self._pad_token_id = 1 + self._bos_token_id = None + super().__init__( + bos_token=None, + eos_token=chr(0), + pad_token=chr(1), + unk_token=None, + add_bos_token=False, + add_eos_token=False, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return self._vocab_size + + def get_vocab(self) -> dict[str, int]: + vocab = dict(self._token_to_id) + vocab.update(self.added_tokens_encoder) + return vocab + + def tokenize(self, text: str, **kwargs) -> List[int]: + del kwargs + return list(text.encode("utf-8")) + + def _tokenize(self, text: str) -> List[str]: + return [str(byte) for byte in text.encode("utf-8")] + + def _convert_token_to_id(self, token: str | int) -> int: + if isinstance(token, int): + return _clamp_token_id(token) + if token in self.added_tokens_encoder: + return self.added_tokens_encoder[token] + if token in self._token_to_id: + return self._token_to_id[token] + return _clamp_token_id(int(token)) + + def _convert_id_to_token(self, index: int) -> str: + index = _clamp_token_id(index) + if index in self.added_tokens_decoder: + return self.added_tokens_decoder[index] + return self._id_to_token[index] + + def convert_tokens_to_string(self, tokens: List[str | int]) -> str: + byte_values = [] + for token in tokens: + if isinstance(token, str) and token in self.added_tokens_encoder: + token = self.added_tokens_encoder[token] + token_id = _clamp_token_id(int(token)) + byte_values.append(token_id) + return "".join(chr(byte) for 
byte in byte_values) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + if token_ids_1 is None: + return list(token_ids_0) + return list(token_ids_0) + list(token_ids_1) + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + if token_ids_1 is None: + return [0] * len(token_ids_0) + return [0] * (len(token_ids_0) + len(token_ids_1)) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + length = len(token_ids_0) if token_ids_1 is None else len(token_ids_0) + len(token_ids_1) + return [0] * length + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + os.makedirs(save_directory, exist_ok=True) + file_name = (filename_prefix + "-" if filename_prefix else "") + "vocab.json" + path = os.path.join(save_directory, file_name) + with open(path, "w", encoding="utf-8") as f: + json.dump({str(i): i for i in range(self._vocab_size)}, f, ensure_ascii=False, indent=2) + return (path,) + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + spaces_between_special_tokens: bool = True, + **kwargs, + ) -> str: + del clean_up_tokenization_spaces, spaces_between_special_tokens + if skip_special_tokens: + token_ids = [ + token_id + for token_id in token_ids + if token_id not in {self.pad_token_id, self.eos_token_id} + ] + return "".join(chr(_clamp_token_id(token_id)) for token_id in token_ids) diff --git a/tests/models/evo2/__init__.py b/tests/models/evo2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/evo2/evo2_1b_base_ground_truth_logits.pt b/tests/models/evo2/evo2_1b_base_ground_truth_logits.pt new file mode 100644 index 000000000000..b7bb1bc0605a Binary files /dev/null and b/tests/models/evo2/evo2_1b_base_ground_truth_logits.pt differ diff --git a/tests/models/evo2/test_modeling_evo2.py b/tests/models/evo2/test_modeling_evo2.py new file mode 100644 index 000000000000..c87e9c98bfbd --- /dev/null +++ b/tests/models/evo2/test_modeling_evo2.py @@ -0,0 +1,321 @@ +import os +import unittest + +import pytest + +pytest.importorskip("parameterized") + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, slow + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + +if is_torch_available(): + import torch + + from transformers import Evo2ForCausalLM, Evo2Model, Evo2Tokenizer + + +class Evo2ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = Evo2Model + + def __init__(self, parent, **kwargs): + super().__init__( + parent, + pad_token_id=1, + bos_token_id=None, + eos_token_id=0, + vocab_size=256, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=64, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + **kwargs, + ) + + def get_config(self): + config = super().get_config() + config.layer_types = ["attention"] * config.num_hidden_layers + config.hyena_filters = 8 + config.hyena_kernel_size = 3 + 
config.hyena_order = 2 + config.tie_word_embeddings = True + return config + + +@require_torch +class Evo2ModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = Evo2ModelTester + + +@require_torch +@slow +class Evo2InferenceTest(unittest.TestCase): + """Test inference against ground truth logits from the original evo2_1b_base model.""" + + @staticmethod + def convert_original_weights_to_transformers(original_weights): + """Convert weights from original Evo2 format to transformers format.""" + from transformers import Evo2Config + + # Create config based on the original model architecture + # vocab_size=512, hidden_size=1920, 25 layers (21 hyena + 4 attention every 7th layer starting from 3) + layer_types = [] + for i in range(25): + if i % 7 == 3: + layer_types.append("attention") + else: + layer_types.append("hyena") + + config = Evo2Config( + vocab_size=512, + hidden_size=1920, + intermediate_size=5120, + num_hidden_layers=25, + num_attention_heads=15, # 1920 / 128 + num_key_value_heads=15, + layer_types=layer_types, + hyena_filters=128, # Number of filter groups + hyena_order=3, # 5760 / 1920 = 3 + hyena_kernel_size=3, # Short filter kernel size + tie_word_embeddings=True, + ) + + # Initialize new state dict + new_state_dict = {} + + # Convert embeddings + new_state_dict["model.embed_tokens.weight"] = original_weights["embedding_layer.weight"] + new_state_dict["lm_head.weight"] = original_weights["unembed.weight"] + + # Convert each layer + for layer_idx in range(25): + layer_type = layer_types[layer_idx] + orig_prefix = f"blocks.{layer_idx}" + new_prefix = f"model.layers.{layer_idx}.block" + + # Common components: norms and MLP + new_state_dict[f"model.layers.{layer_idx}.block.input_layernorm.weight"] = original_weights[ + f"{orig_prefix}.pre_norm.scale" + ] + new_state_dict[f"model.layers.{layer_idx}.block.post_attention_layernorm.weight"] = original_weights[ + f"{orig_prefix}.post_norm.scale" + ] + + # MLP layers + # Original: l1 (gate), l2 (up), l3 (down) + new_state_dict[f"{new_prefix}.mlp.gate_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l1.weight"] + new_state_dict[f"{new_prefix}.mlp.up_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l2.weight"] + new_state_dict[f"{new_prefix}.mlp.down_proj.weight"] = original_weights[f"{orig_prefix}.mlp.l3.weight"] + + if layer_type == "attention": + # Convert attention layer + # Original uses Wqkv (combined), we need separate q_proj, k_proj, v_proj + wqkv = original_weights[f"{orig_prefix}.inner_mha_cls.Wqkv.weight"] + hidden_size = config.hidden_size + head_dim = hidden_size // config.num_attention_heads + + # Split Wqkv into q, k, v + q, k, v = torch.split(wqkv, hidden_size, dim=0) + new_state_dict[f"model.layers.{layer_idx}.block.attention.q_proj.weight"] = q + new_state_dict[f"model.layers.{layer_idx}.block.attention.k_proj.weight"] = k + new_state_dict[f"model.layers.{layer_idx}.block.attention.v_proj.weight"] = v + + # Output projection + new_state_dict[f"model.layers.{layer_idx}.block.attention.o_proj.weight"] = original_weights[ + f"{orig_prefix}.inner_mha_cls.out_proj.weight" + ] + + # Load rotary embedding inv_freq from original weights + if f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.attention.rotary_emb.inv_freq"] = original_weights[ + f"{orig_prefix}.inner_mha_cls.rotary_emb.inv_freq" + ] + + # Note: Original has out_proj.bias but our implementation doesn't use bias + else: + # Convert hyena filter layer + 
new_state_dict[f"model.layers.{layer_idx}.block.filter.projections.weight"] = original_weights[ + f"{orig_prefix}.projections.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.short_filter_weight"] = original_weights[ + f"{orig_prefix}.filter.short_filter_weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.weight"] = original_weights[ + f"{orig_prefix}.out_filter_dense.weight" + ] + new_state_dict[f"model.layers.{layer_idx}.block.filter.out_filter_dense.bias"] = original_weights[ + f"{orig_prefix}.out_filter_dense.bias" + ] + + # Long filter parameters (FIR or IIR) + if f"{orig_prefix}.filter.h" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.h"] = original_weights[ + f"{orig_prefix}.filter.h" + ] + if f"{orig_prefix}.filter.D" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.D"] = original_weights[ + f"{orig_prefix}.filter.D" + ] + if f"{orig_prefix}.filter.log_poles" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.log_poles"] = original_weights[ + f"{orig_prefix}.filter.log_poles" + ] + if f"{orig_prefix}.filter.residues" in original_weights: + new_state_dict[f"model.layers.{layer_idx}.block.filter.residues"] = original_weights[ + f"{orig_prefix}.filter.residues" + ] + + # Final norm + new_state_dict["model.norm.weight"] = original_weights["norm.scale"] + + return new_state_dict, config + + def test_weight_loading(self): + """Test that we can successfully load and convert weights from the original model.""" + from huggingface_hub import hf_hub_download + + # Download original weights + weights_path = hf_hub_download("arcinstitute/evo2_1b_base", "evo2_1b_base.pt") + original_weights = torch.load(weights_path, map_location="cpu", weights_only=False) + + # Convert weights to transformers format + new_state_dict, config = self.convert_original_weights_to_transformers(original_weights) + + # Create model and load converted weights + model = Evo2ForCausalLM(config) + + # Load state dict (strict=False because Hyena layers have optional parameters) + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + # Manually assign filter parameters (h, D, log_poles, residues) + for layer_idx in range(config.num_hidden_layers): + if config.layer_types[layer_idx] == "hyena": + filter_module = model.model.layers[layer_idx].block.filter + orig_prefix = f"blocks.{layer_idx}.filter" + + if f"{orig_prefix}.h" in original_weights: + filter_module.h = original_weights[f"{orig_prefix}.h"] + if f"{orig_prefix}.D" in original_weights: + filter_module.D = original_weights[f"{orig_prefix}.D"] + if f"{orig_prefix}.log_poles" in original_weights: + filter_module.log_poles = original_weights[f"{orig_prefix}.log_poles"] + if f"{orig_prefix}.residues" in original_weights: + filter_module.residues = original_weights[f"{orig_prefix}.residues"] + + # Check that only expected keys are missing/unexpected + # (Hyena filter parameters and rotary embeddings) + expected_patterns = ["filter.h", "filter.D", "filter.log_poles", "filter.residues", "rotary_emb.inv_freq"] + + for key in missing_keys: + self.assertTrue( + any(pattern in key for pattern in expected_patterns), + f"Unexpected missing key: {key}" + ) + + for key in unexpected_keys: + self.assertTrue( + any(pattern in key for pattern in expected_patterns), + f"Unexpected key in state dict: {key}" + ) + + print(f"✓ Successfully loaded weights ({len(missing_keys)} missing, {len(unexpected_keys)} 
unexpected)") + + def test_inference_shape(self): + """Test that the model can run inference and produces the correct output shape.""" + from huggingface_hub import hf_hub_download + + # Load ground truth for reference + ground_truth_path = os.path.join( + os.path.dirname(__file__), "evo2_1b_base_ground_truth_logits.pt" + ) + ground_truth = torch.load(ground_truth_path, map_location="cpu", weights_only=False) + + # Download and convert weights + weights_path = hf_hub_download("arcinstitute/evo2_1b_base", "evo2_1b_base.pt") + original_weights = torch.load(weights_path, map_location="cpu", weights_only=False) + new_state_dict, config = self.convert_original_weights_to_transformers(original_weights) + + # Create and load model + model = Evo2ForCausalLM(config) + model.load_state_dict(new_state_dict, strict=False) + + # Manually assign filter parameters (h, D, log_poles, residues) + # These can't be loaded via load_state_dict because they're None initially + for layer_idx in range(config.num_hidden_layers): + if config.layer_types[layer_idx] == "hyena": + filter_module = model.model.layers[layer_idx].block.filter + orig_prefix = f"blocks.{layer_idx}.filter" + + if f"{orig_prefix}.h" in original_weights: + filter_module.h = original_weights[f"{orig_prefix}.h"] + if f"{orig_prefix}.D" in original_weights: + filter_module.D = original_weights[f"{orig_prefix}.D"] + if f"{orig_prefix}.log_poles" in original_weights: + filter_module.log_poles = original_weights[f"{orig_prefix}.log_poles"] + if f"{orig_prefix}.residues" in original_weights: + filter_module.residues = original_weights[f"{orig_prefix}.residues"] + + model = model.to(torch.bfloat16) + model.eval() + + # Create tokenizer + tokenizer = Evo2Tokenizer() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + sequences = ground_truth["sequences"] + results = ground_truth["results"] + + # Test each sequence + for seq in sequences: + with self.subTest(sequence=seq): + # Get ground truth + gt_input_ids = results[seq]["input_ids"] + gt_logits = results[seq]["logits"] + + # Tokenize + tokens = tokenizer.tokenize(seq) + input_ids = torch.tensor([tokens], dtype=torch.long).to(device) + + # Verify input_ids match + self.assertTrue( + torch.equal(input_ids.cpu(), gt_input_ids.unsqueeze(0)), + f"Input IDs mismatch for sequence {seq!r}" + ) + + # Run inference + with torch.no_grad(): + outputs = model(input_ids) + logits = outputs.logits + + # Check shapes match + expected_shape = gt_logits.shape + actual_shape = logits.shape + self.assertEqual( + actual_shape, + expected_shape, + f"Shape mismatch for {seq!r}: expected {expected_shape}, got {actual_shape}" + ) + + # Check that logits are finite (not NaN or Inf) + self.assertTrue(torch.isfinite(logits).all(), f"Non-finite values in logits for {seq!r}") + + print(f"✓ {seq!r}: shape {actual_shape} OK, logits finite") + + # Check logits values match ground truth + # Using relaxed tolerance for bfloat16 + # rtol=1e-2, atol=1e-2 is typical for bfloat16 accumulation differences + torch.testing.assert_close(logits.cpu(), gt_logits.cpu(), rtol=0.02, atol=0.02) + + print(f"✓ {seq!r}: shape {actual_shape} OK, logits match ground truth") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/models/evo2/test_tokenization_evo2.py b/tests/models/evo2/test_tokenization_evo2.py new file mode 100644 index 000000000000..a5d50cd0be1a --- /dev/null +++ b/tests/models/evo2/test_tokenization_evo2.py @@ -0,0 +1,30 @@ +import pytest + +from transformers import 
+
+
+@pytest.fixture
+def tokenizer():
+    return Evo2Tokenizer()
+
+
+def test_round_trip_ascii(tokenizer):
+    text = "Hello, Evo2!"
+    encoded = tokenizer(text)["input_ids"]
+    expected = list(text.encode("utf-8"))
+    assert encoded == expected
+    decoded = tokenizer.decode(encoded)
+    assert decoded == text
+
+
+def test_clamp_behavior(tokenizer):
+    tokens = [0, 1, 255, 300, -5]
+    decoded = tokenizer.decode(tokens)
+    expected = "".join(chr(max(0, min(255, token))) for token in tokens)
+    assert decoded == expected
+
+
+def test_tokenize_returns_bytes(tokenizer):
+    text = "ABcd"
+    tokens = tokenizer.tokenize(text)
+    assert tokens == list(text.encode("utf-8"))
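
Illustrative note (not part of the patch above): a minimal sketch of how the byte-level special-token handling in `tokenization_evo2.py` behaves, assuming the `Evo2Tokenizer` added by this diff is importable from `transformers`. Token id 0 is the EOS byte and id 1 is the PAD byte, so decoding with `skip_special_tokens=True` drops them before the remaining ids are mapped back to characters.

```python
from transformers import Evo2Tokenizer

tokenizer = Evo2Tokenizer()

# Byte-level ids: 0 (EOS) and 1 (PAD) are special; every other id maps to chr(id).
token_ids = [1, 1] + list("ACGT".encode("utf-8")) + [0]

# Special tokens are kept by default ...
assert tokenizer.decode(token_ids) == chr(1) + chr(1) + "ACGT" + chr(0)

# ... and are filtered out when skip_special_tokens=True.
assert tokenizer.decode(token_ids, skip_special_tokens=True) == "ACGT"
```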