Skip to content

Commit 7eab46d

Browse files
committed
Fix last two beam_size kw errors
1 parent ac67e3a commit 7eab46d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

beginner_source/t5_tutorial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def process_labels(labels, x):
222222
beam_size = 1
223223

224224
model_input = transform(input_text)
225-
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
225+
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
226226
output_text = transform.decode(model_output.tolist())
227227

228228
for i in range(cnndm_batch_size):
@@ -312,7 +312,7 @@ def process_labels(labels, x):
312312
beam_size = 1
313313

314314
model_input = transform(input_text)
315-
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
315+
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
316316
output_text = transform.decode(model_output.tolist())
317317

318318
for i in range(imdb_batch_size):

0 commit comments

Comments
 (0)