2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -480,6 +480,8 @@
title: ErnieM
- local: model_doc/esm
title: ESM
- local: model_doc/evo2
title: Evo2
- local: model_doc/exaone4
title: EXAONE-4.0
- local: model_doc/falcon
80 changes: 80 additions & 0 deletions docs/source/en/model_doc/evo2.md
@@ -0,0 +1,80 @@
<!--Copyright 2025 the HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.

-->


# Evo2

## Overview

The Evo2 model was proposed in [Genome modeling and design across all domains of life with Evo 2](https://www.biorxiv.org/content/10.1101/2024.02.27.582234v1) by Garyk Brixi, Matthew G. Durrant, Jerome Ku, Michael Poli, et al.
It is a biological foundation model trained on 9.3 trillion DNA base pairs from a curated genomic atlas spanning all domains of life.

The abstract from the paper is the following:

Evo 2 is a biological foundation model trained on 9.3 trillion DNA base pairs from a curated genomic atlas spanning all domains of life. The model features 7B and 40B parameter architectures and can process sequences up to 1 million base pairs at nucleotide-level resolution. It learns from DNA sequences alone to accurately predict the functional impacts of genetic variation, including noncoding pathogenic mutations and clinically significant BRCA1 variants, without task-specific finetuning. Mechanistic interpretability analyses reveal that Evo 2 autonomously learns a breadth of biological features, such as exon-intron boundaries, transcription factor binding sites, protein structural elements, and prophage genomic regions. Beyond its predictive capabilities, Evo 2 can generate mitochondrial, prokaryotic, and eukaryotic sequences at genome scale with greater naturalness and coherence than previous methods. Guiding Evo 2 via inference-time search enables controllable generation of epigenomic structure, for which the first inference-time scaling results in biology are demonstrated. The project makes Evo 2 fully open, including model parameters, training code, inference code, and the OpenGenome2 dataset, to accelerate the exploration and design of biological complexity.

Tips:

- Evo 2 is a genomic foundation model, meaning it is designed to process and generate DNA sequences.
- It uses the StripedHyena architecture, which combines attention layers with Hyena convolution filters to handle long contexts efficiently (a configuration sketch is shown under Usage examples below).
- The model is trained on 9.3 trillion base pairs from the OpenGenome2 dataset, which spans all domains of life.
- It can handle context lengths of up to 1 million base pairs (though this specific implementation may be limited by available memory).

This model was contributed by [arcinstitute](https://huggingface.co/arcinstitute).
The original code can be found [here](https://github.com/ArcInstitute/evo2).
The model was converted to Hugging Face format by [McClain Thiel](mailto:[email protected]).

## Usage examples

```python
from transformers import Evo2Config, Evo2ForCausalLM, Evo2Tokenizer

# Build a randomly initialized model and the tokenizer from the default configuration.
# Load converted pretrained weights with `Evo2ForCausalLM.from_pretrained(...)` to get meaningful outputs.
config = Evo2Config()
model = Evo2ForCausalLM(config)
tokenizer = Evo2Tokenizer()

# Encode an input DNA sequence
sequence = "ACGTACGT"
input_ids = tokenizer.encode(sequence, return_tensors="pt")

# Generate a continuation of the sequence
output = model.generate(input_ids, max_length=20)
generated_sequence = tokenizer.decode(output[0])
print(generated_sequence)
```
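
The hybrid attention/Hyena layout is controlled by the `layer_types` argument of `Evo2Config`. The snippet below is a minimal sketch using only the classes added in this PR; the alternating pattern is purely illustrative and is not the layout of any released checkpoint.

```python
from transformers import Evo2Config, Evo2ForCausalLM

# Alternate attention and Hyena blocks (illustrative pattern only).
layer_types = ["attention" if i % 2 == 0 else "hyena" for i in range(24)]

config = Evo2Config(num_hidden_layers=24, layer_types=layer_types)
model = Evo2ForCausalLM(config)

# The configuration exposes per-type layer counts.
print(config.num_attention_layers, config.num_hyena_layers)
```

A common use of genomic language models is scoring sequences by their likelihood, for example to compare variants. Assuming `Evo2ForCausalLM` follows the standard Transformers causal-LM API and returns a cross-entropy loss when `labels` are passed, a sketch of such scoring could look like the following (with a randomly initialized model the scores are not meaningful):

```python
import torch

from transformers import Evo2Config, Evo2ForCausalLM, Evo2Tokenizer

config = Evo2Config()
model = Evo2ForCausalLM(config)  # randomly initialized, for illustration only
tokenizer = Evo2Tokenizer()

# Score each sequence by its average per-token log-likelihood (higher means more likely).
for seq in ["ACGTACGTAC", "ACGTTCGTAC"]:
    input_ids = tokenizer.encode(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    print(seq, -outputs.loss.item())
```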

## Evo2Config

[[autodoc]] Evo2Config

## Evo2ForCausalLM

[[autodoc]] Evo2ForCausalLM
- forward


## Evo2Model

[[autodoc]] Evo2Model
- forward

## Evo2PreTrainedModel

[[autodoc]] Evo2PreTrainedModel
8 changes: 8 additions & 0 deletions src/transformers/__init__.py
@@ -264,6 +264,8 @@
"utils.kernel_config": ["KernelConfig"],
}

_import_structure["models.evo2"] = [
"Evo2Config",
"Evo2ForCausalLM",
"Evo2Model",
"Evo2PreTrainedModel",
"Evo2Tokenizer",
]

# tokenizers-backed objects
try:
if not is_tokenizers_available():
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -120,6 +120,7 @@
from .encoder_decoder import *
from .ernie import *
from .esm import *
from .evo2 import *
from .evolla import *
from .exaone4 import *
from .falcon import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -145,6 +145,7 @@
("ernie4_5_moe", "Ernie4_5_MoeConfig"),
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
("evo2", "Evo2Config"),
("evolla", "EvollaConfig"),
("exaone4", "Exaone4Config"),
("falcon", "FalconConfig"),
@@ -590,6 +591,7 @@
("ernie4_5_moe", "Ernie4_5_MoE"),
("ernie_m", "ErnieM"),
("esm", "ESM"),
("evo2", "Evo2"),
("evolla", "Evolla"),
("exaone4", "EXAONE-4.0"),
("falcon", "Falcon"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -148,6 +148,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("ernie4_5_moe", "Ernie4_5_MoeModel"),
("ernie_m", "ErnieMModel"),
("esm", "EsmModel"),
("evo2", "Evo2Model"),
("evolla", "EvollaModel"),
("exaone4", "Exaone4Model"),
("falcon", "FalconModel"),
@@ -666,6 +667,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("ernie", "ErnieForCausalLM"),
("ernie4_5", "Ernie4_5ForCausalLM"),
("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
("evo2", "Evo2ForCausalLM"),
("exaone4", "Exaone4ForCausalLM"),
("falcon", "FalconForCausalLM"),
("falcon_h1", "FalconH1ForCausalLM"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -241,6 +241,7 @@
("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
("esm", ("EsmTokenizer", None)),
("evo2", ("Evo2Tokenizer", None)),
("evolla", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"exaone4",
39 changes: 39 additions & 0 deletions src/transformers/models/evo2/__init__.py
@@ -0,0 +1,39 @@
"""Evo2 model, tokenizer, and configuration."""

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

_import_structure = {
"configuration_evo2": ["Evo2Config"],
"tokenization_evo2": ["Evo2Tokenizer"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_evo2"] = [
"Evo2ForCausalLM",
"Evo2Model",
"Evo2PreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_evo2 import Evo2Config
from .tokenization_evo2 import Evo2Tokenizer

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_evo2 import Evo2ForCausalLM, Evo2Model, Evo2PreTrainedModel
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
185 changes: 185 additions & 0 deletions src/transformers/models/evo2/configuration_evo2.py
@@ -0,0 +1,185 @@
"""Evo2 model configuration."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import standardize_rope_params
from ...utils import logging


logger = logging.get_logger(__name__)

__all__ = ["Evo2Config"]


class Evo2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`Evo2Model`]. It is used to instantiate an Evo2
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Evo2-1b-base model.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 512):
Vocabulary size of the Evo2 model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`Evo2Model`].
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the model.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the model.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=None`, the model will use the same number of key/value heads as the number of
query heads.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the rms normalization layers.
attn_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the hidden units.
mlp_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the MLP layers.
layer_types (`Sequence[str]`, *optional*):
List of layer types (`"attention"` or `"hyena"`), one per layer. If `None`, all layers default to `"attention"`.
hyena_filters (`int`, *optional*, defaults to 256):
Number of Hyena filter groups.
hyena_kernel_size (`int`, *optional*, defaults to 8):
Kernel size for the short convolution in Hyena.
hyena_hidden_size (`int`, *optional*):
Hidden size for Hyena layers.
hyena_order (`int`, *optional*, defaults to 4):
Order of the Hyena recurrence.
hyena_flip_x1x2 (`bool`, *optional*, defaults to `False`):
Whether to flip x1 and x2 in the Hyena gating mechanism.
hyena_filter_configurations (`List[Dict[str, Any]]`, *optional*):
Optional list of per-layer Hyena filter configuration dictionaries.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
pad_token_id (`int`, *optional*, defaults to 1):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 0):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie the input and output word embeddings.
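
Example (a minimal sketch; it builds a randomly initialized model from the default, Evo2-1b-base-like configuration):

```python
>>> from transformers import Evo2Config, Evo2Model

>>> # Initializing an Evo2 configuration with the default values
>>> configuration = Evo2Config()

>>> # Initializing a (randomly weighted) model from the configuration
>>> model = Evo2Model(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```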
"""

model_type = "evo2"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size: int = 512,
hidden_size: int = 2048,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: Optional[int] = None,
max_position_embeddings: int = 2048,
rope_theta: float = 1_000_000.0,
rms_norm_eps: float = 1e-6,
attn_dropout: float = 0.0,
hidden_dropout: float = 0.0,
mlp_dropout: float = 0.0,
layer_types: Optional[Sequence[str]] = None,
hyena_filters: int = 256,
hyena_kernel_size: int = 8,
hyena_hidden_size: Optional[int] = None,
hyena_order: int = 4,
hyena_flip_x1x2: bool = False,
hyena_filter_configurations: Optional[List[Dict[str, Any]]] = None,
initializer_range: float = 0.02,
use_cache: bool = True,
pad_token_id: int = 1,
bos_token_id: Optional[int] = None,
eos_token_id: int = 0,
tie_word_embeddings: bool = True,
**kwargs,
) -> None:
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size if intermediate_size is not None else hidden_size * 4
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads or num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rms_norm_eps = rms_norm_eps
self.attn_dropout = attn_dropout
self.hidden_dropout = hidden_dropout
self.mlp_dropout = mlp_dropout
self.hyena_filters = hyena_filters
self.hyena_kernel_size = hyena_kernel_size
self.hyena_hidden_size = hyena_hidden_size if hyena_hidden_size is not None else hidden_size
self.hyena_order = hyena_order
self.hyena_flip_x1x2 = hyena_flip_x1x2
self.hyena_filter_configurations = hyena_filter_configurations
self.initializer_range = initializer_range
self.use_cache = use_cache

if layer_types is None:
self.layer_types = ["attention"] * num_hidden_layers
else:
self.layer_types = list(layer_types)

standardize_rope_params(self, rope_theta=self.rope_theta)

if len(self.layer_types) != self.num_hidden_layers:
raise ValueError(
"The length of `layer_types` must match `num_hidden_layers` (received"
f" {len(self.layer_types)} and {self.num_hidden_layers})."
)

for layer_type in self.layer_types:
if layer_type not in {"attention", "hyena"}:
raise ValueError(f"Unsupported layer type: {layer_type}. Expected 'attention' or 'hyena'.")

if self.num_attention_heads <= 0 or self.hidden_size % self.num_attention_heads != 0:
raise ValueError("`num_attention_heads` must be a positive integer that divides `hidden_size`.")

if self.num_key_value_heads <= 0 or self.hidden_size % self.num_key_value_heads != 0:
raise ValueError("`num_key_value_heads` must be a positive integer that divides `hidden_size`.")

logger.info("Initialized Evo2Config with %s layers (%s).", self.num_hidden_layers, ", ".join(self.layer_types))

@property
def head_dim(self) -> int:
return self.hidden_size // self.num_attention_heads

@property
def kv_head_dim(self) -> int:
return self.hidden_size // self.num_key_value_heads

@property
def num_attention_layers(self) -> int:
return sum(layer_type == "attention" for layer_type in self.layer_types)

@property
def num_hyena_layers(self) -> int:
return sum(layer_type == "hyena" for layer_type in self.layer_types)

def to_dict(self) -> dict:
output = super().to_dict()
output["layer_types"] = list(self.layer_types)
return output