From c5a1821c209dd6531c94d57d5e7420c9634eeb39 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 24 Mar 2022 19:49:04 +0800 Subject: [PATCH 01/22] add set_weight_decay --- torchvision/ops/_utils.py | 55 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 30f28e51c4c..9d30d5df9ad 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -61,3 +61,58 @@ def split_normalization_params( else: other_params.extend(p for p in module.parameters() if p.requires_grad) return norm_params, other_params + + +def set_weight_decay( + model: torch.nn.Module, + weight_decay: float, + norm_weight_decay: Optional[float] = None, + bias_weight_decay: Optional[float] = None, + custom_keys_weight_decay: Optional[Dict[str, float]] = None, +): + norm_classes = (torch.nn.modules.batchnorm._BatchNorm, torch.nn.LayerNorm, torch.nn.GroupNorm) + + norm_params = [] + bias_params = [] + other_params = [] + custom_params = {} + if custom_keys_weight_decay is not None: + for key in custom_keys_weight_decay: + custom_params[key] = [] + + for module in model.modules(): + if next(module.children(), None): + for name, p in module.named_parameters(recurse=False): + if not p.requires_grad: + continue + is_custom_key = False + for key in custom_params: + if key in name: + custom_params[key].append(p) + is_custom_key = True + if not is_custom_key: + other_params.append(p) + elif isinstance(module, norm_classes): + if norm_weight_decay is not None: + norm_params.extend(p for p in module.parameters() if p.requires_grad) + else: + other_params.extend(p for p in module.parameters() if p.requires_grad) + else: + for name, p in module.named_parameters(): + if not p.requires_grad: + continue + if name == "bias" and (bias_weight_decay is not None): + bias_params.append(p) + else: + other_params.append(p) + + param_groups = [] + if norm_weight_decay is not None: + param_groups.append({"params": norm_params, "weight_decay": norm_weight_decay}) + if bias_weight_decay is not None: + param_groups.append({"params": bias_params, "weight_decay": bias_weight_decay}) + for key in custom_params: + param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) + param_groups.append({"params": other_params, "weight_decay": weight_decay}) + return param_groups + From 3955d447e7fcc5387ab092e2e33766e88a3124a8 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 25 Mar 2022 19:59:10 +0800 Subject: [PATCH 02/22] Update _utils.py --- torchvision/ops/_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 9d30d5df9ad..45807fe6ff0 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -67,10 +67,19 @@ def set_weight_decay( model: torch.nn.Module, weight_decay: float, norm_weight_decay: Optional[float] = None, + norm_classes: Optional[List[type]] = None, bias_weight_decay: Optional[float] = None, custom_keys_weight_decay: Optional[Dict[str, float]] = None, ): - norm_classes = (torch.nn.modules.batchnorm._BatchNorm, torch.nn.LayerNorm, torch.nn.GroupNorm) + if not norm_classes: + norm_classes = [ + nn.modules.batchnorm._BatchNorm, + nn.LayerNorm, + nn.GroupNorm, + nn.modules.instancenorm._InstanceNorm, + nn.LocalResponseNorm, + ] + norm_classes = tuple(norm_classes) norm_params = [] bias_params = [] @@ -87,21 +96,18 @@ def set_weight_decay( continue is_custom_key = False for key in custom_params: - if 
key in name: + if key == name: custom_params[key].append(p) is_custom_key = True if not is_custom_key: other_params.append(p) - elif isinstance(module, norm_classes): - if norm_weight_decay is not None: - norm_params.extend(p for p in module.parameters() if p.requires_grad) - else: - other_params.extend(p for p in module.parameters() if p.requires_grad) + elif isinstance(module, norm_classes) and norm_weight_decay is not None: + norm_params.extend(p for p in module.parameters() if p.requires_grad) else: for name, p in module.named_parameters(): if not p.requires_grad: continue - if name == "bias" and (bias_weight_decay is not None): + if name == "bias" and bias_weight_decay is not None: bias_params.append(p) else: other_params.append(p) @@ -115,4 +121,3 @@ def set_weight_decay( param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) param_groups.append({"params": other_params, "weight_decay": weight_decay}) return param_groups - From 568d515a4f50bb6c25ade2e10db3200b37ce9f79 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 25 Mar 2022 20:01:40 +0800 Subject: [PATCH 03/22] refactor code --- torchvision/ops/_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 45807fe6ff0..0cff5fd7c6e 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -113,11 +113,12 @@ def set_weight_decay( other_params.append(p) param_groups = [] - if norm_weight_decay is not None: + if len(norm_params) > 0: param_groups.append({"params": norm_params, "weight_decay": norm_weight_decay}) - if bias_weight_decay is not None: + if len(bias_params) > 0: param_groups.append({"params": bias_params, "weight_decay": bias_weight_decay}) for key in custom_params: - param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) + if len(custom_params[key]) > 0: + param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) param_groups.append({"params": other_params, "weight_decay": weight_decay}) return param_groups From c56a01dd684dd419d5cf88353d13ebeabb020815 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 27 Mar 2022 12:02:33 +0800 Subject: [PATCH 04/22] fix import --- torchvision/ops/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 0cff5fd7c6e..36c1df22b42 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn, Tensor From d1093431110a82dcc03dff57190f98f7fa298f64 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 27 Mar 2022 14:42:23 +0800 Subject: [PATCH 05/22] add set_weight_decay --- test/test_ops.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index d1562b00a42..ca8e36ce908 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1366,6 +1366,51 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 + + @pytest.mark.parametrize("norm_weight_decay", [None, 0.0]) + @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm]) + @pytest.mark.parametrize("bias_weight_decay", [None, 0.0]) + @pytest.mark.parametrize("custom_keys_weight_decay", [None, {"class_token": 0.0, "pos_embedding": 0.0}]) + def 
test_set_weight_decay(self, norm_weight_decay, norm_layer, bias_weight_decay, custom_keys_weight_decay): + model = models.VisionTransformer( + image_size=224, + patch_size=16, + num_layers=1, + num_heads=2, + hidden_dim=8, + mlp_dim=4, + ) + param_groups = ops._utils.set_weight_decay( + model, + 1e-3, + norm_weight_decay=norm_weight_decay, + norm_classes=None if norm_layer is None else [norm_layer], + bias_weight_decay=bias_weight_decay, + custom_keys_weight_decay=custom_keys_weight_decay + ) + + if norm_weight_decay is None and bias_weight_decay is None and custom_keys_weight_decay is None: + assert len(param_groups) == 1 + assert len(param_groups[0]["params"]) == 20 + + if norm_weight_decay is not None and bias_weight_decay is None and custom_keys_weight_decay is None: + assert len(param_groups) == 2 + assert len(param_groups[0]["params"]) == 6 + assert len(param_groups[1]["params"]) == 14 + + if norm_weight_decay is not None and bias_weight_decay is not None and custom_keys_weight_decay is None: + assert len(param_groups) == 3 + assert len(param_groups[0]["params"]) == 6 + assert len(param_groups[1]["params"]) == 5 + assert len(param_groups[2]["params"]) == 9 + + if norm_weight_decay is not None and bias_weight_decay is not None and custom_keys_weight_decay is not None: + assert len(param_groups) == 5 + assert len(param_groups[0]["params"]) == 6 + assert len(param_groups[1]["params"]) == 5 + assert len(param_groups[2]["params"]) == 1 + assert len(param_groups[3]["params"]) == 1 + assert len(param_groups[4]["params"]) == 7 class TestDropBlock: From b84da511eea8edc48a65a64b3d59cdf0cff1f0e6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 27 Mar 2022 14:49:52 +0800 Subject: [PATCH 06/22] fix lint --- test/test_ops.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index ca8e36ce908..9bc70ea26c7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1366,10 +1366,10 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 - - @pytest.mark.parametrize("norm_weight_decay", [None, 0.0]) + + @pytest.mark.parametrize("norm_weight_decay", [None, 0.0]) @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm]) - @pytest.mark.parametrize("bias_weight_decay", [None, 0.0]) + @pytest.mark.parametrize("bias_weight_decay", [None, 0.0]) @pytest.mark.parametrize("custom_keys_weight_decay", [None, {"class_token": 0.0, "pos_embedding": 0.0}]) def test_set_weight_decay(self, norm_weight_decay, norm_layer, bias_weight_decay, custom_keys_weight_decay): model = models.VisionTransformer( @@ -1379,16 +1379,16 @@ def test_set_weight_decay(self, norm_weight_decay, norm_layer, bias_weight_decay num_heads=2, hidden_dim=8, mlp_dim=4, - ) + ) param_groups = ops._utils.set_weight_decay( model, 1e-3, norm_weight_decay=norm_weight_decay, norm_classes=None if norm_layer is None else [norm_layer], bias_weight_decay=bias_weight_decay, - custom_keys_weight_decay=custom_keys_weight_decay - ) - + custom_keys_weight_decay=custom_keys_weight_decay, + ) + if norm_weight_decay is None and bias_weight_decay is None and custom_keys_weight_decay is None: assert len(param_groups) == 1 assert len(param_groups[0]["params"]) == 20 From 8cc4eebf79908839ba3c8fc3a5bd4ad9c3ec63ab Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 27 Mar 2022 15:00:21 +0800 Subject: [PATCH 07/22] fix lint --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 
9bc70ea26c7..61bb26bcec6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1366,7 +1366,7 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 - + @pytest.mark.parametrize("norm_weight_decay", [None, 0.0]) @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm]) @pytest.mark.parametrize("bias_weight_decay", [None, 0.0]) @@ -1388,7 +1388,7 @@ def test_set_weight_decay(self, norm_weight_decay, norm_layer, bias_weight_decay bias_weight_decay=bias_weight_decay, custom_keys_weight_decay=custom_keys_weight_decay, ) - + if norm_weight_decay is None and bias_weight_decay is None and custom_keys_weight_decay is None: assert len(param_groups) == 1 assert len(param_groups[0]["params"]) == 20 From a8123879f24e153d7ae0d4fb44eb9ada4ea69e61 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 27 Mar 2022 15:07:27 +0800 Subject: [PATCH 08/22] replace split_normalization_params with set_weight_decay --- references/classification/train.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index eb8b56c1ad0..0c1afe60c06 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -229,12 +229,16 @@ def main(args): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - if args.norm_weight_decay is None: - parameters = model.parameters() - else: - param_groups = torchvision.ops._utils.split_normalization_params(model) - wd_groups = [args.norm_weight_decay, args.weight_decay] - parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p] + custom_keys_weight_decay = None + if hasattr(model, "no_weight_decay_keys"): + custom_keys_weight_decay = {k: 0.0 for k in model.no_weight_decay_keys()} + parameters = torchvision.ops._utils.set_weight_decay( + model, + args.weight_decay, + norm_weight_decay=args.norm_weight_decay, + bias_weight_decay=args.bias_weight_decay, + custom_keys_weight_decay=custom_keys_weight_decay, + ) opt_name = args.opt.lower() if opt_name.startswith("sgd"): @@ -393,6 +397,12 @@ def get_args_parser(add_help=True): type=float, help="weight decay for Normalization layers (default: None, same value as --wd)", ) + parser.add_argument( + "--bias-weight-decay", + default=None, + type=float, + help="weight decay for bias parameter of all layers (default: None, same value as --wd)", + ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" ) From 4f2c206d57d5cc026bb64caa7e326f2da8b4457e Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 28 Mar 2022 20:51:27 +0800 Subject: [PATCH 09/22] simplfy the code --- torchvision/ops/_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 36c1df22b42..30a338dd9b2 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -94,12 +94,9 @@ def set_weight_decay( for name, p in module.named_parameters(recurse=False): if not p.requires_grad: continue - is_custom_key = False - for key in custom_params: - if key == name: - custom_params[key].append(p) - is_custom_key = True - if not is_custom_key: + if key in custom_params: + custom_params[key].append(p) + else: other_params.append(p) elif isinstance(module, norm_classes) and norm_weight_decay is not None: norm_params.extend(p for p in module.parameters() if p.requires_grad) From 
f12bd0840d3a67e875a5f4275ba5a5e90817e198 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 30 Mar 2022 21:21:41 +0800 Subject: [PATCH 10/22] refactor code --- torchvision/ops/_utils.py | 72 +++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 30a338dd9b2..4012e212e0b 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -68,8 +68,7 @@ def set_weight_decay( weight_decay: float, norm_weight_decay: Optional[float] = None, norm_classes: Optional[List[type]] = None, - bias_weight_decay: Optional[float] = None, - custom_keys_weight_decay: Optional[Dict[str, float]] = None, + custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, ): if not norm_classes: norm_classes = [ @@ -81,41 +80,48 @@ def set_weight_decay( ] norm_classes = tuple(norm_classes) - norm_params = [] - bias_params = [] - other_params = [] - custom_params = {} + params = { + "other": [], + "norm": [], + } + params_weight_decay = { + "other": weight_decay, + "norm": norm_weight_decay, + } + custom_keys = [] if custom_keys_weight_decay is not None: - for key in custom_keys_weight_decay: - custom_params[key] = [] - - for module in model.modules(): - if next(module.children(), None): - for name, p in module.named_parameters(recurse=False): - if not p.requires_grad: - continue - if key in custom_params: - custom_params[key].append(p) - else: - other_params.append(p) - elif isinstance(module, norm_classes) and norm_weight_decay is not None: - norm_params.extend(p for p in module.parameters() if p.requires_grad) + for key, weight_decay in custom_keys_weight_decay: + params[key] = [] + params_weight_decay[key] = weight_decay + custom_keys.append(key) + + def _add_params(module, prefix=""): + # We firstly consider norm layers + if norm_weight_decay is not None and isinstance(module, norm_classes): + params["norm"].extend(p for p in module.parameters() if p.requires_grad) else: - for name, p in module.named_parameters(): + for name, p in module.named_parameters(recurse=False): if not p.requires_grad: continue - if name == "bias" and bias_weight_decay is not None: - bias_params.append(p) - else: - other_params.append(p) + is_custom_key = False + for key in custom_keys: + full_name = f"{prefix}.{name}" if prefix != "" else name + if key in full_name: + params[key].append(p) + is_custom_key = True + break + if not is_custom_key: + params["other"].append(p) + + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name + _add_params(child_module, prefix=child_prefix) + + _add_params(model) param_groups = [] - if len(norm_params) > 0: - param_groups.append({"params": norm_params, "weight_decay": norm_weight_decay}) - if len(bias_params) > 0: - param_groups.append({"params": bias_params, "weight_decay": bias_weight_decay}) - for key in custom_params: - if len(custom_params[key]) > 0: - param_groups.append({"params": custom_params[key], "weight_decay": custom_keys_weight_decay[key]}) - param_groups.append({"params": other_params, "weight_decay": weight_decay}) + for key in params: + if len(params[key]) > 0: + print(key, len(params[key])) + param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) return param_groups From f9f7f18096693f64c0c5192598d65cbe5040bd71 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 30 Mar 2022 21:23:20 +0800 Subject: [PATCH 11/22] refactor code --- test/test_ops.py | 45 
+++++++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 61bb26bcec6..fd6b11797c4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1367,11 +1367,10 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 - @pytest.mark.parametrize("norm_weight_decay", [None, 0.0]) + @pytest.mark.parametrize("norm_weight_decay", [None, 0.2]) @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm]) - @pytest.mark.parametrize("bias_weight_decay", [None, 0.0]) - @pytest.mark.parametrize("custom_keys_weight_decay", [None, {"class_token": 0.0, "pos_embedding": 0.0}]) - def test_set_weight_decay(self, norm_weight_decay, norm_layer, bias_weight_decay, custom_keys_weight_decay): + @pytest.mark.parametrize("custom_keys_weight_decay", [None, [("class_token", 0.3), ("pos_embedding", 0.4)]]) + def test_set_weight_decay(self, norm_weight_decay, norm_layer, custom_keys_weight_decay): model = models.VisionTransformer( image_size=224, patch_size=16, @@ -1380,37 +1379,31 @@ def test_set_weight_decay(self, norm_weight_decay, norm_layer, bias_weight_decay hidden_dim=8, mlp_dim=4, ) - param_groups = ops._utils.set_weight_decay( + param_groups = set_weight_decay( model, - 1e-3, + 0.1, norm_weight_decay=norm_weight_decay, norm_classes=None if norm_layer is None else [norm_layer], - bias_weight_decay=bias_weight_decay, - custom_keys_weight_decay=custom_keys_weight_decay, + custom_keys_weight_decay=custom_keys_weight_decay ) - - if norm_weight_decay is None and bias_weight_decay is None and custom_keys_weight_decay is None: + + if norm_weight_decay is None and custom_keys_weight_decay is None: assert len(param_groups) == 1 assert len(param_groups[0]["params"]) == 20 - - if norm_weight_decay is not None and bias_weight_decay is None and custom_keys_weight_decay is None: + + if norm_weight_decay is not None and custom_keys_weight_decay is None: assert len(param_groups) == 2 - assert len(param_groups[0]["params"]) == 6 - assert len(param_groups[1]["params"]) == 14 - - if norm_weight_decay is not None and bias_weight_decay is not None and custom_keys_weight_decay is None: - assert len(param_groups) == 3 - assert len(param_groups[0]["params"]) == 6 - assert len(param_groups[1]["params"]) == 5 - assert len(param_groups[2]["params"]) == 9 - - if norm_weight_decay is not None and bias_weight_decay is not None and custom_keys_weight_decay is not None: - assert len(param_groups) == 5 - assert len(param_groups[0]["params"]) == 6 - assert len(param_groups[1]["params"]) == 5 + param_groups.sort(key=lambda x: x["weight_decay"]) + assert len(param_groups[0]["params"]) == 14 + assert len(param_groups[1]["params"]) == 6 + + if norm_weight_decay is not None and custom_keys_weight_decay is not None: + assert len(param_groups) == 4 + param_groups.sort(key=lambda x: x["weight_decay"]) + assert len(param_groups[0]["params"]) == 12 + assert len(param_groups[1]["params"]) == 6 assert len(param_groups[2]["params"]) == 1 assert len(param_groups[3]["params"]) == 1 - assert len(param_groups[4]["params"]) == 7 class TestDropBlock: From 352978202771b47c8997e4053567dc7e22b776ca Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 30 Mar 2022 21:29:49 +0800 Subject: [PATCH 12/22] fix lint --- test/test_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index fd6b11797c4..20b96d3577c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ 
-1367,7 +1367,7 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 - @pytest.mark.parametrize("norm_weight_decay", [None, 0.2]) + @pytest.mark.parametrize("norm_weight_decay", [None, 0.2]) @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm]) @pytest.mark.parametrize("custom_keys_weight_decay", [None, [("class_token", 0.3), ("pos_embedding", 0.4)]]) def test_set_weight_decay(self, norm_weight_decay, norm_layer, custom_keys_weight_decay): @@ -1384,13 +1384,13 @@ def test_set_weight_decay(self, norm_weight_decay, norm_layer, custom_keys_weigh 0.1, norm_weight_decay=norm_weight_decay, norm_classes=None if norm_layer is None else [norm_layer], - custom_keys_weight_decay=custom_keys_weight_decay + custom_keys_weight_decay=custom_keys_weight_decay, ) - + if norm_weight_decay is None and custom_keys_weight_decay is None: assert len(param_groups) == 1 assert len(param_groups[0]["params"]) == 20 - + if norm_weight_decay is not None and custom_keys_weight_decay is None: assert len(param_groups) == 2 param_groups.sort(key=lambda x: x["weight_decay"]) From d0a0efc1e9320091529de496e9b94bb76a12c1a0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 31 Mar 2022 10:10:34 +0800 Subject: [PATCH 13/22] remove unused --- torchvision/ops/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index bb8c18443e1..79dcb46d259 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn, Tensor @@ -122,6 +122,5 @@ def _add_params(module, prefix=""): param_groups = [] for key in params: if len(params[key]) > 0: - print(key, len(params[key])) param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) return param_groups From 28679642ab43bd9ef848a59aa2c5a68b73c84b80 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 31 Mar 2022 18:59:21 +0800 Subject: [PATCH 14/22] Update test_ops.py --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1db9598be1a..7ba4df005b4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1379,7 +1379,7 @@ def test_set_weight_decay(self, norm_weight_decay, norm_layer, custom_keys_weigh hidden_dim=8, mlp_dim=4, ) - param_groups = set_weight_decay( + param_groups = ops._utils.set_weight_decay( model, 0.1, norm_weight_decay=norm_weight_decay, From e4aba9d541dbb29a9e300a702c8b8c10419654d1 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 31 Mar 2022 19:37:17 +0800 Subject: [PATCH 15/22] Update train.py --- references/classification/train.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 0c1afe60c06..7fd3f8740ea 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -231,12 +231,11 @@ def main(args): custom_keys_weight_decay = None if hasattr(model, "no_weight_decay_keys"): - custom_keys_weight_decay = {k: 0.0 for k in model.no_weight_decay_keys()} + custom_keys_weight_decay = [(key, 0.0) for key in model.no_weight_decay_keys()] parameters = torchvision.ops._utils.set_weight_decay( model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, - bias_weight_decay=args.bias_weight_decay, custom_keys_weight_decay=custom_keys_weight_decay, ) @@ -397,12 
+396,6 @@ def get_args_parser(add_help=True): type=float, help="weight decay for Normalization layers (default: None, same value as --wd)", ) - parser.add_argument( - "--bias-weight-decay", - default=None, - type=float, - help="weight decay for bias parameter of all layers (default: None, same value as --wd)", - ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" ) From bbc0005032d7ed527e7afb94b52db4bc6a18f6db Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 31 Mar 2022 20:51:14 +0800 Subject: [PATCH 16/22] Update _utils.py --- torchvision/ops/_utils.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 79dcb46d259..0a20f3eb4eb 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -96,21 +96,20 @@ def set_weight_decay( custom_keys.append(key) def _add_params(module, prefix=""): - # We firstly consider norm layers - if norm_weight_decay is not None and isinstance(module, norm_classes): - params["norm"].extend(p for p in module.parameters() if p.requires_grad) - else: - for name, p in module.named_parameters(recurse=False): - if not p.requires_grad: - continue - is_custom_key = False - for key in custom_keys: - full_name = f"{prefix}.{name}" if prefix != "" else name - if key in full_name: - params[key].append(p) - is_custom_key = True - break - if not is_custom_key: + for name, p in module.named_parameters(recurse=False): + if not p.requires_grad: + continue + is_custom_key = False + for key in custom_keys: + target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name + if key == target_name: + params[key].append(p) + is_custom_key = True + break + if not is_custom_key: + if norm_weight_decay is not None and isinstance(module, norm_classes): + params["norm"].append(p) + else: params["other"].append(p) for child_name, child_module in module.named_children(): From 5c0bd122b216053859b2cd534fd20073170305d2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 31 Mar 2022 23:08:33 +0800 Subject: [PATCH 17/22] Update train.py --- references/classification/train.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 7fd3f8740ea..bd6e01d8c33 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -229,14 +229,17 @@ def main(args): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - custom_keys_weight_decay = None - if hasattr(model, "no_weight_decay_keys"): - custom_keys_weight_decay = [(key, 0.0) for key in model.no_weight_decay_keys()] + custom_keys_weight_decay = [] + if args.bias_weight_decay is not None: + custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) + if args.transformer_weight_decay is not None: + for key in ["class_token", "position_embedding", "relative_position_bias"]: + custom_keys_weight_decay.append((key, args.transformer_weight_decay)) parameters = torchvision.ops._utils.set_weight_decay( model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, - custom_keys_weight_decay=custom_keys_weight_decay, + custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None, ) opt_name = args.opt.lower() @@ -396,6 +399,18 @@ def get_args_parser(add_help=True): type=float, help="weight decay for Normalization layers (default: None, same value as --wd)", ) 
+ parser.add_argument( + "--bias-weight-decay", + default=None, + type=float, + help="weight decay for bias parameters of all layers (default: None, same value as --wd)", + ) + parser.add_argument( + "--transformer-weight-decay", + default=None, + type=float, + help="weight decay for special parameters for vision transformer models (default: None, same value as --wd)", + ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" ) From f0635dd3009a146976acfd4292ce79de534b13f9 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 09:03:59 +0800 Subject: [PATCH 18/22] add set_weight_decay --- references/classification/utils.py | 63 ++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/references/classification/utils.py b/references/classification/utils.py index 32658a7c137..c31f3928e86 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -5,6 +5,7 @@ import os import time from collections import defaultdict, deque, OrderedDict +from typing import List, Optional, Tuple import torch import torch.distributed as dist @@ -400,3 +401,65 @@ def reduce_across_processes(val): dist.barrier() dist.all_reduce(t) return t + + +def set_weight_decay( + model: torch.nn.Module, + weight_decay: float, + norm_weight_decay: Optional[float] = None, + norm_classes: Optional[List[type]] = None, + custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, +): + if not norm_classes: + norm_classes = [ + torch.nn.modules.batchnorm._BatchNorm, + torch.nn.LayerNorm, + torch.nn.GroupNorm, + torch.nn.modules.instancenorm._InstanceNorm, + torch.nn.LocalResponseNorm, + ] + norm_classes = tuple(norm_classes) + + params = { + "other": [], + "norm": [], + } + params_weight_decay = { + "other": weight_decay, + "norm": norm_weight_decay, + } + custom_keys = [] + if custom_keys_weight_decay is not None: + for key, weight_decay in custom_keys_weight_decay: + params[key] = [] + params_weight_decay[key] = weight_decay + custom_keys.append(key) + + def _add_params(module, prefix=""): + for name, p in module.named_parameters(recurse=False): + if not p.requires_grad: + continue + is_custom_key = False + for key in custom_keys: + target_name = f"{prefix}.{name}" if prefix != "" and "." 
in key else name + if key == target_name: + params[key].append(p) + is_custom_key = True + break + if not is_custom_key: + if norm_weight_decay is not None and isinstance(module, norm_classes): + params["norm"].append(p) + else: + params["other"].append(p) + + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name + _add_params(child_module, prefix=child_prefix) + + _add_params(model) + + param_groups = [] + for key in params: + if len(params[key]) > 0: + param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) + return param_groups From 3564ae64c7366793f167f6d2bc1bbbb26e047801 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 09:05:12 +0800 Subject: [PATCH 19/22] add set_weight_decay --- references/classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index bd6e01d8c33..a7ca2e5de4b 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -235,7 +235,7 @@ def main(args): if args.transformer_weight_decay is not None: for key in ["class_token", "position_embedding", "relative_position_bias"]: custom_keys_weight_decay.append((key, args.transformer_weight_decay)) - parameters = torchvision.ops._utils.set_weight_decay( + parameters = utils.set_weight_decay( model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, From 8251dad9f82d09dee2b8cab5b28688c611b5d3a4 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 09:06:23 +0800 Subject: [PATCH 20/22] Update _utils.py --- torchvision/ops/_utils.py | 62 --------------------------------------- 1 file changed, 62 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 0a20f3eb4eb..107785266a1 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -61,65 +61,3 @@ def split_normalization_params( else: other_params.extend(p for p in module.parameters() if p.requires_grad) return norm_params, other_params - - -def set_weight_decay( - model: torch.nn.Module, - weight_decay: float, - norm_weight_decay: Optional[float] = None, - norm_classes: Optional[List[type]] = None, - custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, -): - if not norm_classes: - norm_classes = [ - nn.modules.batchnorm._BatchNorm, - nn.LayerNorm, - nn.GroupNorm, - nn.modules.instancenorm._InstanceNorm, - nn.LocalResponseNorm, - ] - norm_classes = tuple(norm_classes) - - params = { - "other": [], - "norm": [], - } - params_weight_decay = { - "other": weight_decay, - "norm": norm_weight_decay, - } - custom_keys = [] - if custom_keys_weight_decay is not None: - for key, weight_decay in custom_keys_weight_decay: - params[key] = [] - params_weight_decay[key] = weight_decay - custom_keys.append(key) - - def _add_params(module, prefix=""): - for name, p in module.named_parameters(recurse=False): - if not p.requires_grad: - continue - is_custom_key = False - for key in custom_keys: - target_name = f"{prefix}.{name}" if prefix != "" and "." 
in key else name - if key == target_name: - params[key].append(p) - is_custom_key = True - break - if not is_custom_key: - if norm_weight_decay is not None and isinstance(module, norm_classes): - params["norm"].append(p) - else: - params["other"].append(p) - - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name - _add_params(child_module, prefix=child_prefix) - - _add_params(model) - - param_groups = [] - for key in params: - if len(params[key]) > 0: - param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) - return param_groups From 5a0b0ccaba34bf0d3267905ac2ef420a842a400b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 1 Apr 2022 09:07:07 +0800 Subject: [PATCH 21/22] Update test_ops.py --- test/test_ops.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 7ba4df005b4..ad9aaefee52 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1367,44 +1367,6 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 - @pytest.mark.parametrize("norm_weight_decay", [None, 0.2]) - @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm]) - @pytest.mark.parametrize("custom_keys_weight_decay", [None, [("class_token", 0.3), ("pos_embedding", 0.4)]]) - def test_set_weight_decay(self, norm_weight_decay, norm_layer, custom_keys_weight_decay): - model = models.VisionTransformer( - image_size=224, - patch_size=16, - num_layers=1, - num_heads=2, - hidden_dim=8, - mlp_dim=4, - ) - param_groups = ops._utils.set_weight_decay( - model, - 0.1, - norm_weight_decay=norm_weight_decay, - norm_classes=None if norm_layer is None else [norm_layer], - custom_keys_weight_decay=custom_keys_weight_decay, - ) - - if norm_weight_decay is None and custom_keys_weight_decay is None: - assert len(param_groups) == 1 - assert len(param_groups[0]["params"]) == 20 - - if norm_weight_decay is not None and custom_keys_weight_decay is None: - assert len(param_groups) == 2 - param_groups.sort(key=lambda x: x["weight_decay"]) - assert len(param_groups[0]["params"]) == 14 - assert len(param_groups[1]["params"]) == 6 - - if norm_weight_decay is not None and custom_keys_weight_decay is not None: - assert len(param_groups) == 4 - param_groups.sort(key=lambda x: x["weight_decay"]) - assert len(param_groups[0]["params"]) == 12 - assert len(param_groups[1]["params"]) == 6 - assert len(param_groups[2]["params"]) == 1 - assert len(param_groups[3]["params"]) == 1 - class TestDropBlock: @pytest.mark.parametrize("seed", range(10)) From 755ccd5157f40508bd1e2aa920dc0e7af3d7f0ee Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 1 Apr 2022 15:06:18 +0100 Subject: [PATCH 22/22] Change `--transformer-weight-decay` to `--transformer-embedding-decay` --- references/classification/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index a7ca2e5de4b..6a3c289bc04 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -232,9 +232,9 @@ def main(args): custom_keys_weight_decay = [] if args.bias_weight_decay is not None: custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) - if args.transformer_weight_decay is not None: + if args.transformer_embedding_decay is not None: for key in ["class_token", "position_embedding", "relative_position_bias"]: - 
custom_keys_weight_decay.append((key, args.transformer_weight_decay)) + custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) parameters = utils.set_weight_decay( model, args.weight_decay, @@ -406,10 +406,10 @@ def get_args_parser(add_help=True): help="weight decay for bias parameters of all layers (default: None, same value as --wd)", ) parser.add_argument( - "--transformer-weight-decay", + "--transformer-embedding-decay", default=None, type=float, - help="weight decay for special parameters for vision transformer models (default: None, same value as --wd)", + help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)", ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
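
Usage sketch (not part of the patch series): the snippet below shows how the final set_weight_decay helper added to references/classification/utils.py in PATCH 18 is meant to be consumed — it returns a list of optimizer parameter groups that are passed straight to a torch.optim optimizer, mirroring what train.py does after PATCH 22. TinyModel, the import path, and the concrete decay values are illustrative assumptions; only set_weight_decay, its keyword arguments, and its grouping behaviour come from the diffs above.

    # Minimal usage sketch, assuming set_weight_decay is importable from
    # references/classification/utils.py (e.g. with that directory on sys.path).
    # TinyModel and the decay values are illustrative assumptions, not part of the patches.
    import torch
    from torch import nn

    from utils import set_weight_decay

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.class_token = nn.Parameter(torch.zeros(1, 1, 8))  # caught by the "class_token" custom key
            self.linear = nn.Linear(8, 8)   # weight -> default group, bias -> "bias" custom key
            self.norm = nn.LayerNorm(8)     # weight -> norm group, bias -> "bias" custom key

    model = TinyModel()
    param_groups = set_weight_decay(
        model,
        1e-4,                        # default weight decay for everything else
        norm_weight_decay=0.0,       # no decay on normalization layers
        custom_keys_weight_decay=[("class_token", 0.0), ("bias", 0.0)],
    )
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)

    for group in param_groups:
        print(len(group["params"]), group["weight_decay"])

Because custom keys are matched before the norm-class check, a ("bias", 0.0) entry also picks up the biases of normalization layers, and empty groups are dropped, so the optimizer only ever sees groups that actually contain parameters.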