|
7 | 7 | import logging |
8 | 8 | import argparse |
9 | 9 | import contextlib |
| 10 | +import importlib.util |
10 | 11 | import json |
11 | 12 | import os |
12 | 13 | import re |
|
29 | 30 | sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) |
30 | 31 | import gguf |
31 | 32 | from gguf.vocab import MistralTokenizerType, MistralVocab |
32 | | -from mistral_common.tokens.tokenizers.base import TokenizerVersion |
33 | | -from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD |
34 | | -from mistral_common.tokens.tokenizers.tekken import Tekkenizer |
35 | | -from mistral_common.tokens.tokenizers.sentencepiece import ( |
36 | | - SentencePieceTokenizer, |
37 | | -) |
| 33 | + |
# Optional dependency: `mistral_common` is only needed when converting models
# that use the Mistral tokenizer format.  Guard the imports so the converter
# keeps working for every other format when the package is absent.
#
# EAFP instead of `importlib.util.find_spec`: a findable top-level package does
# not guarantee that the specific submodules (e.g. `...tokenizers.multimodal`)
# import cleanly — an older `mistral-common` or missing extras would otherwise
# crash the whole script at import time, even for non-Mistral conversions.
try:
    from mistral_common.tokens.tokenizers.base import TokenizerVersion  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.multimodal import (  # pyright: ignore[reportMissingImports]
        DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN,
        DATASET_STD as _MISTRAL_COMMON_DATASET_STD,
    )
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.sentencepiece import (  # pyright: ignore[reportMissingImports]
        SentencePieceTokenizer,
    )

    _mistral_common_installed = True
    _mistral_import_error_msg = ""
except ImportError:
    # CLIP-style normalization defaults, mirroring the values that
    # `mistral_common` would have supplied (see set_gguf_parameters).
    _MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

    # Sentinels checked at the Mistral-specific call sites; any attempt to use
    # the Mistral path without the package raises with this message.
    _mistral_common_installed = False
    TokenizerVersion = None
    Tekkenizer = None
    SentencePieceTokenizer = None
    _mistral_import_error_msg = (
        "Mistral format requires `mistral-common` to be installed. Please run "
        "`pip install mistral-common[image,audio]` to install it."
    )
38 | 56 |
|
39 | 57 |
|
40 | 58 | logger = logging.getLogger("hf-to-gguf") |
@@ -107,6 +125,9 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, |
107 | 125 | type(self) is MmprojModel: |
108 | 126 | raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") |
109 | 127 |
|
| 128 | + if self.is_mistral_format and not _mistral_common_installed: |
| 129 | + raise ImportError(_mistral_import_error_msg) |
| 130 | + |
110 | 131 | self.dir_model = dir_model |
111 | 132 | self.ftype = ftype |
112 | 133 | self.fname_out = fname_out |
@@ -1363,8 +1384,8 @@ def set_gguf_parameters(self): |
1363 | 1384 | self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) |
1364 | 1385 |
|
1365 | 1386 | # preprocessor config |
1366 | | - image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] |
1367 | | - image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] |
| 1387 | + image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] |
| 1388 | + image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] |
1368 | 1389 |
|
1369 | 1390 | self.gguf_writer.add_vision_image_mean(image_mean) |
1370 | 1391 | self.gguf_writer.add_vision_image_std(image_std) |
@@ -2033,6 +2054,9 @@ def __init__(self, *args, **kwargs): |
2033 | 2054 | self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) |
2034 | 2055 |
|
2035 | 2056 | def _set_vocab_mistral(self): |
| 2057 | + if not _mistral_common_installed: |
| 2058 | + raise ImportError(_mistral_import_error_msg) |
| 2059 | + |
2036 | 2060 | vocab = MistralVocab(self.dir_model) |
2037 | 2061 | logger.info( |
2038 | 2062 | f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}." |
@@ -9212,7 +9236,7 @@ class MistralModel(LlamaModel): |
9212 | 9236 |
|
9213 | 9237 | @staticmethod |
9214 | 9238 | def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): |
9215 | | - assert TokenizerVersion is not None, "mistral_common is not installed" |
| 9239 | + assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg |
9216 | 9240 | assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), ( |
9217 | 9241 | f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}" |
9218 | 9242 | ) |
@@ -9594,6 +9618,8 @@ def main() -> None: |
9594 | 9618 | fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") |
9595 | 9619 |
|
9596 | 9620 | is_mistral_format = args.mistral_format |
| 9621 | + if is_mistral_format and not _mistral_common_installed: |
| 9622 | + raise ImportError(_mistral_import_error_msg) |
9597 | 9623 | disable_mistral_community_chat_template = args.disable_mistral_community_chat_template |
9598 | 9624 |
|
9599 | 9625 | with torch.inference_mode(): |
|
0 commit comments