From 0b6c24dbb06d757e7568ad2dc2c1ede4b940a187 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 11 Nov 2020 09:30:14 -0800 Subject: [PATCH 1/8] switch to_ivalue to __prepare_scriptable__ --- .../benchmark_basic_english_normalize.py | 2 +- benchmark/benchmark_experimental_vectors.py | 2 +- benchmark/benchmark_experimental_vocab.py | 6 +-- benchmark/benchmark_pytext_vocab.py | 2 +- examples/data_pipeline/pipelines.py | 20 +++++----- examples/data_pipeline/transforms.py | 12 +++--- test/data/test_functional.py | 12 +++--- test/experimental/test_transforms.py | 6 +-- .../test_transforms_with_asset.py | 10 ++--- test/experimental/test_vectors.py | 8 ++-- test/experimental/test_vocab.py | 8 ++-- torchtext/experimental/transforms.py | 38 +++++++++---------- torchtext/experimental/vectors.py | 2 +- torchtext/experimental/vocab.py | 4 +- 14 files changed, 66 insertions(+), 66 deletions(-) diff --git a/benchmark/benchmark_basic_english_normalize.py b/benchmark/benchmark_basic_english_normalize.py index d719e748a5..fa395b1299 100644 --- a/benchmark/benchmark_basic_english_normalize.py +++ b/benchmark/benchmark_basic_english_normalize.py @@ -15,7 +15,7 @@ def _run_benchmark_lookup(train, tokenizer): existing_basic_english_tokenizer = get_tokenizer("basic_english") experimental_basic_english_normalize = basic_english_normalize() - experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize.to_ivalue()) + experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize) # existing eager lookup train, _ = AG_NEWS() diff --git a/benchmark/benchmark_experimental_vectors.py b/benchmark/benchmark_experimental_vectors.py index 42fc008370..f644c14e62 100644 --- a/benchmark/benchmark_experimental_vectors.py +++ b/benchmark/benchmark_experimental_vectors.py @@ -42,7 +42,7 @@ def _run_benchmark_lookup(tokens, vector): # experimental FastText jit lookup print("FastText Experimental - Jit Mode") - 
jit_fast_text_experimental = torch.jit.script(fast_text_experimental.to_ivalue()) + jit_fast_text_experimental = torch.jit.script(fast_text_experimental) _run_benchmark_lookup(tokens, jit_fast_text_experimental) diff --git a/benchmark/benchmark_experimental_vocab.py b/benchmark/benchmark_experimental_vocab.py index 4183c27f6a..f815bf3648 100644 --- a/benchmark/benchmark_experimental_vocab.py +++ b/benchmark/benchmark_experimental_vocab.py @@ -67,7 +67,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True, print("Loading from raw text file with basic_english_normalize tokenizer") for _ in range(num_iters): tokenizer = basic_english_normalize() - jited_tokenizer = torch.jit.script(tokenizer.to_ivalue()) + jited_tokenizer = torch.jit.script(tokenizer) build_vocab_from_text_file(f, jited_tokenizer, num_cpus=1) print("Construction time:", time.monotonic() - t0) else: @@ -140,7 +140,7 @@ def token_iterator(file_path): t0 = time.monotonic() v_experimental = VocabExperimental(ordered_dict) print("Construction time:", time.monotonic() - t0) - jit_v_experimental = torch.jit.script(v_experimental.to_ivalue()) + jit_v_experimental = torch.jit.script(v_experimental) # existing Vocab eager lookup print("Vocab - Eager Mode") @@ -154,7 +154,7 @@ def token_iterator(file_path): _run_benchmark_lookup([tokens], v_experimental) _run_benchmark_lookup(tokens_lists, v_experimental) - jit_v_experimental = torch.jit.script(v_experimental.to_ivalue()) + jit_v_experimental = torch.jit.script(v_experimental) # experimental Vocab jit lookup print("Vocab Experimental - Jit Mode") _run_benchmark_lookup(tokens, jit_v_experimental) diff --git a/benchmark/benchmark_pytext_vocab.py b/benchmark/benchmark_pytext_vocab.py index 2e686dd5dc..6dbe200fd4 100644 --- a/benchmark/benchmark_pytext_vocab.py +++ b/benchmark/benchmark_pytext_vocab.py @@ -150,7 +150,7 @@ def benchmark_experimental_vocab(): t0 = time.monotonic() experimental_script_vocab = 
ExperimentalScriptVocabulary(ordered_dict, unk_token="") print("Construction time:", time.monotonic() - t0) - jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab.to_ivalue()) + jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab) # pytext Vocab eager lookup print("Pytext Vocabulary - Eager Mode") diff --git a/examples/data_pipeline/pipelines.py b/examples/data_pipeline/pipelines.py index 4e2db98021..5721a81b5c 100644 --- a/examples/data_pipeline/pipelines.py +++ b/examples/data_pipeline/pipelines.py @@ -34,9 +34,9 @@ def build_sp_pipeline(spm_file): # Insert token in vocab to match a pretrained vocab vocab.insert_token('', 1) pipeline = TextSequentialTransforms(tokenizer, vocab) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit sentencepiece pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_legacy_torchtext_vocab_pipeline(vocab_file): @@ -59,9 +59,9 @@ def build_experimental_torchtext_pipeline(hf_vocab_file): with open(hf_vocab_file, 'r') as f: vocab = load_vocab_from_file(f) pipeline = TextSequentialTransforms(tokenizer, vocab) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit experimental torchtext pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_legacy_batch_torchtext_vocab_pipeline(vocab_file): @@ -104,9 +104,9 @@ def build_legacy_pytext_script_vocab_pipeline(vocab_file): vocab_list.insert(0, "") pipeline = TextSequentialTransforms(tokenizer, PyTextScriptVocabTransform(ScriptVocabulary(vocab_list))) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit legacy PyText pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def 
build_experimental_pytext_script_pipeline(vocab_file): @@ -125,9 +125,9 @@ def build_experimental_pytext_script_pipeline(vocab_file): # Insert token in vocab to match a pretrained vocab pipeline = TextSequentialTransforms(tokenizer, PyTextScriptVocabTransform(script_vocab(ordered_dict))) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit legacy PyText pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_legacy_fasttext_vector_pipeline(): @@ -143,10 +143,10 @@ def build_experimental_fasttext_vector_pipeline(): vector = FastTextExperimental() pipeline = TextSequentialTransforms(tokenizer, vector) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit legacy fasttext pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def run_benchmark_lookup(text_classification_dataset, pipeline): diff --git a/examples/data_pipeline/transforms.py b/examples/data_pipeline/transforms.py index 7a6d9214e5..5c9d33a0c3 100644 --- a/examples/data_pipeline/transforms.py +++ b/examples/data_pipeline/transforms.py @@ -24,11 +24,11 @@ def forward(self, tokens: List[str]) -> List[int]: def insert_token(self, token: str, index: int) -> None: self.vocab.insert_token(token, index) - def to_ivalue(self): - if hasattr(self.vocab, 'to_ivalue'): + def __prepare_scriptable__(self): + if hasattr(self.vocab, '__prepare_scriptable__'): sp_model = self.sp_model new_module = PretrainedSPVocab(sp_model) - new_module.vocab = self.vocab.to_ivalue() + new_module.vocab = self.vocab.__prepare_scriptable__() return new_module return self @@ -57,9 +57,9 @@ def __init__(self, vocab): def forward(self, tokens: List[str]) -> List[int]: return self.vocab.lookup_indices_1d(tokens) - def to_ivalue(self): - if hasattr(self.vocab, 'to_ivalue'): - vocab = 
self.vocab.to_ivalue() + def __prepare_scriptable__(self): + if hasattr(self.vocab, '__prepare_scriptable__'): + vocab = self.vocab.__prepare_scriptable__() return PyTextScriptVocabTransform(vocab) return self diff --git a/test/data/test_functional.py b/test/data/test_functional.py index 66fda21154..e6e7da5246 100644 --- a/test/data/test_functional.py +++ b/test/data/test_functional.py @@ -94,14 +94,14 @@ def test_BasicEnglishNormalize(self): basic_eng_norm = basic_english_normalize() experimental_eager_tokens = basic_eng_norm(test_sample) - jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue()) + jit_basic_eng_norm = torch.jit.script(basic_eng_norm) experimental_jit_tokens = jit_basic_eng_norm(test_sample) basic_english_tokenizer = data.get_tokenizer("basic_english") eager_tokens = basic_english_tokenizer(test_sample) assert not basic_eng_norm.is_jitable - assert basic_eng_norm.to_ivalue().is_jitable + assert basic_eng_norm.__prepare_scriptable__().is_jitable self.assertEqual(experimental_jit_tokens, ref_results) self.assertEqual(eager_tokens, ref_results) @@ -109,7 +109,7 @@ def test_BasicEnglishNormalize(self): # test load and save save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt') - torch.save(basic_eng_norm.to_ivalue(), save_path) + torch.save(basic_eng_norm.__prepare_scriptable__(), save_path) loaded_basic_eng_norm = torch.load(save_path) loaded_eager_tokens = loaded_basic_eng_norm(test_sample) @@ -138,18 +138,18 @@ def test_RegexTokenizer(self): r_tokenizer = regex_tokenizer(patterns_list) eager_tokens = r_tokenizer(test_sample) - jit_r_tokenizer = torch.jit.script(r_tokenizer.to_ivalue()) + jit_r_tokenizer = torch.jit.script(r_tokenizer) jit_tokens = jit_r_tokenizer(test_sample) assert not r_tokenizer.is_jitable - assert r_tokenizer.to_ivalue().is_jitable + assert r_tokenizer.__prepare_scriptable__().is_jitable self.assertEqual(eager_tokens, ref_results) self.assertEqual(jit_tokens, ref_results) # test load and save save_path = 
os.path.join(self.test_dir, 'regex.pt') - torch.save(r_tokenizer.to_ivalue(), save_path) + torch.save(r_tokenizer.__prepare_scriptable__(), save_path) loaded_r_tokenizer = torch.load(save_path) loaded_eager_tokens = loaded_r_tokenizer(test_sample) diff --git a/test/experimental/test_transforms.py b/test/experimental/test_transforms.py index 69816dd6e7..1766fd029b 100644 --- a/test/experimental/test_transforms.py +++ b/test/experimental/test_transforms.py @@ -16,7 +16,7 @@ class TestTransforms(TorchtextTestCase): def test_sentencepiece_processor(self): model_path = get_asset_path('spm_example.model') spm_transform = sentencepiece_processor(model_path) - jit_spm_transform = torch.jit.script(spm_transform.to_ivalue()) + jit_spm_transform = torch.jit.script(spm_transform) test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer' ref_results = [15340, 4286, 981, 1207, 1681, 17, 84, 684, 8896, 5366, 144, 3689, 9, 5602, 12114, 6, 560, 649, 5602, 12114] @@ -28,7 +28,7 @@ def test_sentencepiece_processor(self): def test_sentencepiece_tokenizer(self): model_path = get_asset_path('spm_example.model') spm_tokenizer = sentencepiece_tokenizer(model_path) - jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue()) + jit_spm_tokenizer = torch.jit.script(spm_tokenizer) test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer' ref_results = ['\u2581Sent', 'ence', 'P', 'ie', 'ce', '\u2581is', '\u2581an', '\u2581un', 'super', 'vis', 'ed', '\u2581text', @@ -48,7 +48,7 @@ def test_vector_transform(self): data_path = os.path.join(dir_name, asset_name) shutil.copy(asset_path, data_path) vector_transform = VectorTransform(FastText(root=dir_name, validate_file=False)) - jit_vector_transform = torch.jit.script(vector_transform.to_ivalue()) + jit_vector_transform = torch.jit.script(vector_transform) # The first 3 entries in each vector. 
expected_fasttext_simple_en = torch.tensor([[-0.065334, -0.093031, -0.017571], [-0.32423, -0.098845, -0.0073467]]) diff --git a/test/experimental/test_transforms_with_asset.py b/test/experimental/test_transforms_with_asset.py index 492aa94e3e..0e27da4a62 100644 --- a/test/experimental/test_transforms_with_asset.py +++ b/test/experimental/test_transforms_with_asset.py @@ -33,7 +33,7 @@ def test_vocab_transform(self): vocab_transform = VocabTransform(load_vocab_from_file(f)) self.assertEqual(vocab_transform(['of', 'that', 'new']), [7, 18, 24]) - jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue()) + jit_vocab_transform = torch.jit.script(vocab_transform) self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']), [7, 18, 24, 18]) @@ -79,7 +79,7 @@ def test_glove(self): data_path = os.path.join(dir_name, asset_name) shutil.copy(asset_path, data_path) vectors_obj = GloVe(root=dir_name, validate_file=False) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) # The first 3 entries in each vector. 
expected_glove = { @@ -142,7 +142,7 @@ def test_vocab_from_raw_text_file(self): asset_path = get_asset_path(asset_name) with open(asset_path, 'r') as f: tokenizer = basic_english_normalize() - jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) + jit_tokenizer = torch.jit.script(tokenizer) v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='') expected_itos = ['', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed', 'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent', @@ -174,7 +174,7 @@ def test_text_sequential_transform(self): asset_path = get_asset_path(asset_name) with open(asset_path, 'r') as f: pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f)) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) self.assertEqual(pipeline('of that new'), [7, 18, 24]) self.assertEqual(jit_pipeline('of that new'), [7, 18, 24]) @@ -201,7 +201,7 @@ def test_fast_text(self): data_path = os.path.join(dir_name, asset_name) shutil.copy(asset_path, data_path) vectors_obj = FastText(root=dir_name, validate_file=False) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) # The first 3 entries in each vector. 
expected_fasttext_simple_en = { diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py index 946c436e95..71695d0b75 100644 --- a/test/experimental/test_vectors.py +++ b/test/experimental/test_vectors.py @@ -54,10 +54,10 @@ def test_vectors_jit(self): tokens = ['a', 'b'] vecs = torch.stack((tensorA, tensorB), 0) vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) assert not vectors_obj.is_jitable - assert vectors_obj.to_ivalue().is_jitable + assert vectors_obj.__prepare_scriptable__().is_jitable self.assertEqual(vectors_obj['a'], jit_vectors_obj['a']) self.assertEqual(vectors_obj['b'], jit_vectors_obj['b']) @@ -71,7 +71,7 @@ def test_vectors_forward(self): tokens = ['a', 'b'] vecs = torch.stack((tensorA, tensorB), 0) vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) tokens_to_lookup = ['a', 'b', 'c'] expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0) @@ -124,7 +124,7 @@ def test_vectors_load_and_save(self): vectors_obj['b'] = tensorC vector_path = os.path.join(self.test_dir, 'vectors.pt') - torch.save(vectors_obj.to_ivalue(), vector_path) + torch.save(vectors_obj.__prepare_scriptable__(), vector_path) loaded_vectors_obj = torch.load(vector_path) self.assertEqual(loaded_vectors_obj['a'], tensorA) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 626db2e726..996a672513 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -104,13 +104,13 @@ def test_vocab_jit(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - jit_v = torch.jit.script(v.to_ivalue()) + jit_v = torch.jit.script(v) expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in 
enumerate(expected_itos)} assert not v.is_jitable - assert v.to_ivalue().is_jitable + assert v.__prepare_scriptable__().is_jitable self.assertEqual(jit_v.get_itos(), expected_itos) self.assertEqual(dict(jit_v.get_stoi()), expected_stoi) @@ -121,7 +121,7 @@ def test_vocab_forward(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c) - jit_v = torch.jit.script(v.to_ivalue()) + jit_v = torch.jit.script(v) tokens = ['b', 'a', 'c'] expected_indices = [2, 1, 3] @@ -200,7 +200,7 @@ def test_vocab_load_and_save(self): self.assertEqual(dict(v.get_stoi()), expected_stoi) vocab_path = os.path.join(self.test_dir, 'vocab.pt') - torch.save(v.to_ivalue(), vocab_path) + torch.save(v.__prepare_scriptable__(), vocab_path) loaded_v = torch.load(vocab_path) self.assertEqual(v.get_itos(), expected_itos) diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index 6c3896aa61..c62bd2ce1c 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -50,7 +50,7 @@ def basic_english_normalize(): >>> from torchtext.experimental.transforms import basic_english_normalize >>> test_sample = 'Basic English Normalization for a Line of Text' >>> basic_eng_norm = basic_english_normalize() - >>> jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue()) + >>> jit_basic_eng_norm = torch.jit.script(basic_eng_norm) >>> tokens = jit_basic_eng_norm(test_sample) """ @@ -124,7 +124,7 @@ def forward(self, line: str) -> List[str]: return self.regex_tokenizer.forward(line) - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable BasicEnglishNormalize. """ @@ -159,7 +159,7 @@ def forward(self, line: str) -> List[str]: return self.regex_tokenizer.forward(line) - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable RegexTokenizer. 
""" @@ -177,7 +177,7 @@ class TextSequentialTransforms(nn.Sequential): >>> txt_pipeline = TextSequentialTransforms(tokenizer) >>> txt_pipeline('here is an example') ['here', 'is', 'an', 'example'] - >>> jit_txt_pipeline = torch.jit.script(txt_pipeline.to_ivalue()) + >>> jit_txt_pipeline = torch.jit.script(txt_pipeline) """ def forward(self, input: str): @@ -185,14 +185,14 @@ def forward(self, input: str): input = module(input) return input - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable TextSequentialTransforms. """ module_list = [] for _idx, _module in enumerate(self): - if hasattr(_module, 'to_ivalue'): - _module = _module.to_ivalue() + if hasattr(_module, '__prepare_scriptable__'): + _module = _module.__prepare_scriptable__() module_list.append((str(_idx), _module)) return TextSequentialTransforms(OrderedDict(module_list)) @@ -263,7 +263,7 @@ def sentencepiece_tokenizer(sp_model): >>> import torch >>> from torchtext.experimental.transforms import sentencepiece_tokenizer >>> spm_tokenizer = sentencepiece_tokenizer('m_user.model') - >>> jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue()) + >>> jit_spm_tokenizer = torch.jit.script(spm_tokenizer) """ spm = load_sp_model(sp_model) @@ -308,7 +308,7 @@ def decode(self, tokens: List[str]) -> str: return self.sp_model.DecodePieces(tokens) - def to_ivalue(self): + def __prepare_scriptable__(self): torchbind_spm = torch.classes.torchtext.SentencePiece(self.sp_model._return_content()) return SentencePieceTokenizer(torchbind_spm) @@ -323,7 +323,7 @@ def sentencepiece_processor(sp_model): >>> import torch >>> from torchtext.experimental.transforms import sentencepiece_processor >>> spm_processor = sentencepiece_processor('m_user.model') - >>> jit_spm_processor = torch.jit.script(spm_processor.to_ivalue()) + >>> jit_spm_processor = torch.jit.script(spm_processor) """ spm = load_sp_model(sp_model) @@ -366,7 +366,7 @@ def decode(self, ids: List[int]) -> str: return 
self.sp_model.DecodeIds(ids) - def to_ivalue(self): + def __prepare_scriptable__(self): torchbind_spm = torch.classes.torchtext.SentencePiece(self.sp_model._return_content()) return SentencePieceProcessor(torchbind_spm) @@ -382,7 +382,7 @@ class VocabTransform(nn.Module): >>> from torchtext.experimental.vocab import vocab_from_file_object >>> f = open('vocab.txt', 'r') >>> vocab_transform = VocabTransform(vocab_from_file_object(f)) - >>> jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue()) + >>> jit_vocab_transform = torch.jit.script(vocab_transform) """ def __init__(self, vocab): @@ -402,9 +402,9 @@ def forward(self, tokens: List[str]) -> List[int]: return self.vocab.lookup_indices(tokens) - def to_ivalue(self): - if hasattr(self.vocab, 'to_ivalue'): - vocab = self.vocab.to_ivalue() + def __prepare_scriptable__(self): + if hasattr(self.vocab, '__prepare_scriptable__'): + vocab = self.vocab.__prepare_scriptable__() return VocabTransform(vocab) return self @@ -419,7 +419,7 @@ class VectorTransform(nn.Module): >>> import torch >>> from torchtext.experimental.vectors import FastText >>> vector_transform = VectorTransform(FastText()) - >>> jit_vector_transform = torch.jit.script(vector_transform.to_ivalue()) + >>> jit_vector_transform = torch.jit.script(vector_transform) """ def __init__(self, vector): @@ -439,8 +439,8 @@ def forward(self, tokens: List[str]) -> Tensor: return self.vector.lookup_vectors(tokens) - def to_ivalue(self): - if hasattr(self.vector, 'to_ivalue'): - vector = self.vector.to_ivalue() + def __prepare_scriptable__(self): + if hasattr(self.vector, '__prepare_scriptable__'): + vector = self.vector.__prepare_scriptable__() return VectorTransform(vector) return self diff --git a/torchtext/experimental/vectors.py b/torchtext/experimental/vectors.py index a606e12bcc..e1b7590897 100644 --- a/torchtext/experimental/vectors.py +++ b/torchtext/experimental/vectors.py @@ -285,7 +285,7 @@ def lookup_vectors(self, tokens: List[str]) -> Tensor: 
return self.vectors.lookup_vectors(tokens) - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable Vectors. """ stoi = self.vectors.get_stoi() diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index a7707003ee..c0dacead8a 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -43,7 +43,7 @@ def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_tok >>> f = open('vocab.txt', 'r') >>> tokenizer = basic_english_normalize() >>> tokenizer = basic_english_normalize() - >>> jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) + >>> jit_tokenizer = torch.jit.script(tokenizer) >>> v = build_vocab_from_text_file(f, jit_tokenizer) """ vocab_obj = _build_vocab_from_text_file(file_object.name, unk_token, min_freq, num_cpus, jited_tokenizer) @@ -264,7 +264,7 @@ def get_itos(self) -> List[str]: """ return self.vocab.get_itos() - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_) From 9eb20d9006973e9f66abb99952372b5b05af2f45 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 13 Nov 2020 07:48:13 -0800 Subject: [PATCH 2/8] Skip __prepare_scriptable__ func in some cases --- torchtext/experimental/transforms.py | 5 ++++- torchtext/experimental/vectors.py | 2 ++ torchtext/experimental/vocab.py | 2 ++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index c62bd2ce1c..d4af226059 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -127,7 +127,8 @@ def forward(self, line: str) -> List[str]: def __prepare_scriptable__(self): r"""Return a JITable BasicEnglishNormalize. 
""" - + if self.is_jitable: + return BasicEnglishNormalize(self.regex_tokenizer) regex_tokenizer = torch.classes.torchtext.RegexTokenizer(self.regex_tokenizer.patterns_, self.regex_tokenizer.replacements_, True) return BasicEnglishNormalize(regex_tokenizer) @@ -162,6 +163,8 @@ def forward(self, line: str) -> List[str]: def __prepare_scriptable__(self): r"""Return a JITable RegexTokenizer. """ + if self.is_jitable: + return RegexTokenizer(self.regex_tokenizer) regex_tokenizer = torch.classes.torchtext.RegexTokenizer(self.regex_tokenizer.patterns_, self.regex_tokenizer.replacements_, False) return RegexTokenizer(regex_tokenizer) diff --git a/torchtext/experimental/vectors.py b/torchtext/experimental/vectors.py index e1b7590897..8d9c7e426a 100644 --- a/torchtext/experimental/vectors.py +++ b/torchtext/experimental/vectors.py @@ -288,6 +288,8 @@ def lookup_vectors(self, tokens: List[str]) -> Tensor: def __prepare_scriptable__(self): r"""Return a JITable Vectors. """ + if self.is_jitable: + return Vectors(self.vectors) stoi = self.vectors.get_stoi() cpp_vectors = torch.classes.torchtext.Vectors(list(stoi.keys()), list(stoi.values()), self.vectors.vectors_, self.vectors.unk_tensor_) return(Vectors(cpp_vectors)) diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index c0dacead8a..5f6a1ff702 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -267,5 +267,7 @@ def get_itos(self) -> List[str]: def __prepare_scriptable__(self): r"""Return a JITable Vocab. 
""" + if self.is_jitable: + return Vocab(self.vocab) cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_) return Vocab(cpp_vocab) From 1396901a4ea35570f3a46e789b4d4ab9f608881b Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 6 Jan 2021 07:30:04 -0800 Subject: [PATCH 3/8] checkpoint --- test/experimental/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/experimental/test_transforms.py b/test/experimental/test_transforms.py index e241bfc0f9..d8a76d34f7 100644 --- a/test/experimental/test_transforms.py +++ b/test/experimental/test_transforms.py @@ -74,7 +74,7 @@ def test_sentencepiece_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'spm_torchscript.pt') - spm = sentencepiece_tokenizer((model_path)).to_ivalue() + spm = sentencepiece_tokenizer((model_path)).__prepare_scriptable__() torch.save(spm, save_path) loaded_spm = torch.load(save_path) self.assertEqual(expected, loaded_spm(input)) From fff1adb2679ff477ada54ad71a7c2062569d84d6 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 9 Jan 2021 12:18:19 -0500 Subject: [PATCH 4/8] remove __prepare_scriptable__ func from TextSequentialTransforms --- torchtext/experimental/transforms.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index f805a3d5f2..262202a60e 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -2,7 +2,6 @@ import torch.nn as nn from typing import List from torchtext._torchtext import RegexTokenizer as RegexTokenizerPybind -from collections import OrderedDict from torch import Tensor from torchtext._torchtext import SentencePiece as SentencePiecePybind import io @@ -188,17 +187,6 @@ def forward(self, input: str): input = module(input) return input - def __prepare_scriptable__(self): - r"""Return a JITable TextSequentialTransforms. 
- """ - - module_list = [] - for _idx, _module in enumerate(self): - if hasattr(_module, '__prepare_scriptable__'): - _module = _module.__prepare_scriptable__() - module_list.append((str(_idx), _module)) - return TextSequentialTransforms(OrderedDict(module_list)) - PRETRAINED_SP_MODEL = { 'text_unigram_15000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model', From f9a7d5befad3a9c1b466d0c626504f12b12bc53a Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 9 Jan 2021 12:45:54 -0500 Subject: [PATCH 5/8] remove __prepare_scriptable__ func from the transforms in the example --- examples/data_pipeline/pipelines.py | 1 - examples/data_pipeline/transforms.py | 14 -------------- 2 files changed, 15 deletions(-) diff --git a/examples/data_pipeline/pipelines.py b/examples/data_pipeline/pipelines.py index 5721a81b5c..d8c6f71f7d 100644 --- a/examples/data_pipeline/pipelines.py +++ b/examples/data_pipeline/pipelines.py @@ -32,7 +32,6 @@ def build_sp_pipeline(spm_file): vocab = PretrainedSPVocab(load_sp_model(spm_file)) # Insert token in vocab to match a pretrained vocab - vocab.insert_token('', 1) pipeline = TextSequentialTransforms(tokenizer, vocab) jit_pipeline = torch.jit.script(pipeline) print('jit sentencepiece pipeline success!') diff --git a/examples/data_pipeline/transforms.py b/examples/data_pipeline/transforms.py index 5c9d33a0c3..2bcb8ff34c 100644 --- a/examples/data_pipeline/transforms.py +++ b/examples/data_pipeline/transforms.py @@ -24,14 +24,6 @@ def forward(self, tokens: List[str]) -> List[int]: def insert_token(self, token: str, index: int) -> None: self.vocab.insert_token(token, index) - def __prepare_scriptable__(self): - if hasattr(self.vocab, '__prepare_scriptable__'): - sp_model = self.sp_model - new_module = PretrainedSPVocab(sp_model) - new_module.vocab = self.vocab.__prepare_scriptable__() - return new_module - return self - class PyTextVocabTransform(nn.Module): r"""PyTextVocabTransform transform @@ -57,12 
+49,6 @@ def __init__(self, vocab): def forward(self, tokens: List[str]) -> List[int]: return self.vocab.lookup_indices_1d(tokens) - def __prepare_scriptable__(self): - if hasattr(self.vocab, '__prepare_scriptable__'): - vocab = self.vocab.__prepare_scriptable__() - return PyTextScriptVocabTransform(vocab) - return self - class ToLongTensor(nn.Module): r"""Convert a list of integers to long tensor From c6f42363c4511ca96c664bd0c7f58a2b90262811 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 9 Jan 2021 12:57:17 -0500 Subject: [PATCH 6/8] remove __prepare_scriptable__ func from VocabTransform and VectorTransform --- torchtext/experimental/transforms.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index 262202a60e..b3cd617956 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -393,12 +393,6 @@ def forward(self, tokens: List[str]) -> List[int]: return self.vocab.lookup_indices(tokens) - def __prepare_scriptable__(self): - if hasattr(self.vocab, '__prepare_scriptable__'): - vocab = self.vocab.__prepare_scriptable__() - return VocabTransform(vocab) - return self - class VectorTransform(nn.Module): r"""Vector transform @@ -429,9 +423,3 @@ def forward(self, tokens: List[str]) -> Tensor: """ return self.vector.lookup_vectors(tokens) - - def __prepare_scriptable__(self): - if hasattr(self.vector, '__prepare_scriptable__'): - vector = self.vector.__prepare_scriptable__() - return VectorTransform(vector) - return self From 1fb69bda2fd6473ef7795a6bf75e51036d632c0e Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 11 Jan 2021 12:03:23 -0800 Subject: [PATCH 7/8] add a note --- test/data/test_functional.py | 8 ++++++++ test/experimental/test_transforms.py | 2 ++ test/experimental/test_vectors.py | 4 ++++ test/experimental/test_vocab.py | 4 ++++ 4 files changed, 18 insertions(+) diff --git a/test/data/test_functional.py 
b/test/data/test_functional.py index 9135680c3b..199c022786 100644 --- a/test/data/test_functional.py +++ b/test/data/test_functional.py @@ -101,6 +101,8 @@ def test_BasicEnglishNormalize(self): eager_tokens = basic_english_tokenizer(test_sample) assert not basic_eng_norm.is_jitable + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # Users are not expected to use the torchbind version in eager mode, but we still need a CI test here. assert basic_eng_norm.__prepare_scriptable__().is_jitable self.assertEqual(experimental_jit_tokens, ref_results) @@ -121,6 +123,8 @@ def test_basicEnglishNormalize_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt') + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # Users are not expected to use the torchbind version in eager mode, but we still need a CI test here. ben = basic_english_normalize().__prepare_scriptable__() torch.save(ben, save_path) loaded_ben = torch.load(save_path) @@ -153,6 +157,8 @@ def test_RegexTokenizer(self): jit_tokens = jit_r_tokenizer(test_sample) assert not r_tokenizer.is_jitable + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # Users are not expected to use the torchbind version in eager mode, but we still need a CI test here. assert r_tokenizer.__prepare_scriptable__().is_jitable self.assertEqual(eager_tokens, ref_results) @@ -186,6 +192,8 @@ def test_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'regex_torchscript.pt') + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # Users are not expected to use the torchbind version in eager mode, but we still need a CI test here.
tokenizer = regex_tokenizer(patterns_list).__prepare_scriptable__() torch.save(tokenizer, save_path) loaded_tokenizer = torch.load(save_path) diff --git a/test/experimental/test_transforms.py b/test/experimental/test_transforms.py index d8a76d34f7..d3cc651ddc 100644 --- a/test/experimental/test_transforms.py +++ b/test/experimental/test_transforms.py @@ -74,6 +74,8 @@ def test_sentencepiece_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'spm_torchscript.pt') + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here. spm = sentencepiece_tokenizer((model_path)).__prepare_scriptable__() torch.save(spm, save_path) loaded_spm = torch.load(save_path) diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py index 0e48ba8317..0cb32ffe1d 100644 --- a/test/experimental/test_vectors.py +++ b/test/experimental/test_vectors.py @@ -57,6 +57,8 @@ def test_vectors_jit(self): jit_vectors_obj = torch.jit.script(vectors_obj) assert not vectors_obj.is_jitable + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here. assert vectors_obj.__prepare_scriptable__().is_jitable self.assertEqual(vectors_obj['a'], jit_vectors_obj['a']) @@ -148,6 +150,8 @@ def test_vectors_load_and_save(self): with self.subTest('torchscript'): vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt') + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here. 
torch.save(vectors_obj.__prepare_scriptable__(), vector_path) loaded_vectors_obj = torch.load(vector_path) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index c7ec3a59bd..879c03e72d 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -110,6 +110,8 @@ def test_vocab_jit(self): expected_stoi = {x: index for index, x in enumerate(expected_itos)} assert not v.is_jitable + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here. assert v.__prepare_scriptable__().is_jitable self.assertEqual(jit_v.get_itos(), expected_itos) @@ -208,6 +210,8 @@ def test_vocab_load_and_save(self): with self.subTest('torchscript'): vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt') + # Call the __prepare_scriptable__() func and convert the building block to the torchbind version + # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here. torch.save(v.__prepare_scriptable__(), vocab_path) loaded_v = torch.load(vocab_path) self.assertEqual(v.get_itos(), expected_itos) From 8809bd53edf6fff0a119bf11f9c99a382c781336 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 9 Feb 2021 09:23:53 -0800 Subject: [PATCH 8/8] remove return self --- torchtext/experimental/transforms.py | 5 ----- torchtext/experimental/vectors.py | 2 -- torchtext/experimental/vocab.py | 2 -- 3 files changed, 9 deletions(-) diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index b3cd617956..3b542aeb45 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -126,8 +126,6 @@ def forward(self, line: str) -> List[str]: def __prepare_scriptable__(self): r"""Return a JITable BasicEnglishNormalize. 
""" - if self.is_jitable: - return BasicEnglishNormalize(self.regex_tokenizer) regex_tokenizer = torch.classes.torchtext.RegexTokenizer(self.regex_tokenizer.patterns_, self.regex_tokenizer.replacements_, True) return BasicEnglishNormalize(regex_tokenizer) @@ -162,9 +160,6 @@ def forward(self, line: str) -> List[str]: def __prepare_scriptable__(self): r"""Return a JITable RegexTokenizer. """ - if self.is_jitable: - return RegexTokenizer(self.regex_tokenizer) - regex_tokenizer = torch.classes.torchtext.RegexTokenizer(self.regex_tokenizer.patterns_, self.regex_tokenizer.replacements_, False) return RegexTokenizer(regex_tokenizer) diff --git a/torchtext/experimental/vectors.py b/torchtext/experimental/vectors.py index 552ba74c50..e7779d9ad4 100644 --- a/torchtext/experimental/vectors.py +++ b/torchtext/experimental/vectors.py @@ -288,8 +288,6 @@ def lookup_vectors(self, tokens: List[str]) -> Tensor: def __prepare_scriptable__(self): r"""Return a JITable Vectors. """ - if self.is_jitable: - return Vectors(self.vectors) stoi = self.vectors.get_stoi() cpp_vectors = torch.classes.torchtext.Vectors(list(stoi.keys()), list(stoi.values()), self.vectors.vectors_, self.vectors.unk_tensor_) return(Vectors(cpp_vectors)) diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 8ead21e7ad..606965daa9 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -267,7 +267,5 @@ def get_itos(self) -> List[str]: def __prepare_scriptable__(self): r"""Return a JITable Vocab. """ - if self.is_jitable: - return Vocab(self.vocab) cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_) return Vocab(cpp_vocab)