From 69b17e315b240fe67e1152887918b37efccb7b1b Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Fri, 17 Jan 2025 22:42:29 -0800
Subject: [PATCH 1/2] support model snapshots to save quantized models

---
 torchchat/cli/builder.py | 32 +++++++++++++++++++++++++
 torchchat/cli/cli.py     | 14 ++++++++++-
 torchchat/export.py      | 50 ++++++++++++++++++++++++++++++++++++++--
 3 files changed, 93 insertions(+), 3 deletions(-)

diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 38d0e33b2..fb87302b8 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -56,6 +56,7 @@ class BuilderArgs:
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     aoti_package_path: Optional[Union[Path, str]] = None
+    snapshot_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -81,6 +82,7 @@ def __post_init__(self):
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
+            or (self.snapshot_path and Path(self.snapshot_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
             )
@@ -136,6 +138,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
+        snapshot_path = getattr(args, "snapshot_path", None)

         is_chat_model = False
         if args.is_chat_model:
@@ -163,6 +166,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         output_pte_path = getattr(args, "output_pte_path", None)
         output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
+        output_snapshot_path = getattr(args, "output_snapshot_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
@@ -189,6 +193,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dso_path=dso_path,
             aoti_package_path=aoti_package_path,
             pte_path=pte_path,
+            snapshot_path=snapshot_path,
             device=args.device,
             precision=dtype,
             setup_caches=(
@@ -614,6 +619,33 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.snapshot_path:
+        # Resolve ModelArgs so we can validate the loaded snapshot's architecture.
+        # If a manual params_path is provided, use that
+        if builder_args.params_path:
+            config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
+        else:
+            # TODO: Instead of loading the whole model, refactor to call a
+            # helper that generates just model.config
+            with measure_time("Time to load model: {time:.02f} seconds"):
+                model = _load_model(builder_args)
+                device_sync(device=builder_args.device)
+            config = model.config
+            model = None
+        try:
+            model = torch.load(builder_args.snapshot_path, weights_only=False)
+        except Exception:
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+        # _active_backend() does not allow both DSO & AOTI to be true.
+        # Choose either.
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
+            raise RuntimeError("loaded model architecture mismatch")
+        ##
+        ## import all libraries with custom kernels and custom operators
+        ## that quantize may be pulling in
+        ##
    elif builder_args.distributed:
        pp_degree = builder_args.pp
        tp_degree = builder_args.tp
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 91bdcaf26..94324772c 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -200,6 +200,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--output-snapshot-path",
+        type=str,
+        default=None,
+        help="Output to the specified torchchat snapshot .tc model file",
+    )
     exclusive_parser.add_argument(
         "--output-aoti-package-path",
         type=str,
@@ -247,7 +253,13 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified ExecuTorch .pte model file",
     )
-
+    exclusive_parser.add_argument(
+        "--snapshot-path",
+        type=Path,
+        default=None,
+        help="Use the specified torchchat snapshot .tc model file",
+    )
+

 # Add CLI Args related to JIT downloading of model artifacts
 def _add_jit_downloading_args(parser) -> None:
diff --git a/torchchat/export.py b/torchchat/export.py
index 979778b7c..3aa91cc68 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -28,6 +28,31 @@
 default_device = "cpu"


+"""
+Export Snapshot
+"""
+
+
+def export_snapshot(
+    model: nn.Module,
+    device: Optional[str] = None,
+    output_path: str = "model-snapshot.tc",
+) -> str:
+    """
+    Export the model as a snapshot.
+
+    Args:
+        model: The model to be exported.
+        device: The device to run the model on.
+        output_path: The path to save the exported model.
+    Returns:
+        The path to the exported model.
+    """
+    assert output_path.endswith(".tc"), "use .tc extension for snapshots"
+    torch.save(model, output_path)
+    return output_path
+
+
 """
 Export for Server
 """
@@ -66,7 +91,7 @@ def export_for_server(
         )
         dynamic_shapes = None

-    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
         metadata = {}  # TODO: put more metadata here
         options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
         if not package:
@@ -359,6 +384,7 @@ def main(args):

     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
+    output_snapshot_path = args.output_snapshot_path
     output_aoti_package_path = args.output_aoti_package_path

     if output_pte_path and builder_args.device != "cpu":
@@ -366,7 +392,7 @@ def main(args):
             f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
         )
         builder_args.device = "cpu"
-    elif "mps" in builder_args.device:
+    elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
         print("Warning! Device MPS not supported for export. Exporting for device CPU.")
         builder_args.device = "cpu"

@@ -402,6 +428,7 @@ def main(args):
         model_to_pte = model
         model_to_dso = model
         model_to_aoti_package = model
+        model_to_snapshot = model
     else:
         if output_pte_path:
             _set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -421,6 +448,15 @@ def main(args):
             model_to_dso = model_to_aoti_package
             _unset_gguf_kwargs(builder_args)

+        if output_snapshot_path:
+            _set_gguf_kwargs(builder_args, is_et=False, context="export")
+            model_to_snapshot = _initialize_model(
+                builder_args,
+                quantize,
+                support_tensor_subclass=False,
+            )
+            _unset_gguf_kwargs(builder_args)
+
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
@@ -454,3 +490,13 @@ def main(args):
                 builder_args.dynamic_shapes,
                 package=True,
             )
+
+    if output_snapshot_path:
+        output_snapshot_path = str(os.path.abspath(output_snapshot_path))
+        print(f"Exporting model using Snapshot to {output_snapshot_path}")
+        export_snapshot(
+            model_to_snapshot,
+            builder_args.device,
+            output_snapshot_path,
+        )

From 0ae743acb716a7a36aa04bfc5843bab241e734d1 Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Fri, 17 Jan 2025 23:03:53 -0800
Subject: [PATCH 2/2] import set backend

---
 torchchat/cli/builder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index fb87302b8..b018e00e1 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -638,6 +638,7 @@ def do_nothing(max_batch_size, max_seq_length):
             raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
         # _active_backend() does not allow both DSO & AOTI to be true.
         # Choose either.
+        from torchchat.utils.build_utils import set_backend
         set_backend(dso=True, pte=False, aoti_package=False)
         if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
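
Note for reviewers: the snapshot format introduced here is a whole-module torch.save, so loading requires weights_only=False and any modules registering the custom kernels/operators used by quantization must be imported before torch.load (the comment block in builder.py above flags this). Below is a minimal sketch of the round trip; DemoModel and the file name are hypothetical stand-ins for a quantized torchchat model, whose real flow goes through export_snapshot and _initialize_model.

    import torch
    import torch.nn as nn


    class DemoModel(nn.Module):
        """Illustrative stand-in for a quantized torchchat model."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 16)

        def forward(self, x):
            return self.linear(x)


    model = DemoModel()

    # torch.save pickles the full module, preserving quantized tensor
    # subclasses and module structure (not just a state_dict), exactly
    # as export_snapshot does above.
    torch.save(model, "model-snapshot.tc")

    # weights_only=False is required because the snapshot contains pickled
    # Python objects, not just tensors; any libraries providing custom
    # quantization kernels must already be imported at this point.
    restored = torch.load("model-snapshot.tc", weights_only=False)
    print(restored(torch.randn(1, 16)).shape)  # torch.Size([1, 16])

Because the pickle is only valid against a compatible model definition, builder.py re-resolves ModelArgs and checks model.config against the loaded snapshot before using it.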