@@ -116,12 +116,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         if self.full_batch_size:
             # Create CB inputs (make 1 batch index have proper inputs for decode pass)
             batch_index = torch.arange(1).view(-1, 1)
-            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)
+            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)  # shape: [batch_size, num_logits_to_keep]
             input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id)
             input_ids[batch_index.view(-1)] = batch_idx_input_ids
+
             position_ids = torch.full((self.full_batch_size, decode_len), 0)
             batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + (inputs["position_ids"].max(1, keepdim=True).values + 1)
             position_ids[batch_index.view(-1)] = batch_idx_position_ids
+
             updated_inputs["input_ids"] = input_ids
             updated_inputs["position_ids"] = position_ids
             updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
@@ -132,7 +134,7 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
             batch_size = input_ids.size(0)
             position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
         else:
-            input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+            input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)  # shape: [batch_size, 1]
             position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
         updated_inputs["input_ids"] = input_ids
         updated_inputs["position_ids"] = position_ids
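To make the shape comments added in this diff concrete, here is a minimal standalone sketch of the two decode-input update paths it touches. All sizes (`batch_size`, `num_logits_to_keep`, `full_batch_size`, `decode_len`, `vocab_size`, `pad_token_id`) are hypothetical stand-ins chosen for illustration, not values from the repo, and the random tensors stand in for `pt_outputs.logits` and the prefill `inputs["position_ids"]`:

```python
import torch

# Hypothetical sizes for illustration only.
batch_size, num_logits_to_keep, vocab_size = 1, 3, 32000
full_batch_size, pad_token_id = 4, 0
decode_len = num_logits_to_keep

# Stand-ins for pt_outputs.logits and the prefill inputs["position_ids"].
logits = torch.randn(batch_size, num_logits_to_keep, vocab_size)
prefill_position_ids = torch.arange(8).view(1, -1).repeat(batch_size, 1)

# CB path: greedy tokens fill one real batch index; every other row is padding.
batch_index = torch.arange(1).view(-1, 1)
batch_idx_input_ids = logits.detach().argmax(2)  # [batch_size, num_logits_to_keep]
input_ids = torch.full((full_batch_size, decode_len), pad_token_id)
input_ids[batch_index.view(-1)] = batch_idx_input_ids

position_ids = torch.full((full_batch_size, decode_len), 0)
batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + (
    prefill_position_ids.max(1, keepdim=True).values + 1
)
position_ids[batch_index.view(-1)] = batch_idx_position_ids
print(input_ids.shape, position_ids.shape)  # both torch.Size([4, 3])

# Non-CB path: only the last position's logits are kept, so argmax yields
# exactly one greedy token per sequence.
last_logits = torch.randn(batch_size, 1, vocab_size)
next_input_ids = last_logits.argmax(-1).reshape(-1, 1)  # [batch_size, 1]
next_position_ids = prefill_position_ids.max(1, keepdim=True).values + 1
print(next_input_ids.shape, next_position_ids.shape)  # both torch.Size([1, 1])
```

The CB path scatters the new tokens into a `full_batch_size`-wide buffer so that only the decoded batch index carries real tokens and position ids, matching the pad-fill behavior visible in the diff's context lines.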