@@ -1257,6 +1257,10 @@ def create(
     # Handle specialized decoding approaches
     if decoding:
         logger.info(f"Using specialized decoding approach: {decoding}")
+
+        # Ensure model is in eval mode and on correct device
+        pipeline.current_model.eval()
+        device = pipeline.current_model.device
 
         if decoding == "cot_decoding":
             # Use directly available parameters for CoT
@@ -1284,32 +1288,46 @@ def create(
             completion_tokens = len(pipeline.tokenizer.encode(result))
 
         elif decoding == "entropy_decoding":
-            # Configure generator for entropy decoding
-            generator = None
-            if seed is not None:
-                generator = torch.Generator(device=pipeline.current_model.device)
-                generator.manual_seed(seed)
-
-            # Use directly available parameters for entropy decoding
 
-            entropy_params = {
-                "max_new_tokens": max_tokens if max_tokens is not None else 512,
-                "temperature": 0.666,
-                "top_p": 0.90,
-                "top_k": top_k,
-                "min_p": min_p,
-                "generator": generator
-            }
+            # Ensure model is using full precision
+            original_dtype = pipeline.current_model.dtype
+            pipeline.current_model = pipeline.current_model.to(torch.float32)
+
+            try:
+                # Configure generator for entropy decoding
+                generator = None
+                if seed is not None:
+                    generator = torch.Generator(device=device)
+                    generator.manual_seed(seed)
+                else:
+                    generator = torch.Generator(device=device)
+                    generator.manual_seed(1337)  # Default seed as in original implementation
+
+                # Use directly available parameters for entropy decoding
+                entropy_params = {
+                    "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "top_k": top_k,
+                    "min_p": min_p,
+                    "generator": generator
+                }
+
+                # Disable autocast and run in full precision
+                with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
+                    result = entropy_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        **entropy_params
+                    )
+                responses = [result]
+                logprobs_results = [None]
+                completion_tokens = len(pipeline.tokenizer.encode(result))
 
-            result = entropy_decode(
-                pipeline.current_model,
-                pipeline.tokenizer,
-                messages,
-                **entropy_params
-            )
-            responses = [result]
-            logprobs_results = [None]
-            completion_tokens = len(pipeline.tokenizer.encode(result))
+            finally:
+                # Restore original dtype
+                pipeline.current_model = pipeline.current_model.to(original_dtype)
 
         else:
             raise ValueError(f"Unknown specialized decoding approach: {decoding}")
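Taken together, the change puts the model in eval mode once, pins a device handle, and wraps entropy decoding in a full-precision guard: cast to float32, always seed the generator (falling back to 1337, the default from the original implementation), disable CUDA autocast, run under torch.inference_mode(), and restore the original dtype in a finally block even if decoding raises. The same guard can be factored into a reusable context manager; the sketch below is illustrative only (the helpers full_precision and seeded_generator are not part of this repo) and assumes a Hugging Face-style model that exposes a .dtype attribute:

```python
import contextlib

import torch


@contextlib.contextmanager
def full_precision(model):
    """Temporarily cast `model` to float32, restoring its dtype on exit.

    Assumes a Hugging Face-style model exposing `.dtype`; plain nn.Modules
    do not have this attribute.
    """
    original_dtype = model.dtype
    model.to(torch.float32)
    try:
        yield model
    finally:
        # Runs even if generation raises, mirroring the try/finally above
        model.to(original_dtype)


def seeded_generator(device, seed=None):
    """Build a torch.Generator, falling back to the fixed default seed 1337."""
    gen = torch.Generator(device=device)
    gen.manual_seed(seed if seed is not None else 1337)
    return gen
```

With these helpers the entropy branch would reduce to calling entropy_decode inside `with full_precision(pipeline.current_model), torch.amp.autocast('cuda', enabled=False), torch.inference_mode():`, which makes the dtype restore impossible to skip on an early exit.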