@@ -333,7 +333,7 @@ def sample(
333333 compute_convergence_checks : bool = True ,
334334 keep_warning_stat : bool = False ,
335335 return_inferencedata : bool = True ,
336- idata_kwargs : dict = None ,
336+ idata_kwargs : Optional [ Dict [ str , Any ]] = None ,
337337 callback = None ,
338338 mp_ctx = None ,
339339 model : Optional [Model ] = None ,
@@ -687,7 +687,36 @@ def sample(
687687
688688 t_sampling = time .time () - t_start
689689
690- # Wrap chain traces in a MultiTrace
690+ # Packaging, validating and returning the result was extracted
691+ # into a function to make it easier to test and refactor.
692+ return _sample_return (
693+ traces = traces ,
694+ tune = tune ,
695+ t_sampling = t_sampling ,
696+ discard_tuned_samples = discard_tuned_samples ,
697+ compute_convergence_checks = compute_convergence_checks ,
698+ return_inferencedata = return_inferencedata ,
699+ keep_warning_stat = keep_warning_stat ,
700+ idata_kwargs = idata_kwargs or {},
701+ model = model ,
702+ )
703+
704+
705+ def _sample_return (
706+ * ,
707+ traces : Sequence [IBaseTrace ],
708+ tune : int ,
709+ t_sampling : float ,
710+ discard_tuned_samples : bool ,
711+ compute_convergence_checks : bool ,
712+ return_inferencedata : bool ,
713+ keep_warning_stat : bool ,
714+ idata_kwargs : Dict [str , Any ],
715+ model : Model ,
716+ ) -> Union [InferenceData , MultiTrace ]:
717+ """Final step of `pm.sampler` that picks/slices chains,
718+ runs diagnostics and converts to the desired return type."""
719+ # Pick and slice chains to keep the maximum number of samples
691720 if discard_tuned_samples :
692721 traces , length = _choose_chains (traces , tune )
693722 else :
@@ -725,8 +754,7 @@ def sample(
725754 idata = None
726755 if compute_convergence_checks or return_inferencedata :
727756 ikwargs : Dict [str , Any ] = dict (model = model , save_warmup = not discard_tuned_samples )
728- if idata_kwargs :
729- ikwargs .update (idata_kwargs )
757+ ikwargs .update (idata_kwargs )
730758 idata = pm .to_inference_data (mtrace , ** ikwargs )
731759
732760 if compute_convergence_checks :
0 commit comments