From 84f8bc97975ea35a5a21b89d0bb0cd5b5da496bd Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Mon, 2 Dec 2024 18:56:33 +0200
Subject: [PATCH 1/8] Moved special tokens assignment below so the collection
 won't be modified

---
 src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index 6c08fae5b5..f407a8c914 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -770,8 +770,6 @@ private static BertTokenizer Create(
                 if (lowerCase)
                 {
                     Dictionary<string, int> dic = options.SpecialTokens.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
-                    options.SpecialTokens = dic;
-
                     foreach (var kvp in options.SpecialTokens)
                     {
                         if (!vocab.TryGetValue(new StringSpanOrdinalKey(kvp.Key), out int id) || id != kvp.Value)
@@ -782,6 +780,8 @@ private static BertTokenizer Create(
                         // Ensure that the special tokens are lowercased.
                         dic[kvp.Key.ToLowerInvariant()] = kvp.Value;
                     }
+
+                    options.SpecialTokens = dic;
                 }
             }
             else

From ffaa96270c95baf5eeceb75f65316775db428ad1 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Thu, 5 Dec 2024 01:09:55 +0200
Subject: [PATCH 2/8] Added safe dictionary inversion

---
 src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
index e362da9b93..6dc0051346 100644
--- a/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -42,7 +42,7 @@ internal WordPieceTokenizer(
         options ??= new();
 
         SpecialTokens = options.SpecialTokens;
-        SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
+        SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.GroupBy(kvp => kvp.Value).ToDictionary(g => g.Key, g => g.First().Key) : null;
 
         if (options.UnknownToken is null)
        {
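
Why PATCH 2 calls the inversion "safe": Enumerable.ToDictionary throws an ArgumentException as soon as two entries produce the same key, and inverting a token-to-id map turns the ids into keys, so two special tokens aliased to the same id would make the WordPieceTokenizer constructor throw. Grouping by id first keeps one representative token per id. A minimal standalone sketch of the difference (the token values here are illustrative, not taken from the library):

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class InversionDemo
    {
        static void Main()
        {
            // Two tokens deliberately share id 1 to trigger the duplicate-key case.
            var tokens = new Dictionary<string, int> { ["[UNK]"] = 1, ["<unk>"] = 1, ["[PAD]"] = 0 };

            // Naive inversion: throws ArgumentException on the duplicated id.
            // var reverse = tokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);

            // Safe inversion, as in the patch: group by id, keep one token per id.
            Dictionary<int, string> reverse = tokens.GroupBy(kvp => kvp.Value)
                                                    .ToDictionary(g => g.Key, g => g.First().Key);

            Console.WriteLine(reverse[1]); // whichever token for id 1 enumerates first
        }
    }

The trade-off of this design is that decoding an id shared by several tokens always renders the single representative that was kept.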
@@ -800,4 +800,4 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
         return OperationStatus.Done;
     }
 }
-}
\ No newline at end of file
+}

From 4ab63b2e150065df413287f0706e43739cfb1106 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Thu, 5 Dec 2024 01:19:20 +0200
Subject: [PATCH 3/8] Added storing the not-normalized special tokens

---
 .../Model/BertTokenizer.cs | 46 ++++++++++++-------
 1 file changed, 30 insertions(+), 16 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index f407a8c914..6b8a11bb46 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -762,6 +762,7 @@ private static BertTokenizer Create(
 
         options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null;
 
+        Dictionary<string, int>? specialTokensDict = null;
         if (options.SplitOnSpecialTokens)
         {
             bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization;
@@ -769,7 +770,7 @@ private static BertTokenizer Create(
             {
                 if (lowerCase)
                 {
-                    Dictionary<string, int> dic = options.SpecialTokens.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
+                    specialTokensDict = [];
                     foreach (var kvp in options.SpecialTokens)
                     {
                         if (!vocab.TryGetValue(new StringSpanOrdinalKey(kvp.Key), out int id) || id != kvp.Value)
@@ -777,39 +778,52 @@ private static BertTokenizer Create(
                             throw new ArgumentException($"The special token '{kvp.Key}' is not in the vocabulary or assigned id value {id} different than the value {kvp.Value} in the special tokens.");
                         }
 
-                        // Ensure that the special tokens are lowercased.
-                        dic[kvp.Key.ToLowerInvariant()] = kvp.Value;
+                        // Add the special token into our dictionary, normalizing it, and adding it into the
+                        // main vocab, if needed.
+                        AddSpecialToken(vocab, specialTokensDict, kvp.Key, true);
                     }
-
-                    options.SpecialTokens = dic;
+                }
+                else
+                {
+                    specialTokensDict = options.SpecialTokens?.ToDictionary();
                 }
             }
             else
             {
-                // Create a dictionary with the special tokens.
-                Dictionary<string, int> specialTokens = new Dictionary<string, int>();
-                options.SpecialTokens = specialTokens;
-
-                AddSpecialToken(vocab, specialTokens, options.UnknownToken, lowerCase);
-                AddSpecialToken(vocab, specialTokens, options.SeparatorToken, lowerCase);
-                AddSpecialToken(vocab, specialTokens, options.PaddingToken, lowerCase);
-                AddSpecialToken(vocab, specialTokens, options.ClassificationToken, lowerCase);
-                AddSpecialToken(vocab, specialTokens, options.MaskingToken, lowerCase);
+                // Create a dictionary with the special tokens - store the un-normalized forms in the options as
+                // that field is exposed to the public. In addition, store the normalized form for creating the
+                // pre-tokenizer.
+                specialTokensDict = [];
+                Dictionary<string, int> notNormalizedSpecialTokens = [];
+                AddSpecialToken(vocab, specialTokensDict, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, specialTokensDict, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, specialTokensDict, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, specialTokensDict, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, specialTokensDict, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);
+
+                options.SpecialTokens = notNormalizedSpecialTokens;
             }
         }
 
-        options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? options.SpecialTokens : null) : PreTokenizer.CreateWhiteSpace();
+        // We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can
+        // keep the not-normalized special tokens dict in the options passed to the WordPieceTokenizer.
+        options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? specialTokensDict : null) : PreTokenizer.CreateWhiteSpace();
 
         return new BertTokenizer(vocab, vocabReverse, options);
     }
 
-    private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
+    private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase, Dictionary<string, int>? notNormalizedSpecialTokens = null)
     {
         if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
         {
             throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
         }
 
+        if (notNormalizedSpecialTokens is not null)
+        {
+            notNormalizedSpecialTokens[token] = id;
+        }
+
         string normalizedToken = token;
         if (lowerCase)
        {
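
The shape PATCH 3 builds toward is two views of the same special tokens: one with the caller's original casing (publicly exposed via options.SpecialTokens) and one lowercased (used to construct the pre-tokenizer when lowercasing is on). A minimal sketch of that idea outside the library (the names here are invented for illustration):

    using System;
    using System.Collections.Generic;

    class SpecialTokenViews
    {
        static void Main()
        {
            // publicView models what options.SpecialTokens ends up holding;
            // matchingView models what the pre-tokenizer is built from.
            var publicView = new Dictionary<string, int>();
            var matchingView = new Dictionary<string, int>();

            void Add(string token, int id, bool lowerCase)
            {
                publicView[token] = id;
                matchingView[lowerCase ? token.ToLowerInvariant() : token] = id;
            }

            Add("[CLS]", 2, lowerCase: true);

            Console.WriteLine(string.Join(", ", publicView.Keys));   // [CLS]
            Console.WriteLine(string.Join(", ", matchingView.Keys)); // [cls]
        }
    }

Keeping both views is what lets callers read back the special tokens exactly as they supplied them, while the lowercasing normalizer still matches them in input text.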
From 49bca972065e49c6f6be04e0ab17ce76912414b1 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Thu, 5 Dec 2024 01:56:34 +0200
Subject: [PATCH 4/8] Added support for net standard

---
 src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index 6b8a11bb46..30f58e511e 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -785,7 +785,7 @@ private static BertTokenizer Create(
             }
             else
             {
-                specialTokensDict = options.SpecialTokens?.ToDictionary();
+                specialTokensDict = options.SpecialTokens?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
             }
         }
        else
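
PATCH 4's one-liner is a downlevel-compatibility fix: the parameterless ToDictionary() overload over a sequence of KeyValuePair was only added in .NET 8, so it does not exist for netstandard2.0 targets, while the selector-based overload compiles everywhere. A small sketch of the two spellings (assuming only that the source implements IReadOnlyDictionary):

    using System.Collections.Generic;
    using System.Linq;

    class ToDictionaryCompat
    {
        static Dictionary<string, int> Copy(IReadOnlyDictionary<string, int> source)
        {
            // .NET 8+ only: parameterless ToDictionary() over IEnumerable<KeyValuePair<K, V>>.
            // return source.ToDictionary();

            // Compiles on netstandard2.0 and newer alike: explicit selectors.
            return source.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
        }

        static void Main()
        {
            var copy = Copy(new Dictionary<string, int> { ["[CLS]"] = 2 });
            System.Console.WriteLine(copy.Count); // 1
        }
    }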
From 1b7691ceec2a379db5bc05b6c6f537d48f57dcb0 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Thu, 5 Dec 2024 01:56:40 +0200
Subject: [PATCH 5/8] Added and updated tests

---
 .../BertTokenizerTests.cs | 93 ++++++++++++++++++-
 1 file changed, 91 insertions(+), 2 deletions(-)

diff --git a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
index fb1c3850ba..8a5042f645 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
 {
     public class BertTokenizerTests
     {
+        [Fact]
+        public void TestWithLowerCasingExplicitSpecialTokens()
+        {
+            // Add [SPECIAL] token at end (to keep indices as is)
+            // Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13 
+            string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
+
+            string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
+
+            Dictionary<string, int> specialTokens = new() {
+                { "[PAD]", 0 },
+                { "[UNK]", 1 },
+                { "[CLS]", 2 },
+                { "[SEP]", 3 },
+                { "[MASK]", 4 },
+                { "[SPECIAL]", 13 },
+            };
+            var bertOptions = new BertOptions()
+            {
+                SpecialTokens = specialTokens
+            };
+
+            try
+            {
+                using Stream vocabStream = File.OpenRead(vocabFile);
+                BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];
+
+                foreach (var tokenizer in bertTokenizers)
+                {
+                    Assert.NotNull(tokenizer.PreTokenizer);
+                    Assert.Equal("[UNK]", tokenizer.UnknownToken);
+                    Assert.Equal(1, tokenizer.UnknownTokenId);
+                    Assert.NotNull(tokenizer.Normalizer);
+                    Assert.NotNull(tokenizer.PreTokenizer);
+
+                    Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));
+
+                    string text = "Hello, How are you [SPECIAL]?";
+                    var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
+                    Assert.Equal("hello, how are you [special]?", normalizedText);
+
+                    Assert.Equal(
+                        [
+                            new EncodedToken(8, "hello", new Range(0, 5)),
+                            new EncodedToken(6, ",", new Range(5, 6)),
+                            new EncodedToken(10, "how", new Range(7, 10)),
+                            new EncodedToken(11, "are", new Range(11, 14)),
+                            new EncodedToken(12, "you", new Range(15, 18)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
+                            new EncodedToken(7, "?", new Range(28, 29))
+                        ],
+                        tokens);
+
+                    var ids = tokenizer.EncodeToIds(text);
+                    Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);
+
+                    Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
+                    Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
+
[sep]", normalizedText); + Assert.Equal( + [ + new EncodedToken(2, "[CLS]", new Range(0, 5)), + new EncodedToken(8, "hello", new Range(6, 11)), + new EncodedToken(6, ",", new Range(11, 12)), + new EncodedToken(10, "how", new Range(13, 16)), + new EncodedToken(11, "are", new Range(17, 20)), + new EncodedToken(12, "you", new Range(21, 24)), + new EncodedToken(13, "[SPECIAL]", new Range(25, 34)), + new EncodedToken(7, "?", new Range(34, 35)), + new EncodedToken(3, "[SEP]", new Range(36, 41)) + ], + tokens); + + ids = tokenizer.EncodeToIds(normalizedText!); + Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids); + } + } + finally + { + File.Delete(vocabFile); + } + } + [Fact] public void TestWithLowerCasing() { @@ -35,6 +120,10 @@ public void TestWithLowerCasing() Assert.NotNull(tokenizer.Normalizer); Assert.NotNull(tokenizer.PreTokenizer); + // Make sure the SpecialTokens dictionary contains the not-normalized tokens + Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken)); + Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken)); + string text = "Hello, How are you?"; var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText); Assert.Equal("hello, how are you?", normalizedText); @@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences() } } } -} \ No newline at end of file +} From fd1eb2c523f9afe6728c76779b5cfed67c2dbdf6 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Thu, 5 Dec 2024 03:42:55 +0200 Subject: [PATCH 6/8] Updated without additional memory allocation --- .../Model/BertTokenizer.cs | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs index 30f58e511e..8d23442e89 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs @@ -762,7 +762,7 @@ private static BertTokenizer Create( options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null; - Dictionary? specialTokensDict = null; + IReadOnlyDictionary? specialTokensDict = options.SpecialTokens; if (options.SplitOnSpecialTokens) { bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization; @@ -770,7 +770,9 @@ private static BertTokenizer Create( { if (lowerCase) { - specialTokensDict = []; + Dictionary tempSpecialTokens = []; + specialTokensDict = tempSpecialTokens; + foreach (var kvp in options.SpecialTokens) { if (!vocab.TryGetValue(new StringSpanOrdinalKey(kvp.Key), out int id) || id != kvp.Value) @@ -780,28 +782,25 @@ private static BertTokenizer Create( // Add the special token into our dictionary, normalizing it, and adding it into the // main vocab, if needed. - AddSpecialToken(vocab, specialTokensDict, kvp.Key, true); + AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true); } } - else - { - specialTokensDict = options.SpecialTokens?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value); - } } else { // Create a dictionary with the special tokens - store the un-normalized forms in the options as // that field is exposed to the public. In addition, store the normalized form for creating the // pre-tokenizer. 
From fd1eb2c523f9afe6728c76779b5cfed67c2dbdf6 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Thu, 5 Dec 2024 03:42:55 +0200
Subject: [PATCH 6/8] Updated without additional memory allocation

---
 .../Model/BertTokenizer.cs | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index 30f58e511e..8d23442e89 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -762,7 +762,7 @@ private static BertTokenizer Create(
 
         options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null;
 
-        Dictionary<string, int>? specialTokensDict = null;
+        IReadOnlyDictionary<string, int>? specialTokensDict = options.SpecialTokens;
         if (options.SplitOnSpecialTokens)
         {
             bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization;
@@ -770,7 +770,9 @@ private static BertTokenizer Create(
             {
                 if (lowerCase)
                 {
-                    specialTokensDict = [];
+                    Dictionary<string, int> tempSpecialTokens = [];
+                    specialTokensDict = tempSpecialTokens;
+
                     foreach (var kvp in options.SpecialTokens)
                     {
                         if (!vocab.TryGetValue(new StringSpanOrdinalKey(kvp.Key), out int id) || id != kvp.Value)
@@ -780,28 +782,25 @@ private static BertTokenizer Create(
                         }
 
                         // Add the special token into our dictionary, normalizing it, and adding it into the
                         // main vocab, if needed.
-                        AddSpecialToken(vocab, specialTokensDict, kvp.Key, true);
+                        AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true);
                     }
                 }
-                else
-                {
-                    specialTokensDict = options.SpecialTokens?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
-                }
             }
             else
             {
                 // Create a dictionary with the special tokens - store the un-normalized forms in the options as
                 // that field is exposed to the public. In addition, store the normalized form for creating the
                 // pre-tokenizer.
-                specialTokensDict = [];
+                Dictionary<string, int> tempSpecialTokens = [];
                 Dictionary<string, int> notNormalizedSpecialTokens = [];
-                AddSpecialToken(vocab, specialTokensDict, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
-                AddSpecialToken(vocab, specialTokensDict, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
-                AddSpecialToken(vocab, specialTokensDict, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
-                AddSpecialToken(vocab, specialTokensDict, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
-                AddSpecialToken(vocab, specialTokensDict, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, tempSpecialTokens, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
+                AddSpecialToken(vocab, tempSpecialTokens, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);
 
                 options.SpecialTokens = notNormalizedSpecialTokens;
+                specialTokensDict = tempSpecialTokens;
             }
        }
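
The pattern behind PATCH 6: widen the local to IReadOnlyDictionary so the caller's dictionary can be aliased directly, and allocate a new dictionary only on the paths that actually transform the keys. That is what removes the defensive copy PATCH 4 made in the non-lowercasing branch. A minimal sketch of the alias-or-copy pattern (not library code):

    using System.Collections.Generic;

    class AliasOrCopy
    {
        // Reuse the caller's map unless a transformed copy is required,
        // so the no-transformation path allocates nothing.
        static IReadOnlyDictionary<string, int> Resolve(IReadOnlyDictionary<string, int> source, bool lowerCase)
        {
            if (!lowerCase)
            {
                return source; // alias; no copy
            }

            var lowered = new Dictionary<string, int>(); // allocate only when transforming
            foreach (var kvp in source)
            {
                lowered[kvp.Key.ToLowerInvariant()] = kvp.Value;
            }
            return lowered;
        }
    }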
From 3b5ce564150a67f91bb41ca660465146f900d4a6 Mon Sep 17 00:00:00 2001
From: Tarek Mahmoud Sayed <10833894+tarekgh@users.noreply.github.com>
Date: Thu, 5 Dec 2024 09:48:06 -0800
Subject: [PATCH 7/8] Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs       | 6 +++---
 test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index 8d23442e89..64534eddc8 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -770,7 +770,7 @@ private static BertTokenizer Create(
             {
                 if (lowerCase)
                 {
-                    Dictionary<string, int> tempSpecialTokens = [];
+                    Dictionary<string, int> tempSpecialTokens = new Dictionary<string, int>();
                     specialTokensDict = tempSpecialTokens;
 
                     foreach (var kvp in options.SpecialTokens)
@@ -791,8 +791,8 @@ private static BertTokenizer Create(
                 // Create a dictionary with the special tokens - store the un-normalized forms in the options as
                 // that field is exposed to the public. In addition, store the normalized form for creating the
                 // pre-tokenizer.
-                Dictionary<string, int> tempSpecialTokens = [];
-                Dictionary<string, int> notNormalizedSpecialTokens = [];
+                Dictionary<string, int> tempSpecialTokens = {};
+                Dictionary<string, int> tempSpecialTokens = {};
                 AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
                 AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
                 AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
index 8a5042f645..9e6d7b0566 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
@@ -19,7 +19,7 @@ public void TestWithLowerCasingExplicitSpecialTokens()
         {
             // Add [SPECIAL] token at end (to keep indices as is)
             // Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13 
-            string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
+            string[] vocabTokens = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"};
 
            string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
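
Context for the final patch: the two suggested declarations above are not compilable C#. The `= {};` form is array-initializer syntax, which is invalid for a dictionary-typed local, and the second line redeclares tempSpecialTokens while dropping notNormalizedSpecialTokens, which the AddSpecialToken calls that follow still reference. PATCH 8 restores two distinct, explicitly constructed dictionaries. For reference, a summary of which initializer spellings compile for a local dictionary (this summary is editorial, not from the patch):

    using System.Collections.Generic;

    class InitializerForms
    {
        static void Main()
        {
            Dictionary<string, int> a = new Dictionary<string, int>(); // classic; compiles everywhere
            Dictionary<string, int> b = new();                         // target-typed new (C# 9+)
            Dictionary<string, int> c = [];                            // collection expression (C# 12+; empty literal only for dictionaries)
            // Dictionary<string, int> d = {};                         // invalid: array-initializer syntax is only allowed for array types
            System.Console.WriteLine(a.Count + b.Count + c.Count);     // 0
        }
    }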
From 1362232e150afefa91e61aeaf85620a02b5261b7 Mon Sep 17 00:00:00 2001
From: Tarek Mahmoud Sayed
Date: Thu, 5 Dec 2024 10:01:56 -0800
Subject: [PATCH 8/8] Fix copilot changes

---
 src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs | 10 +++++-----
 .../BertTokenizerTests.cs | 4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index 64534eddc8..b31081f538 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -781,7 +781,7 @@ private static BertTokenizer Create(
                         }
 
                         // Add the special token into our dictionary, normalizing it, and adding it into the
-                        // main vocab, if needed. 
+                        // main vocab, if needed.
                         AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true);
                     }
                 }
@@ -789,10 +789,10 @@ private static BertTokenizer Create(
             else
             {
                 // Create a dictionary with the special tokens - store the un-normalized forms in the options as
-                // that field is exposed to the public. In addition, store the normalized form for creating the 
+                // that field is exposed to the public. In addition, store the normalized form for creating the
                 // pre-tokenizer.
-                Dictionary<string, int> tempSpecialTokens = {};
-                Dictionary<string, int> tempSpecialTokens = {};
+                Dictionary<string, int> tempSpecialTokens = new Dictionary<string, int>();
+                Dictionary<string, int> notNormalizedSpecialTokens = new Dictionary<string, int>();
                 AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
                 AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
                 AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
@@ -804,7 +804,7 @@ private static BertTokenizer Create(
             }
         }
 
-        // We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can 
+        // We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can
         // keep the not-normalized special tokens dict in the options passed to the WordPieceTokenizer.
         options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? specialTokensDict : null) : PreTokenizer.CreateWhiteSpace();
 
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
index 9e6d7b0566..ee4e541947 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
@@ -18,8 +18,8 @@ public class BertTokenizerTests
         public void TestWithLowerCasingExplicitSpecialTokens()
         {
             // Add [SPECIAL] token at end (to keep indices as is)
-            // Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13 
-            string[] vocabTokens = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"};
+            // Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13
+            string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
 
            string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);