@@ -87,11 +87,13 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
8787 input_tensor (Tensor): Tensor to be gated with stochastic gates
8888
8989
90- Outputs:
91- gated_input (Tensor): Tensor of the same shape weighted by the sampled
90+ Returns:
91+ tuple[Tensor, Tensor]:
92+
93+ - gated_input (Tensor): Tensor of the same shape weighted by the sampled
9294 gate values
9395
94- l0_reg (Tensor): L0 regularization term to be optimized together with
96+ - l0_reg (Tensor): L0 regularization term to be optimized together with
9597 model loss,
9698 e.g. loss(model_out, target) + l0_reg
9799 """
@@ -140,16 +142,18 @@ def get_gate_values(self, clamp: bool = True) -> Tensor:
140142 Get the gate values, which are the means of the underneath gate distributions,
141143 optionally clamped within 0 and 1.
142144
143- Returns:
144- gate_values (Tensor): value of each gate in shape(n_gates)
145-
146- clamp (bool): if clamp the gate values. As smoothed Bernoulli
147- variables, gate values are clamped withn 0 and 1 by defautl.
145+ Args:
146+ clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
147+ variables, gate values are clamped within 0 and 1 by default.
148148 Turn this off to get the raw means of the underneath
149- distribution (e.g., conrete , gaussian), which can be useful to
149+ distribution (e.g., concrete , gaussian), which can be useful to
150150 differentiate the gates' importance when multiple gate
151151 values are beyond 0 or 1.
152152 Default: True
153+
154+ Returns:
155+ Tensor:
156+ - gate_values (Tensor): value of each gate in shape(n_gates)
153157 """
154158 gate_values = self ._get_gate_values ()
155159 if clamp :
@@ -162,7 +166,8 @@ def get_gate_active_probs(self) -> Tensor:
162166 Get the active probability of each gate, i.e, gate value > 0
163167
164168 Returns:
165- probs (Tensor): probabilities tensor of the gates are active
169+ Tensor:
170+ - probs (Tensor): probabilities tensor of the gates are active
166171 in shape(n_gates)
167172 """
168173 return self ._get_gate_active_probs ().detach ()
0 commit comments