Commit a8e88c5

Merge branch 'main' into fix_harware_check
2 parents 96d271e + ed76e9c commit a8e88c5

21 files changed, +831 -129 lines changed

.github/workflows/regression_test.yml

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ jobs:
            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
            gpu-arch-type: "cuda"
            gpu-arch-version: "12.1"
+
          - name: CPU 2.3
            runs-on: linux.4xlarge
            torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'

examples/sam2_amg_server/cli.py

Lines changed: 11 additions & 5 deletions
@@ -6,6 +6,8 @@
 from server import model_type_to_paths
 from server import MODEL_TYPES_TO_MODEL
 from server import set_fast
+from server import set_aot_fast
+from server import load_aot_fast
 from server import set_furious
 from torchao._models.sam2.build_sam import build_sam2
 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
@@ -22,17 +24,20 @@ def main_docstring():
     """
 
 
-def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
+def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
     if verbose:
         print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
-    if fast:
-        set_fast(mask_generator)
     if furious:
         set_furious(mask_generator)
+    if load_fast:
+        load_aot_fast(mask_generator, load_fast)
+    if fast:
+        set_fast(mask_generator, load_fast)
+
     image_tensor = file_bytes_to_image_tensor(input_bytes)
     if verbose:
         print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
@@ -50,7 +55,7 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102
     buf.seek(0)
     return buf.getvalue()
 
-def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
+def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
    input_bytes = bytearray(open(input_path, 'rb').read())
    output_bytes = main_headless(checkpoint_path,
                                 model_type,
@@ -59,7 +64,8 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=
                                 output_format=output_format,
                                 verbose=verbose,
                                 fast=fast,
-                                furious=furious)
+                                furious=furious,
+                                load_fast=load_fast)
    with open(output_path, "wb") as file:
        file.write(output_bytes)
 
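For orientation only (not part of this commit), a minimal sketch of how the updated entry point might be called from Python once AOT-exported artifacts exist. The import path, checkpoint directory, the "large" model-type key, and the export directory are placeholder assumptions, not values taken from the diff.

# --- illustrative usage sketch, placeholders throughout ---
from cli import main  # examples/sam2_amg_server/cli.py

main(
    checkpoint_path="checkpoints",   # placeholder: directory holding the SAM2 checkpoints
    model_type="large",              # placeholder key into MODEL_TYPES_TO_MODEL
    input_path="dog.jpg",
    output_path="dog_masks.png",
    fast=True,                       # torch.compile the parts that are not AOT-loaded
    load_fast="exported_models",     # placeholder: dir of .pt2 packages written by set_aot_fast
)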

examples/sam2_amg_server/server.py

Lines changed: 197 additions & 14 deletions
@@ -332,14 +332,175 @@ def model_type_to_paths(checkpoint_path, model_type):
     model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
     return sam2_checkpoint, model_cfg
 
-def set_fast(mask_generator):
-    # TODO: Using CUDA graphs can cause numerical differences?
-    mask_generator.predictor.model.image_encoder = torch.compile(
-        mask_generator.predictor.model.image_encoder,
-        mode="max-autotune",
-        fullgraph=True,
-        dynamic=False,
+
+def aot_compile(model_directory, name, fn, sample_args):
+    path = Path(model_directory) / Path(f"{name}.pt2")
+    print(f"Saving at {path=}")
+    options = {
+        "max_autotune": True,
+        "triton.cudagraphs": True,
+    }
+
+    exported = torch.export.export_for_inference(fn, sample_args)
+    output_path = torch._inductor.aoti_compile_and_package(
+        exported,
+        package_path=str(path),
+        inductor_configs=options,
     )
+    return output_path
+
+
+def aot_load(path):
+    return torch._export.aot_load(path, "cuda")
+
+class FunctionModel(torch.nn.Module):
+
+    def __init__(self, module, fn_name):
+        super().__init__()
+        self.module = module
+        self.fn_name = fn_name
+
+    def forward(self, *args):
+        return getattr(self.module, self.fn_name)(*args)
+
+
+def set_aot_fast(mask_generator, model_directory):
+    example_input = torch.empty(1, 3, 1024, 1024)
+    example_input = example_input.to(mask_generator.predictor._image_dtype)
+    example_input = (example_input.to(mask_generator.predictor.device),)
+    aot_compile(model_directory,
+                "sam2_image_encoder",
+                mask_generator.predictor.model.image_encoder,
+                example_input)
+
+    # NOTE: THIS DOESN'T WORK YET!
+    # example_input_0_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_0_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device)
+    # example_input = ([example_input_0_0, example_input_0_1],
+    #                  example_input_1,
+    #                  example_input_2,
+    #                  example_input_3,
+    #                  None,
+    #                  None,
+    #                  True,
+    #                  True,
+    #                  -1)
+    # mask_generator.forward = mask_generator.predictor._predict_masks_with_features
+    # mask_generator(*example_input)
+    # aot_compile("sam2__predict_masks_with_features",
+    #             mask_generator,
+    #             example_input)
+
+    # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device)
+    # aot_compile("sam2_sam_prompt_encoder",
+    #             mask_generator.predictor.model.sam_prompt_encoder,
+    #             ((example_input_2, example_input_3),
+    #              None,
+    #              None))
+
+    # NOTE: THIS DOESN'T WORK YET!
+    # example_input_0 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_2 = torch.empty(1024, 2, 256, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_3 = torch.empty(1024, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+
+    # example_input_4_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_4_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device)
+
+    # example_input = (example_input_0,
+    #                  example_input_1,
+    #                  example_input_2,
+    #                  example_input_3,
+    #                  True,
+    #                  True,
+    #                  [example_input_4_0, example_input_4_1])
+    # print("Example")
+    # mask_generator.predictor.model.sam_mask_decoder(*example_input)
+    # print("Example done")
+    # aot_compile("sam2_sam_mask_decoder",
+    #             mask_generator.predictor.model.sam_mask_decoder,
+    #             example_input)
+
+    # example_input_0 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_1 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_2 = torch.empty(1024, 8, 256, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input = (example_input_0, example_input_1, example_input_2)
+
+    # mask_generator.predictor.model.sam_mask_decoder.transformer(*example_input)
+    # aot_compile("sam2_sam_mask_decoder_transformer",
+    #             mask_generator.predictor.model.sam_mask_decoder.transformer,
+    #             example_input)
+
+
+
+
+class LoadedModel(torch.nn.Module):
+
+    def __init__(self, aoti_compiled_model):
+        super().__init__()
+        self.aoti_compiled_model = aoti_compiled_model
+
+    def forward(self, *args):
+        return self.aoti_compiled_model(*args)
+
+class LoadedDecoder(torch.nn.Module):
+
+    def __init__(self, aoti_compiled_model, other):
+        super().__init__()
+        self.aoti_compiled_model = aoti_compiled_model
+        self.other = other
+
+    def forward(self, *args):
+        return self.aoti_compiled_model(*args)
+
+    def get_dense_pe(self, *args, **kwargs) -> torch.Tensor:
+        return self.other.get_dense_pe(*args, **kwargs)
+
+def load_aot_fast(mask_generator, model_directory):
+    t0 = time.time()
+    path = Path(model_directory) / Path(f"sam2_image_encoder.pt2")
+    assert path.exists(), f"Expected {path} to exist."
+    print(f"Start load from {path}")
+    pkg = torch._inductor.aoti_load_package(str(path))
+    pkg_m = LoadedModel(pkg)
+    mask_generator.predictor.model.image_encoder = pkg_m
+
+    # NOTE: This doesn't work yet!
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2"))
+    # pkg_m = LoadedModel(pkg)
+    # mask_generator.predictor._predict_masks_with_features = pkg_m.forward
+
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_prompt_encoder.pt2"))
+    # pkg_m = LoadedDecoder(pkg, mask_generator.predictor.model.sam_prompt_encoder)
+    # mask_generator.predictor.model.sam_prompt_encoder = pkg_m
+
+    # NOTE: This doesn't work yet!
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder.pt2"))
+    # pkg_m = LoadedModel(pkg)
+    # pkg_m.conv_s0 = mask_generator.predictor.model.sam_mask_decoder.conv_s0
+    # pkg_m.conv_s1 = mask_generator.predictor.model.sam_mask_decoder.conv_s1
+    # mask_generator.predictor.model.sam_mask_decoder = pkg_m
+
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder_transformer.pt2"))
+    # pkg_m = LoadedModel(pkg)
+    # mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m
+
+    print(f"End load. Took {time.time() - t0}s")
+
+
+def set_fast(mask_generator, load_fast=""):
+    if load_fast == "":
+        # TODO: Using CUDA graphs can cause numerical differences?
+        mask_generator.predictor.model.image_encoder = torch.compile(
+            mask_generator.predictor.model.image_encoder,
+            mode="max-autotune",
+            fullgraph=True,
+            dynamic=False,
+        )
 
     mask_generator.predictor._predict_masks = torch.compile(
         mask_generator.predictor._predict_masks,
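As a reading aid (not part of the commit), here is a hedged sketch of the save/load round trip these helpers enable, with the mask generator built the same way main() builds it further down. The checkpoint directory, the "large" model-type key, and the export directory are placeholder assumptions.

# --- illustrative round-trip sketch, placeholders noted in comments ---
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from server import model_type_to_paths, set_aot_fast, load_aot_fast, set_fast

CHECKPOINT_DIR = "checkpoints"    # placeholder: where the SAM2 checkpoints live
MODEL_TYPE = "large"              # placeholder key into MODEL_TYPES_TO_MODEL
EXPORT_DIR = "exported_models"    # placeholder directory for the .pt2 packages

sam2_checkpoint, model_cfg = model_type_to_paths(CHECKPOINT_DIR, MODEL_TYPE)
sam2 = build_sam2(model_cfg, sam2_checkpoint, device="cuda", apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=1024, output_mode="uncompressed_rle")

# One-off export: writes exported_models/sam2_image_encoder.pt2 via torch.export +
# torch._inductor.aoti_compile_and_package (see aot_compile above).
set_aot_fast(mask_generator, EXPORT_DIR)

# In a later process: swap in the precompiled image encoder, then let set_fast
# torch.compile only the remaining pieces (it skips the encoder when load_fast is non-empty).
load_aot_fast(mask_generator, EXPORT_DIR)
set_fast(mask_generator, EXPORT_DIR)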
@@ -371,6 +532,7 @@ def main(checkpoint_path,
          baseline=False,
          fast=False,
          furious=False,
+         use_autoquant=False,
          unittest=False,
          benchmark=False,
          profile=None,
@@ -380,7 +542,9 @@ def main(checkpoint_path,
          port=5000,
          host="127.0.0.1",
          dry=False,
-         batch_size=1):
+         batch_size=1,
+         load_fast="",
+         save_fast=""):
     if verbose:
         logging.basicConfig(level=logging.INFO,
                             format='%(asctime)s - %(levelname)s - %(message)s',
@@ -399,22 +563,41 @@ def main(checkpoint_path,
     from torchao._models.sam2.build_sam import build_sam2
     from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
     from torchao._models.sam2.utils.amg import rle_to_mask
-
+
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
-
+
     logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
-
+
     logging.info(f"Using {points_per_batch} points_per_batch")
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
 
+    if load_fast != "":
+        load_aot_fast(mask_generator, load_fast)
+
+    if save_fast != "":
+        assert load_fast == "", "Can't save compiled models while loading them with --load-fast."
+        assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
+        print(f"Saving compiled models under directory {save_fast}")
+        set_aot_fast(mask_generator, save_fast)
+
     if fast:
         assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
-        set_fast(mask_generator)
+        set_fast(mask_generator, load_fast)
 
     if furious:
         set_furious(mask_generator)
+    # since autoquant is replicating what furious mode is doing, don't use these two together
+    elif use_autoquant:
+        from torchao import autoquant
+        from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+        mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+
+        # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
+    # NOTE: Not baseline feature
+    mask_generator.predictor._transforms_device = mask_generator.predictor.device
+    torch.set_float32_matmul_precision('high')
 
     with open('dog.jpg', 'rb') as f:
         image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
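For reference (not from this commit), a small self-contained illustration of the torchao autoquant call used above, applied to a toy module instead of the SAM2 image encoder. The layer sizes and dtype are invented, and a recent CUDA GPU is assumed.

# --- illustrative autoquant sketch on a toy model ---
import torch
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

toy = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).cuda().to(torch.bfloat16)

# autoquant wraps eligible layers, benchmarks candidate float quantization layouts
# when real inputs flow through, and keeps the fastest option that stays above the
# requested SQNR floor (min_sqnr=40, as in the server change above).
toy = autoquant(toy, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
out = toy(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))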
@@ -494,7 +677,7 @@ async def upload_rle(image: UploadFile = File(...)):
         await request_queue.put((image_tensor, response_future))
         masks = await response_future
         return masks_to_rle_dict(masks)
-
+
     @app.post("/upload")
     async def upload_image(image: UploadFile = File(...)):
         image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
@@ -512,7 +695,7 @@ async def upload_image(image: UploadFile = File(...)):
         plt.savefig(buf, format='png')
         buf.seek(0)
         return StreamingResponse(buf, media_type="image/png")
-
+
 
     # uvicorn.run(app, host=host, port=port, log_level="info")
     uvicorn.run(app, host=host, port=port)
