@@ -332,14 +332,175 @@ def model_type_to_paths(checkpoint_path, model_type):
     model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
     return sam2_checkpoint, model_cfg
 
-def set_fast(mask_generator):
-    # TODO: Using CUDA graphs can cause numerical differences?
-    mask_generator.predictor.model.image_encoder = torch.compile(
-        mask_generator.predictor.model.image_encoder,
-        mode="max-autotune",
-        fullgraph=True,
-        dynamic=False,
+
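+# Export `fn` with torch.export and package it via AOTInductor into a .pt2
+# archive under `model_directory`, so later runs can load it without recompiling.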
+def aot_compile(model_directory, name, fn, sample_args):
+    path = Path(model_directory) / Path(f"{name}.pt2")
+    print(f"Saving at {path=}")
+    options = {
+        "max_autotune": True,
+        "triton.cudagraphs": True,
+    }
+
+    exported = torch.export.export_for_inference(fn, sample_args)
+    output_path = torch._inductor.aoti_compile_and_package(
+        exported,
+        package_path=str(path),
+        inductor_configs=options,
     )
+    return output_path
+
+
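+# Load a .pt2 archive produced by aot_compile back onto the GPU. Note this goes
+# through the private torch._export.aot_load API, which may change between releases.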
+def aot_load(path):
+    return torch._export.aot_load(path, "cuda")
+
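+# Wrap an arbitrary method of `module` as an nn.Module so it can be handed to
+# torch.export, which wants a module rather than a bound method.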
+class FunctionModel(torch.nn.Module):
+
+    def __init__(self, module, fn_name):
+        super().__init__()
+        self.module = module
+        self.fn_name = fn_name
+
+    def forward(self, *args):
+        return getattr(self.module, self.fn_name)(*args)
+
+
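+# AOT-compile the parts of the pipeline that export cleanly. Only the image
+# encoder works today; the commented-out blocks below record attempts at the
+# prompt encoder, mask decoder and its transformer. The (1, 3, 1024, 1024)
+# example input matches SAM2's fixed 1024x1024 input resolution.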
+def set_aot_fast(mask_generator, model_directory):
+    example_input = torch.empty(1, 3, 1024, 1024)
+    example_input = example_input.to(mask_generator.predictor._image_dtype)
+    example_input = (example_input.to(mask_generator.predictor.device),)
+    aot_compile(model_directory,
+                "sam2_image_encoder",
+                mask_generator.predictor.model.image_encoder,
+                example_input)
+
+    # NOTE: THIS DOESN'T WORK YET!
+    # example_input_0_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_0_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device)
+    # example_input = ([example_input_0_0, example_input_0_1],
+    #                  example_input_1,
+    #                  example_input_2,
+    #                  example_input_3,
+    #                  None,
+    #                  None,
+    #                  True,
+    #                  True,
+    #                  -1)
+    # mask_generator.forward = mask_generator.predictor._predict_masks_with_features
+    # mask_generator(*example_input)
+    # aot_compile("sam2__predict_masks_with_features",
+    #             mask_generator,
+    #             example_input)
+
+    # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device)
+    # aot_compile("sam2_sam_prompt_encoder",
+    #             mask_generator.predictor.model.sam_prompt_encoder,
+    #             ((example_input_2, example_input_3),
+    #              None,
+    #              None))
+
+    # NOTE: THIS DOESN'T WORK YET!
+    # example_input_0 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_2 = torch.empty(1024, 2, 256, dtype=torch.float32, device=mask_generator.predictor.device)
+    # example_input_3 = torch.empty(1024, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
+
+    # example_input_4_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_4_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device)
+
+    # example_input = (example_input_0,
+    #                  example_input_1,
+    #                  example_input_2,
+    #                  example_input_3,
+    #                  True,
+    #                  True,
+    #                  [example_input_4_0, example_input_4_1])
+    # print("Example")
+    # mask_generator.predictor.model.sam_mask_decoder(*example_input)
+    # print("Example done")
+    # aot_compile("sam2_sam_mask_decoder",
+    #             mask_generator.predictor.model.sam_mask_decoder,
+    #             example_input)
+
+    # example_input_0 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_1 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input_2 = torch.empty(1024, 8, 256, dtype=torch.float16, device=mask_generator.predictor.device)
+    # example_input = (example_input_0, example_input_1, example_input_2)
+
+    # mask_generator.predictor.model.sam_mask_decoder.transformer(*example_input)
+    # aot_compile("sam2_sam_mask_decoder_transformer",
+    #             mask_generator.predictor.model.sam_mask_decoder.transformer,
+    #             example_input)
+
+
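+# Thin nn.Module wrappers around an AOTI-compiled callable so it can be swapped
+# in for the original submodule. LoadedDecoder additionally forwards
+# get_dense_pe to the original (uncompiled) prompt encoder, which the predictor
+# still calls directly.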
+class LoadedModel(torch.nn.Module):
+
+    def __init__(self, aoti_compiled_model):
+        super().__init__()
+        self.aoti_compiled_model = aoti_compiled_model
+
+    def forward(self, *args):
+        return self.aoti_compiled_model(*args)
+
+class LoadedDecoder(torch.nn.Module):
+
+    def __init__(self, aoti_compiled_model, other):
+        super().__init__()
+        self.aoti_compiled_model = aoti_compiled_model
+        self.other = other
+
+    def forward(self, *args):
+        return self.aoti_compiled_model(*args)
+
+    def get_dense_pe(self, *args, **kwargs) -> torch.Tensor:
+        return self.other.get_dense_pe(*args, **kwargs)
+
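+# Load the precompiled image encoder from `model_directory` and install it on
+# the mask generator. The remaining components stay as eager modules until
+# their AOT export paths work (see the commented-out attempts below).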
+def load_aot_fast(mask_generator, model_directory):
+    t0 = time.time()
+    path = Path(model_directory) / Path("sam2_image_encoder.pt2")
+    assert path.exists(), f"Expected {path} to exist."
+    print(f"Start load from {path}")
+    pkg = torch._inductor.aoti_load_package(str(path))
+    pkg_m = LoadedModel(pkg)
+    mask_generator.predictor.model.image_encoder = pkg_m
+
+    # NOTE: This doesn't work yet!
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2"))
+    # pkg_m = LoadedModel(pkg)
+    # mask_generator.predictor._predict_masks_with_features = pkg_m.forward
+
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_prompt_encoder.pt2"))
+    # pkg_m = LoadedDecoder(pkg, mask_generator.predictor.model.sam_prompt_encoder)
+    # mask_generator.predictor.model.sam_prompt_encoder = pkg_m
+
+    # NOTE: This doesn't work yet!
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder.pt2"))
+    # pkg_m = LoadedModel(pkg)
+    # pkg_m.conv_s0 = mask_generator.predictor.model.sam_mask_decoder.conv_s0
+    # pkg_m.conv_s1 = mask_generator.predictor.model.sam_mask_decoder.conv_s1
+    # mask_generator.predictor.model.sam_mask_decoder = pkg_m
+
+    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder_transformer.pt2"))
+    # pkg_m = LoadedModel(pkg)
+    # mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m
+
+    print(f"End load. Took {time.time() - t0}s")
+
+
+def set_fast(mask_generator, load_fast=""):
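+    # If precompiled artifacts were loaded via load_fast, the image encoder has
+    # already been replaced, so only torch.compile it when load_fast is unset.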
+    if load_fast == "":
+        # TODO: Using CUDA graphs can cause numerical differences?
+        mask_generator.predictor.model.image_encoder = torch.compile(
+            mask_generator.predictor.model.image_encoder,
+            mode="max-autotune",
+            fullgraph=True,
+            dynamic=False,
+        )
 
     mask_generator.predictor._predict_masks = torch.compile(
         mask_generator.predictor._predict_masks,
@@ -371,6 +532,7 @@ def main(checkpoint_path,
          baseline=False,
          fast=False,
          furious=False,
+         use_autoquant=False,
          unittest=False,
          benchmark=False,
          profile=None,
@@ -380,7 +542,9 @@ def main(checkpoint_path,
          port=5000,
          host="127.0.0.1",
          dry=False,
-         batch_size=1):
+         batch_size=1,
+         load_fast="",
+         save_fast=""):
     if verbose:
         logging.basicConfig(level=logging.INFO,
                             format='%(asctime)s - %(levelname)s - %(message)s',
@@ -399,22 +563,41 @@ def main(checkpoint_path,
     from torchao._models.sam2.build_sam import build_sam2
     from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
     from torchao._models.sam2.utils.amg import rle_to_mask
-
+
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
-
+
     logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
-
+
     logging.info(f"Using {points_per_batch} points_per_batch")
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
 
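+    # Optionally swap in precompiled (AOTInductor) modules instead of compiling
+    # at startup, or export and save them for later runs.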
+    if load_fast != "":
+        load_aot_fast(mask_generator, load_fast)
+
+    if save_fast != "":
+        assert load_fast == "", "Can't save compiled models while loading them with --load-fast."
+        assert not baseline, "--save-fast cannot be combined with --baseline; the code must be torch.compile(fullgraph=True) compatible."
+        print(f"Saving compiled models under directory {save_fast}")
+        set_aot_fast(mask_generator, save_fast)
+
     if fast:
         assert not baseline, "--fast cannot be combined with --baseline; the code must be torch.compile(fullgraph=True) compatible."
-        set_fast(mask_generator)
+        set_fast(mask_generator, load_fast)
 
     if furious:
         set_furious(mask_generator)
+    # Since autoquant replicates what furious mode does, don't use the two together.
+    elif use_autoquant:
+        from torchao import autoquant
+        from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+        mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
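+        # min_sqnr=40 should restrict autoquant to quantization choices whose
+        # signal-to-quantized-noise ratio vs. the float baseline stays above 40.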
+
+        # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
+        # NOTE: Not baseline feature
+        mask_generator.predictor._transforms_device = mask_generator.predictor.device
+        torch.set_float32_matmul_precision('high')
 
     with open('dog.jpg', 'rb') as f:
         image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
@@ -494,7 +677,7 @@ async def upload_rle(image: UploadFile = File(...)):
         await request_queue.put((image_tensor, response_future))
         masks = await response_future
         return masks_to_rle_dict(masks)
-
+
     @app.post("/upload")
     async def upload_image(image: UploadFile = File(...)):
         image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
@@ -512,7 +695,7 @@ async def upload_image(image: UploadFile = File(...)):
         plt.savefig(buf, format='png')
         buf.seek(0)
         return StreamingResponse(buf, media_type="image/png")
-
+
 
     # uvicorn.run(app, host=host, port=port, log_level="info")
     uvicorn.run(app, host=host, port=port)