@@ -237,10 +237,14 @@ def _sample_external_nuts(
237237 model : Model ,
238238 progressbar : bool ,
239239 idata_kwargs : Optional [Dict ],
240+ nuts_sampler_kwargs : Optional [Dict ],
240241 ** kwargs ,
241242):
242243 warnings .warn ("Use of external NUTS sampler is still experimental" , UserWarning )
243244
245+ if nuts_sampler_kwargs is None :
246+ nuts_sampler_kwargs = {}
247+
244248 if sampler == "nutpie" :
245249 try :
246250 import nutpie
@@ -271,7 +275,7 @@ def _sample_external_nuts(
271275 target_accept = target_accept ,
272276 seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
273277 progress_bar = progressbar ,
274- ** kwargs ,
278+ ** nuts_sampler_kwargs ,
275279 )
276280 return idata
277281
@@ -288,7 +292,7 @@ def _sample_external_nuts(
288292 model = model ,
289293 progressbar = progressbar ,
290294 idata_kwargs = idata_kwargs ,
291- ** kwargs ,
295+ ** nuts_sampler_kwargs ,
292296 )
293297 return idata
294298
@@ -304,7 +308,7 @@ def _sample_external_nuts(
304308 initvals = initvals ,
305309 model = model ,
306310 idata_kwargs = idata_kwargs ,
307- ** kwargs ,
311+ ** nuts_sampler_kwargs ,
308312 )
309313 return idata
310314
@@ -334,6 +338,7 @@ def sample(
334338 keep_warning_stat : bool = False ,
335339 return_inferencedata : bool = True ,
336340 idata_kwargs : Optional [Dict [str , Any ]] = None ,
341+ nuts_sampler_kwargs : Optional [Dict [str , Any ]] = None ,
337342 callback = None ,
338343 mp_ctx = None ,
339344 model : Optional [Model ] = None ,
@@ -410,6 +415,9 @@ def sample(
410415 `MultiTrace` (False). Defaults to `True`.
411416 idata_kwargs : dict, optional
412417 Keyword arguments for :func:`pymc.to_inference_data`
418+ nuts_sampler_kwargs : dict, optional
419+ Keyword arguments for the sampling library that implements nuts.
420+ Only used when an external sampler is specified via the `nuts_sampler` kwarg.
413421 callback : function, default=None
414422 A function which gets called for every sample from the trace of a chain. The function is
415423 called with the trace and the current draw and will contain all samples for a single trace.
@@ -493,6 +501,8 @@ def sample(
493501 stacklevel = 2 ,
494502 )
495503 initvals = kwargs .pop ("start" )
504+ if nuts_sampler_kwargs is None :
505+ nuts_sampler_kwargs = {}
496506 if "target_accept" in kwargs :
497507 if "nuts" in kwargs and "target_accept" in kwargs ["nuts" ]:
498508 raise ValueError (
@@ -569,6 +579,7 @@ def sample(
569579 model = model ,
570580 progressbar = progressbar ,
571581 idata_kwargs = idata_kwargs ,
582+ nuts_sampler_kwargs = nuts_sampler_kwargs ,
572583 ** kwargs ,
573584 )
574585
0 commit comments