12 changes: 7 additions & 5 deletions benchmark/benchmark_experimental_vocab.py
@@ -3,7 +3,7 @@
import time

import torch
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.datasets import DATASETS
from torchtext.experimental.vocab import (
vocab as VocabExperimental,
load_vocab_from_file,
@@ -76,7 +76,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
print("Construction time:", time.monotonic() - t0)


def benchmark_experimental_vocab_lookup(vocab_file_path=None):
def benchmark_experimental_vocab_lookup(vocab_file_path=None, dataset = 'AG_NEWS'):
def _run_benchmark_lookup(tokens, vocab):
t0 = time.monotonic()
# list lookup
@@ -94,7 +94,7 @@ def _run_benchmark_lookup(tokens, vocab):
tokens = []
tokens_lists = []

train = AG_NEWS(split='train')
train = DATASETS[dataset](split='train')
vocab = train.get_vocab()
for (_, text) in train:
cur_tokens = []
@@ -124,7 +124,7 @@ def token_iterator(file_path):
v_experimental = load_vocab_from_file(f)
print("Construction time:", time.monotonic() - t0)
else:
print("Loading Vocab from AG News")
print("Loading Vocab from {}".format(dataset))
counter = Counter(tokens)
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
@@ -174,11 +174,13 @@ def token_iterator(file_path):
help='The name of vocab file used for construction')
parser.add_argument('--vocab-filename-lookup', type=str, default=None,
help='The name of vocab file used for lookup')
parser.add_argument('--dataset', type=str, default='AG_NEWS',
help='The name of the dataset used for the lookup benchmark')
args = parser.parse_args()

if args.run_construction_benchmark:
print("is_legacy", args.is_legacy)
benchmark_experimental_vocab_construction(args.vocab_filename_construction,
is_raw_text=args.is_raw_text, is_legacy=args.is_legacy)
else:
benchmark_experimental_vocab_lookup(args.vocab_filename_lookup)
benchmark_experimental_vocab_lookup(args.vocab_filename_lookup, args.dataset)
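Note: the lookup benchmark is now parameterized over the experimental DATASETS registry instead of hard-coding AG_NEWS, so any experimental dataset can be selected by name via --dataset. A minimal sketch of the same selection pattern, assuming (as the diff above does) that DATASETS maps dataset names to constructors accepting a split argument and that the resulting train split exposes get_vocab():

from torchtext.experimental.datasets import DATASETS

def load_train_split(dataset_name='AG_NEWS'):
    # Mirror `DATASETS[dataset](split='train')` from the benchmark above,
    # with an explicit error for unknown names.
    if dataset_name not in DATASETS:
        raise ValueError('Unknown dataset: {}'.format(dataset_name))
    return DATASETS[dataset_name](split='train')

# 'AG_NEWS' is the benchmark's default; other registry keys are assumptions.
train = load_train_split('AG_NEWS')
vocab = train.get_vocab()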
243 changes: 140 additions & 103 deletions torchtext/csrc/register_bindings.cpp
@@ -1,26 +1,26 @@
#include <iostream>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
#include <sentencepiece.h> // @manual
#include <regex_tokenizer.h> // @manual
#include <sentencepiece.h> // @manual
#include <torch/csrc/jit/python/pybind_utils.h> // @manual
#include <torch/csrc/utils/pybind.h> // @manual
#include <torch/csrc/utils/pybind.h> // @manual
#include <torch/script.h>
#include <vectors.h> // @manual
#include <vocab.h> // @manual

namespace torchtext {

namespace py = pybind11;

namespace {
Vocab build_vocab_from_text_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq,
const int64_t num_cpus,
const int64_t min_freq, const int64_t num_cpus,
py::object fn) {
torch::jit::script::Module module(*torch::jit::as_module(fn));
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, module);
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
module);
}
} // namespace

@@ -40,23 +40,27 @@ PYBIND11_MODULE(_torchtext, m) {
return _deserialize_regex(std::move(state));
}));

py::class_<RegexTokenizer, c10::intrusive_ptr<RegexTokenizer>>(m, "RegexTokenizer")
py::class_<RegexTokenizer, c10::intrusive_ptr<RegexTokenizer>>(
m, "RegexTokenizer")
.def_readonly("patterns_", &RegexTokenizer::patterns_)
.def_readonly("replacements_", &RegexTokenizer::replacements_)
.def_readonly("to_lower_", &RegexTokenizer::to_lower_)
.def(py::init<std::vector<std::string>, std::vector<std::string>, bool>())
.def("forward", &RegexTokenizer::forward)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
}));

py::class_<SentencePiece, c10::intrusive_ptr<SentencePiece>>(m, "SentencePiece")
py::class_<SentencePiece, c10::intrusive_ptr<SentencePiece>>(m,
"SentencePiece")
.def(py::init<std::string>())
.def("_return_content",
[](const SentencePiece &self) { return py::bytes(self.content_); })
@@ -70,14 +74,14 @@ PYBIND11_MODULE(_torchtext, m) {
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> py::bytes{
return py::bytes(self->content_);
},
// __setstate__
[](py::bytes state) -> c10::intrusive_ptr<SentencePiece> {
return c10::make_intrusive<SentencePiece>(std::string(state));
}));
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> py::bytes {
return py::bytes(self->content_);
},
// __setstate__
[](py::bytes state) -> c10::intrusive_ptr<SentencePiece> {
return c10::make_intrusive<SentencePiece>(std::string(state));
}));

py::class_<Vectors, c10::intrusive_ptr<Vectors>>(m, "Vectors")
.def(py::init<std::vector<std::string>, std::vector<int64_t>,
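Note: the py::pickle pair registered for SentencePiece above serializes the handle by returning its raw model content as py::bytes, so the object round-trips through Python's pickle. A short sketch, under the assumption that the pybind extension is importable as torchtext._torchtext and that the SentencePiece constructor accepts the serialized model content (as the __setstate__ lambda above implies):

import pickle

from torchtext._torchtext import SentencePiece  # assumed module path

# Hypothetical model path; any pretrained sentencepiece model file would do.
with open('spm_user.model', 'rb') as f:
    sp_model = SentencePiece(f.read())

payload = pickle.dumps(sp_model)    # __getstate__ -> py::bytes of content_
restored = pickle.loads(payload)    # __setstate__ rebuilds from those bytes
assert restored.GetPieceSize() == sp_model.GetPieceSize()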
@@ -103,13 +107,30 @@ PYBIND11_MODULE(_torchtext, m) {
.def(py::init<std::vector<std::string>, std::string>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
.def("__getitem__", &Vocab::__getitem__)
.def("__getitem__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__getitem__(c10::string_view{buffer, (size_t)length});
})
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices", &Vocab::lookup_indices)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self, const py::list &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
ssize_t length;
const char *buffer =
PyUnicode_AsUTF8AndSize(item.ptr(), &length);
indices[counter++] =
self->__getitem__(c10::string_view{buffer, (size_t)length});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def(py::pickle(
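Note: the overridden __getitem__ and lookup_indices bindings above read each Python str's UTF-8 buffer directly with PyUnicode_AsUTF8AndSize and hand it to the C++ Vocab as a c10::string_view, skipping a per-token std::string copy; lookup_indices additionally keeps the loop on the C++ side. The Python-facing API is unchanged. A minimal sketch, assuming the experimental vocab factory accepts an OrderedDict of token frequencies, as the benchmark file above builds one:

from collections import OrderedDict

from torchtext.experimental.vocab import vocab

# Assumed toy frequency dict; the benchmark derives one from Counter(tokens).
ordered_dict = OrderedDict([('hello', 3), ('world', 2), ('torchtext', 1)])
v = vocab(ordered_dict)

idx = v['hello']                                              # new __getitem__ fast path
indices = v.lookup_indices(['hello', 'world', 'torchtext'])   # new batched lookup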
@@ -131,96 +152,112 @@ PYBIND11_MODULE(_torchtext, m) {

TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});

m.class_<RegexTokenizer>("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>, bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});
.def(torch::init<std::vector<std::string>, std::vector<std::string>,
bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});

m.class_<SentencePiece>("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
// The underlying content of SentencePiece contains byte string,
// and returing it as std::string cause UTF8 decoding error.
// Since TorchScript does not support byte string, we use byte Tensor to
// pass around the data.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data = static_cast<void*>(const_cast<char*>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char*>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
// The underlying content of SentencePiece contains a byte string,
// and returning it as std::string causes a UTF-8 decoding error.
// Since TorchScript does not support byte strings, we use a byte Tensor
// to pass the data around.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data =
static_cast<void *>(const_cast<char *>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char *>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});

m.class_<Vectors>("Vectors")
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>, torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, std::string>())
.def("__getitem__", &Vocab::__getitem__)
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices", &Vocab::lookup_indices)
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});
.def(torch::init<StringList, std::string>())
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self,
const std::vector<std::string> &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
indices[counter++] = self->__getitem__(c10::string_view{item});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});

m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);