|
3 | 3 | import operator |
4 | 4 | import os |
5 | 5 | import pkgutil |
| 6 | +import platform |
6 | 7 | import sys |
7 | 8 | import warnings |
8 | 9 | from collections import OrderedDict |
@@ -343,12 +344,25 @@ def _check_input_backprop(model, inputs): |
343 | 344 | _model_params[m] = {"input_shape": (1, 3, 64, 64)} |
344 | 345 |
|
345 | 346 |
|
346 | | -# skip big models to reduce memory usage on CI test |
| 347 | +# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device). |
347 | 348 | skipped_big_models = { |
348 | | - "vit_h_14", |
349 | | - "regnet_y_128gf", |
| 349 | + "vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")}, |
| 350 | + "regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")}, |
| 351 | + "mvit_v1_b": {("Windows", "cuda")}, |
| 352 | + "mvit_v2_s": {("Windows", "cuda")}, |
350 | 353 | } |
351 | 354 |
|
| 355 | + |
| 356 | +def is_skippable(model_name, device): |
| 357 | + if model_name not in skipped_big_models: |
| 358 | + return False |
| 359 | + |
| 360 | + platform_system = platform.system() |
| 361 | + device_name = str(device).split(":")[0] |
| 362 | + |
| 363 | + return (platform_system, device_name) in skipped_big_models[model_name] |
| 364 | + |
| 365 | + |
352 | 366 | # The following contains configuration and expected values to be used tests that are model specific |
353 | 367 | _model_tests_values = { |
354 | 368 | "retinanet_resnet50_fpn": { |
@@ -612,7 +626,7 @@ def test_classification_model(model_fn, dev): |
612 | 626 | "input_shape": (1, 3, 224, 224), |
613 | 627 | } |
614 | 628 | model_name = model_fn.__name__ |
615 | | - if SKIP_BIG_MODEL and model_name in skipped_big_models: |
| 629 | + if SKIP_BIG_MODEL and is_skippable(model_name, dev): |
616 | 630 | pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") |
617 | 631 | kwargs = {**defaults, **_model_params.get(model_name, {})} |
618 | 632 | num_classes = kwargs.get("num_classes") |
@@ -841,7 +855,7 @@ def test_video_model(model_fn, dev): |
841 | 855 | "num_classes": 50, |
842 | 856 | } |
843 | 857 | model_name = model_fn.__name__ |
844 | | - if SKIP_BIG_MODEL and model_name in skipped_big_models: |
| 858 | + if SKIP_BIG_MODEL and is_skippable(model_name, dev): |
845 | 859 | pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") |
846 | 860 | kwargs = {**defaults, **_model_params.get(model_name, {})} |
847 | 861 | num_classes = kwargs.get("num_classes") |
|
0 commit comments