From a3d0ef0cc21e6453b6d07b395c2ed85022cccfef Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Tue, 27 Aug 2024 22:52:18 -0700
Subject: [PATCH] migrate remaining TransformerArgs usages to ModelArgs

---
 build/convert_hf_checkpoint.py |  4 ++--
 build/model_dist.py            | 25 +++++++++++++------------
 dist_run.py                    |  4 ++--
 docs/ADVANCED-USERS.md         |  6 +++---
 4 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/build/convert_hf_checkpoint.py b/build/convert_hf_checkpoint.py
index 6baa20908..de176af56 100644
--- a/build/convert_hf_checkpoint.py
+++ b/build/convert_hf_checkpoint.py
@@ -17,7 +17,7 @@
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from build.model import TransformerArgs
+from build.model import ModelArgs
 
 
 @torch.inference_mode()
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = model_dir.name
 
-    config = TransformerArgs.from_name(model_name)
+    config = ModelArgs.from_name(model_name).text_transformer_args
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
diff --git a/build/model_dist.py b/build/model_dist.py
index 820cb2f87..1351b2b0d 100644
--- a/build/model_dist.py
+++ b/build/model_dist.py
@@ -112,18 +112,19 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         # print(f"stage output shape: {x.shape}")
         return x
-
-    @classmethod
-    def from_name(cls, name: str):
-        return cls(TransformerArgs.from_name(name))
-
-    @classmethod
-    def from_table(cls, name: str):
-        return cls(TransformerArgs.from_table(name))
-
-    @classmethod
-    def from_params(cls, params_path: str):
-        return cls(TransformerArgs.from_params(params_path))
+
+    # Temporarily disabled: these constructors are missing essential inputs
+    # @classmethod
+    # def from_name(cls, name: str):
+    #     return cls(TransformerArgs.from_name(name))
+
+    # @classmethod
+    # def from_table(cls, name: str):
+    #     return cls(TransformerArgs.from_table(name))
+
+    # @classmethod
+    # def from_params(cls, params_path: str):
+    #     return cls(TransformerArgs.from_params(params_path))
 
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
diff --git a/dist_run.py b/dist_run.py
index 20b9ca6a8..34732a008 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -11,12 +11,12 @@
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
-from build.model import TransformerArgs
+from build.model import ModelArgs
 from build.model_dist import TransformerStage
 
 # Model config
 def main():
-    config = TransformerArgs.from_name("Transformer-2-7b-chat-hf")
+    config = ModelArgs.from_name("Transformer-2-7b-chat-hf").text_transformer_args
     print(config)
 
     # Construct a device mesh with available devices (multi-host or single host)
diff --git a/docs/ADVANCED-USERS.md b/docs/ADVANCED-USERS.md
index 5a8d41db0..645501642 100644
--- a/docs/ADVANCED-USERS.md
+++ b/docs/ADVANCED-USERS.md
@@ -123,14 +123,14 @@ For example, for the stories15M model, this would be expressed as
 
 For models using a configuration not in the list of known
 configurations, you can construct the model by initializing the
-`TransformerArgs` dataclass that controls model construction from a
+`ModelArgs` dataclass that controls model construction from a
 parameter json using the `params-path ${PARAMS_PATH}` containing the
 appropriate model parameters to initialize the `ModelArgs` for the
 model. (We use the model constructor `Model.from_params()`).
 
 The parameter file should be in JSON format specifying these
-parameters. You can find the `TransformerArgs` data class in
-[`model.py`](https://github.com/pytorch/torchchat/blob/main/model.py#L22).
+parameters. You can find the `ModelArgs` data class in
+[`model.py`](https://github.com/pytorch/torchchat/blob/main/build/model.py#L70).
 
 The final way to initialize a torchchat model is from GGUF. You load a
 GGUF model with the option `--load-gguf ${MODELNAME}.gguf`. Presently,
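
Reviewer note: the sketch below illustrates the call pattern this patch migrates to, using only names that appear in the hunks above (`ModelArgs.from_name` and its `text_transformer_args` field from `build/model.py`). It is an illustration for review, not part of the patch itself.

```python
# Illustrative only -- not part of the patch. Assumes build/model.py
# exposes ModelArgs with a from_name() constructor and a
# text_transformer_args attribute, as the hunks above reference.
from build.model import ModelArgs

# Before this patch, callers built the transformer config directly:
#     config = TransformerArgs.from_name(model_name)
# After it, the transformer config is nested inside ModelArgs, and
# callers unwrap it explicitly:
config = ModelArgs.from_name("Transformer-2-7b-chat-hf").text_transformer_args
print(f"Model config {config.__dict__}")
```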