4646from aesara .tensor .random .opt import local_subtensor_rv_lift
4747from aesara .tensor .random .var import RandomStateSharedVariable
4848from aesara .tensor .sharedvar import ScalarSharedVariable
49- from aesara .tensor .var import TensorVariable
49+ from aesara .tensor .var import TensorConstant , TensorVariable
5050
5151from pymc .aesaraf import (
5252 compile_pymc ,
6161from pymc .distributions import joint_logpt
6262from pymc .distributions .logprob import _get_scaling
6363from pymc .distributions .transforms import _default_transform
64- from pymc .exceptions import ImputationWarning , SamplingError , ShapeError
64+ from pymc .exceptions import ImputationWarning , SamplingError , ShapeError , ShapeWarning
6565from pymc .initial_point import make_initial_point_fn
6666from pymc .math import flatten_list
6767from pymc .util import (
@@ -1179,23 +1179,48 @@ def set_data(
11791179 # Reject resizing if we already know that it would create shape problems.
11801180 # NOTE: If there are multiple pm.MutableData containers sharing this dim, but the user only
11811181 # changes the values for one of them, they will run into shape problems nonetheless.
1182- length_belongs_to = length_tensor .owner .inputs [0 ].owner .inputs [0 ]
1183- if not isinstance (length_belongs_to , SharedVariable ) and length_changed :
1184- raise ShapeError (
1185- f"Resizing dimension '{ dname } ' with values of length { new_length } would lead to incompatibilities, "
1186- f"because the dimension was initialized from '{ length_belongs_to } ' which is not a shared variable. "
1187- f"Check if the dimension was defined implicitly before the shared variable '{ name } ' was created, "
1188- f"for example by a model variable." ,
1189- actual = new_length ,
1190- expected = old_length ,
1191- )
1192- if original_coords is not None and length_changed :
1193- if length_changed and new_coords is None :
1194- raise ValueError (
1195- f"The '{ name } ' variable already had { len (original_coords )} coord values defined for"
1196- f"its { dname } dimension. With the new values this dimension changes to length "
1197- f"{ new_length } , so new coord values for the { dname } dimension are required."
1182+ if length_changed :
1183+ if isinstance (length_tensor , TensorConstant ):
1184+ raise ShapeError (
1185+ f"Resizing dimension '{ dname } ' is impossible, because "
1186+ f"a 'TensorConstant' stores its length. To be able "
1187+ f"to change the dimension length, 'fixed' in "
1188+ f"'model.add_coord' must be set to `False`."
11981189 )
1190+ if length_tensor .owner is None :
1191+ # This is the case if the dimension was initialized
1192+ # from custom coords, but dimension length was not
1193+ # stored in TensorConstant e.g by 'fixed' set to False
1194+
1195+ warnings .warn (
1196+ f"You're changing the shape of a variable "
1197+ f"in the '{ dname } ' dimension which was initialized "
1198+ f"from coords. Make sure to update the corresponding "
1199+ f"coords, otherwise you'll get shape issues." ,
1200+ ShapeWarning ,
1201+ )
1202+ else :
1203+ length_belongs_to = length_tensor .owner .inputs [0 ].owner .inputs [0 ]
1204+ if not isinstance (length_belongs_to , SharedVariable ):
1205+ raise ShapeError (
1206+ f"Resizing dimension '{ dname } ' with values of length { new_length } would lead to incompatibilities, "
1207+ f"because the dimension was initialized from '{ length_belongs_to } ' which is not a shared variable. "
1208+ f"Check if the dimension was defined implicitly before the shared variable '{ name } ' was created, "
1209+ f"for example by another model variable." ,
1210+ actual = new_length ,
1211+ expected = old_length ,
1212+ )
1213+ if original_coords is not None :
1214+ if new_coords is None :
1215+ raise ValueError (
1216+ f"The '{ name } ' variable already had { len (original_coords )} coord values defined for "
1217+ f"its { dname } dimension. With the new values this dimension changes to length "
1218+ f"{ new_length } , so new coord values for the { dname } dimension are required."
1219+ )
1220+ if isinstance (length_tensor , ScalarSharedVariable ):
1221+ # Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1222+ length_tensor .set_value (new_length )
1223+
11991224 if new_coords is not None :
12001225 # Update the registered coord values (also if they were None)
12011226 if len (new_coords ) != new_length :
@@ -1204,10 +1229,8 @@ def set_data(
12041229 actual = len (new_coords ),
12051230 expected = new_length ,
12061231 )
1207- self ._coords [dname ] = new_coords
1208- if isinstance (length_tensor , ScalarSharedVariable ) and new_length != old_length :
1209- # Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1210- length_tensor .set_value (new_length )
1232+ # store it as tuple for immutability as in add_coord
1233+ self ._coords [dname ] = tuple (new_coords )
12111234
12121235 shared_object .set_value (values )
12131236
0 commit comments