Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8d19c03

Browse files
parmeetfacebook-github-bot
authored andcommitted
Fix issue in label Transform
Summary: In the construction of Vocab within label transform, the default index is set to 0. This index is returned when OOV token is given. For this transform, the default index should never be set. Otherwise, it will return default index (which is 0) for unknown labels that might get passed (Ideally it should throw error in this case because we do not know what to do when wrong label is passed for query) Reviewed By: hudeven Differential Revision: D32610834 fbshipit-source-id: e49385fb313929627c41fc515b6d900a6bfc3591
1 parent 68dc59c commit 8d19c03

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

test/test_transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def test_labeltoindex(self):
9595
expected = [0, 1, 2]
9696
self.assertEqual(actual, expected)
9797

98+
with self.assertRaises(RuntimeError):
99+
transform(['OOV'])
100+
98101
transform = transforms.LabelToIndex(label_names=label_names, sort_names=True)
99102
actual = transform(label_names)
100103
expected = [2, 1, 0]

torchtext/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130

131131
if sort_names:
132132
label_names = sorted(label_names)
133-
self._label_vocab = Vocab(torch.classes.torchtext.Vocab(label_names, 0))
133+
self._label_vocab = Vocab(torch.classes.torchtext.Vocab(label_names, None))
134134
self._label_names = self._label_vocab.get_itos()
135135

136136
def forward(self, labels: Union[str, List[str]]) -> Union[int, List[int]]:

0 commit comments

Comments
 (0)