Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
function Base.show(io::IO, instruction::Instruction)
fun = instruction.fun
tape = instruction.tape
println(io, "Instruction($(fun)), tape=$(objectid(tape)))")
println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))")
end

function Base.show(io::IO, tp::Tape)
Expand Down Expand Up @@ -75,7 +75,8 @@ function run_and_record!(tape::Tape, f, args...)
f = val(f) # f maybe a Boxed closure
output = try
box(f(map(val, args)...))
catch
catch e
@warn e
any_box(nothing)
end
ins = Instruction(f, args, output, tape)
Expand All @@ -94,11 +95,14 @@ end
function unbox_condition(ir)
for blk in IRTools.blocks(ir)
vars = keys(blk)
for br in IRTools.branches(blk)
brs = IRTools.branches(blk)
for (i, br) in enumerate(brs)
IRTools.isconditional(br) || continue
cond = br.condition
prev_cond = IRTools.insert!(ir, cond, ir[cond])
ir[cond] = IRTools.xcall(@__MODULE__, :val, prev_cond)
new_cond = IRTools.push!(
blk,
IRTools.xcall(@__MODULE__, :val, cond))
brs[i] = IRTools.Branch(br; condition=new_cond)
end
end
end
Expand Down
32 changes: 23 additions & 9 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ struct TapedTask
counter::Ref{Int}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, counter, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, counter, pch, cch, Any[])
end
end

function TapedTask(tf::TapedFunction, args...)
Expand Down Expand Up @@ -35,24 +41,33 @@ function TapedTask(tf::TapedFunction, args...)
close(consume_ch)
end
t = TapedTask(task, tf, counter, produce_ch, consume_ch)
# task.storage === nothing && (task.storage = IdDict())
# task.storage[:tapedtask] = t
task.storage === nothing && (task.storage = IdDict())
task.storage[:tapedtask] = t
tf.owner = t
return t
end

TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
# Issue: evaluating model without a trace, see
# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
func(t::TapedTask) = t.tf.func

function step_in(tf::TapedFunction, counter::Ref{Int}, args)
len = length(tf.tape)
if(counter[] <= 1)
if(counter[] <= 1 && length(args) > 0)
input = map(box, args)
tf.tape[1].input = input
end
while counter[] <= len
tf.tape[counter[]]()
# produce and wait after an instruction is done
ttask = tf.owner
if length(ttask.produced_val) > 0
val = pop!(ttask.produced_val)
put!(ttask.produce_ch, val)
take!(ttask.consume_ch) # wait for next consumer
end
counter[] += 1
end
end
Expand All @@ -76,7 +91,7 @@ function (instr::Instruction{typeof(produce)})()
internal_produce(instr, args)
end

#=

# Another way to support `produce` in nested call. This way has its caveat:
# `produce` may deeply hide in an instruction, but not be an instruction
# itself, and when we copy a task, the newly copied task will resume from
Expand All @@ -95,11 +110,10 @@ end
function produce(val)
is_in_tapedtask() || return nothing
ttask = current_task().storage[:tapedtask]
put!(ttask.produce_ch, val)
take!(ttask.consume_ch) # wait for next consumer
return nothing
length(ttask.produced_val) > 1 &&
error("There is a produced value which is not consumed.")
push!(ttask.produced_val, val)
end
=#

function consume(ttask::TapedTask)
if istaskstarted(ttask.task)
Expand Down