Skip to content

Commit c69acbe

Browse files
authored
Embed the Tokenizer data files inside the assembly (#6403)
1 parent 1903fa5 commit c69acbe

File tree

9 files changed

+100381
-102
lines changed

9 files changed

+100381
-102
lines changed

src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs

Lines changed: 85 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,71 @@ public sealed class EnglishRoberta : Model
3636
/// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
3737
public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
3838
{
39+
if (vocabularyPath is null)
40+
{
41+
throw new ArgumentNullException(nameof(vocabularyPath));
42+
}
43+
44+
if (mergePath is null)
45+
{
46+
throw new ArgumentNullException(nameof(mergePath));
47+
}
48+
49+
if (highestOccurrenceMappingPath is null)
50+
{
51+
throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
52+
}
53+
54+
using Stream vocabularyStream = File.OpenRead(vocabularyPath);
55+
using Stream mergeStream = File.OpenRead(mergePath);
56+
using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
57+
3958
// vocabularyPath like encoder.json
4059
// merge file like vocab.bpe
4160
// highestOccurrenceMappingPath like dict.txt
4261

43-
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingPath);
44-
_vocab = GetVocabulary(vocabularyPath);
62+
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
63+
_vocab = GetVocabulary(vocabularyStream);
4564
_vocabReverse = _vocab.ReverseSorted();
46-
_mergeRanks = GetMergeRanks(mergePath);
65+
_mergeRanks = GetMergeRanks(mergeStream);
66+
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
67+
_charToString = new string[maxCharValue];
68+
for (char c = (char)0; c < (char)maxCharValue; c++)
69+
{
70+
_charToString[c] = c.ToString();
71+
}
72+
73+
_unicodeToByte = _byteToUnicode.Reverse();
74+
_cache = new Cache<string, IReadOnlyList<Token>>();
75+
}
76+
77+
/// <summary>
78+
/// Construct tokenizer object to use with the English Roberta model.
79+
/// </summary>
80+
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
81+
/// <param name="mergeStream">The stream of a file containing the list of token pairs.</param>
82+
/// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
83+
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
84+
{
85+
if (vocabularyStream is null)
86+
{
87+
throw new ArgumentNullException(nameof(vocabularyStream));
88+
}
89+
90+
if (mergeStream is null)
91+
{
92+
throw new ArgumentNullException(nameof(mergeStream));
93+
}
94+
95+
if (highestOccurrenceMappingStream is null)
96+
{
97+
throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
98+
}
99+
100+
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
101+
_vocab = GetVocabulary(vocabularyStream);
102+
_vocabReverse = _vocab.ReverseSorted();
103+
_mergeRanks = GetMergeRanks(mergeStream);
47104
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
48105
_charToString = new string[maxCharValue];
49106
for (char c = (char)0; c < (char)maxCharValue; c++)
@@ -298,28 +355,24 @@ private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens,
298355
return tokens;
299356
}
300357

301-
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(string highestOccurrenceMappingPath) =>
302-
HighestOccurrenceMapping.Load(highestOccurrenceMappingPath);
358+
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
359+
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
303360

304-
private Dictionary<string, int> GetVocabulary(string vocabularyPath)
361+
private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
305362
{
306363
Dictionary<string, int>? vocab;
307364
try
308365
{
309-
using (Stream stream = File.OpenRead(vocabularyPath))
310-
{
311-
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;
312-
313-
}
366+
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
314367
}
315368
catch (Exception e)
316369
{
317-
throw new ArgumentException($"Problems met when parsing JSON object in {vocabularyPath}.{Environment.NewLine}Error message: {e.Message}");
370+
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
318371
}
319372

320373
if (vocab is null)
321374
{
322-
throw new ArgumentException($"Failed to read the vocabulary file '{vocabularyPath}'");
375+
throw new ArgumentException($"Failed to read the vocabulary file.");
323376
}
324377

325378
if (_vocabIdToHighestOccurrence.BosWord is not null)
@@ -345,28 +398,28 @@ private Dictionary<string, int> GetVocabulary(string vocabularyPath)
345398
return vocab;
346399
}
347400

348-
private Dictionary<(string, string), int> GetMergeRanks(string mergePath)
401+
private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
349402
{
350-
string[] splitContents;
403+
List<string> splitContents = new();
351404

352405
try
353406
{
354-
splitContents = File.ReadAllLines(mergePath);
407+
using StreamReader reader = new StreamReader(mergeStream);
408+
while (reader.Peek() >= 0)
409+
{
410+
splitContents.Add(reader.ReadLine());
411+
}
355412
}
356413
catch (Exception e)
357414
{
358-
throw new IOException($"Cannot read the file '{mergePath}'.{Environment.NewLine}Error message: {e.Message}", e);
415+
throw new IOException($"Cannot read the Merge file.{Environment.NewLine}Error message: {e.Message}", e);
359416
}
360417

361418
var mergeRanks = new Dictionary<(string, string), int>();
362419

363-
for (int i = 0; i < splitContents.Length; i++)
420+
// We ignore the first and last line in the file
421+
for (int i = 1; i < splitContents.Count - 1; i++)
364422
{
365-
if (i == 0 || i == splitContents.Length - 1)
366-
{
367-
continue;
368-
}
369-
370423
var split = splitContents[i].Split(' ');
371424
if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1]))
372425
{
@@ -664,22 +717,25 @@ public int this[int idx]
664717
/// 284 432911125
665718
/// ...
666719
/// </summary>
667-
public static HighestOccurrenceMapping Load(string fileName)
720+
public static HighestOccurrenceMapping Load(Stream stream)
668721
{
669722
var mapping = new HighestOccurrenceMapping();
670-
mapping.AddFromFile(fileName);
723+
mapping.AddFromStream(stream);
671724
return mapping;
672725
}
673726

674727
/// <summary>
675-
/// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance.
728+
/// Loads a pre-existing vocabulary from a text stream and adds its symbols to this instance.
676729
/// </summary>
677-
public void AddFromFile(string fileName)
730+
public void AddFromStream(Stream stream)
678731
{
679-
var lines = File.ReadAllLines(fileName, Encoding.UTF8);
732+
Debug.Assert(stream is not null);
733+
using StreamReader reader = new StreamReader(stream);
680734

681-
foreach (var line in lines)
735+
while (reader.Peek() >= 0)
682736
{
737+
string line = reader.ReadLine();
738+
683739
var splitLine = line.Trim().Split(' ');
684740
if (splitLine.Length != 2)
685741
{

src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Reflection;
78
using System.Runtime.CompilerServices;
89
using Microsoft.ML.Runtime;
910
using Microsoft.ML.Tokenizers;
@@ -13,25 +14,22 @@ namespace Microsoft.ML.TorchSharp.Extensions
1314
{
1415
internal static class TokenizerExtensions
1516
{
16-
private const string EncoderJsonName = "encoder.json";
17-
private const string MergeName = "vocab.bpe";
18-
private const string DictName = "dict.txt";
19-
20-
private static readonly Uri _encoderJsonUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json");
21-
private static readonly Uri _mergeUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe");
22-
private static readonly Uri _dictUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt");
23-
2417
private static Tokenizer _instance;
2518

2619
internal static Tokenizer GetInstance(IChannel ch)
2720
{
2821
if (_instance is null)
2922
{
30-
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, EncoderJsonName, _encoderJsonUrl, ch);
31-
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, MergeName, _mergeUrl, ch);
32-
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, DictName, _dictUrl, ch);
33-
34-
EnglishRoberta model = new EnglishRoberta(EncoderJsonName, MergeName, DictName);
23+
// encoder.json, vocab.bpe, and dict.txt are picked up from the following source:
24+
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
25+
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
26+
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
27+
Assembly assembly = typeof(TokenizerExtensions).Assembly;
28+
29+
EnglishRoberta model = new EnglishRoberta(
30+
assembly.GetManifestResourceStream("encoder.json"),
31+
assembly.GetManifestResourceStream("vocab.bpe"),
32+
assembly.GetManifestResourceStream("dict.txt"));
3533
model.AddMaskSymbol();
3634
_instance = new Tokenizer(model, new RobertaPreTokenizer());
3735
}

src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
<ItemGroup>
1212
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
13-
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all"/>
14-
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all"/>
15-
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all"/>
13+
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all" />
14+
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all" />
15+
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all" />
1616
</ItemGroup>
1717

1818
<ItemGroup>
@@ -24,4 +24,16 @@
2424
</ProjectReference>
2525
</ItemGroup>
2626

27+
<ItemGroup>
28+
<EmbeddedResource Include="Resources\dict.txt">
29+
<LogicalName>dict.txt</LogicalName>
30+
</EmbeddedResource>
31+
<EmbeddedResource Include="Resources\encoder.json">
32+
<LogicalName>encoder.json</LogicalName>
33+
</EmbeddedResource>
34+
<EmbeddedResource Include="Resources\vocab.bpe">
35+
<LogicalName>vocab.bpe</LogicalName>
36+
</EmbeddedResource>
37+
</ItemGroup>
38+
2739
</Project>

0 commit comments

Comments
 (0)