From b383787b4690da3a8f57ebc25b2d43fae762e45a Mon Sep 17 00:00:00 2001
From: Oliver Aobo Yang
Date: Tue, 13 Dec 2022 10:30:06 -0800
Subject: [PATCH] Support different reg_reduction in Captum STG

Summary:
Add a new `str` argument `reg_reduction` to the Captum STG classes to specify
how the returned regularization should be reduced. Following the design of
PyTorch losses, support 3 modes: `sum`, `mean`, and `none`. The default is
`sum`.

(Other modes, such as `weighted_sum`, may be needed in the future: with a
customized `mask`, each gate can handle a different number of elements, and an
application may want to use as few elements as possible rather than as few
gates. For now, such use cases can pass `none` and apply their own reduction,
as sketched in the example below.)

Although we previously used `mean`, we decided to make `sum` the default for
3 reasons:

1. The original paper "Learning Sparse Neural Networks through L0
   Regularization" used `sum` both in its text and in its
   [implementation](https://github.com/AMLab-Amsterdam/L0_regularization/blob/master/l0_layers.py#L70).

2. L1 and L2 regularization also `sum` over the parameters without averaging
   over the total number of parameters in the model; see
   [PyTorch's implementation](https://github.com/pytorch/pytorch/blob/df569367ef444dc9831ef0dde3bc611bcabcfbf9/torch/optim/adagrad.py#L268).

3. When a model has multiple STGs of imbalanced sizes, the results are
   comparable under `sum` but not under `mean`. If one STG has 100 gates and
   another has a single gate, `mean` divides the regularization of each gate
   in the first STG by 100, making the first STG 100 times weaker than the
   second. This is usually unexpected for users.

The choice between `mean` and `sum` does not impact performance when there is
only one STG layer, because users can tune `reg_weight` to compensate for the
difference. The authors of "Feature Selection using Stochastic Gates" mixed
`sum` and `mean` in
[their implementation](https://github.com/runopti/stg/blob/master/python/stg/models.py#L164-L195).

For backward compatibility, we explicitly specify `reg_reduction = "mean"` for
all existing usages in Pyper and MVAI.
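A minimal usage sketch for illustration (not part of this diff; the input
values and the per-gate element counts below are made up):

```python
import torch

from captum.module.binary_concrete_stochastic_gates import (
    BinaryConcreteStochasticGates,
)

n_gates = 3
input_tensor = torch.rand(2, n_gates)  # (batch, n_gates), arbitrary values

# "sum" is the new default: reg is a scalar, reg_weight * probs.sum()
stg = BinaryConcreteStochasticGates(n_gates)
gated_input, reg = stg(input_tensor)

# "none" skips the reduction: reg keeps one non-zero probability per gate,
# shape (n_gates,), so the caller can apply its own weighting before
# reducing, e.g. to use as few elements as possible instead of as few gates
stg_none = BinaryConcreteStochasticGates(n_gates, reg_reduction="none")
_, per_gate_reg = stg_none(input_tensor)
elements_per_gate = torch.tensor([4.0, 2.0, 1.0])  # hypothetical counts
weighted_reg = (per_gate_reg * elements_per_gate).sum()
```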
Differential Revision: D41991741
fbshipit-source-id: 77f54cf3948e44e943afff795bf473adaa01fa56
---
 .../binary_concrete_stochastic_gates.py      | 13 +++++++-
 captum/module/gaussian_stochastic_gates.py   | 12 ++++++-
 captum/module/stochastic_gates_base.py       | 29 ++++++++++++++--
 .../test_binary_concrete_stochastic_gates.py | 32 +++++++++++++++---
 .../module/test_gaussian_stochastic_gates.py | 33 ++++++++++++++++---
 5 files changed, 107 insertions(+), 12 deletions(-)

diff --git a/captum/module/binary_concrete_stochastic_gates.py b/captum/module/binary_concrete_stochastic_gates.py
index b45ca58e55..d6fa318d87 100644
--- a/captum/module/binary_concrete_stochastic_gates.py
+++ b/captum/module/binary_concrete_stochastic_gates.py
@@ -60,6 +60,7 @@ def __init__(
         lower_bound: float = -0.1,
         upper_bound: float = 1.1,
         eps: float = 1e-8,
+        reg_reduction: str = "sum",
     ):
         """
         Args:
@@ -93,8 +94,18 @@ def __init__(
             eps (float): term to improve numerical stability in binary concrete
                 sampling
                 Default: 1e-8
+
+            reg_reduction (str, optional): the reduction to apply to
+                the regularization: 'none'|'mean'|'sum'. 'none': no reduction
+                will be applied and it will be the same as the return of
+                get_active_probs, 'mean': the sum of the gates' non-zero
+                probabilities will be divided by the number of gates, 'sum':
+                the gates' non-zero probabilities will be summed.
+                Default: 'sum'
         """
-        super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
+        super().__init__(
+            n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
+        )

         # avoid changing the tensor's variable name
         # when the module is used after compilation,
diff --git a/captum/module/gaussian_stochastic_gates.py b/captum/module/gaussian_stochastic_gates.py
index ebaa692c32..b10f837dc1 100644
--- a/captum/module/gaussian_stochastic_gates.py
+++ b/captum/module/gaussian_stochastic_gates.py
@@ -38,6 +38,7 @@ def __init__(
         mask: Optional[Tensor] = None,
         reg_weight: Optional[float] = 1.0,
         std: Optional[float] = 0.5,
+        reg_reduction: str = "sum",
     ):
         """
         Args:
@@ -58,8 +59,17 @@ def __init__(
             std (Optional[float]): standard deviation that will be fixed
                 throughout.
                 Default: 0.5 (by paper reference)
+            reg_reduction (str, optional): the reduction to apply to
+                the regularization: 'none'|'mean'|'sum'. 'none': no reduction
+                will be applied and it will be the same as the return of
+                get_active_probs, 'mean': the sum of the gates' non-zero
+                probabilities will be divided by the number of gates, 'sum':
+                the gates' non-zero probabilities will be summed.
+                Default: 'sum'
         """
-        super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
+        super().__init__(
+            n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
+        )

         mu = torch.empty(n_gates)
         nn.init.normal_(mu, mean=0.5, std=0.01)
diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py
index c10d32d596..75eebb2d65 100644
--- a/captum/module/stochastic_gates_base.py
+++ b/captum/module/stochastic_gates_base.py
@@ -29,7 +29,11 @@ class StochasticGatesBase(Module, ABC):
     """

     def __init__(
-        self, n_gates: int, mask: Optional[Tensor] = None, reg_weight: float = 1.0
+        self,
+        n_gates: int,
+        mask: Optional[Tensor] = None,
+        reg_weight: float = 1.0,
+        reg_reduction: str = "sum",
     ):
         """
         Args:
@@ -46,6 +50,14 @@ def __init__(
             reg_weight (Optional[float]): rescaling weight for L0 regularization
                 term.
                 Default: 1.0
+
+            reg_reduction (str, optional): the reduction to apply to
+                the regularization: 'none'|'mean'|'sum'. 'none': no reduction
+                will be applied and it will be the same as the return of
+                get_active_probs, 'mean': the sum of the gates' non-zero
+                probabilities will be divided by the number of gates, 'sum':
+                the gates' non-zero probabilities will be summed.
+                Default: 'sum'
         """
         super().__init__()

@@ -57,6 +69,12 @@ def __init__(
             " should correspond to a gate"
         )

+        valid_reg_reduction = ["none", "mean", "sum"]
+        assert (
+            reg_reduction in valid_reg_reduction
+        ), f"reg_reduction must be one of [none, mean, sum], received: {reg_reduction}"
+        self.reg_reduction = reg_reduction
+
         self.n_gates = n_gates
         self.register_buffer(
             "mask", mask.detach().clone() if mask is not None else None
         )
@@ -106,7 +124,14 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
         gated_input = input_tensor * gate_values

         prob_density = self._get_gate_active_probs()
-        l0_reg = self.reg_weight * prob_density.mean()
+        if self.reg_reduction == "sum":
+            l0_reg = prob_density.sum()
+        elif self.reg_reduction == "mean":
+            l0_reg = prob_density.mean()
+        else:
+            l0_reg = prob_density
+
+        l0_reg *= self.reg_weight

         return gated_input, l0_reg
diff --git a/tests/module/test_binary_concrete_stochastic_gates.py b/tests/module/test_binary_concrete_stochastic_gates.py
index c910370350..25efbb26ad 100644
--- a/tests/module/test_binary_concrete_stochastic_gates.py
+++ b/tests/module/test_binary_concrete_stochastic_gates.py
@@ -32,7 +32,7 @@ def test_bcstg_1d_input(self) -> None:
         ).to(self.testing_device)

         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8316
+        expected_reg = 2.4947

         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
@@ -42,6 +42,30 @@ def test_bcstg_1d_input(self) -> None:
         assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
         assertTensorAlmostEqual(self, reg, expected_reg)

+    def test_bcstg_1d_input_with_reg_reduction(self) -> None:
+
+        dim = 3
+        mean_bcstg = BinaryConcreteStochasticGates(dim, reg_reduction="mean").to(
+            self.testing_device
+        )
+        none_bcstg = BinaryConcreteStochasticGates(dim, reg_reduction="none").to(
+            self.testing_device
+        )
+        input_tensor = torch.tensor(
+            [
+                [0.0, 0.1, 0.2],
+                [0.3, 0.4, 0.5],
+            ]
+        ).to(self.testing_device)
+
+        mean_gated_input, mean_reg = mean_bcstg(input_tensor)
+        none_gated_input, none_reg = none_bcstg(input_tensor)
+        expected_mean_reg = 0.8316
+        expected_none_reg = torch.tensor([0.8321, 0.8310, 0.8325])
+
+        assertTensorAlmostEqual(self, mean_reg, expected_mean_reg)
+        assertTensorAlmostEqual(self, none_reg, expected_none_reg)
+
     def test_bcstg_1d_input_with_n_gates_error(self) -> None:

         dim = 3
@@ -85,7 +109,7 @@ def test_bcstg_1d_input_with_mask(self) -> None:
         ).to(self.testing_device)

         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8321
+        expected_reg = 1.6643

         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
@@ -118,7 +142,7 @@ def test_bcstg_2d_input(self) -> None:

         gated_input, reg = bcstg(input_tensor)

-        expected_reg = 0.8317
+        expected_reg = 4.9903

         if self.testing_device == "cpu":
             expected_gated_input = [
                 [[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
@@ -179,7 +203,7 @@ def test_bcstg_2d_input_with_mask(self) -> None:
         ).to(self.testing_device)

         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8316
+        expected_reg = 2.4947

         if self.testing_device == "cpu":
             expected_gated_input = [
diff --git a/tests/module/test_gaussian_stochastic_gates.py b/tests/module/test_gaussian_stochastic_gates.py
index 06baaa8947..03df56c51f 100644
--- a/tests/module/test_gaussian_stochastic_gates.py
+++ b/tests/module/test_gaussian_stochastic_gates.py
@@ -25,6 +25,7 @@ def test_gstg_1d_input(self) -> None:

         dim = 3
         gstg = GaussianStochasticGates(dim).to(self.testing_device)
+
         input_tensor = torch.tensor(
             [
                 [0.0, 0.1, 0.2],
                 [0.3, 0.4, 0.5],
             ]
         ).to(self.testing_device)

         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8404
+        expected_reg = 2.5213

         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
@@ -43,6 +44,30 @@ def test_gstg_1d_input(self) -> None:
         assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
         assertTensorAlmostEqual(self, reg, expected_reg)

+    def test_gstg_1d_input_with_reg_reduction(self) -> None:
+        dim = 3
+        mean_gstg = GaussianStochasticGates(dim, reg_reduction="mean").to(
+            self.testing_device
+        )
+        none_gstg = GaussianStochasticGates(dim, reg_reduction="none").to(
+            self.testing_device
+        )
+
+        input_tensor = torch.tensor(
+            [
+                [0.0, 0.1, 0.2],
+                [0.3, 0.4, 0.5],
+            ]
+        ).to(self.testing_device)
+
+        _, mean_reg = mean_gstg(input_tensor)
+        _, none_reg = none_gstg(input_tensor)
+        expected_mean_reg = 0.8404
+        expected_none_reg = torch.tensor([0.8424, 0.8384, 0.8438])
+
+        assertTensorAlmostEqual(self, mean_reg, expected_mean_reg)
+        assertTensorAlmostEqual(self, none_reg, expected_none_reg)
+
     def test_gstg_1d_input_with_n_gates_error(self) -> None:

         dim = 3
@@ -65,7 +90,7 @@ def test_gstg_1d_input_with_mask(self) -> None:
         ).to(self.testing_device)

         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8424
+        expected_reg = 1.6849

         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
@@ -111,7 +136,7 @@ def test_gstg_2d_input(self) -> None:
         ).to(self.testing_device)

         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8410
+        expected_reg = 5.0458

         if self.testing_device == "cpu":
             expected_gated_input = [
@@ -173,7 +198,7 @@ def test_gstg_2d_input_with_mask(self) -> None:
         ).to(self.testing_device)

         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8404
+        expected_reg = 2.5213

         if self.testing_device == "cpu":
             expected_gated_input = [