1- // Licensed to the .NET Foundation under one or more agreements.
1+ // Licensed to the .NET Foundation under one or more agreements.
22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
@@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
1414{
1515 public class BertTokenizerTests
1616 {
17+ [Fact]
18+ public void TestWithLowerCasingExplicitSpecialTokens()
19+ {
20+ // Add [SPECIAL] token at end (to keep indices as is)
21+ // Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13
22+ string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
23+
24+ string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
25+
26+ Dictionary<string, int> specialTokens = new() {
27+ { "[PAD]", 0 },
28+ { "[UNK]", 1 },
29+ { "[CLS]", 2 },
30+ { "[SEP]", 3 },
31+ { "[MASK]", 4 },
32+ { "[SPECIAL]", 13 },
33+ };
34+ var bertOptions = new BertOptions()
35+ {
36+ SpecialTokens = specialTokens
37+ };
38+
39+ try
40+ {
41+ using Stream vocabStream = File.OpenRead(vocabFile);
42+ BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];
43+
44+ foreach (var tokenizer in bertTokenizers)
45+ {
46+ Assert.NotNull(tokenizer.PreTokenizer);
47+ Assert.Equal("[UNK]", tokenizer.UnknownToken);
48+ Assert.Equal(1, tokenizer.UnknownTokenId);
49+ Assert.NotNull(tokenizer.Normalizer);
50+ Assert.NotNull(tokenizer.PreTokenizer);
51+
52+ Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));
53+
54+ string text = "Hello, How are you [SPECIAL]?";
55+ var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
56+ Assert.Equal("hello, how are you [special]?", normalizedText);
57+
58+ Assert.Equal(
59+ [
60+ new EncodedToken(8, "hello", new Range(0, 5)),
61+ new EncodedToken(6, ",", new Range(5, 6)),
62+ new EncodedToken(10, "how", new Range(7, 10)),
63+ new EncodedToken(11, "are", new Range(11, 14)),
64+ new EncodedToken(12, "you", new Range(15, 18)),
65+ new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
66+ new EncodedToken(7, "?", new Range(28, 29))
67+ ],
68+ tokens);
69+
70+ var ids = tokenizer.EncodeToIds(text);
71+ Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);
72+
73+ Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
74+ Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
75+
76+ tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
77+ Assert.Equal("[cls] hello, how are you [special]? [sep]", normalizedText);
78+ Assert.Equal(
79+ [
80+ new EncodedToken(2, "[CLS]", new Range(0, 5)),
81+ new EncodedToken(8, "hello", new Range(6, 11)),
82+ new EncodedToken(6, ",", new Range(11, 12)),
83+ new EncodedToken(10, "how", new Range(13, 16)),
84+ new EncodedToken(11, "are", new Range(17, 20)),
85+ new EncodedToken(12, "you", new Range(21, 24)),
86+ new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
87+ new EncodedToken(7, "?", new Range(34, 35)),
88+ new EncodedToken(3, "[SEP]", new Range(36, 41))
89+ ],
90+ tokens);
91+
92+ ids = tokenizer.EncodeToIds(normalizedText!);
93+ Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids);
94+ }
95+ }
96+ finally
97+ {
98+ File.Delete(vocabFile);
99+ }
100+ }
101+
17102 [Fact]
18103 public void TestWithLowerCasing()
19104 {
@@ -35,6 +120,10 @@ public void TestWithLowerCasing()
35120 Assert.NotNull(tokenizer.Normalizer);
36121 Assert.NotNull(tokenizer.PreTokenizer);
37122
123+ // Make sure the SpecialTokens dictionary contains the not-normalized tokens
124+ Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
125+ Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));
126+
38127 string text = "Hello, How are you?";
39128 var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
40129 Assert.Equal("hello, how are you?", normalizedText);
@@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
511600 }
512601 }
513602 }
514- }
603+ }
0 commit comments