7 changes: 5 additions & 2 deletions pymc/backends/__init__.py
@@ -72,6 +72,7 @@
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc.backends.base import BaseTrace, IBaseTrace
from pymc.backends.ndarray import NDArray
from pymc.blocking import PointType
from pymc.model import Model
from pymc.step_methods.compound import BlockedStep, CompoundStep

@@ -100,11 +101,12 @@ def _init_trace(
trace: BaseTrace | None,
model: Model,
trace_vars: list[TensorVariable] | None = None,
initial_point: PointType | None = None,
) -> BaseTrace:
"""Initialize a trace backend for a chain."""
strace: BaseTrace
if trace is None:
strace = NDArray(model=model, vars=trace_vars)
strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
elif isinstance(trace, BaseTrace):
if len(trace) > 0:
raise ValueError("Continuation of traces is no longer supported.")
@@ -122,7 +124,7 @@ def init_traces(
chains: int,
expected_length: int,
step: BlockedStep | CompoundStep,
initial_point: Mapping[str, np.ndarray],
initial_point: PointType,
model: Model,
trace_vars: list[TensorVariable] | None = None,
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
@@ -145,6 +147,7 @@
trace=backend,
model=model,
trace_vars=trace_vars,
initial_point=initial_point,
)
for chain_number in range(chains)
]
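The change above threads the initial point computed once in `init_traces` into every per-chain backend, so `NDArray` no longer re-derives it via `model.initial_point()` inside `BaseTrace.__init__`. A minimal sketch of the same reuse pattern against the post-PR API (the toy model is illustrative, not from the PR):

```python
import pymc as pm
from pymc.backends.ndarray import NDArray

with pm.Model() as model:
    pm.Normal("x")

    # Compute the initial point once and hand it to each chain's backend,
    # instead of letting every backend recompute it internally.
    ip = model.initial_point()
    straces = [NDArray(model=model, test_point=ip) for _ in range(4)]
```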
53 changes: 37 additions & 16 deletions pymc/backends/base.py
@@ -30,9 +30,11 @@
)

import numpy as np
import pytensor

from pymc.backends.report import SamplerReport
from pymc.model import modelcontext
from pymc.pytensorf import compile_pymc
from pymc.util import get_var_name

logger = logging.getLogger(__name__)
@@ -147,32 +149,51 @@ class BaseTrace(IBaseTrace):
use a different test point, which may have variables with different shapes
"""

def __init__(self, name, model=None, vars=None, test_point=None):
self.name = name

def __init__(
self,
name=None,
model=None,
vars=None,
test_point=None,
*,
fn=None,
var_shapes=None,
var_dtypes=None,
):
model = modelcontext(model)
self.model = model

if vars is None:
vars = model.unobserved_value_vars

unnamed_vars = {var for var in vars if var.name is None}
if unnamed_vars:
raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
self.vars = vars
self.varnames = [var.name for var in vars]
self.fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore")

if fn is None:
# borrow=True avoids a deepcopy when an input is also an output, as is the case for untransformed value variables
fn = compile_pymc(
inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
outputs=[pytensor.Out(v, borrow=True) for v in vars],
on_unused_input="ignore",
)
fn.trust_input = True

# Get variable shapes. Most backends will need this
# information.
if test_point is None:
test_point = model.initial_point()
else:
test_point_ = model.initial_point().copy()
test_point_.update(test_point)
test_point = test_point_
var_values = list(zip(self.varnames, self.fn(test_point)))
self.var_shapes = {var: value.shape for var, value in var_values}
self.var_dtypes = {var: value.dtype for var, value in var_values}
if var_shapes is None or var_dtypes is None:
if test_point is None:
test_point = model.initial_point()
var_values = tuple(zip(vars, fn(**test_point)))
var_shapes = {var.name: value.shape for var, value in var_values}
var_dtypes = {var.name: value.dtype for var, value in var_values}

self.name = name
self.model = model
self.fn = fn
self.vars = vars
self.varnames = [var.name for var in vars]
self.var_shapes = var_shapes
self.var_dtypes = var_dtypes
self.chain = None
self._is_base_setup = False
self.sampler_vars = None
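The compilation above trades safety checks for per-draw speed: `borrow=True` on inputs and outputs skips the deep copy PyTensor would otherwise insert when a value variable is returned unchanged, and `trust_input = True` skips input validation on every call, so the caller must supply correctly typed arrays. A self-contained sketch of the same pattern with toy variables (not the PR's graph):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")  # stands in for an untransformed value variable
y = pt.dvector("y")
fn = pytensor.function(
    inputs=[pytensor.In(x, borrow=True), pytensor.In(y, borrow=True)],
    # x is returned unchanged; borrow=True avoids the deep copy PyTensor
    # would otherwise add for an output that is also an input.
    outputs=[pytensor.Out(x, borrow=True), pytensor.Out(pt.exp(y), borrow=True)],
    on_unused_input="ignore",
)
fn.trust_input = True  # no per-call validation: dtypes and shapes must match

x_out, exp_y = fn(np.zeros(3), np.ones(3))
```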
29 changes: 18 additions & 11 deletions pymc/backends/ndarray.py
@@ -40,8 +40,8 @@ class NDArray(base.BaseTrace):
`model.unobserved_RVs` is used.
"""

def __init__(self, name=None, model=None, vars=None, test_point=None):
super().__init__(name, model, vars, test_point)
def __init__(self, name=None, model=None, vars=None, test_point=None, **kwargs):
super().__init__(name, model, vars, test_point, **kwargs)
self.draw_idx = 0
self.draws = None
self.samples = {}
@@ -76,7 +76,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
else: # Otherwise, make array of zeros for each variable.
self.draws = draws
for varname, shape in self.var_shapes.items():
self.samples[varname] = np.zeros((draws, *shape), dtype=self.var_dtypes[varname])
self.samples[varname] = np.empty((draws, *shape), dtype=self.var_dtypes[varname])

if sampler_vars is None:
return
@@ -105,17 +105,18 @@ def record(self, point, sampler_stats=None) -> None:
point: dict
Values mapped to variable names
"""
for varname, value in zip(self.varnames, self.fn(point)):
self.samples[varname][self.draw_idx] = value
samples = self.samples
draw_idx = self.draw_idx
for varname, value in zip(self.varnames, self.fn(*point.values())):
samples[varname][draw_idx] = value

if self._stats is not None and sampler_stats is None:
raise ValueError("Expected sampler_stats")
if self._stats is None and sampler_stats is not None:
raise ValueError("Unknown sampler_stats")
if sampler_stats is not None:
for data, vars in zip(self._stats, sampler_stats):
for key, val in vars.items():
data[key][self.draw_idx] = val
data[key][draw_idx] = val
elif self._stats is not None:
raise ValueError("Expected sampler_stats")

self.draw_idx += 1

def _get_sampler_stats(
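Because the function now trusts its inputs, `record` calls it positionally with `*point.values()`; this assumes the point dict is ordered like the compiled inputs (the model's value variables), which holds for points PyMC builds itself. A self-contained sketch of that write loop, with toy variables standing in for the traced model:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.dscalar("a")
fn = pytensor.function([a], [a, pt.exp(a)])
fn.trust_input = True

varnames = ["a", "exp_a"]
samples = {name: np.empty(10) for name in varnames}  # preallocated, as in setup()
point = {"a": np.array(0.5)}  # insertion order must match the compiled inputs

draw_idx = 0
for varname, value in zip(varnames, fn(*point.values())):
    samples[varname][draw_idx] = value
```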
@@ -166,7 +167,13 @@ def _slice(self, idx: slice):
# Only the first `draw_idx` values are valid because of preallocation
idx = slice(*idx.indices(len(self)))

sliced = NDArray(model=self.model, vars=self.vars)
sliced = type(self)(
model=self.model,
vars=self.vars,
fn=self.fn,
var_shapes=self.var_shapes,
var_dtypes=self.var_dtypes,
)
sliced.chain = self.chain
sliced.samples = {varname: values[idx] for varname, values in self.samples.items()}
sliced.sampler_vars = self.sampler_vars
25 changes: 10 additions & 15 deletions pymc/blocking.py
@@ -39,11 +39,11 @@
StatShape: TypeAlias = Sequence[int | None] | None


# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
# `point_map_info` is a tuple of tuples containing `(name, shape, size, dtype)` for
# each of the raveled variables.
class RaveledVars(NamedTuple):
data: np.ndarray
point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
point_map_info: tuple[tuple[str, tuple[int, ...], int, np.dtype], ...]


class Compose(Generic[T]):
@@ -67,10 +67,9 @@ class DictToArrayBijection:
@staticmethod
def map(var_dict: PointType) -> RaveledVars:
"""Map a dictionary of names and variables to a concatenated 1D array space."""
vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
raveled_vars = [v[0].ravel() for v in vars_info]
if raveled_vars:
result = np.concatenate(raveled_vars)
vars_info = tuple((v, k, v.shape, v.size, v.dtype) for k, v in var_dict.items())
if vars_info:
result = np.concatenate(tuple(v[0].ravel() for v in vars_info))
else:
result = np.array([])
return RaveledVars(result, tuple(v[1:] for v in vars_info))
@@ -91,19 +90,15 @@ def rmap(

"""
if start_point:
result = dict(start_point)
result = start_point.copy()
else:
result = {}

if not isinstance(array, RaveledVars):
raise TypeError("`array` must be a `RaveledVars` type")

last_idx = 0
for name, shape, dtype in array.point_map_info:
arr_len = np.prod(shape, dtype=int)
var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
result[name] = var
last_idx += arr_len
for name, shape, size, dtype in array.point_map_info:
end = last_idx + size
result[name] = array.data[last_idx:end].reshape(shape).astype(dtype)
last_idx = end

return result
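With the flat `size` stored in `point_map_info`, `rmap` slices the raveled array directly instead of recomputing `np.prod(shape)` for every variable. A quick round-trip sketch of the post-PR behaviour (variable names are made up):

```python
import numpy as np
from pymc.blocking import DictToArrayBijection

point = {"mu": np.zeros((2, 3)), "sigma_log__": np.ones(4)}
raveled = DictToArrayBijection.map(point)

# point_map_info now holds (name, shape, size, dtype) per variable, e.g.
# (('mu', (2, 3), 6, dtype('float64')), ('sigma_log__', (4,), 4, dtype('float64')))
restored = DictToArrayBijection.rmap(raveled)
assert all(np.array_equal(point[k], restored[k]) for k in point)
```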

82 changes: 53 additions & 29 deletions pymc/model/core.py
@@ -48,7 +48,7 @@
ShapeError,
ShapeWarning,
)
from pymc.initial_point import make_initial_point_fn
from pymc.initial_point import PointType, make_initial_point_fn
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
@@ -61,6 +61,7 @@
gradient,
hessian,
inputvars,
join_nonshared_inputs,
rewrite_pregrad,
)
from pymc.util import (
@@ -172,6 +173,9 @@
dtype=None,
casting="no",
compute_grads=True,
model=None,
initial_point: PointType | None = None,
ravel_inputs: bool | None = None,
**kwargs,
):
if extra_vars_and_values is None:
@@ -219,9 +223,7 @@
givens = []
self._extra_vars_shared = {}
for var, value in extra_vars_and_values.items():
shared = pytensor.shared(
value, var.name + "_shared__", shape=[1 if s == 1 else None for s in value.shape]
)
shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape)
self._extra_vars_shared[var.name] = shared
givens.append((var, shared))
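Giving the shared variable the exact static shape of its initial value, rather than only marking size-1 dimensions as broadcastable, lets PyTensor specialize the compiled graph for that shape; the trade-off is that the extra variable can no longer be resized later. A minimal sketch of that behaviour with a hypothetical variable name:

```python
import numpy as np
import pytensor

value = np.zeros((3, 2))
shared = pytensor.shared(value, "beta_shared__", shape=value.shape)

shared.set_value(np.ones((3, 2)), borrow=True)  # same shape: accepted, no copy
# shared.set_value(np.ones((4, 2)))             # different shape: now rejected
```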

@@ -231,13 +233,28 @@
grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
for grad_wrt, var in zip(grads, grad_vars):
grad_wrt.name = f"{var.name}_grad"
outputs = [cost, *grads]
grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads])
outputs = [cost, grads]
else:
outputs = [cost]

inputs = grad_vars
if ravel_inputs:
if initial_point is None:
initial_point = modelcontext(model).initial_point()

outputs, raveled_grad_vars = join_nonshared_inputs(
point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False
)
inputs = [raveled_grad_vars]
else:
if ravel_inputs is None:
warnings.warn(
"ValueGradFunction will become a function of raveled inputs.\n"
"Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
)
inputs = grad_vars

self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
self._raveled_inputs = ravel_inputs

def set_weights(self, values):
if values.shape != (self._n_costs - 1,):
@@ -247,38 +264,29 @@
def set_extra_values(self, extra_vars):
self._extra_are_set = True
for var in self._extra_vars:
self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
self._extra_vars_shared[var.name].set_value(extra_vars[var.name], borrow=True)

def get_extra_values(self):
if not self._extra_are_set:
raise ValueError("Extra values are not set.")

return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}

def __call__(self, grad_vars, grad_out=None, extra_vars=None):
def __call__(self, grad_vars, *, extra_vars=None):
if extra_vars is not None:
self.set_extra_values(extra_vars)

if not self._extra_are_set:
elif not self._extra_are_set:
raise ValueError("Extra values are not set.")

if isinstance(grad_vars, RaveledVars):
grad_vars = list(DictToArrayBijection.rmap(grad_vars).values())

cost, *grads = self._pytensor_function(*grad_vars)

if grads:
grads_raveled = DictToArrayBijection.map(
{v.name: gv for v, gv in zip(self._grad_vars, grads)}
)

if grad_out is None:
return cost, grads_raveled.data
if self._raveled_inputs:
grad_vars = (grad_vars.data,)
else:
np.copyto(grad_out, grads_raveled.data)
return cost
else:
return cost
grad_vars = DictToArrayBijection.rmap(grad_vars).values()
elif self._raveled_inputs and not isinstance(grad_vars, Sequence):
grad_vars = (grad_vars,)

return self._pytensor_function(*grad_vars)

@property
def profile(self):
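With `ravel_inputs=True` the compiled function takes a single flat vector and returns the cost followed by one concatenated gradient vector, so `__call__` no longer re-ravels per-variable gradients. A usage sketch against the post-PR API (small toy model; `set_extra_values({})` because this model has no extra, non-gradient value variables):

```python
import pymc as pm
from pymc.blocking import DictToArrayBijection

with pm.Model() as model:
    pm.Normal("mu")
    pm.HalfNormal("sigma")

fn = model.logp_dlogp_function(ravel_inputs=True)
fn.set_extra_values({})  # nothing to set here, but marks extra values as ready

q = DictToArrayBijection.map(model.initial_point())
logp, dlogp = fn(q)         # RaveledVars: q.data is passed straight through
logp2, dlogp2 = fn(q.data)  # a bare 1-D array works as well
```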
@@ -521,7 +529,14 @@
def isroot(self):
return self.parent is None

def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
def logp_dlogp_function(
self,
grad_vars=None,
tempered=False,
initial_point: PointType | None = None,
ravel_inputs: bool | None = None,
**kwargs,
):
"""Compile a PyTensor function that computes logp and gradient.

Parameters
@@ -547,13 +562,22 @@
costs = [self.logp()]

input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
ip = self.initial_point(0)
if initial_point is None:
initial_point = self.initial_point(0)
extra_vars_and_values = {
var: ip[var.name]
var: initial_point[var.name]
for var in self.value_vars
if var in input_vars and var not in grad_vars
}
return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
return ValueGradFunction(
costs,
grad_vars,
extra_vars_and_values,
model=self,
initial_point=initial_point,
ravel_inputs=ravel_inputs,
**kwargs,
)

def compile_logp(
self,