diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py index b34a4d5f4..7c9d752c0 100644 --- a/captum/module/stochastic_gates_base.py +++ b/captum/module/stochastic_gates_base.py @@ -30,6 +30,8 @@ class StochasticGatesBase(Module, ABC): extend this class and implement the distribution specific functions. """ + mask: Optional[Tensor] + def __init__( self, n_gates: int, diff --git a/setup.py b/setup.py index bb1126589..f6c62d8a0 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ def report(*args): "ufmt", "scikit-learn", "annoy", + "click<8.2.0", ] )