Skip to content

Commit 8323952

Browse files
authored
Backport some features from #100 (#102)
* Fix unbox condition function (ref #100) * Port new produce mechanism from #100. * Minor bugfixes. * Fix new produce mechanism. * Update src/tapedtask.jl * Update src/tapedtask.jl * Update src/tapedtask.jl * Update src/tapedtask.jl
1 parent 48703aa commit 8323952

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

src/tapedfunction.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
function Base.show(io::IO, instruction::Instruction)
4343
fun = instruction.fun
4444
tape = instruction.tape
45-
println(io, "Instruction($(fun)), tape=$(objectid(tape)))")
45+
println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))")
4646
end
4747

4848
function Base.show(io::IO, tp::Tape)
@@ -75,7 +75,8 @@ function run_and_record!(tape::Tape, f, args...)
7575
f = val(f) # f maybe a Boxed closure
7676
output = try
7777
box(f(map(val, args)...))
78-
catch
78+
catch e
79+
@warn e
7980
any_box(nothing)
8081
end
8182
ins = Instruction(f, args, output, tape)
@@ -94,11 +95,14 @@ end
9495
function unbox_condition(ir)
9596
for blk in IRTools.blocks(ir)
9697
vars = keys(blk)
97-
for br in IRTools.branches(blk)
98+
brs = IRTools.branches(blk)
99+
for (i, br) in enumerate(brs)
98100
IRTools.isconditional(br) || continue
99101
cond = br.condition
100-
prev_cond = IRTools.insert!(ir, cond, ir[cond])
101-
ir[cond] = IRTools.xcall(@__MODULE__, :val, prev_cond)
102+
new_cond = IRTools.push!(
103+
blk,
104+
IRTools.xcall(@__MODULE__, :val, cond))
105+
brs[i] = IRTools.Branch(br; condition=new_cond)
102106
end
103107
end
104108
end

src/tapedtask.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ struct TapedTask
88
counter::Ref{Int}
99
produce_ch::Channel{Any}
1010
consume_ch::Channel{Int}
11+
produced_val::Vector{Any}
12+
13+
function TapedTask(
14+
t::Task, tf::TapedFunction, counter, pch::Channel{Any}, cch::Channel{Int})
15+
new(t, tf, counter, pch, cch, Any[])
16+
end
1117
end
1218

1319
function TapedTask(tf::TapedFunction, args...)
@@ -35,24 +41,33 @@ function TapedTask(tf::TapedFunction, args...)
3541
close(consume_ch)
3642
end
3743
t = TapedTask(task, tf, counter, produce_ch, consume_ch)
38-
# task.storage === nothing && (task.storage = IdDict())
39-
# task.storage[:tapedtask] = t
44+
task.storage === nothing && (task.storage = IdDict())
45+
task.storage[:tapedtask] = t
4046
tf.owner = t
4147
return t
4248
end
4349

44-
TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
50+
# Issue: evaluating model without a trace, see
51+
# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
52+
TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
4553
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
4654
func(t::TapedTask) = t.tf.func
4755

4856
function step_in(tf::TapedFunction, counter::Ref{Int}, args)
4957
len = length(tf.tape)
50-
if(counter[] <= 1)
58+
if(counter[] <= 1 && length(args) > 0)
5159
input = map(box, args)
5260
tf.tape[1].input = input
5361
end
5462
while counter[] <= len
5563
tf.tape[counter[]]()
64+
# produce and wait after an instruction is done
65+
ttask = tf.owner
66+
if length(ttask.produced_val) > 0
67+
val = pop!(ttask.produced_val)
68+
put!(ttask.produce_ch, val)
69+
take!(ttask.consume_ch) # wait for next consumer
70+
end
5671
counter[] += 1
5772
end
5873
end
@@ -76,7 +91,7 @@ function (instr::Instruction{typeof(produce)})()
7691
internal_produce(instr, args)
7792
end
7893

79-
#=
94+
8095
# Another way to support `produce` in nested call. This way has its caveat:
8196
# `produce` may deeply hide in an instruction, but not be an instruction
8297
# itself, and when we copy a task, the newly copied task will resume from
@@ -95,11 +110,10 @@ end
95110
function produce(val)
96111
is_in_tapedtask() || return nothing
97112
ttask = current_task().storage[:tapedtask]
98-
put!(ttask.produce_ch, val)
99-
take!(ttask.consume_ch) # wait for next consumer
100-
return nothing
113+
length(ttask.produced_val) > 1 &&
114+
error("There is a produced value which is not consumed.")
115+
push!(ttask.produced_val, val)
101116
end
102-
=#
103117

104118
function consume(ttask::TapedTask)
105119
if istaskstarted(ttask.task)

0 commit comments

Comments
 (0)