6 changes: 5 additions & 1 deletion fast_llm/engine/checkpoint/external.py
@@ -283,7 +283,7 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]:
return exported_config # Noqa

@classmethod
def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa
def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str | tuple[str, ...], typing.Any]:
kwargs = {}
for converter in cls._get_config_converters():
try:
@@ -306,7 +306,11 @@ def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # no
kwargs[fast_llm_name] = value
except Exception as e:
raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
return kwargs

@classmethod
def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa
kwargs = cls._import_config_dict(config)
return cls._model_class.get_base_model_config_class().from_dict({}, kwargs)

def _convert_state_dict(
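The split into `_import_config_dict` and `_import_config` lets subclasses adjust imported values without re-implementing the final config construction. A minimal, self-contained sketch of that pattern (toy classes, not the real Fast-LLM handler; `extra_field` is a made-up key):

```python
import typing


class _ToyConverter:
    """Stand-in for the handler patched above; mirrors the new two-step import."""

    @classmethod
    def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]:
        # The real method runs every registered converter; here we just copy the dict.
        return dict(config)

    @classmethod
    def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]:
        # The real method feeds these kwargs into the base-model config class.
        return cls._import_config_dict(config)


class _ToySubConverter(_ToyConverter):
    @classmethod
    def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]:
        # Subclasses can now post-process imported kwargs without touching _import_config.
        kwargs = super()._import_config_dict(config)
        kwargs.setdefault("extra_field", 0)  # hypothetical key, for illustration only
        return kwargs


print(_ToySubConverter._import_config({"hidden_size": 8}))
# {'hidden_size': 8, 'extra_field': 0}
```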
3 changes: 3 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -134,6 +134,7 @@ class CustomModelingExportMixin:
configuration_file: typing.ClassVar[str]
configuration_cls: typing.ClassVar[type[PretrainedConfig]]
generation_utils_file: str | None = None
additional_files: typing.ClassVar[list[str]] = []

# Use custom config instead of relying on the transformers library
@classmethod
@@ -159,3 +160,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json"
if gen_config.exists():
shutil.copy(gen_config, config.path)
for file in self.additional_files:
shutil.copy(file, config.path)
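The new `additional_files` class variable gives export mixins a way to ship extra files alongside the generated modeling code. A hedged sketch of how a subclass might use it (class and file names are illustrative, not from this PR):

```python
import pathlib
import shutil
import typing


class _ToyExportMixin:
    """Stand-in for CustomModelingExportMixin; only the new copy loop is mirrored."""

    additional_files: typing.ClassVar[list[str]] = []

    def copy_modeling_files(self, output_dir: pathlib.Path) -> None:
        # Copy any extra files next to the exported modeling/config files.
        for file in self.additional_files:
            shutil.copy(file, output_dir)


class _ToyVisionExport(_ToyExportMixin):
    # Hypothetical extra files; a real converter would list its own helper modules.
    additional_files = ["image_processing_custom.py", "processor_config.json"]
```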
2 changes: 1 addition & 1 deletion fast_llm/layers/ssm/config.py
@@ -22,7 +22,7 @@ class SSMDimNames:
v_heads = "v_heads" # Number of V heads

# Mamba 2
x_proj_dim_2 = "x_proj_dim" # d_xb
x_proj_dim_2 = "x_proj_dim_2" # d_xb


class SSMBlockType(enum.StrEnum):
15 changes: 14 additions & 1 deletion fast_llm/layers/ssm/mamba2.py
@@ -1,3 +1,4 @@
import logging
import math
import typing

@@ -10,6 +11,8 @@
from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_
from fast_llm.utils import get_lr_scale

logger = logging.getLogger(__name__)

try:
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa

@@ -144,7 +147,15 @@ def init_from_tensor_(
value: torch.Tensor,
) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]:
def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa
return tensor.copy_(value)
logger.info(
f"Initializing {meta.tensor_name} with shape {meta.shape}, tensor shape {tensor.shape} from value shape {value.shape}"
)
# TODO: fix and remove try-except
try:
return tensor.copy_(value)
except RuntimeError as e:
logger.error(f"Failed to copy value to tensor: {e}")
return tensor.fill_(0.0)

return init_

@@ -156,6 +167,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator)
lr_scale=mamba_layer_lr_scale,
)
# define bias outside the linear layer since its also used in the selective_scan_fn
logger.info(f"td_inner: {td_inner}, inv_dt: {inv_dt.shape}")
self.dt_proj_bias = ParameterMeta.from_dims(
(td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale
)
@@ -166,6 +178,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator)
d=self.d_inner,
).contiguous()
A_log = torch.log(A).flatten() # Keep A_log in fp32
logger.info(f"A_log: {A_log.shape}, td_inner: {td_inner}, td_state: {td_state}")
self.A_log = ParameterMeta.from_dims(
(td_inner, td_state),
init_method=init_from_tensor_(A_log),
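Outside Fast-LLM, the guarded copy in `init_from_tensor_` behaves like this minimal sketch with plain tensors (the zero-fill fallback is the temporary workaround flagged by the TODO above):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def init_from_tensor_(value: torch.Tensor):
    def init_(tensor: torch.Tensor) -> torch.Tensor:
        # Mirrors the guarded copy above: fall back to zeros on a shape mismatch.
        try:
            return tensor.copy_(value)
        except RuntimeError as e:
            logger.error(f"Failed to copy value to tensor: {e}")
            return tensor.fill_(0.0)

    return init_


init_fn = init_from_tensor_(torch.arange(4, dtype=torch.float32))
print(init_fn(torch.empty(4)))  # copies: tensor([0., 1., 2., 3.])
print(init_fn(torch.empty(3)))  # shape mismatch -> falls back to tensor([0., 0., 0.])
```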
8 changes: 4 additions & 4 deletions fast_llm/layers/vision_encoder/adapter.py
@@ -24,15 +24,15 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace):
input_dim,
tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size),
bias=True,
weight_init_method=init_normal_(),
bias_init_method=init_normal_(),
weight_init_method=init_normal_(std=config.adapter_init_method_std),
bias_init_method=init_normal_(std=config.adapter_init_method_std),
)
self.layer_2 = Linear(
tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size),
tensor_space.get_tensor_dim(TransformerDimNames.hidden),
bias=True,
weight_init_method=init_normal_(),
bias_init_method=init_normal_(),
weight_init_method=init_normal_(std=config.adapter_init_method_std),
bias_init_method=init_normal_(std=config.adapter_init_method_std),
)

def forward(
12 changes: 12 additions & 0 deletions fast_llm/layers/vision_encoder/config.py
@@ -150,6 +150,18 @@ class VisionEncoderConfig(BaseModelConfig):
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
adapter_init_method_std: float = Field(
default=None,
desc="Standard deviation for the normal initialization of the adapter weights. Default: adapter_size ** -0.5.",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)

def _validate(self) -> None:
with self._set_implicit_default():
if self.adapter_init_method_std is None:
self.adapter_init_method_std = self.adapter_size**-0.5
super()._validate()

def setup_tensor_space(self, tensor_space: TensorSpace):
tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size))
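For reference, the implicit default added in `_validate` works out as follows (the adapter size here is illustrative):

```python
# With adapter_size = 1024 and adapter_init_method_std left unset,
# _validate picks std = adapter_size ** -0.5.
adapter_size = 1024
adapter_init_method_std = adapter_size**-0.5
print(adapter_init_method_std)  # 0.03125
```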
33 changes: 25 additions & 8 deletions fast_llm/layers/vision_encoder/preprocessing.py
@@ -163,11 +163,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None:
for imgs in images
]

labels = kwargs[LanguageModelKwargs.labels]
if (self._config.image_break_token is not None) or (self._config.image_end_token is not None):
# If image break or end token is present, we need to replace image token ids to -100 in labels
# TODO: avoid double cloning labels in case of loss masking spans?
labels = labels.clone()
if LanguageModelKwargs.labels in kwargs:
labels = kwargs[LanguageModelKwargs.labels]
if (self._config.image_break_token is not None) or (self._config.image_end_token is not None):
# If image break or end token is present, we need to replace image token ids to -100 in labels
# TODO: avoid double cloning labels in case of loss masking spans?
labels = labels.clone()

patches = []
patch_position_ids = []
@@ -191,8 +192,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None:
image_break=self._config.image_break_token is not None,
image_end=self._config.image_end_token is not None,
)
# set labels for image patches to -100
labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100
if LanguageModelKwargs.labels in kwargs:
# set labels for image patches to -100
labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100
if seqlen > max_seqlen:
max_seqlen = seqlen
cu_seqlens.append(cu_seqlens[-1] + seqlen)
@@ -261,4 +263,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None:
)
kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen
kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen
kwargs[LanguageModelKwargs.labels] = labels
if LanguageModelKwargs.labels in kwargs:
kwargs[LanguageModelKwargs.labels] = labels

# TODO: add proper preprocessing for attention-mask when not using flash attention
# Following is just a dummy code to run the tests.
kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones(
(1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]),
dtype=torch.bool,
device=self._tensor_space.distributed.device,
)
kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full(
[],
torch.finfo(self._distributed_config.training_dtype.torch).min,
dtype=self._distributed_config.training_dtype.torch,
device=self._tensor_space.distributed.device,
)
3 changes: 3 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -257,6 +257,9 @@ def _validate(self) -> None:
Assert.none(reference_model.model.base_model.cross_entropy_splits)
Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings)
Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads)
if self.model.base_model.vision_encoder.enabled:
assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder"
Assert.gt(self.batch.max_image_size, 0)

@classmethod
def _from_dict(