Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions captum/module/binary_concrete_stochastic_gates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
import math
from typing import Optional
from typing import Optional, Tuple

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
Expand Down Expand Up @@ -133,12 +133,11 @@ def __init__(
# pre-calculate the fixed term used in active prob
self.active_prob_offset = temperature * math.log(-lower_bound / upper_bound)

def forward(self, *args, **kwargs):
def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
input_tensor (Tensor): Tensor to be gated with stochastic gates


Outputs:
gated_input (Tensor): Tensor of the same shape weighted by the sampled
gate values
Expand All @@ -147,7 +146,35 @@ def forward(self, *args, **kwargs):
model loss,
e.g. loss(model_out, target) + l0_reg
"""
return super().forward(*args, **kwargs)
return super().forward(input_tensor)

def get_gate_values(self, clamp: bool = True) -> Tensor:
"""
Get the gate values, which are the means of the underlying gate distributions,
optionally clamped within 0 and 1.

Returns:
gate_values (Tensor): value of each gate in shape(n_gates)

clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe indicates whether to clamp the gate values or not ? (also in gaussian stochastic gates)

variables, gate values are clamped within 0 and 1 by default.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: default ?

Turn this off to get the raw means of the underlying
distribution (e.g., concrete, gaussian), which can be useful to
differentiate the gates' importance when multiple gate
values are beyond 0 or 1.
Default: True
"""
return super().get_gate_values(clamp)

def get_gate_active_probs(self) -> Tensor:
"""
Get the active probability of each gate, i.e., gate value > 0

Returns:
probs (Tensor): tensor of the probabilities that the gates are active,
in shape(n_gates)
"""
return super().get_gate_active_probs()

def _sample_gate_values(self, batch_size: int) -> Tensor:
"""
Expand Down
34 changes: 31 additions & 3 deletions captum/module/gaussian_stochastic_gates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
import math
from typing import Optional
from typing import Optional, Tuple

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
assert 0 < std, f"the standard deviation should be positive, received {std}"
self.std = std

def forward(self, *args, **kwargs):
def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
input_tensor (Tensor): Tensor to be gated with stochastic gates
Expand All @@ -91,7 +91,35 @@ def forward(self, *args, **kwargs):
model loss,
e.g. loss(model_out, target) + l0_reg
"""
return super().forward(*args, **kwargs)
return super().forward(input_tensor)

def get_gate_values(self, clamp: bool = True) -> Tensor:
"""
Get the gate values, which are the means of the underlying gate distributions,
optionally clamped within 0 and 1.

Returns:
gate_values (Tensor): value of each gate in shape(n_gates)

clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
variables, gate values are clamped within 0 and 1 by default.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: default ?

Turn this off to get the raw means of the underlying
distribution (e.g., concrete, gaussian), which can be useful to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: concrete ?

differentiate the gates' importance when multiple gate
values are beyond 0 or 1.
Default: True
"""
return super().get_gate_values(clamp)

def get_gate_active_probs(self) -> Tensor:
"""
Get the active probability of each gate, i.e., gate value > 0

Returns:
probs (Tensor): tensor of the probabilities that the gates are active,
in shape(n_gates)
"""
return super().get_gate_active_probs()

def _sample_gate_values(self, batch_size: int) -> Tensor:
"""
Expand Down