@@ -49,6 +49,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
4949 ordered_dict .pop (token , None )
5050
5151 tokens = []
52+ # Keep only the tokens whose frequency is at least min_freq
5253 for token , freq in ordered_dict .items ():
5354 if freq >= min_freq :
5455 tokens .append (token )
@@ -61,7 +62,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
6162 return Vocab (VocabPybind (tokens , None ))
6263
6364
64- def build_vocab_from_iterator (iterator : Iterable , min_freq : int = 1 , specials : Optional [List [str ]] = None , special_first : bool = True ) -> Vocab :
65+ def build_vocab_from_iterator (iterator : Iterable , min_freq : int = 1 , specials : Optional [List [str ]] = None , special_first : bool = True , max_tokens : Optional [ int ] = None ) -> Vocab :
6566 """
6667 Build a Vocab from an iterator.
6768
@@ -70,6 +71,7 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
7071 min_freq: The minimum frequency needed to include a token in the vocabulary.
7172 specials: Special symbols to add. The order of supplied tokens will be preserved.
7273 special_first: Indicates whether to insert symbols at the beginning or at the end.
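74+ max_tokens: If provided, creates the vocab from the `max_tokens - len(specials)` most frequent tokens (ties broken lexicographically); `min_freq` is still applied to those tokens. Must be greater than `len(specials)`.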
7375
7476
7577 Returns:
@@ -90,10 +92,16 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
9092 for tokens in iterator :
9193 counter .update (tokens )
9294
93- sorted_by_freq_tuples = sorted (counter .items (), key = lambda x : x [0 ])
94- sorted_by_freq_tuples .sort (key = lambda x : x [1 ], reverse = True )
95- ordered_dict = OrderedDict (sorted_by_freq_tuples )
95+ specials = specials or []
96+
97+ # First sort by descending frequency, then lexicographically
98+ sorted_by_freq_tuples = sorted (counter .items (), key = lambda x : (- x [1 ], x [0 ]))
99+
100+ if max_tokens is None :
101+ ordered_dict = OrderedDict (sorted_by_freq_tuples )
102+ else :
103+ assert len (specials ) < max_tokens , "len(specials) >= max_tokens, so the vocab will be entirely special tokens."
104+ ordered_dict = OrderedDict (sorted_by_freq_tuples [:max_tokens - len (specials )])
96105
97- word_vocab = vocab (ordered_dict , min_freq = min_freq , specials = specials or [],
98- special_first = special_first )
106+ word_vocab = vocab (ordered_dict , min_freq = min_freq , specials = specials , special_first = special_first )
99107 return word_vocab
0 commit comments