@@ -203,25 +203,6 @@ def test_parallel_start(self):
203203 assert idata .warmup_posterior ["x" ].sel (chain = 0 , draw = 0 ).values [0 ] > 0
204204 assert idata .warmup_posterior ["x" ].sel (chain = 1 , draw = 0 ).values [0 ] < 0
205205
206- def test_sample_tune_len (self ):
207- with self .model :
208- with warnings .catch_warnings ():
209- warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
210- warnings .filterwarnings ("ignore" , "Tuning samples will be included.*" , UserWarning )
211- trace = pm .sample (draws = 100 , tune = 50 , cores = 1 , return_inferencedata = False )
212- assert len (trace ) == 100
213- trace = pm .sample (
214- draws = 100 ,
215- tune = 50 ,
216- cores = 1 ,
217- step = pm .Metropolis (),
218- return_inferencedata = False ,
219- discard_tuned_samples = False ,
220- )
221- assert len (trace ) == 150
222- trace = pm .sample (draws = 100 , tune = 50 , cores = 4 , return_inferencedata = False )
223- assert len (trace ) == 100
224-
225206 def test_reset_tuning (self ):
226207 with self .model :
227208 tune = 50
@@ -233,80 +214,6 @@ def test_reset_tuning(self):
233214 assert step .potential ._n_samples == tune
234215 assert step .step_adapt ._count == tune + 1
235216
236- @pytest .mark .parametrize ("step_cls" , [pm .NUTS , pm .Metropolis , pm .Slice ])
237- @pytest .mark .parametrize ("discard" , [True , False ])
238- def test_trace_report (self , step_cls , discard ):
239- with self .model :
240- # add more variables, because stats are 2D with CompoundStep!
241- pm .Uniform ("uni" )
242- with warnings .catch_warnings ():
243- warnings .filterwarnings (
244- "ignore" , ".*Tuning samples will be included.*" , UserWarning
245- )
246- trace = pm .sample (
247- draws = 100 ,
248- tune = 50 ,
249- cores = 1 ,
250- discard_tuned_samples = discard ,
251- step = step_cls (),
252- compute_convergence_checks = False ,
253- return_inferencedata = False ,
254- )
255- assert trace .report .n_tune == 50
256- assert trace .report .n_draws == 100
257- assert isinstance (trace .report .t_sampling , float )
258-
259- def test_return_inferencedata (self ):
260- with self .model :
261- kwargs = dict (draws = 100 , tune = 50 , cores = 1 , chains = 2 , step = pm .Metropolis ())
262-
263- # trace with tuning
264- with pytest .warns (UserWarning , match = "will be included" ):
265- result = pm .sample (
266- ** kwargs , return_inferencedata = False , discard_tuned_samples = False
267- )
268- assert isinstance (result , pm .backends .base .MultiTrace )
269- assert len (result ) == 150
270-
271- # inferencedata with tuning
272- result = pm .sample (** kwargs , return_inferencedata = True , discard_tuned_samples = False )
273- assert isinstance (result , InferenceData )
274- assert result .posterior .sizes ["draw" ] == 100
275- assert result .posterior .sizes ["chain" ] == 2
276- assert len (result ._groups_warmup ) > 0
277-
278- # inferencedata without tuning, with idata_kwargs
279- prior = pm .sample_prior_predictive (return_inferencedata = False )
280- result = pm .sample (
281- ** kwargs ,
282- return_inferencedata = True ,
283- discard_tuned_samples = True ,
284- idata_kwargs = {"prior" : prior },
285- random_seed = - 1 ,
286- )
287- assert "prior" in result
288- assert isinstance (result , InferenceData )
289- assert result .posterior .sizes ["draw" ] == 100
290- assert result .posterior .sizes ["chain" ] == 2
291- assert len (result ._groups_warmup ) == 0
292-
293- @pytest .mark .parametrize ("cores" , [1 , 2 ])
294- def test_sampler_stat_tune (self , cores ):
295- with self .model :
296- with warnings .catch_warnings ():
297- warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
298- warnings .filterwarnings ("ignore" , "Tuning samples will be included.*" , UserWarning )
299- tune_stat = pm .sample (
300- tune = 5 ,
301- draws = 7 ,
302- cores = cores ,
303- discard_tuned_samples = False ,
304- return_inferencedata = False ,
305- step = pm .Metropolis (),
306- ).get_sampler_stats ("tune" , chains = 1 )
307- assert list (tune_stat ).count (True ) == 5
308- assert list (tune_stat ).count (False ) == 7
309-
310217 @pytest .mark .parametrize (
311218 "start, error" ,
312219 [
@@ -411,6 +318,188 @@ def test_transform_with_rv_dependency(self, symbolic_rv):
411318 assert np .allclose (scipy .special .expit (trace ["y_interval__" ]), trace ["y" ])
412319
413320
321+ class ApocalypticMetropolis (pm .Metropolis ):
322+ """A stepper that warns in every iteration."""
323+
324+ stats_dtypes_shapes = {
325+ ** pm .Metropolis .stats_dtypes_shapes ,
326+ "warning" : (SamplerWarning , None ),
327+ }
328+
329+ def astep (self , q0 ):
330+ draw , stats = super ().astep (q0 )
331+ stats [0 ]["warning" ] = SamplerWarning (
332+ WarningType .BAD_ENERGY ,
333+ "Asteroid incoming!" ,
334+ "warn" ,
335+ )
336+ return draw , stats
337+
338+
339+ class TestSampleReturn :
340+ """Tests related to kwargs that parametrize how `pm.sample` results are returned."""
341+
342+ def test_sample_tune_len (self ):
343+ with pm .Model ():
344+ pm .Normal ("n" )
345+ with warnings .catch_warnings ():
346+ warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
347+ warnings .filterwarnings ("ignore" , "Tuning samples will be included.*" , UserWarning )
348+ trace = pm .sample (
349+ draws = 100 , tune = 50 , cores = 1 , step = pm .Metropolis (), return_inferencedata = False
350+ )
351+ assert len (trace ) == 100
352+ trace = pm .sample (
353+ draws = 100 ,
354+ tune = 50 ,
355+ cores = 1 ,
356+ step = pm .Metropolis (),
357+ return_inferencedata = False ,
358+ discard_tuned_samples = False ,
359+ )
360+ assert len (trace ) == 150
361+ trace = pm .sample (
362+ draws = 100 ,
363+ tune = 50 ,
364+ cores = 4 ,
365+ step = pm .Metropolis (),
366+ return_inferencedata = False ,
367+ )
368+ assert len (trace ) == 100
369+
370+ @pytest .mark .parametrize ("discard" , [True , False ])
371+ def test_trace_report (self , discard ):
372+ with pm .Model ():
373+ pm .Uniform ("uni" )
374+ with warnings .catch_warnings ():
375+ warnings .filterwarnings (
376+ "ignore" , ".*Tuning samples will be included.*" , UserWarning
377+ )
378+ trace = pm .sample (
379+ draws = 100 ,
380+ tune = 50 ,
381+ cores = 1 ,
382+ discard_tuned_samples = discard ,
383+ step = pm .Metropolis (),
384+ compute_convergence_checks = False ,
385+ return_inferencedata = False ,
386+ )
387+ assert trace .report .n_tune == 50
388+ assert trace .report .n_draws == 100
389+ assert isinstance (trace .report .t_sampling , float )
390+
391+ def test_return_inferencedata (self ):
392+ model , _ , step , _ = simple_init ()
393+ with model :
394+ kwargs = dict (draws = 100 , tune = 50 , cores = 1 , chains = 2 , step = step )
395+
396+ # trace with tuning
397+ with pytest .warns (UserWarning , match = "will be included" ):
398+ result = pm .sample (
399+ ** kwargs , return_inferencedata = False , discard_tuned_samples = False
400+ )
401+ assert isinstance (result , pm .backends .base .MultiTrace )
402+ assert len (result ) == 150
403+
404+ # inferencedata with tuning
405+ result = pm .sample (** kwargs , return_inferencedata = True , discard_tuned_samples = False )
406+ assert isinstance (result , InferenceData )
407+ assert result .posterior .sizes ["draw" ] == 100
408+ assert result .posterior .sizes ["chain" ] == 2
409+ assert len (result ._groups_warmup ) > 0
410+
411+ # inferencedata without tuning, with idata_kwargs
412+ prior = pm .sample_prior_predictive (return_inferencedata = False )
413+ result = pm .sample (
414+ ** kwargs ,
415+ return_inferencedata = True ,
416+ discard_tuned_samples = True ,
417+ idata_kwargs = {"prior" : prior },
418+ random_seed = - 1 ,
419+ )
420+ assert "prior" in result
421+ assert isinstance (result , InferenceData )
422+ assert result .posterior .sizes ["draw" ] == 100
423+ assert result .posterior .sizes ["chain" ] == 2
424+ assert len (result ._groups_warmup ) == 0
425+
426+ @pytest .mark .parametrize ("cores" , [1 , 2 ])
427+ def test_sampler_stat_tune (self , cores ):
428+ with pm .Model ():
429+ pm .Normal ("n" )
430+ with warnings .catch_warnings ():
431+ warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
432+ warnings .filterwarnings ("ignore" , "Tuning samples will be included.*" , UserWarning )
433+ tune_stat = pm .sample (
434+ tune = 5 ,
435+ draws = 7 ,
436+ cores = cores ,
437+ discard_tuned_samples = False ,
438+ return_inferencedata = False ,
439+ step = pm .Metropolis (),
440+ ).get_sampler_stats ("tune" , chains = 1 )
441+ assert list (tune_stat ).count (True ) == 5
442+ assert list (tune_stat ).count (False ) == 7
443+
444+ @pytest .mark .parametrize ("cores" , [1 , 2 ])
445+ def test_logs_sampler_warnings (self , caplog , cores ):
446+ """Asserts that "warning" sampler stats are logged during sampling."""
447+ with pm .Model ():
448+ pm .Normal ("n" )
449+ with caplog .at_level (logging .WARNING ):
450+ idata = pm .sample (
451+ tune = 2 ,
452+ draws = 3 ,
453+ cores = cores ,
454+ chains = cores ,
455+ step = ApocalypticMetropolis (),
456+ compute_convergence_checks = False ,
457+ discard_tuned_samples = False ,
458+ keep_warning_stat = True ,
459+ )
460+
461+ # Sampler warnings should be logged
462+ nwarns = sum ("Asteroid" in rec .message for rec in caplog .records )
463+ assert nwarns == (2 + 3 ) * cores
464+
465+ @pytest .mark .parametrize ("keep_warning_stat" , [None , True ])
466+ def test_keep_warning_stat_setting (self , keep_warning_stat ):
467+ """The ``keep_warning_stat`` stat (aka "Adrian's kwarg) enables users
468+ to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group.
469+ This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``.
470+ """
471+ sample_kwargs = dict (
472+ tune = 2 ,
473+ draws = 3 ,
474+ chains = 1 ,
475+ compute_convergence_checks = False ,
476+ discard_tuned_samples = False ,
477+ keep_warning_stat = keep_warning_stat ,
478+ )
479+ if keep_warning_stat :
480+ sample_kwargs ["keep_warning_stat" ] = True
481+ with pm .Model ():
482+ pm .Normal ("n" )
483+ idata = pm .sample (step = ApocalypticMetropolis (), ** sample_kwargs )
484+
485+ if keep_warning_stat :
486+ assert "warning" in idata .warmup_sample_stats
487+ assert "warning" in idata .sample_stats
488+ # And end up in the InferenceData
489+ assert "warning" in idata .sample_stats
490+ # NOTE: The stats are squeezed by default but this does not always work.
491+ # This tests flattens so we don't have to be exact in accessing (non-)squeezed items.
492+ # Also see https://github.com/pymc-devs/pymc/issues/6207.
493+ warn_objs = list (idata .sample_stats .warning .sel (chain = 0 ).values .flatten ())
494+ assert any (isinstance (w , SamplerWarning ) for w in warn_objs )
495+ assert any ("Asteroid" in w .message for w in warn_objs )
496+ else :
497+ assert "warning" not in idata .warmup_sample_stats
498+ assert "warning" not in idata .sample_stats
499+ assert "warning_dim_0" not in idata .warmup_sample_stats
500+ assert "warning_dim_0" not in idata .sample_stats
501+
502+
414503def test_sample_find_MAP_does_not_modify_start ():
415504 # see https://github.com/pymc-devs/pymc/pull/4458
416505 with pm .Model ():
@@ -603,84 +692,6 @@ def test_step_args():
603692 npt .assert_allclose (idata1 .sample_stats .scaling , 0 )
604693
605694
606- class ApocalypticMetropolis (pm .Metropolis ):
607- """A stepper that warns in every iteration."""
608-
609- stats_dtypes_shapes = {
610- ** pm .Metropolis .stats_dtypes_shapes ,
611- "warning" : (SamplerWarning , None ),
612- }
613-
614- def astep (self , q0 ):
615- draw , stats = super ().astep (q0 )
616- stats [0 ]["warning" ] = SamplerWarning (
617- WarningType .BAD_ENERGY ,
618- "Asteroid incoming!" ,
619- "warn" ,
620- )
621- return draw , stats
622-
623-
624- @pytest .mark .parametrize ("cores" , [1 , 2 ])
625- def test_logs_sampler_warnings (caplog , cores ):
626- """Asserts that "warning" sampler stats are logged during sampling."""
627- with pm .Model ():
628- pm .Normal ("n" )
629- with caplog .at_level (logging .WARNING ):
630- idata = pm .sample (
631- tune = 2 ,
632- draws = 3 ,
633- cores = cores ,
634- chains = cores ,
635- step = ApocalypticMetropolis (),
636- compute_convergence_checks = False ,
637- discard_tuned_samples = False ,
638- keep_warning_stat = True ,
639- )
640-
641- # Sampler warnings should be logged
642- nwarns = sum ("Asteroid" in rec .message for rec in caplog .records )
643- assert nwarns == (2 + 3 ) * cores
644-
645-
646- @pytest .mark .parametrize ("keep_warning_stat" , [None , True ])
647- def test_keep_warning_stat_setting (keep_warning_stat ):
648- """The ``keep_warning_stat`` stat (aka "Adrian's kwarg) enables users
649- to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group.
650- This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``.
651- """
652- sample_kwargs = dict (
653- tune = 2 ,
654- draws = 3 ,
655- chains = 1 ,
656- compute_convergence_checks = False ,
657- discard_tuned_samples = False ,
658- keep_warning_stat = keep_warning_stat ,
659- )
660- if keep_warning_stat :
661- sample_kwargs ["keep_warning_stat" ] = True
662- with pm .Model ():
663- pm .Normal ("n" )
664- idata = pm .sample (step = ApocalypticMetropolis (), ** sample_kwargs )
665-
666- if keep_warning_stat :
667- assert "warning" in idata .warmup_sample_stats
668- assert "warning" in idata .sample_stats
669- # And end up in the InferenceData
670- assert "warning" in idata .sample_stats
671- # NOTE: The stats are squeezed by default but this does not always work.
672- # This tests flattens so we don't have to be exact in accessing (non-)squeezed items.
673- # Also see https://github.com/pymc-devs/pymc/issues/6207.
674- warn_objs = list (idata .sample_stats .warning .sel (chain = 0 ).values .flatten ())
675- assert any (isinstance (w , SamplerWarning ) for w in warn_objs )
676- assert any ("Asteroid" in w .message for w in warn_objs )
677- else :
678- assert "warning" not in idata .warmup_sample_stats
679- assert "warning" not in idata .sample_stats
680- assert "warning_dim_0" not in idata .warmup_sample_stats
681- assert "warning_dim_0" not in idata .sample_stats
682-
683-
684695def test_init_nuts (caplog ):
685696 with pm .Model () as model :
686697 a = pm .Normal ("a" )
0 commit comments