Skip to content

Commit 64d17a6

Browse files
committed
Fix CI breakage
1 parent e4d4031 commit 64d17a6

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

torchtitan/components/tokenizer.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import json
99
import os
1010
from abc import ABC, abstractmethod
11-
from typing import Any, Optional
11+
from typing import Any, Optional, Union
1212

1313
from tokenizers import AddedToken, Tokenizer as HfTokenizer
14-
14+
from torchtitan.config_manager import JobConfig
1515
from typing_extensions import override
1616

1717

@@ -407,20 +407,18 @@ def id_to_token(self, token_id: int) -> Optional[str]:
407407
return self.tokenizer.id_to_token(token_id)
408408

409409

410-
def build_hf_tokenizer(tokenizer_path: str) -> HuggingFaceTokenizer:
410+
def build_hf_tokenizer(
411+
job_config: JobConfig,
412+
) -> Union[HuggingFaceTokenizer, Tokenizer]:
411413
"""
412414
Builds a HuggingFaceTokenizer from the specified path.
413-
414415
This function creates a HuggingFaceTokenizer instance that handles BOS/EOS token
415416
inference and intelligent encoding. The tokenizer automatically detects and loads
416417
from various file formats and infers special token behavior.
417-
418418
Args:
419-
tokenizer_path (str): Path to the directory containing tokenizer files.
420-
Should contain one or more of the supported file types.
421-
419+
JobConfig: A JobConfig object containing the path to the tokenizer directory.
422420
Returns:
423421
tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling
424422
"""
425-
tokenizer = HuggingFaceTokenizer(tokenizer_path)
423+
tokenizer = HuggingFaceTokenizer(job_config.model.tokenizer_path)
426424
return tokenizer

torchtitan/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(self, job_config: JobConfig):
126126

127127
# build dataloader
128128
tokenizer = (
129-
self.train_spec.build_tokenizer_fn(job_config.model.tokenizer_path)
129+
self.train_spec.build_tokenizer_fn(job_config)
130130
if self.train_spec.build_tokenizer_fn is not None
131131
else None
132132
)

0 commit comments

Comments
 (0)