Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 97ee0c1

Browse files
committed
added type annotation
1 parent e56eafb commit 97ee0c1

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

torchtext/data/datasets_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ def _wrap_split_argument_with_fn(fn, splits):
209209
argspec.args[1] == "split" and
210210
argspec.varargs is None and
211211
argspec.varkw is None and
212-
len(argspec.kwonlyargs) == 0 and
213-
len(argspec.annotations) == 0
212+
len(argspec.kwonlyargs) == 0
214213
):
215214
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
216215

@@ -246,10 +245,9 @@ def decorator(func):
246245
argspec.args[1] == "split" and
247246
argspec.varargs is None and
248247
argspec.varkw is None and
249-
len(argspec.kwonlyargs) == 0 and
250-
len(argspec.annotations) == 0
248+
len(argspec.kwonlyargs) == 0
251249
):
252-
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
250+
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(func))
253251

254252
@functools.wraps(func)
255253
def wrapper(root=_CACHE_DIR, *args, **kwargs):

torchtext/datasets/amazonreviewpolarity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from torchtext._internal.module_utils import is_module_available
2-
2+
from typing import Union, Tuple
33
if is_module_available("torchdata"):
44
from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
55

@@ -34,7 +34,7 @@
3434
@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
3535
@_create_dataset_directory(dataset_name=DATASET_NAME)
3636
@_wrap_split_argument(("train", "test"))
37-
def AmazonReviewPolarity(root, split):
37+
def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
3838
# TODO Remove this after removing conditional dependency
3939
if not is_module_available("torchdata"):
4040
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

0 commit comments

Comments
 (0)