Commit 97eb708

Merge pull request #95 from codelion/fix-entropy-decoding-in-local-server
Fix entropy decoding in local server
2 parents: c3535c4 + db31686

2 files changed: +43 −25 lines changed

optillm/inference.py

Lines changed: 42 additions & 24 deletions
@@ -1257,6 +1257,10 @@ def create(
         # Handle specialized decoding approaches
         if decoding:
             logger.info(f"Using specialized decoding approach: {decoding}")
+
+            # Ensure model is in eval mode and on correct device
+            pipeline.current_model.eval()
+            device = pipeline.current_model.device
 
             if decoding == "cot_decoding":
                 # Use directly available parameters for CoT
@@ -1284,32 +1288,46 @@ def create(
                 completion_tokens = len(pipeline.tokenizer.encode(result))
 
             elif decoding == "entropy_decoding":
-                # Configure generator for entropy decoding
-                generator = None
-                if seed is not None:
-                    generator = torch.Generator(device=pipeline.current_model.device)
-                    generator.manual_seed(seed)
-
-                # Use directly available parameters for entropy decoding
 
-                entropy_params = {
-                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
-                    "temperature": 0.666,
-                    "top_p": 0.90,
-                    "top_k": top_k,
-                    "min_p": min_p,
-                    "generator": generator
-                }
+                # Ensure model is using full precision
+                original_dtype = pipeline.current_model.dtype
+                pipeline.current_model = pipeline.current_model.to(torch.float32)
+
+                try:
+                    # Configure generator for entropy decoding
+                    generator = None
+                    if seed is not None:
+                        generator = torch.Generator(device=device)
+                        generator.manual_seed(seed)
+                    else:
+                        generator = torch.Generator(device=device)
+                        generator.manual_seed(1337) # Default seed as in original implementation
+
+                    # Use directly available parameters for entropy decoding
+                    entropy_params = {
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": temperature,
+                        "top_p": top_p,
+                        "top_k": top_k,
+                        "min_p": min_p,
+                        "generator": generator
+                    }
+
+                    # Disable autocast and run in full precision
+                    with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
+                        result = entropy_decode(
+                            pipeline.current_model,
+                            pipeline.tokenizer,
+                            messages,
+                            **entropy_params
+                        )
+                    responses = [result]
+                    logprobs_results = [None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))
 
-                result = entropy_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    **entropy_params
-                )
-                responses = [result]
-                logprobs_results = [None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                finally:
+                    # Restore original dtype
+                    pipeline.current_model = pipeline.current_model.to(original_dtype)
 
             else:
                 raise ValueError(f"Unknown specialized decoding approach: {decoding}")
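
The substance of the fix is the precision guard around entropy_decode: the model is cast to float32 for the duration of decoding, CUDA autocast is disabled, a torch.Generator is always seeded (falling back to 1337), and the original dtype is restored in a finally block so a failure inside decoding cannot leave the served model in the wrong precision. The sketch below isolates that pattern; it is illustrative only — run_in_full_precision is not part of optillm, and it assumes a Hugging Face PreTrainedModel-style object that exposes .dtype and .device the way pipeline.current_model does in the diff.

    # Minimal sketch of the full-precision guard used in the diff above
    # (assumed helper, not optillm API).
    import torch

    def run_in_full_precision(model, decode_fn, *args, seed=None, **kwargs):
        """Cast model to float32, run decode_fn, then restore the original dtype."""
        original_dtype = model.dtype
        model = model.to(torch.float32)  # nn.Module.to() converts parameters in place
        try:
            # Always seed the generator; 1337 mirrors the default used in the diff.
            generator = torch.Generator(device=model.device)
            generator.manual_seed(seed if seed is not None else 1337)
            # Disable autocast so intermediate activations also stay in float32.
            with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
                return decode_fn(model, *args, generator=generator, **kwargs)
        finally:
            # Restore the serving dtype even if decoding raised.
            model.to(original_dtype)

In the commit this pattern is applied inline around entropy_decode, and the previously hard-coded sampling values (temperature=0.666, top_p=0.90, max_new_tokens=512) are replaced with the temperature, top_p, and max_tokens supplied by the request, defaulting to 4096 new tokens.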

setup.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 setup(
     name="optillm",
-    version="0.0.12",
+    version="0.0.13",
     packages=find_packages(),
     py_modules=['optillm'],
     package_data={
