Skip to content

Commit 543209b

Browse files
authored
Add floating point options for autoquant and add accuracy measurement (#1355)
* Add floating point options for autoquant and add accuracy measurement Summary: * This PR adds float32/float16/bfloat16 as a list of options for autoquant, it converts input/weight/bias/output to the specified dtype * Also adds min_sqnr (https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio) to allow users to filter out the quantization methods that has large numerical impact compared to original output Note that we use random generated input activation right now, we can improve this by adding the support for using real inputs Test Plan: python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-fp Reviewers: Subscribers: Tasks: Tags: * update docstring * fix * ruff * skip if no cuda
1 parent 04a25e7 commit 543209b

File tree

5 files changed

+190
-19
lines changed

5 files changed

+190
-19
lines changed

examples/sam2_amg_server/server.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ def main(checkpoint_path,
371371
baseline=False,
372372
fast=False,
373373
furious=False,
374+
use_autoquant=False,
374375
unittest=False,
375376
benchmark=False,
376377
profile=None,
@@ -399,13 +400,13 @@ def main(checkpoint_path,
399400
from torchao._models.sam2.build_sam import build_sam2
400401
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
401402
from torchao._models.sam2.utils.amg import rle_to_mask
402-
403+
403404
device = "cuda"
404405
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
405-
406+
406407
logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
407408
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
408-
409+
409410
logging.info(f"Using {points_per_batch} points_per_batch")
410411
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
411412

@@ -416,6 +417,18 @@ def main(checkpoint_path,
416417
if furious:
417418
set_furious(mask_generator)
418419

420+
# since autoquant is replicating what furious mode is doing, don't use these two together
421+
elif use_autoquant:
422+
from torchao import autoquant
423+
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
424+
mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
425+
426+
mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
427+
# NOTE: Not baseline feature
428+
mask_generator.predictor._transforms_device = mask_generator.predictor.device
429+
torch.set_float32_matmul_precision('high')
430+
431+
419432
with open('dog.jpg', 'rb') as f:
420433
image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
421434

@@ -494,7 +507,7 @@ async def upload_rle(image: UploadFile = File(...)):
494507
await request_queue.put((image_tensor, response_future))
495508
masks = await response_future
496509
return masks_to_rle_dict(masks)
497-
510+
498511
@app.post("/upload")
499512
async def upload_image(image: UploadFile = File(...)):
500513
image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
@@ -512,7 +525,7 @@ async def upload_image(image: UploadFile = File(...)):
512525
plt.savefig(buf, format='png')
513526
buf.seek(0)
514527
return StreamingResponse(buf, media_type="image/png")
515-
528+
516529

517530
# uvicorn.run(app, host=host, port=port, log_level="info")
518531
uvicorn.run(app, host=host, port=port)

test/integration/test_integration.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,23 @@ def forward(self, x):
15141514
assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
15151515
model(x_in)
15161516

1517+
@parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES)))
1518+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
1519+
def test_autoquant_min_sqnr(self, device, dtype):
1520+
m, k, n = 128, 128, 128
1521+
example_input = torch.randn(m, k, device=device, dtype=dtype)
1522+
model = torch.nn.Sequential(
1523+
torch.nn.ReLU(),
1524+
torch.nn.Linear(k,n),
1525+
torch.nn.ReLU(),
1526+
).to(device).to(dtype)
1527+
out = model(example_input)
1528+
torchao.autoquant(model, min_sqnr=60)
1529+
out2 = model(example_input)
1530+
sqnr = SQNR(out, out2)
1531+
# without setting min_sqnr to 60, we get around 45-50 final sqnr
1532+
# setting min_sqnr for individual linear to be 60 allows us to achieve >= 50 final sqnr
1533+
self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}")
15171534

15181535

15191536

torchao/_models/llama/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,8 @@ def main(
402402
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
403403
elif "autoquant-float8" == quantization:
404404
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
405+
if "autoquant-fp" == quantization:
406+
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs)
405407
else:
406408
model = autoquant(model, manual=True, example_input=inputs)
407409

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from .autoquant import (
1313
DEFAULT_AUTOQUANT_CLASS_LIST,
14+
DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
1415
DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
1516
OTHER_AUTOQUANT_CLASS_LIST,
1617
autoquant,
@@ -89,6 +90,7 @@
8990
"autoquant",
9091
"DEFAULT_AUTOQUANT_CLASS_LIST",
9192
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
93+
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
9294
"OTHER_AUTOQUANT_CLASS_LIST",
9395
# top level API - manual
9496
"quantize_",

0 commit comments

Comments
 (0)