Skip to content

Commit bde00d0

Browse files
committed
formatting
1 parent bfb9ce0 commit bde00d0

File tree

2 files changed

+21
-28
lines changed

2 files changed

+21
-28
lines changed

src/varinfo.jl

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,11 @@ end
302302
303303
Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable.
304304
"""
305-
Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right)
306-
Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right)
305+
Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =
306+
_merge(varinfo_left, varinfo_right)
307+
function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo)
308+
return _merge(varinfo_left, varinfo_right)
309+
end
307310

308311
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
309312
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
@@ -314,9 +317,8 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
314317
end
315318

316319
function merge_metadata(
317-
metadata_left::NamedTuple{names_left},
318-
metadata_right::NamedTuple{names_right}
319-
) where {names_left, names_right}
320+
metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
321+
) where {names_left,names_right}
320322
# TODO: Improve this. Maybe make `@generated`?
321323
metadata = map(names_left) do sym
322324
if sym in names_right
@@ -332,7 +334,9 @@ function merge_metadata(
332334
end
333335
end
334336

335-
return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...))
337+
return NamedTuple{(names_left..., names_right_only...)}(
338+
tuple(metadata..., metadata_right_only...)
339+
)
336340
end
337341

338342
function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
@@ -361,13 +365,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
361365

362366
# Initialize required fields for `metadata`.
363367
vns = VarName[]
364-
idcs = Dict{VarName, Int}()
368+
idcs = Dict{VarName,Int}()
365369
ranges = Vector{UnitRange{Int}}()
366370
vals = T[]
367371
dists = D[]
368372
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
369373
orders = Int[]
370-
flags = Dict{String, BitVector}()
374+
flags = Dict{String,BitVector}()
371375
# Initialize the `flags`.
372376
for k in union(keys(metadata_left.flags), keys(metadata_right.flags))
373377
flags[k] = BitVector()
@@ -442,16 +446,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
442446
end
443447
end
444448

445-
return Metadata(
446-
idcs,
447-
vns,
448-
ranges,
449-
vals,
450-
dists,
451-
gids,
452-
orders,
453-
flags,
454-
)
449+
return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags)
455450
end
456451

457452
const VarView = Union{Int,UnitRange,Vector{Int}}
@@ -1589,7 +1584,6 @@ run before sampling `vn`.
15891584
getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn)
15901585
getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)]
15911586

1592-
15931587
#######################################
15941588
# Rand & replaying method for VarInfo #
15951589
#######################################
@@ -1602,7 +1596,9 @@ Check whether `vn` has a true value for `flag` in `vi`.
16021596
function is_flagged(vi::VarInfo, vn::VarName, flag::String)
16031597
return is_flagged(getmetadata(vi, vn), vn, flag)
16041598
end
1605-
is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)]
1599+
function is_flagged(metadata::Metadata, vn::VarName, flag::String)
1600+
return metadata.flags[flag][getidx(metadata, vn)]
1601+
end
16061602

16071603
"""
16081604
unset_flag!(vi::VarInfo, vn::VarName, flag::String)

test/varinfo.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -481,12 +481,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
481481
]
482482

483483
# All variables.
484-
@test isempty(
485-
setdiff(
486-
keys(varinfo),
487-
vns,
488-
),
489-
)
484+
@test isempty(setdiff(keys(varinfo), vns))
490485

491486
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [
492487
[@varname(s)],
@@ -526,7 +521,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
526521
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
527522
@testset "$(short_varinfo_name(varinfo))" for varinfo in [
528523
VarInfo(model),
529-
last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext()))
524+
last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())),
530525
]
531526
vns = DynamicPPL.TestUtils.varnames(model)
532527
@testset "with itself" begin
@@ -551,7 +546,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
551546

552547
@testset "with different value" begin
553548
x = DynamicPPL.TestUtils.rand(model)
554-
varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns)
549+
varinfo_changed = DynamicPPL.TestUtils.update_values!!(
550+
deepcopy(varinfo), x, vns
551+
)
555552
# After `merge`, we should have the same values as `x`.
556553
varinfo_merged = merge(varinfo, varinfo_changed)
557554
DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns)

0 commit comments

Comments
 (0)