@@ -458,7 +458,7 @@ and then converts these into a `Chains` object using `AbstractMCMC.bundle_sample
458458
459459# Example
460460```jldoctest
461- julia> using Turing; Turing.turnprogress (false);
461+ julia> using Turing; Turing.setprogress! (false);
462462[ Info: [Turing]: progress logging is disabled globally
463463
464464julia> @model function linear_reg(x, y, σ = 0.1)
@@ -517,31 +517,31 @@ function predict(model::Model, chain::MCMCChains.Chains; kwargs...)
517517 return predict (Random. GLOBAL_RNG, model, chain; kwargs... )
518518end
519519function predict (rng:: AbstractRNG , model:: Model , chain:: MCMCChains.Chains ; include_all = false )
520+ # Don't need all the diagnostics
521+ chain_parameters = MCMCChains. get_sections (chain, :parameters )
522+
520523 spl = DynamicPPL. SampleFromPrior ()
521524
522525 # Sample transitions using `spl` conditioned on values in `chain`
523- transitions = [
524- transitions_from_chain (rng, model, chain[:, :, chn_idx]; sampler = spl)
525- for chn_idx = 1 : size (chain, 3 )
526- ]
526+ transitions = transitions_from_chain (rng, model, chain_parameters; sampler = spl)
527527
528528 # Let the Turing internals handle everything else for you
529529 chain_result = reduce (
530530 MCMCChains. chainscat, [
531531 AbstractMCMC. bundle_samples (
532- transitions[chn_idx ],
532+ transitions[:, chain_idx ],
533533 model,
534534 spl,
535535 nothing ,
536536 MCMCChains. Chains
537- ) for chn_idx = 1 : size (chain, 3 )
537+ ) for chain_idx = 1 : size (transitions, 2 )
538538 ]
539539 )
540540
541541 parameter_names = if include_all
542542 names (chain_result, :parameters )
543543 else
544- filter (k -> ∉ (k, names (chain , :parameters )), names (chain_result, :parameters ))
544+ filter (k -> ∉ (k, names (chain_parameters , :parameters )), names (chain_result, :parameters ))
545545 end
546546
547547 return chain_result[parameter_names]
@@ -603,44 +603,22 @@ function transitions_from_chain(
603603)
604604 return transitions_from_chain (Random. GLOBAL_RNG, model, chain; kwargs... )
605605end
606+
606607function transitions_from_chain (
607- rng:: AbstractRNG ,
608+ rng:: Random. AbstractRNG ,
608609 model:: Turing.Model ,
609610 chain:: MCMCChains.Chains ;
610611 sampler = DynamicPPL. SampleFromPrior ()
611612)
612613 vi = Turing. VarInfo (model)
613614
614- transitions = map (1 : length (chain)) do i
615- c = chain[i]
616- md = vi. metadata
617- for v in keys (md)
618- for vn in md[v]. vns
619- vn_sym = Symbol (vn)
620-
621- # Cannot use `vn_sym` to index in the chain
622- # so we have to extract the corresponding "linear"
623- # indices and use those.
624- # `ks` is empty if `vn_sym` not in `c`.
625- ks = MCMCChains. namesingroup (c, vn_sym)
626-
627- if ! isempty (ks)
628- # 1st dimension is of size 1 since `c`
629- # only contains a single sample, and the
630- # last dimension is of size 1 since
631- # we're assuming we're working with a single chain.
632- val = copy (vec (c[ks]. value))
633- DynamicPPL. setval! (vi, val, vn)
634- DynamicPPL. settrans! (vi, false , vn)
635- else
636- DynamicPPL. set_flag! (vi, vn, " del" )
637- end
638- end
639- end
640- # Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
615+ iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
616+ transitions = map (iters) do (sample_idx, chain_idx)
617+ # Set variables present in `chain` and mark those NOT present in chain to be resampled.
618+ DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
641619 model (rng, vi, sampler)
642620
643- # Convert `VarInfo` into `NamedTuple` and save
621+ # Convert `VarInfo` into `NamedTuple` and save.
644622 theta = DynamicPPL. tonamedtuple (vi)
645623 lp = Turing. getlogp (vi)
646624 Transition (theta, lp)
0 commit comments