@@ -36,14 +36,71 @@ public sealed class EnglishRoberta : Model
     /// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
     public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
     {
+        if (vocabularyPath is null)
+        {
+            throw new ArgumentNullException(nameof(vocabularyPath));
+        }
+
+        if (mergePath is null)
+        {
+            throw new ArgumentNullException(nameof(mergePath));
+        }
+
+        if (highestOccurrenceMappingPath is null)
+        {
+            throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
+        }
+
+        using Stream vocabularyStream = File.OpenRead(vocabularyPath);
+        using Stream mergeStream = File.OpenRead(mergePath);
+        using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
+
         // vocabularyPath like encoder.json
         // merge file like vocab.bpe
         // highestOccurrenceMappingPath like dict.txt

-        _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingPath);
-        _vocab = GetVocabulary(vocabularyPath);
+        _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
+        _vocab = GetVocabulary(vocabularyStream);
         _vocabReverse = _vocab.ReverseSorted();
-        _mergeRanks = GetMergeRanks(mergePath);
+        _mergeRanks = GetMergeRanks(mergeStream);
+        int maxCharValue = GetByteToUnicode(out _byteToUnicode);
+        _charToString = new string[maxCharValue];
+        for (char c = (char)0; c < (char)maxCharValue; c++)
+        {
+            _charToString[c] = c.ToString();
+        }
+
+        _unicodeToByte = _byteToUnicode.Reverse();
+        _cache = new Cache<string, IReadOnlyList<Token>>();
+    }
+
+    /// <summary>
+    /// Construct a tokenizer object to use with the English Roberta model.
+    /// </summary>
+    /// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
+    /// <param name="mergeStream">The stream of a file containing the tokens' pairs list.</param>
+    /// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+    public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
+    {
+        if (vocabularyStream is null)
+        {
+            throw new ArgumentNullException(nameof(vocabularyStream));
+        }
+
+        if (mergeStream is null)
+        {
+            throw new ArgumentNullException(nameof(mergeStream));
+        }
+
+        if (highestOccurrenceMappingStream is null)
+        {
+            throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
+        }
+
+        _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
+        _vocab = GetVocabulary(vocabularyStream);
+        _vocabReverse = _vocab.ReverseSorted();
+        _mergeRanks = GetMergeRanks(mergeStream);
         int maxCharValue = GetByteToUnicode(out _byteToUnicode);
         _charToString = new string[maxCharValue];
         for (char c = (char)0; c < (char)maxCharValue; c++)
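For orientation, here is a minimal usage sketch of the two constructor overloads above. Only the EnglishRoberta signatures come from this diff; the file names, the Microsoft.ML.Tokenizers namespace, and the caller code are assumptions for illustration. Note that, as written above, the stream overload does not dispose the streams, so the caller keeps ownership of them.

    using System.IO;
    using Microsoft.ML.Tokenizers;

    // Path-based overload: the constructor opens and disposes the files itself.
    var fromPaths = new EnglishRoberta("encoder.json", "vocab.bpe", "dict.txt");

    // Stream-based overload: the caller supplies (and later disposes) the streams,
    // so the data could come from somewhere other than the file system.
    using Stream vocab = File.OpenRead("encoder.json");
    using Stream merges = File.OpenRead("vocab.bpe");
    using Stream mapping = File.OpenRead("dict.txt");
    var fromStreams = new EnglishRoberta(vocab, merges, mapping);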
@@ -298,28 +355,24 @@ private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens,
         return tokens;
     }

-    private static HighestOccurrenceMapping GetHighestOccurrenceMapping(string highestOccurrenceMappingPath) =>
-        HighestOccurrenceMapping.Load(highestOccurrenceMappingPath);
+    private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
+        HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);

-    private Dictionary<string, int> GetVocabulary(string vocabularyPath)
+    private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
     {
         Dictionary<string, int>? vocab;
         try
         {
-            using (Stream stream = File.OpenRead(vocabularyPath))
-            {
-                vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;
-
-            }
+            vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
         }
         catch (Exception e)
         {
-            throw new ArgumentException($"Problems met when parsing JSON object in {vocabularyPath}.{Environment.NewLine}Error message: {e.Message}");
370+ throw new ArgumentException ( $ "Problems met when parsing JSON vocabulary object .{ Environment . NewLine } Error message: { e . Message } ") ;
         }

         if (vocab is null)
         {
-            throw new ArgumentException($"Failed to read the vocabulary file '{vocabularyPath}'");
+            throw new ArgumentException($"Failed to read the vocabulary file.");
         }

         if (_vocabIdToHighestOccurrence.BosWord is not null)
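For reference, a small sketch of the encoder.json shape that the Deserialize call above expects: a flat JSON object mapping token text to integer id, which is why a plain Dictionary<string, int> can hold it directly. The sample tokens and ids below are illustrative, not taken from the real file.

    using System.Collections.Generic;
    using System.Text.Json;

    // Illustrative subset of an encoder.json-style vocabulary.
    string sampleJson = "{ \"hello\": 31373, \"Ġworld\": 995 }";

    Dictionary<string, int>? vocab =
        JsonSerializer.Deserialize<Dictionary<string, int>>(sampleJson);
    // vocab["hello"] == 31373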
@@ -345,28 +398,28 @@ private Dictionary<string, int> GetVocabulary(string vocabularyPath)
         return vocab;
     }

-    private Dictionary<(string, string), int> GetMergeRanks(string mergePath)
+    private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
     {
-        string[] splitContents;
+        List<string> splitContents = new();

         try
         {
-            splitContents = File.ReadAllLines(mergePath);
+            using StreamReader reader = new StreamReader(mergeStream);
+            while (reader.Peek() >= 0)
+            {
+                splitContents.Add(reader.ReadLine());
+            }
         }
         catch (Exception e)
         {
-            throw new IOException($"Cannot read the file '{mergePath}'.{Environment.NewLine}Error message: {e.Message}", e);
+            throw new IOException($"Cannot read the merge file.{Environment.NewLine}Error message: {e.Message}", e);
         }

         var mergeRanks = new Dictionary<(string, string), int>();

-        for (int i = 0; i < splitContents.Length; i++)
+        // We ignore the first and last line in the file
+        for (int i = 1; i < splitContents.Count - 1; i++)
         {
-            if (i == 0 || i == splitContents.Length - 1)
-            {
-                continue;
-            }
-
             var split = splitContents[i].Split(' ');
             if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1]))
             {
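To make the skipped first and last lines concrete, below is a small standalone sketch of the GPT-2 style merge-list format the loop above assumes. The sample pairs and the idea that the rank equals the line index are assumptions for illustration; only the skipping of the header and trailing line is taken from the diff.

    using System;
    using System.Collections.Generic;

    // vocab.bpe starts with a "#version: ..." header and typically ends with a
    // blank trailing line, hence the loop bounds of 1 .. Count - 2 above.
    string[] sample =
    {
        "#version: 0.2",   // header line, skipped
        "Ġ t",
        "h e",
        "i n",
        ""                 // trailing line, skipped
    };

    var ranks = new Dictionary<(string, string), int>();
    for (int i = 1; i < sample.Length - 1; i++)
    {
        string[] split = sample[i].Split(' ');
        ranks[(split[0], split[1])] = i;   // earlier pairs merge first
    }

    Console.WriteLine(ranks[("h", "e")]);  // prints 2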
@@ -664,22 +717,25 @@ public int this[int idx]
     /// 284 432911125
     /// ...
     /// </summary>
-    public static HighestOccurrenceMapping Load(string fileName)
+    public static HighestOccurrenceMapping Load(Stream stream)
     {
         var mapping = new HighestOccurrenceMapping();
-        mapping.AddFromFile(fileName);
+        mapping.AddFromStream(stream);
         return mapping;
     }

     /// <summary>
-    /// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance.
+    /// Loads a pre-existing vocabulary from a text stream and adds its symbols to this instance.
     /// </summary>
-    public void AddFromFile(string fileName)
+    public void AddFromStream(Stream stream)
     {
-        var lines = File.ReadAllLines(fileName, Encoding.UTF8);
+        Debug.Assert(stream is not null);
+        using StreamReader reader = new StreamReader(stream);

-        foreach (var line in lines)
+        while (reader.Peek() >= 0)
         {
+            string line = reader.ReadLine();
+
             var splitLine = line.Trim().Split(' ');
             if (splitLine.Length != 2)
             {
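Taken together, the stream overloads above mean the three data files no longer have to exist on disk. As a final hedged sketch (the resource names and the bundling choice are assumptions, not part of this change), the same model could be built from embedded resources:

    using System.IO;
    using System.Reflection;
    using Microsoft.ML.Tokenizers;

    // Hypothetical: encoder.json, vocab.bpe and dict.txt are bundled as
    // embedded resources in the calling assembly.
    Assembly assembly = Assembly.GetExecutingAssembly();

    using Stream vocab = assembly.GetManifestResourceStream("MyApp.Data.encoder.json")!;
    using Stream merges = assembly.GetManifestResourceStream("MyApp.Data.vocab.bpe")!;
    using Stream mapping = assembly.GetManifestResourceStream("MyApp.Data.dict.txt")!;

    var model = new EnglishRoberta(vocab, merges, mapping);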