Skip to content

Commit 85fe968

Browse files
committed
Make Constant and Shared variables subclasses of the basic Variable types
1 parent 082081a commit 85fe968

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

pytensor/sparse/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def pytensor_hash(self):
478478
return hash_from_sparse(d)
479479

480480

481-
class SparseConstant(TensorConstant, _sparse_py_operators):
481+
class SparseConstant(SparseVariable, TensorConstant):
482482
format = property(lambda self: self.type.format)
483483

484484
def signature(self):

pytensor/sparse/sharedvar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import scipy.sparse
44

55
from pytensor.compile import shared_constructor
6-
from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators
6+
from pytensor.sparse.basic import SparseTensorType, SparseVariable
77
from pytensor.tensor.sharedvar import TensorSharedVariable
88

99

10-
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
10+
class SparseTensorSharedVariable(TensorSharedVariable, SparseVariable):
1111
@property
1212
def format(self):
1313
return self.type.format

pytensor/tensor/sharedvar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.misc.safe_asarray import _asarray
77
from pytensor.tensor import _get_vector_length
88
from pytensor.tensor.type import TensorType
9-
from pytensor.tensor.variable import _tensor_py_operators
9+
from pytensor.tensor.variable import TensorVariable
1010

1111

1212
def load_shared_variable(val):
@@ -19,7 +19,7 @@ def load_shared_variable(val):
1919
return tensor_constructor(val)
2020

2121

22-
class TensorSharedVariable(_tensor_py_operators, SharedVariable):
22+
class TensorSharedVariable(SharedVariable, TensorVariable):
2323
def zero(self, borrow: bool = False):
2424
r"""Set the values of a shared variable to 0.
2525

0 commit comments

Comments
 (0)