@@ -233,22 +233,20 @@ def test_sample_start_good_shape(self, start):
     def test_sample_callback(self):
         callback = mock.Mock()
         test_cores = [1, 2]
-        test_chains = [1, 2]
         with self.model:
             for cores in test_cores:
-                for chain in test_chains:
-                    with warnings.catch_warnings():
-                        warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
-                        pm.sample(
-                            10,
-                            tune=0,
-                            chains=chain,
-                            step=self.step,
-                            cores=cores,
-                            random_seed=self.random_seed,
-                            callback=callback,
-                        )
-                    assert callback.called
+                with warnings.catch_warnings():
+                    warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                    pm.sample(
+                        10,
+                        tune=0,
+                        chains=2,
+                        step=self.step,
+                        cores=cores,
+                        random_seed=self.random_seed,
+                        callback=callback,
+                    )
+                assert callback.called
 
     def test_callback_can_cancel(self):
         trace_cancel_length = 5
@@ -339,107 +337,81 @@ def astep(self, q0):
 class TestSampleReturn:
     """Tests related to kwargs that parametrize how `pm.sample` results are returned."""
 
-    def test_sample_tune_len(self):
-        with pm.Model():
+    def test_sample_return_lengths(self):
+        with pm.Model() as model:
             pm.Normal("n")
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
-                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
-                trace = pm.sample(
-                    draws=100, tune=50, cores=1, step=pm.Metropolis(), return_inferencedata=False
-                )
-                assert len(trace) == 100
-                trace = pm.sample(
-                    draws=100,
-                    tune=50,
-                    cores=1,
-                    step=pm.Metropolis(),
-                    return_inferencedata=False,
-                    discard_tuned_samples=False,
-                )
-                assert len(trace) == 150
-                trace = pm.sample(
-                    draws=100,
-                    tune=50,
-                    cores=4,
-                    step=pm.Metropolis(),
-                    return_inferencedata=False,
-                )
-                assert len(trace) == 100
 
-    @pytest.mark.parametrize("discard", [True, False])
-    def test_trace_report(self, discard):
-        with pm.Model():
-            pm.Uniform("uni")
-            with warnings.catch_warnings():
-                warnings.filterwarnings(
-                    "ignore", ".*Tuning samples will be included.*", UserWarning
-                )
-                trace = pm.sample(
+            # Get a MultiTrace with warmup
+            with pytest.warns(UserWarning, match="will be included"):
+                mtrace = pm.sample(
                     draws=100,
                     tune=50,
                     cores=1,
-                    discard_tuned_samples=discard,
+                    chains=3,
                     step=pm.Metropolis(),
-                    compute_convergence_checks=False,
                     return_inferencedata=False,
-                )
-            assert trace.report.n_tune == 50
-            assert trace.report.n_draws == 100
-            assert isinstance(trace.report.t_sampling, float)
-
-    def test_return_inferencedata(self):
-        model, _, step, _ = simple_init()
-        with model:
-            kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=step)
-
-            # trace with tuning
-            with pytest.warns(UserWarning, match="will be included"):
-                result = pm.sample(
-                    **kwargs, return_inferencedata=False, discard_tuned_samples=False
-                )
-            assert isinstance(result, pm.backends.base.MultiTrace)
-            assert len(result) == 150
-
-            # inferencedata with tuning
-            result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
-            assert isinstance(result, InferenceData)
-            assert result.posterior.sizes["draw"] == 100
-            assert result.posterior.sizes["chain"] == 2
-            assert len(result._groups_warmup) > 0
-
-            # inferencedata without tuning, with idata_kwargs
-            prior = pm.sample_prior_predictive(return_inferencedata=False)
-            result = pm.sample(
-                **kwargs,
-                return_inferencedata=True,
-                discard_tuned_samples=True,
-                idata_kwargs={"prior": prior},
-                random_seed=-1,
-            )
-            assert "prior" in result
-            assert isinstance(result, InferenceData)
-            assert result.posterior.sizes["draw"] == 100
-            assert result.posterior.sizes["chain"] == 2
-            assert len(result._groups_warmup) == 0
-
-    @pytest.mark.parametrize("cores", [1, 2])
-    def test_sampler_stat_tune(self, cores):
-        with pm.Model():
-            pm.Normal("n")
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
-                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
-                tune_stat = pm.sample(
-                    tune=5,
-                    draws=7,
-                    cores=cores,
                     discard_tuned_samples=False,
-                    return_inferencedata=False,
-                    step=pm.Metropolis(),
-                ).get_sampler_stats("tune", chains=1)
-                assert list(tune_stat).count(True) == 5
-                assert list(tune_stat).count(False) == 7
+                )
+            assert isinstance(mtrace, pm.backends.base.MultiTrace)
+            assert len(mtrace) == 150
+
+            # Now instead of running more MCMCs, we'll test the other return
+            # options using the basetraces inside the MultiTrace.
+            traces = list(mtrace._straces.values())
+            assert len(traces) == 3
+
+            # MultiTrace without warmup
+            mtrace_pst = pm.sampling.mcmc._sample_return(
+                traces=traces,
+                tune=50,
+                t_sampling=123.4,
+                discard_tuned_samples=True,
+                return_inferencedata=False,
+                compute_convergence_checks=False,
+                keep_warning_stat=True,
+                idata_kwargs={},
+                model=model,
+            )
+            assert isinstance(mtrace_pst, pm.backends.base.MultiTrace)
+            assert len(mtrace_pst) == 100
+            assert mtrace_pst.report.t_sampling == 123.4
+            assert mtrace_pst.report.n_tune == 50
+            assert mtrace_pst.report.n_draws == 100
+
+            # InferenceData with warmup
+            idata_w = pm.sampling.mcmc._sample_return(
+                traces=traces,
+                tune=50,
+                t_sampling=123.4,
+                discard_tuned_samples=False,
+                compute_convergence_checks=False,
+                return_inferencedata=True,
+                keep_warning_stat=True,
+                idata_kwargs={},
+                model=model,
+            )
+            assert isinstance(idata_w, InferenceData)
+            assert hasattr(idata_w, "warmup_posterior")
+            assert idata_w.warmup_posterior.sizes["draw"] == 50
+            assert idata_w.posterior.sizes["draw"] == 100
+            assert idata_w.posterior.sizes["chain"] == 3
+
+            # InferenceData without warmup
+            idata = pm.sampling.mcmc._sample_return(
+                traces=traces,
+                tune=50,
+                t_sampling=123.4,
+                discard_tuned_samples=True,
+                compute_convergence_checks=False,
+                return_inferencedata=True,
+                keep_warning_stat=False,
+                idata_kwargs={},
+                model=model,
+            )
+            assert isinstance(idata, InferenceData)
+            assert not hasattr(idata, "warmup_posterior")
+            assert idata.posterior.sizes["draw"] == 100
+            assert idata.posterior.sizes["chain"] == 3
 
     @pytest.mark.parametrize("cores", [1, 2])
     def test_logs_sampler_warnings(self, caplog, cores):
0 commit comments