From 6518ea5adba819880e2409f71c7758c25c084f7b Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Fri, 16 Dec 2022 18:45:10 -0800 Subject: [PATCH 1/2] update stg sphinx to include inherited methods --- .../binary_concrete_stochastic_gates.py | 16 ---------------- captum/module/gaussian_stochastic_gates.py | 15 --------------- captum/module/stochastic_gates_base.py | 19 ++++++++++++------- sphinx/source/binary_concrete_stg.rst | 1 + sphinx/source/gaussian_stg.rst | 1 + 5 files changed, 14 insertions(+), 38 deletions(-) diff --git a/captum/module/binary_concrete_stochastic_gates.py b/captum/module/binary_concrete_stochastic_gates.py index d6fa318d87..c3cabc39b4 100644 --- a/captum/module/binary_concrete_stochastic_gates.py +++ b/captum/module/binary_concrete_stochastic_gates.py @@ -133,22 +133,6 @@ def __init__( # pre-calculate the fixed term used in active prob self.active_prob_offset = temperature * math.log(-lower_bound / upper_bound) - def forward(self, *args, **kwargs): - """ - Args: - input_tensor (Tensor): Tensor to be gated with stochastic gates - - - Outputs: - gated_input (Tensor): Tensor of the same shape weighted by the sampled - gate values - - l0_reg (Tensor): L0 regularization term to be optimized together with - model loss, - e.g. 
loss(model_out, target) + l0_reg - """ - return super().forward(*args, **kwargs) - def _sample_gate_values(self, batch_size: int) -> Tensor: """ Sample gate values for each example in the batch from the binary concrete diff --git a/captum/module/gaussian_stochastic_gates.py b/captum/module/gaussian_stochastic_gates.py index b10f837dc1..13054c55f5 100644 --- a/captum/module/gaussian_stochastic_gates.py +++ b/captum/module/gaussian_stochastic_gates.py @@ -78,21 +78,6 @@ def __init__( assert 0 < std, f"the standard deviation should be positive, received {std}" self.std = std - def forward(self, *args, **kwargs): - """ - Args: - input_tensor (Tensor): Tensor to be gated with stochastic gates - - Outputs: - gated_input (Tensor): Tensor of the same shape weighted by the sampled - gate values - - l0_reg (Tensor): L0 regularization term to be optimized together with - model loss, - e.g. loss(model_out, target) + l0_reg - """ - return super().forward(*args, **kwargs) - def _sample_gate_values(self, batch_size: int) -> Tensor: """ Sample gate values for each example in the batch from the Gaussian distribution diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py index 75eebb2d65..b1ef662b5a 100644 --- a/captum/module/stochastic_gates_base.py +++ b/captum/module/stochastic_gates_base.py @@ -87,11 +87,13 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]: input_tensor (Tensor): Tensor to be gated with stochastic gates - Outputs: - gated_input (Tensor): Tensor of the same shape weighted by the sampled + Returns: + tuple[Tensor, Tensor]: + + - gated_input (Tensor): Tensor of the same shape weighted by the sampled gate values - l0_reg (Tensor): L0 regularization term to be optimized together with + - l0_reg (Tensor): L0 regularization term to be optimized together with model loss, e.g. 
loss(model_out, target) + l0_reg """ @@ -140,9 +142,7 @@ def get_gate_values(self, clamp: bool = True) -> Tensor: Get the gate values, which are the means of the underneath gate distributions, optionally clamped within 0 and 1. - Returns: - gate_values (Tensor): value of each gate in shape(n_gates) - + Args: clamp (bool): if clamp the gate values. As smoothed Bernoulli variables, gate values are clamped withn 0 and 1 by defautl. Turn this off to get the raw means of the underneath @@ -150,6 +150,10 @@ def get_gate_values(self, clamp: bool = True) -> Tensor: differentiate the gates' importance when multiple gate values are beyond 0 or 1. Default: True + + Returns: + Tensor: + - gate_values (Tensor): value of each gate in shape(n_gates) """ gate_values = self._get_gate_values() if clamp: @@ -162,7 +166,8 @@ def get_gate_active_probs(self) -> Tensor: Get the active probability of each gate, i.e, gate value > 0 Returns: - probs (Tensor): probabilities tensor of the gates are active + Tensor: + - probs (Tensor): probabilities tensor of the gates are active in shape(n_gates) """ return self._get_gate_active_probs().detach() diff --git a/sphinx/source/binary_concrete_stg.rst b/sphinx/source/binary_concrete_stg.rst index 0889f6f839..11d4d442a9 100644 --- a/sphinx/source/binary_concrete_stg.rst +++ b/sphinx/source/binary_concrete_stg.rst @@ -3,3 +3,4 @@ BinaryConcreteStochasticGates .. autoclass:: captum.module.BinaryConcreteStochasticGates :members: + :inherited-members: Module diff --git a/sphinx/source/gaussian_stg.rst b/sphinx/source/gaussian_stg.rst index 22e7df6a82..dcecd361f4 100644 --- a/sphinx/source/gaussian_stg.rst +++ b/sphinx/source/gaussian_stg.rst @@ -3,3 +3,4 @@ GaussianStochasticGates .. 
autoclass:: captum.module.GaussianStochasticGates :members: + :inherited-members: Module From 1ec8b9619cfe46c24cfde89e7fd8992ad4e10533 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Mon, 19 Dec 2022 11:37:02 -0800 Subject: [PATCH 2/2] fix typos in get_gate_values docstring --- captum/module/stochastic_gates_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py index b1ef662b5a..16691a4e36 100644 --- a/captum/module/stochastic_gates_base.py +++ b/captum/module/stochastic_gates_base.py @@ -143,10 +143,10 @@ def get_gate_values(self, clamp: bool = True) -> Tensor: optionally clamped within 0 and 1. Args: - clamp (bool): if clamp the gate values. As smoothed Bernoulli - variables, gate values are clamped withn 0 and 1 by defautl. + clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli + variables, gate values are clamped within 0 and 1 by default. Turn this off to get the raw means of the underneath - distribution (e.g., conrete, gaussian), which can be useful to + distribution (e.g., concrete, gaussian), which can be useful to differentiate the gates' importance when multiple gate values are beyond 0 or 1. Default: True