1717import os
1818import pkgutil
1919import urllib .request
20+ import warnings
2021
2122from copy import copy
22- from typing import Any , Dict , List , Sequence
23+ from typing import Any , Dict , List , Optional , Sequence , Union
2324
2425import aesara
2526import aesara .tensor as at
2627import numpy as np
2728import pandas as pd
2829
30+ from aesara .compile .sharedvalue import SharedVariable
2931from aesara .graph .basic import Apply
3032from aesara .tensor .type import TensorType
31- from aesara .tensor .var import TensorVariable
33+ from aesara .tensor .var import TensorConstant , TensorVariable
34+ from packaging import version
3235
3336import pymc as pm
3437
4043 "Minibatch" ,
4144 "align_minibatches" ,
4245 "Data" ,
46+ "ConstantData" ,
47+ "MutableData" ,
4348]
4449BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}"
4550
@@ -463,16 +468,115 @@ def align_minibatches(batches=None):
463468 rng .seed ()
464469
465470
466- class Data :
467- """Data container class that wraps :func:`aesara.shared` and lets
468- the model be aware of its inputs and outputs.
471+ def determine_coords (model , value , dims : Optional [Sequence [str ]] = None ) -> Dict [str , Sequence ]:
472+ """Determines coordinate values from data or the model (via ``dims``)."""
473+ coords = {}
474+
475+ # If value is a df or a series, we interpret the index as coords:
476+ if isinstance (value , (pd .Series , pd .DataFrame )):
477+ dim_name = None
478+ if dims is not None :
479+ dim_name = dims [0 ]
480+ if dim_name is None and value .index .name is not None :
481+ dim_name = value .index .name
482+ if dim_name is not None :
483+ coords [dim_name ] = value .index
484+
485+ # If value is a df, we also interpret the columns as coords:
486+ if isinstance (value , pd .DataFrame ):
487+ dim_name = None
488+ if dims is not None :
489+ dim_name = dims [1 ]
490+ if dim_name is None and value .columns .name is not None :
491+ dim_name = value .columns .name
492+ if dim_name is not None :
493+ coords [dim_name ] = value .columns
494+
495+ if isinstance (value , np .ndarray ) and dims is not None :
496+ if len (dims ) != value .ndim :
497+ raise pm .exceptions .ShapeError (
498+ "Invalid data shape. The rank of the dataset must match the " "length of `dims`." ,
499+ actual = value .shape ,
500+ expected = value .ndim ,
501+ )
502+ for size , dim in zip (value .shape , dims ):
503+ coord = model .coords .get (dim , None )
504+ if coord is None :
505+ coords [dim ] = pd .RangeIndex (size , name = dim )
506+
507+ return coords
508+
509+
510+ def ConstantData (
511+ name : str ,
512+ value ,
513+ * ,
514+ dims : Optional [Sequence [str ]] = None ,
515+ export_index_as_coords = False ,
516+ ** kwargs ,
517+ ) -> TensorConstant :
518+ """Alias for ``pm.Data(..., mutable=False)``.
519+
520+ Registers the ``value`` as a ``TensorConstant`` with the model.
521+ """
522+ return Data (
523+ name ,
524+ value ,
525+ dims = dims ,
526+ export_index_as_coords = export_index_as_coords ,
527+ mutable = False ,
528+ ** kwargs ,
529+ )
530+
531+
532+ def MutableData (
533+ name : str ,
534+ value ,
535+ * ,
536+ dims : Optional [Sequence [str ]] = None ,
537+ export_index_as_coords = False ,
538+ ** kwargs ,
539+ ) -> SharedVariable :
540+ """Alias for ``pm.Data(..., mutable=True)``.
541+
542+ Registers the ``value`` as a ``SharedVariable`` with the model.
543+ """
544+ return Data (
545+ name ,
546+ value ,
547+ dims = dims ,
548+ export_index_as_coords = export_index_as_coords ,
549+ mutable = True ,
550+ ** kwargs ,
551+ )
552+
553+
554+ def Data (
555+ name : str ,
556+ value ,
557+ * ,
558+ dims : Optional [Sequence [str ]] = None ,
559+ export_index_as_coords = False ,
560+ mutable : Optional [bool ] = None ,
561+ ** kwargs ,
562+ ) -> Union [SharedVariable , TensorConstant ]:
563+ """Data container that registers a data variable with the model.
564+
565+ Depending on the ``mutable`` setting (default: True), the variable
566+ is registered as a ``SharedVariable``, enabling it to be altered
567+ in value and shape, but NOT in dimensionality using ``pm.set_data()``.
469568
470569 Parameters
471570 ----------
472571 name: str
473572 The name for this variable
474573 value: {List, np.ndarray, pd.Series, pd.Dataframe}
475574 A value to associate with this variable
575+ mutable : bool, optional
576+ Switches between creating a ``SharedVariable`` (``mutable=True``, default)
577+ vs. creating a ``TensorConstant`` (``mutable=False``).
578+ Consider using ``pm.ConstantData`` or ``pm.MutableData`` as less verbose
579+ alternatives to ``pm.Data(..., mutable=...)``.
476580 dims: {str, tuple of str}, optional, default=None
477581 Dimension names of the random variables (as opposed to the shapes of these
478582 random variables). Use this when `value` is a pandas Series or DataFrame. The
@@ -495,7 +599,7 @@ class Data:
495599 >>> observed_data = [mu + np.random.randn(20) for mu in true_mu]
496600
497601 >>> with pm.Model() as model:
498- ... data = pm.Data ('data', observed_data[0])
602+ ... data = pm.MutableData ('data', observed_data[0])
499603 ... mu = pm.Normal('mu', 0, 10)
500604 ... pm.Normal('y', mu=mu, sigma=1, observed=data)
501605
@@ -513,104 +617,58 @@ class Data:
513617 For more information, take a look at this example notebook
514618 https://docs.pymc.io/notebooks/data_container.html
515619 """
620+ if isinstance (value , list ):
621+ value = np .array (value )
516622
517- def __new__ (
518- self ,
519- name ,
520- value ,
521- * ,
522- dims = None ,
523- export_index_as_coords = False ,
524- ** kwargs ,
525- ):
526- if isinstance (value , list ):
527- value = np .array (value )
528-
529- # Add data container to the named variables of the model.
530- try :
531- model = pm .Model .get_context ()
532- except TypeError :
533- raise TypeError (
534- "No model on context stack, which is needed to instantiate a data container. "
535- "Add variable inside a 'with model:' block."
536- )
537- name = model .name_for (name )
538-
539- # `pandas_to_array` takes care of parameter `value` and
540- # transforms it to something digestible for pymc
541- shared_object = aesara .shared (pandas_to_array (value ), name , ** kwargs )
542-
543- if isinstance (dims , str ):
544- dims = (dims ,)
545- if not (dims is None or len (dims ) == shared_object .ndim ):
546- raise pm .exceptions .ShapeError (
547- "Length of `dims` must match the dimensions of the dataset." ,
548- actual = len (dims ),
549- expected = shared_object .ndim ,
623+ # Add data container to the named variables of the model.
624+ try :
625+ model = pm .Model .get_context ()
626+ except TypeError :
627+ raise TypeError (
628+ "No model on context stack, which is needed to instantiate a data container. "
629+ "Add variable inside a 'with model:' block."
630+ )
631+ name = model .name_for (name )
632+
633+ # `pandas_to_array` takes care of parameter `value` and
634+ # transforms it to something digestible for Aesara.
635+ arr = pandas_to_array (value )
636+
637+ if mutable is None :
638+ current = version .Version (pm .__version__ )
639+ mutable = current .major == 4 and current .minor < 1
640+ if mutable :
641+ warnings .warn (
642+ "The `mutable` kwarg was not specified. Currently it defaults to `pm.Data(mutable=True)`,"
643+ " which is equivalent to using `pm.MutableData()`."
644+ " In v4.1.0 the default will change to `pm.Data(mutable=False)`, equivalent to `pm.ConstantData`."
645+ " Set `pm.Data(..., mutable=False/True)`, or use `pm.ConstantData`/`pm.MutableData`." ,
646+ FutureWarning ,
550647 )
551-
552- coords = self .set_coords (model , value , dims )
553-
554- if export_index_as_coords :
555- model .add_coords (coords )
556- elif dims :
557- # Register new dimension lengths
558- for d , dname in enumerate (dims ):
559- if not dname in model .dim_lengths :
560- model .add_coord (dname , values = None , length = shared_object .shape [d ])
561-
562- # To draw the node for this variable in the graphviz Digraph we need
563- # its shape.
564- # XXX: This needs to be refactored
565- # shared_object.dshape = tuple(shared_object.shape.eval())
566- # if dims is not None:
567- # shape_dims = model.shape_from_dims(dims)
568- # if shared_object.dshape != shape_dims:
569- # raise pm.exceptions.ShapeError(
570- # "Data shape does not match with specified `dims`.",
571- # actual=shared_object.dshape,
572- # expected=shape_dims,
573- # )
574-
575- model .add_random_variable (shared_object , dims = dims )
576-
577- return shared_object
578-
579- @staticmethod
580- def set_coords (model , value , dims = None ) -> Dict [str , Sequence ]:
581- coords = {}
582-
583- # If value is a df or a series, we interpret the index as coords:
584- if isinstance (value , (pd .Series , pd .DataFrame )):
585- dim_name = None
586- if dims is not None :
587- dim_name = dims [0 ]
588- if dim_name is None and value .index .name is not None :
589- dim_name = value .index .name
590- if dim_name is not None :
591- coords [dim_name ] = value .index
592-
593- # If value is a df, we also interpret the columns as coords:
594- if isinstance (value , pd .DataFrame ):
595- dim_name = None
596- if dims is not None :
597- dim_name = dims [1 ]
598- if dim_name is None and value .columns .name is not None :
599- dim_name = value .columns .name
600- if dim_name is not None :
601- coords [dim_name ] = value .columns
602-
603- if isinstance (value , np .ndarray ) and dims is not None :
604- if len (dims ) != value .ndim :
605- raise pm .exceptions .ShapeError (
606- "Invalid data shape. The rank of the dataset must match the "
607- "length of `dims`." ,
608- actual = value .shape ,
609- expected = value .ndim ,
610- )
611- for size , dim in zip (value .shape , dims ):
612- coord = model .coords .get (dim , None )
613- if coord is None :
614- coords [dim ] = pd .RangeIndex (size , name = dim )
615-
616- return coords
648+ if mutable :
649+ x = aesara .shared (arr , name , ** kwargs )
650+ else :
651+ x = at .as_tensor_variable (arr , name , ** kwargs )
652+
653+ if isinstance (dims , str ):
654+ dims = (dims ,)
655+ if not (dims is None or len (dims ) == x .ndim ):
656+ raise pm .exceptions .ShapeError (
657+ "Length of `dims` must match the dimensions of the dataset." ,
658+ actual = len (dims ),
659+ expected = x .ndim ,
660+ )
661+
662+ coords = determine_coords (model , value , dims )
663+
664+ if export_index_as_coords :
665+ model .add_coords (coords )
666+ elif dims :
667+ # Register new dimension lengths
668+ for d , dname in enumerate (dims ):
669+ if not dname in model .dim_lengths :
670+ model .add_coord (dname , values = None , length = x .shape [d ])
671+
672+ model .add_random_variable (x , dims = dims )
673+
674+ return x
0 commit comments