2222
2323import pymc as pm
2424
25- from pymc .variational import opvi , test_functions
25+ from pymc .variational import test_functions
2626from pymc .variational .approximations import Empirical , FullRank , MeanField
2727from pymc .variational .operators import KL , KSD
2828
@@ -334,7 +334,7 @@ class ADVI(KLqp):
334334 The last ones are local random variables
335335 :math:`{\cal Z}=\{\mathbf{z}_{i}\}_{i=1}^{N}`, where
336336 :math:`\mathbf{z}_{i}=\{\mathbf{z}_{i}^{k}\}_{k=1}^{V_{l}}`.
337- These RVs are used only in AEVB.
337+ These RVs are used only in AEVB (which is not implemented in PyMC) .
338338
339339 The goal of ADVI is to approximate the posterior distribution
340340 :math:`p(\Theta,{\cal Z}|{\cal Y})` by variational posterior
@@ -408,8 +408,8 @@ class ADVI(KLqp):
408408
409409 - The probabilistic model
410410
411- `model` with three types of RVs (`observed_RVs`,
412- `global_RVs` and `local_RVs` ).
411+ `model` with two types of RVs (`observed_RVs`,
412+ `global_RVs`).
413413
414414 - (optional) Minibatches
415415
@@ -428,10 +428,6 @@ class ADVI(KLqp):
428428
429429 Parameters
430430 ----------
431- local_rv: dict[var->tuple]
432- mapping {model_variable -> approx params}
433- Local Vars are used for Autoencoding Variational Bayes
434- See (AEVB; Kingma and Welling, 2014) for details
435431 model: :class:`pymc.Model`
436432 PyMC model for inference
437433 random_seed: None or int
@@ -463,10 +459,6 @@ class FullRankADVI(KLqp):
463459
464460 Parameters
465461 ----------
466- local_rv: dict[var->tuple]
467- mapping {model_variable -> approx params}
468- Local Vars are used for Autoencoding Variational Bayes
469- See (AEVB; Kingma and Welling, 2014) for details
470462 model: :class:`pymc.Model`
471463 PyMC model for inference
472464 random_seed: None or int
@@ -571,8 +563,6 @@ def __init__(
571563 kernel = test_functions .rbf ,
572564 ** kwargs ,
573565 ):
574- if kwargs .get ("local_rv" ) is not None :
575- raise opvi .AEVBInferenceError ("SVGD does not support local groups" )
576566 empirical = Empirical (
577567 size = n_particles ,
578568 jitter = jitter ,
@@ -639,9 +629,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
639629 "is often **underestimated** when using temperature = 1."
640630 )
641631 if approx is None :
642- approx = FullRank (
643- model = kwargs .pop ("model" , None ), local_rv = kwargs .pop ("local_rv" , None )
644- )
632+ approx = FullRank (model = kwargs .pop ("model" , None ))
645633 super ().__init__ (estimator = estimator , approx = approx , kernel = kernel , ** kwargs )
646634
647635 def fit (
0 commit comments