@@ -8,6 +8,12 @@ struct TapedTask
88 counter:: Ref{Int}
99 produce_ch:: Channel{Any}
1010 consume_ch:: Channel{Int}
11+ produced_val:: Vector{Any}
12+
13+ function TapedTask (
14+ t:: Task , tf:: TapedFunction , counter, pch:: Channel{Any} , cch:: Channel{Int} )
15+ new (t, tf, counter, pch, cch, Any[])
16+ end
1117end
1218
1319function TapedTask (tf:: TapedFunction , args... )
@@ -35,24 +41,33 @@ function TapedTask(tf::TapedFunction, args...)
3541 close (consume_ch)
3642 end
3743 t = TapedTask (task, tf, counter, produce_ch, consume_ch)
38- # task.storage === nothing && (task.storage = IdDict())
39- # task.storage[:tapedtask] = t
44+ task. storage === nothing && (task. storage = IdDict ())
45+ task. storage[:tapedtask ] = t
4046 tf. owner = t
4147 return t
4248end
4349
44- TapedTask (f, args... ) = TapedTask (TapedFunction (f, arity= length (args)), args... )
50+ # Issue: evaluating model without a trace, see
51+ # https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
52+ TapedTask (f, args... ) = TapedTask (TapedFunction (f, arity= length (args)), args... )
4553TapedTask (t:: TapedTask , args... ) = TapedTask (func (t), args... )
4654func (t:: TapedTask ) = t. tf. func
4755
4856function step_in (tf:: TapedFunction , counter:: Ref{Int} , args)
4957 len = length (tf. tape)
50- if (counter[] <= 1 )
58+ if (counter[] <= 1 && length (args) > 0 )
5159 input = map (box, args)
5260 tf. tape[1 ]. input = input
5361 end
5462 while counter[] <= len
5563 tf. tape[counter[]]()
64+ # produce and wait after an instruction is done
65+ ttask = tf. owner
66+ if length (ttask. produced_val) > 0
67+ val = pop! (ttask. produced_val)
68+ put! (ttask. produce_ch, val)
69+ take! (ttask. consume_ch) # wait for next consumer
70+ end
5671 counter[] += 1
5772 end
5873end
@@ -76,7 +91,7 @@ function (instr::Instruction{typeof(produce)})()
7691 internal_produce (instr, args)
7792end
7893
79- #=
94+
8095# Another way to support `produce` in nested call. This way has its caveat:
8196# `produce` may deeply hide in an instruction, but not be an instruction
8297# itself, and when we copy a task, the newly copied task will resume from
95110function produce (val)
96111 is_in_tapedtask () || return nothing
97112 ttask = current_task (). storage[:tapedtask ]
98- put! (ttask.produce_ch, val)
99- take!(ttask.consume_ch) # wait for next consumer
100- return nothing
113+ length (ttask. produced_val) > 1 &&
114+ error ( " There is a produced value which is not consumed. " )
115+ push! (ttask . produced_val, val)
101116end
102- =#
103117
104118function consume (ttask:: TapedTask )
105119 if istaskstarted (ttask. task)
0 commit comments