
Commit 30afac5

couple other fixes to address style nits missing in earlier commit
1 parent 89ff85f

File tree

1 file changed: +6 -41 lines


torchchat/utils/quantize.py

Lines changed: 6 additions & 41 deletions
@@ -20,8 +20,6 @@
 # torchao Quantizer:
 # * Int8DynActInt4WeightQuantizer: dynamic quantization for int8 acitvation and int4 weight. Using torchao API.
 #
-
-
 from __future__ import annotations
 
 import json
@@ -30,7 +28,6 @@
 # from math import gcd
 
 from typing import Any, Callable, Dict, List, Optional
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -62,7 +59,6 @@
 
 
 # Flag for whether the a8wxdq quantizer is available.
-
 torchao_experimental_load_error: Optional[Exception] = None
 
 #########################################################################
@@ -79,20 +75,16 @@ def get_named_parameters(func: Callable) -> List[str]:
 
     # Filter and return named parameters
     named_params = [
-        name
-        for name, param in parameters.items()
-        if param.kind
-        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        name for name, param in parameters.items()
+        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
     ]
     return named_params
 
 
 def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
     for key in list(q_kwargs.keys()):
         if key not in named_params:
-            print(
-                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
-            )
+            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
             del q_kwargs[key]
     return q_kwargs
 
@@ -232,7 +224,6 @@ def quantize_model(
         if quantizer == "embedding:wx":
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
-
             if get_precision() != torch.float32:
                 print(
                     f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
@@ -261,15 +252,13 @@ def quantize_model(
            )
        # We set global precision from quantize options if it is specified at cli.py:485
        # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
-
        precision = get_precision()
 
        q = quantizer_class_dict[quantizer]
        named_params = get_named_parameters(q.__init__)
        q_kwargs = validate_args(named_params, q_kwargs, quantizer)
 
        # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs
-
        if "tokenizer" in named_params:
            q_kwargs["tokenizer"] = tokenizer
        if quantizer == "embedding:wx":
@@ -278,7 +267,6 @@ def quantize_model(
        quant_handler = q(device=device, precision=precision, **q_kwargs)
 
        # quantize model
-
        model = quant_handler.quantize(model)
 
 
@@ -288,13 +276,7 @@ def quantize_model(
 
 
 class QuantHandler:
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device="cpu",
-        precision=None,
-        tokenizer=None,
-    ):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -312,7 +294,6 @@ def quantized_model(self) -> nn.Module:
         return self.model_
 
     # fallback for TC QuantHandlers that do not implement the method .quantize()
-
     def quantize(self, model: nn.Module) -> nn.Module:
         self.model_ = model
         return self.quantized_model()
@@ -323,15 +304,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
 
 
 class PrecisionHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device="cpu",
-        precision=None,
-        tokenizer=None,
-        *,
-        dtype,
-    ):
+    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -360,15 +333,7 @@ def quantized_model(self) -> nn.Module:
 
 
 class ExecutorHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device="cpu",
-        precision=None,
-        tokenizer=None,
-        *,
-        accelerator,
-    ):
+    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
         self.model_ = model
 
         if isinstance(accelerator, str):
