114 changes: 85 additions & 29 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -36,14 +36,71 @@ public sealed class EnglishRoberta : Model
/// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
{
if (vocabularyPath is null)
{
throw new ArgumentNullException(nameof(vocabularyPath));
}

if (mergePath is null)
{
throw new ArgumentNullException(nameof(mergePath));
}

if (highestOccurrenceMappingPath is null)
{
throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
}

using Stream vocabularyStream = File.OpenRead(vocabularyPath);
using Stream mergeStream = File.OpenRead(mergePath);
using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);

// vocabularyPath like encoder.json
// merge file like vocab.bpe
// highestOccurrenceMappingPath like dict.txt

_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingPath);
_vocab = GetVocabulary(vocabularyPath);
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
_mergeRanks = GetMergeRanks(mergePath);
_mergeRanks = GetMergeRanks(mergeStream);
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
_charToString = new string[maxCharValue];
for (char c = (char)0; c < (char)maxCharValue; c++)
{
_charToString[c] = c.ToString();
}

_unicodeToByte = _byteToUnicode.Reverse();
_cache = new Cache<string, IReadOnlyList<Token>>();
}
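
A quick usage sketch of the two constructors (the stream overload is added just below); the file names are assumed to be local copies of the GPT-2 data files named in the comments above:

// Path-based construction; the constructor opens and disposes the files itself.
EnglishRoberta fromPaths = new EnglishRoberta("encoder.json", "vocab.bpe", "dict.txt");

// Stream-based construction; any readable Stream works, and the caller
// stays responsible for disposing the streams.
using Stream vocabStream = File.OpenRead("encoder.json");
using Stream mergeStream = File.OpenRead("vocab.bpe");
using Stream dictStream = File.OpenRead("dict.txt");
EnglishRoberta fromStreams = new EnglishRoberta(vocabStream, mergeStream, dictStream);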

/// <summary>
/// Construct tokenizer object to use with the English Roberta model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the tokens's pairs list.</param>
/// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
{
if (vocabularyStream is null)
{
throw new ArgumentNullException(nameof(vocabularyStream));
}

if (mergeStream is null)
{
throw new ArgumentNullException(nameof(mergeStream));
}

if (highestOccurrenceMappingStream is null)
{
throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
}

_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
_mergeRanks = GetMergeRanks(mergeStream);
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
_charToString = new string[maxCharValue];
for (char c = (char)0; c < (char)maxCharValue; c++)
@@ -298,28 +355,24 @@ private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens,
return tokens;
}

private static HighestOccurrenceMapping GetHighestOccurrenceMapping(string highestOccurrenceMappingPath) =>
HighestOccurrenceMapping.Load(highestOccurrenceMappingPath);
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);

private Dictionary<string, int> GetVocabulary(string vocabularyPath)
private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
{
Dictionary<string, int>? vocab;
try
{
using (Stream stream = File.OpenRead(vocabularyPath))
{
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;

}
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
}
catch (Exception e)
{
throw new ArgumentException($"Problems met when parsing JSON object in {vocabularyPath}.{Environment.NewLine}Error message: {e.Message}");
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
}

if (vocab is null)
{
throw new ArgumentException($"Failed to read the vocabulary file '{vocabularyPath}'");
throw new ArgumentException($"Failed to read the vocabulary file.");
}

if (_vocabIdToHighestOccurrence.BosWord is not null)
@@ -345,28 +398,28 @@ private Dictionary<string, int> GetVocabulary(string vocabularyPath)
return vocab;
}
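
For reference, the vocabulary stream parsed above is expected to hold a single JSON object mapping token strings to their ids, in the style of GPT-2's encoder.json; an illustrative (not verbatim) fragment:

{"!": 0, "\"": 1, "hello": 31373}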

private Dictionary<(string, string), int> GetMergeRanks(string mergePath)
private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
{
string[] splitContents;
List<string> splitContents = new();

try
{
splitContents = File.ReadAllLines(mergePath);
using StreamReader reader = new StreamReader(mergeStream);
while (reader.Peek() >= 0)
{
splitContents.Add(reader.ReadLine());
}
}
catch (Exception e)
{
throw new IOException($"Cannot read the file '{mergePath}'.{Environment.NewLine}Error message: {e.Message}", e);
throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e);
}

var mergeRanks = new Dictionary<(string, string), int>();

for (int i = 0; i < splitContents.Length; i++)
// We ignore the first and last line in the file
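// (In GPT-2-style vocab.bpe files, the first line is typically a version
// header such as "#version: 0.2" and the last line is often empty, so
// neither encodes a merge pair.)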
for (int i = 1; i < splitContents.Count - 1; i++)
{
if (i == 0 || i == splitContents.Length - 1)
{
continue;
}

var split = splitContents[i].Split(' ');
if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1]))
{
@@ -664,22 +717,25 @@ public int this[int idx]
/// 284 432911125
/// ...
/// </summary>
public static HighestOccurrenceMapping Load(string fileName)
public static HighestOccurrenceMapping Load(Stream stream)
{
var mapping = new HighestOccurrenceMapping();
mapping.AddFromFile(fileName);
mapping.AddFromStream(stream);
return mapping;
}

/// <summary>
/// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance.
/// Loads a pre-existing vocabulary from a text stream and adds its symbols to this instance.
/// </summary>
public void AddFromFile(string fileName)
public void AddFromStream(Stream stream)
{
var lines = File.ReadAllLines(fileName, Encoding.UTF8);
Debug.Assert(stream is not null);
using StreamReader reader = new StreamReader(stream);

foreach (var line in lines)
while (reader.Peek() >= 0)
{
string line = reader.ReadLine();

var splitLine = line.Trim().Split(' ');
if (splitLine.Length != 2)
{
24 changes: 11 additions & 13 deletions src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs
@@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.ML.Runtime;
using Microsoft.ML.Tokenizers;
@@ -13,25 +14,22 @@ namespace Microsoft.ML.TorchSharp.Extensions
{
internal static class TokenizerExtensions
{
private const string EncoderJsonName = "encoder.json";
private const string MergeName = "vocab.bpe";
private const string DictName = "dict.txt";

private static readonly Uri _encoderJsonUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json");
private static readonly Uri _mergeUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe");
private static readonly Uri _dictUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt");

private static Tokenizer _instance;

internal static Tokenizer GetInstance(IChannel ch)
{
if (_instance is null)
{
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, EncoderJsonName, _encoderJsonUrl, ch);
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, MergeName, _mergeUrl, ch);
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, DictName, _dictUrl, ch);

EnglishRoberta model = new EnglishRoberta(EncoderJsonName, MergeName, DictName);
// encoder.json, vocab.bpe, and dict.txt are picked up from the following sources:
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
Assembly assembly = typeof(TokenizerExtensions).Assembly;

EnglishRoberta model = new EnglishRoberta(
assembly.GetManifestResourceStream("encoder.json"),
assembly.GetManifestResourceStream("vocab.bpe"),
assembly.GetManifestResourceStream("dict.txt"));
model.AddMaskSymbol();
_instance = new Tokenizer(model, new RobertaPreTokenizer());
}
18 changes: 15 additions & 3 deletions src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj
@@ -10,9 +10,9 @@

<ItemGroup>
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all"/>
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all"/>
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all"/>
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all" />
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all" />
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all" />
</ItemGroup>

<ItemGroup>
@@ -24,4 +24,16 @@
</ProjectReference>
</ItemGroup>

<ItemGroup>
<EmbeddedResource Include="Resources\dict.txt">
<LogicalName>dict.txt</LogicalName>
</EmbeddedResource>
<EmbeddedResource Include="Resources\encoder.json">
<LogicalName>encoder.json</LogicalName>
</EmbeddedResource>
<EmbeddedResource Include="Resources\vocab.bpe">
<LogicalName>vocab.bpe</LogicalName>
</EmbeddedResource>
</ItemGroup>

</Project>
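
Because each EmbeddedResource above sets a LogicalName, the manifest resource names are exactly "encoder.json", "vocab.bpe", and "dict.txt", which is what GetInstance passes to GetManifestResourceStream. Since GetManifestResourceStream returns null when no resource matches, a defensive variant could look like this (hypothetical helper sketch, not part of this change):

// Hypothetical helper: resolve an embedded resource by its LogicalName
// and fail loudly if the name does not match any manifest resource.
private static Stream GetRequiredResource(string name)
{
    Assembly assembly = typeof(TokenizerExtensions).Assembly;
    return assembly.GetManifestResourceStream(name)
        ?? throw new InvalidOperationException($"Embedded resource '{name}' was not found.");
}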