This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
build/convert_hf_checkpoint.py (4 changes: 2 additions & 2 deletions)
@@ -17,7 +17,7 @@
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from build.model import TransformerArgs
+from build.model import ModelArgs
 
 
 @torch.inference_mode()
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = model_dir.name
 
-    config = TransformerArgs.from_name(model_name)
+    config = ModelArgs.from_name(model_name).text_transformer_args
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
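Note: the config-access change in this file (and in dist_run.py below) follows a single pattern. A minimal sketch, assuming, per this diff, that `ModelArgs.from_name()` returns a wrapper whose `text_transformer_args` field holds the former `TransformerArgs`:

from build.model import ModelArgs

model_name = "Transformer-2-7b-chat-hf"  # any known configuration name

# Before this PR: config = TransformerArgs.from_name(model_name)
# After: ModelArgs wraps the model config, and the text model's
# TransformerArgs lives on its text_transformer_args field.
config = ModelArgs.from_name(model_name).text_transformer_args
print(f"Model config {config.__dict__}")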
build/model_dist.py (25 changes: 13 additions & 12 deletions)
@@ -112,18 +112,19 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
 
         # print(f"stage output shape: {x.shape}")
         return x
 
-    @classmethod
-    def from_name(cls, name: str):
-        return cls(TransformerArgs.from_name(name))
-
-    @classmethod
-    def from_table(cls, name: str):
-        return cls(TransformerArgs.from_table(name))
-
-    @classmethod
-    def from_params(cls, params_path: str):
-        return cls(TransformerArgs.from_params(params_path))
+    # Temporarily disabled: these constructors are missing essential inputs
+    # @classmethod
+    # def from_name(cls, name: str):
+    #     return cls(TransformerArgs.from_name(name))
+
+    # @classmethod
+    # def from_table(cls, name: str):
+    #     return cls(TransformerArgs.from_table(name))
+
+    # @classmethod
+    # def from_params(cls, params_path: str):
+    #     return cls(TransformerArgs.from_params(params_path))
 
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
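With these convenience constructors commented out, callers that relied on them must build the `TransformerArgs` themselves and supply whatever stage-specific inputs the classmethods could not. A minimal caller-side sketch, assuming the constructor still accepts a `TransformerArgs` config; `stage_idx` and `n_stages` are hypothetical names for the missing inputs:

from build.model import TransformerArgs
from build.model_dist import TransformerStage

config = TransformerArgs.from_name("Transformer-2-7b-chat-hf")

# stage_idx and n_stages are hypothetical stand-ins for the "essential
# inputs" the disabled constructors could not provide; check
# TransformerStage.__init__ for the actual signature.
stage = TransformerStage(config, stage_idx=0, n_stages=2)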
dist_run.py (4 changes: 2 additions & 2 deletions)
@@ -11,12 +11,12 @@
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
-from build.model import TransformerArgs
+from build.model import ModelArgs
 from build.model_dist import TransformerStage
 
 # Model config
 def main():
-    config = TransformerArgs.from_name("Transformer-2-7b-chat-hf")
+    config = ModelArgs.from_name("Transformer-2-7b-chat-hf").text_transformer_args
     print(config)
 
     # Construct a device mesh with available devices (multi-host or single host)
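For context on the imports above: the collapsed remainder of `main()` wires this rank's stage into a GPipe schedule. A rough sketch of that wiring with the stock `torch.distributed.pipelining` API, where `stage_module` and `example_input` are stand-ins for code collapsed in this diff; depending on the PyTorch version, `PipelineStage` may also require example `input_args`:

import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device(f"cuda:{rank}")

# Wrap this rank's stage module as one stage of the pipeline.
stage = PipelineStage(stage_module, stage_index=rank, num_stages=world_size, device=device)
schedule = ScheduleGPipe(stage, n_microbatches=4)

if rank == 0:
    schedule.step(example_input)  # the first stage feeds the inputs
else:
    output = schedule.step()      # later stages receive activations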
docs/ADVANCED-USERS.md (6 changes: 3 additions & 3 deletions)
@@ -123,14 +123,14 @@ For example, for the stories15M model, this would be expressed as
 
 For models using a configuration not in the list of known
 configurations, you can construct the model by initializing the
-`TransformerArgs` dataclass that controls model construction from a
+`ModelArgs` dataclass that controls model construction from a
 parameter json using the `params-path ${PARAMS_PATH}` containing the
 appropriate model parameters to initialize the `ModelArgs` for the
 model. (We use the model constructor `Model.from_params()`).
 
 The parameter file should be in JSON format specifying these
-parameters. You can find the `TransformerArgs` data class in
-[`model.py`](https://github.com/pytorch/torchchat/blob/main/model.py#L22).
+parameters. You can find the `ModelArgs` data class in
+[`model.py`](https://github.com/pytorch/torchchat/blob/main/build/model.py#L70).
 
 The final way to initialize a torchchat model is from GGUF. You load a
 GGUF model with the option `--load-gguf ${MODELNAME}.gguf`. Presently,
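Since the documentation above walks through the `params-path` flow, a minimal sketch of that flow may help. The field names follow the stories15M convention the doc mentions and should be checked against the `ModelArgs`/`TransformerArgs` definitions in `build/model.py`:

import json

from build.model import Model

# Hypothetical minimal parameter file in the stories15M style.
params = {
    "dim": 288,
    "n_layers": 6,
    "n_heads": 6,
    "vocab_size": 32000,
}
with open("params.json", "w") as f:
    json.dump(params, f)

# Equivalent to constructing via the `params-path ${PARAMS_PATH}` option.
model = Model.from_params("params.json")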