From 7af069936656e033ff1f321d6516ebc984dabb89 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Fri, 7 Jan 2022 21:36:56 +0000 Subject: [PATCH 1/9] add regnet_y_128gf --- torchvision/models/regnet.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 85f53751dd0..cf1863ccfe4 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -26,6 +26,7 @@ "regnet_y_8gf", "regnet_y_16gf", "regnet_y_32gf", + "regnet_y_128gf", "regnet_x_400mf", "regnet_x_800mf", "regnet_x_1_6gf", @@ -505,6 +506,16 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) +def regnet_y_128gf(**kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_128GF architecture from + `"Designing Network Design Spaces" `_. + NOTE: Pretrained weights are not available for this model. + """ + params = BlockParams.from_init_params(depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25) + return _regnet("regnet_y_128gf", params, **kwargs) + + def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_400MF architecture from From 85b2ec1daaac600cc794a9fc3d83292d40e7f302 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Sat, 8 Jan 2022 00:38:57 +0000 Subject: [PATCH 2/9] fix test --- torchvision/models/regnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index cf1863ccfe4..97d4af953e6 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -512,8 +512,10 @@ def regnet_y_128gf(**kwargs: Any) -> RegNet: `"Designing Network Design Spaces" `_. NOTE: Pretrained weights are not available for this model. """ + if "pretrained" in kwargs and kwargs.pop("pretrained"): + raise ValueError("No pretrained weights available for regnet_y_128gf.") params = BlockParams.from_init_params(depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25) - return _regnet("regnet_y_128gf", params, **kwargs) + return _regnet("regnet_y_128gf", params, pretrained=False, progress=False, **kwargs) def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: From 83dd16a4883b713ed3c264103cf26b58d146bdb6 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Sat, 8 Jan 2022 01:51:18 +0000 Subject: [PATCH 3/9] add expected test file --- .../ModelTester.test_regnet_y_128gf_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_regnet_y_128gf_expect.pkl diff --git a/test/expect/ModelTester.test_regnet_y_128gf_expect.pkl b/test/expect/ModelTester.test_regnet_y_128gf_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..4f6037929cc79ec90862a76bdfd341c353247da5 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5@+ShDb~e?&3jT;vD?@x$nJGvnqkwH*<>a6BYaQ5lC3s=KLc!DztrE8BCyh$ zfAgB%M|N$sxo+}v7vGwLR%;%d*?l1YgY~~($$S1hb=%b>CA4Ru5vR?wk6)}K5C612 z6f9}=orA+@ZzKU=0tgvecqtU@W*fnG+dYNFfJd8gqeczBn&E zlnH1n2nTpGf+%>JM2 Date: Mon, 10 Jan 2022 20:11:49 +0000 Subject: [PATCH 4/9] update regnet factory function, add to prototype as well --- torchvision/models/regnet.py | 10 +++++----- torchvision/prototype/models/regnet.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 97d4af953e6..1066ade43f4 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -506,16 +506,16 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) -def regnet_y_128gf(**kwargs: Any) -> RegNet: +def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_128GF architecture from `"Designing Network Design Spaces" `_. NOTE: Pretrained weights are not available for this model. """ - if "pretrained" in kwargs and kwargs.pop("pretrained"): - raise ValueError("No pretrained weights available for regnet_y_128gf.") - params = BlockParams.from_init_params(depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25) - return _regnet("regnet_y_128gf", params, pretrained=False, progress=False, **kwargs) + params = BlockParams.from_init_params( + depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs) def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index db1e86fdcab..c23a971bb5f 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -20,6 +20,7 @@ "RegNet_Y_8GF_Weights", "RegNet_Y_16GF_Weights", "RegNet_Y_32GF_Weights", + "RegNet_Y_128GF_Weights", "RegNet_X_400MF_Weights", "RegNet_X_800MF_Weights", "RegNet_X_1_6GF_Weights", @@ -34,6 +35,7 @@ "regnet_y_8gf", "regnet_y_16gf", "regnet_y_32gf", + "regnet_y_128gf", "regnet_x_400mf", "regnet_x_800mf", "regnet_x_1_6gf", @@ -253,6 +255,11 @@ class RegNet_Y_32GF_Weights(WeightsEnum): default = ImageNet1K_V2 +class RegNet_Y_128GF_Weights(WeightsEnum): + # weights are not available yet. + pass + + class RegNet_X_400MF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", @@ -501,6 +508,16 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: return _regnet(params, weights, progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", None)) +def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + weights = RegNet_Y_128GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + @handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1)) def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_400MF_Weights.verify(weights) From cf27e0e9a21ee26cf421b6c4bf5d329625d508b1 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Tue, 11 Jan 2022 00:01:30 +0000 Subject: [PATCH 5/9] write torchscript to temp file instead bytesio in model test --- test/test_models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 2e0ed783849..62a214506b2 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,6 +1,5 @@ import contextlib import functools -import io import operator import os import pkgutil @@ -8,6 +7,7 @@ import traceback import warnings from collections import OrderedDict +from tempfile import TemporaryDirectory import pytest import torch @@ -126,10 +126,10 @@ def assert_export_import_module(m, args): def get_export_import_copy(m): """Save and load a TorchScript model""" - buffer = io.BytesIO() - torch.jit.save(m, buffer) - buffer.seek(0) - imported = torch.jit.load(buffer) + with TemporaryDirectory() as dir: + path = os.path.join(dir, "script.pt") + m.save(path) + imported = torch.jit.load(path) return imported m_import = get_export_import_copy(m) From 6985776e3d0c30428ec032254930cba1010c6d91 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Tue, 11 Jan 2022 17:04:16 +0000 Subject: [PATCH 6/9] docs --- docs/source/models.rst | 2 ++ hubconf.py | 1 + 2 files changed, 3 insertions(+) diff --git a/docs/source/models.rst b/docs/source/models.rst index 9c750908b06..62c104cf927 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -76,6 +76,7 @@ You can construct a model with random weights by calling its constructor: regnet_y_8gf = models.regnet_y_8gf() regnet_y_16gf = models.regnet_y_16gf() regnet_y_32gf = models.regnet_y_32gf() + regnet_y_128gf = models.regnet_y_128gf() regnet_x_400mf = models.regnet_x_400mf() regnet_x_800mf = models.regnet_x_800mf() regnet_x_1_6gf = models.regnet_x_1_6gf() @@ -439,6 +440,7 @@ RegNet regnet_y_8gf regnet_y_16gf regnet_y_32gf + regnet_y_128gf regnet_x_400mf regnet_x_800mf regnet_x_1_6gf diff --git a/hubconf.py b/hubconf.py index 81b15ff9ff1..2b2eeb1c166 100644 --- a/hubconf.py +++ b/hubconf.py @@ -27,6 +27,7 @@ regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, + regnet_y_128gf, regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, From ffcdf09a09331c8c3e555b7142fb004dc07884c8 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Wed, 12 Jan 2022 22:20:37 +0000 Subject: [PATCH 7/9] clear GPU memory --- test/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_models.py b/test/test_models.py index 62a214506b2..c2c67f94351 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -166,6 +166,7 @@ def get_export_import_copy(m): torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4) assert_export_import_module(sm, args) + torch.cuda.empty_cache() def _check_fx_compatible(model, inputs): From 9b608aa5728317ba6d6250207674baa1859a72ee Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Wed, 12 Jan 2022 23:31:57 +0000 Subject: [PATCH 8/9] no_grad --- test/test_models.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index c2c67f94351..f8a9d4cf0fc 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,10 +133,12 @@ def get_export_import_copy(m): return imported m_import = get_export_import_copy(m) - with freeze_rng_state(): - results = m(*args) - with freeze_rng_state(): - results_from_imported = m_import(*args) + with torch.no_grad(): + with freeze_rng_state(): + results = m(*args) + with torch.no_grad(): + with freeze_rng_state(): + results_from_imported = m_import(*args) tol = 3e-4 torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) @@ -156,17 +158,18 @@ def get_export_import_copy(m): sm = torch.jit.script(nn_module) - with freeze_rng_state(): - eager_out = nn_module(*args) + with torch.no_grad(): + with freeze_rng_state(): + eager_out = nn_module(*args) - with freeze_rng_state(): - script_out = sm(*args) - if unwrapper: - script_out = unwrapper(script_out) + with torch.no_grad(): + with freeze_rng_state(): + script_out = sm(*args) + if unwrapper: + script_out = unwrapper(script_out) torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4) assert_export_import_module(sm, args) - torch.cuda.empty_cache() def _check_fx_compatible(model, inputs): From a038c858f06b42604513506db7e301b831cb8724 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Thu, 13 Jan 2022 00:34:47 +0000 Subject: [PATCH 9/9] nit --- test/test_models.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index f8a9d4cf0fc..f4f1828d8af 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,12 +133,10 @@ def get_export_import_copy(m): return imported m_import = get_export_import_copy(m) - with torch.no_grad(): - with freeze_rng_state(): - results = m(*args) - with torch.no_grad(): - with freeze_rng_state(): - results_from_imported = m_import(*args) + with torch.no_grad(), freeze_rng_state(): + results = m(*args) + with torch.no_grad(), freeze_rng_state(): + results_from_imported = m_import(*args) tol = 3e-4 torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) @@ -158,15 +156,13 @@ def get_export_import_copy(m): sm = torch.jit.script(nn_module) - with torch.no_grad(): - with freeze_rng_state(): - eager_out = nn_module(*args) + with torch.no_grad(), freeze_rng_state(): + eager_out = nn_module(*args) - with torch.no_grad(): - with freeze_rng_state(): - script_out = sm(*args) - if unwrapper: - script_out = unwrapper(script_out) + with torch.no_grad(), freeze_rng_state(): + script_out = sm(*args) + if unwrapper: + script_out = unwrapper(script_out) torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4) assert_export_import_module(sm, args)