2222from typing import Dict , Optional , Sequence , Tuple , Union , cast
2323
2424import numpy as np
25+ import pandas as pd
2526import pytensor
2627import pytensor .tensor as at
28+ import xarray as xr
2729
2830from pytensor .compile .sharedvalue import SharedVariable
2931from pytensor .raise_op import Assert
@@ -205,17 +207,17 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
205207
206208def determine_coords (
207209 model ,
208- value ,
210+ value : Union [ pd . DataFrame , pd . Series , xr . DataArray ] ,
209211 dims : Optional [Sequence [Optional [str ]]] = None ,
210- coords : Optional [Dict [str , Sequence ]] = None ,
211- ) -> Tuple [Dict [str , Sequence ], Sequence [Optional [str ]]]:
212+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
213+ ) -> Tuple [Dict [str , Union [ Sequence , np . ndarray ] ], Sequence [Optional [str ]]]:
212214 """Determines coordinate values from data or the model (via ``dims``)."""
213215 if coords is None :
214216 coords = {}
215217
218+ dim_name = None
216219 # If value is a df or a series, we interpret the index as coords:
217220 if hasattr (value , "index" ):
218- dim_name = None
219221 if dims is not None :
220222 dim_name = dims [0 ]
221223 if dim_name is None and value .index .name is not None :
@@ -225,14 +227,20 @@ def determine_coords(
225227
226228 # If value is a df, we also interpret the columns as coords:
227229 if hasattr (value , "columns" ):
228- dim_name = None
229230 if dims is not None :
230231 dim_name = dims [1 ]
231232 if dim_name is None and value .columns .name is not None :
232233 dim_name = value .columns .name
233234 if dim_name is not None :
234235 coords [dim_name ] = value .columns
235236
237+ if isinstance (value , xr .DataArray ):
238+ if dims is not None :
239+ for dim in dims :
240+ dim_name = dim
241+ # str is applied because dim entries may be None
242+ coords [str (dim_name )] = value [dim ].to_numpy ()
243+
236244 if isinstance (value , np .ndarray ) and dims is not None :
237245 if len (dims ) != value .ndim :
238246 raise pm .exceptions .ShapeError (
@@ -257,21 +265,29 @@ def ConstantData(
257265 value ,
258266 * ,
259267 dims : Optional [Sequence [str ]] = None ,
260- coords : Optional [Dict [str , Sequence ]] = None ,
268+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
261269 export_index_as_coords = False ,
270+ infer_dims_and_coords = False ,
262271 ** kwargs ,
263272) -> TensorConstant :
264273 """Alias for ``pm.Data(..., mutable=False)``.
265274
266275 Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model.
267276 For more information, please reference :class:`pymc.Data`.
268277 """
278+ if export_index_as_coords :
279+ infer_dims_and_coords = export_index_as_coords
280+ warnings .warn (
281+ "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead." ,
282+ DeprecationWarning ,
283+ )
284+
269285 var = Data (
270286 name ,
271287 value ,
272288 dims = dims ,
273289 coords = coords ,
274- export_index_as_coords = export_index_as_coords ,
290+ infer_dims_and_coords = infer_dims_and_coords ,
275291 mutable = False ,
276292 ** kwargs ,
277293 )
@@ -283,21 +299,29 @@ def MutableData(
283299 value ,
284300 * ,
285301 dims : Optional [Sequence [str ]] = None ,
286- coords : Optional [Dict [str , Sequence ]] = None ,
302+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
287303 export_index_as_coords = False ,
304+ infer_dims_and_coords = False ,
288305 ** kwargs ,
289306) -> SharedVariable :
290307 """Alias for ``pm.Data(..., mutable=True)``.
291308
292309 Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable`
293310 with the model. For more information, please reference :class:`pymc.Data`.
294311 """
312+ if export_index_as_coords :
313+ infer_dims_and_coords = export_index_as_coords
314+ warnings .warn (
315+ "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead." ,
316+ DeprecationWarning ,
317+ )
318+
295319 var = Data (
296320 name ,
297321 value ,
298322 dims = dims ,
299323 coords = coords ,
300- export_index_as_coords = export_index_as_coords ,
324+ infer_dims_and_coords = infer_dims_and_coords ,
301325 mutable = True ,
302326 ** kwargs ,
303327 )
@@ -309,8 +333,9 @@ def Data(
309333 value ,
310334 * ,
311335 dims : Optional [Sequence [str ]] = None ,
312- coords : Optional [Dict [str , Sequence ]] = None ,
336+ coords : Optional [Dict [str , Union [ Sequence , np . ndarray ] ]] = None ,
313337 export_index_as_coords = False ,
338+ infer_dims_and_coords = False ,
314339 mutable : Optional [bool ] = None ,
315340 ** kwargs ,
316341) -> Union [SharedVariable , TensorConstant ]:
@@ -347,7 +372,9 @@ def Data(
347372 names.
348373 coords : dict, optional
349374 Coordinate values to set for new dimensions introduced by this ``Data`` variable.
350- export_index_as_coords : bool, default=False
375+ export_index_as_coords : bool
376+ Deprecated, previous version of "infer_dims_and_coords"
377+ infer_dims_and_coords : bool, default=False
351378 If True, the ``Data`` container will try to infer what the coordinates
352379 and dimension names should be if there is an index in ``value``.
353380 mutable : bool, optional
@@ -427,6 +454,13 @@ def Data(
427454
428455 # Optionally infer coords and dims from the input value.
429456 if export_index_as_coords :
457+ infer_dims_and_coords = export_index_as_coords
458+ warnings .warn (
459+ "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead." ,
460+ DeprecationWarning ,
461+ )
462+
463+ if infer_dims_and_coords :
430464 coords , dims = determine_coords (model , value , dims )
431465
432466 if dims :
0 commit comments