Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit bdb8586

Browse files
committed
Update
[ghstack-poisoned]
1 parent a9bae50 commit bdb8586

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

test/test_base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import unittest
1111
import warnings
1212
from itertools import product
13+
from typing import Any, Callable, Dict, List, Optional, Tuple
1314

1415
import pytest
1516

@@ -53,14 +54,19 @@
5354
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
5455

5556

56-
def filtered_parametrize(param_list, filter_func=None):
57+
def filtered_parametrize(
58+
param_list: List[Tuple[str, List[Any]]],
59+
filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
60+
):
5761
"""
5862
A decorator that works like pytest.mark.parametrize but filters out
5963
unwanted parameter combinations.
6064
61-
:param param_list: A list of tuples, each containing (arg_name, [arg_values])
62-
:param filter_func: A function that takes a dictionary of parameter names and values,
63-
and returns True for valid combinations, False otherwise
65+
Args:
66+
param_list: A list of tuples, each containing (arg_name, [arg_values])
67+
filter_func: A function that takes a dictionary of parameter names and values,
68+
and returns True for valid combinations, False otherwise
69+
6470
"""
6571

6672
def decorator(func):

0 commit comments

Comments
 (0)