
Commit f3ffa0e

rename ModelArgs into TransformerArgs
1 parent 2970f8e commit f3ffa0e

4 files changed: 19 additions & 19 deletions


build/convert_hf_checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from build.model import ModelArgs
+from build.model import TransformerArgs
 
 
 @torch.inference_mode()
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = model_dir.name
 
-    config = ModelArgs.from_name(model_name)
+    config = TransformerArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
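
For reference, a minimal sketch of how the renamed class is used to resolve a configuration during checkpoint conversion; the model name "stories15M" is only an illustrative placeholder:

    from build.model import TransformerArgs

    # Look up a known configuration by name and print its fields,
    # mirroring the calls in convert_hf_checkpoint above.
    config = TransformerArgs.from_name("stories15M")
    print(f"Model config {config.__dict__}")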

build/gguf_loader.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
 from quantization.quantize import pack_scales_and_zeros
 from build.gguf_util import Q4_0, to_float
-from build.model import ModelArgs, Transformer
+from build.model import TransformerArgs, Transformer
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -107,7 +107,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     arch = metadata["general.architecture"]
     assert arch == "llama", "Only LLaMa models are supported by this converter."
 
-    model_args = ModelArgs(
+    model_args = TransformerArgs(
         dim=metadata[f"{arch}.embedding_length"],
         n_layers=metadata[f"{arch}.block_count"],
         n_heads=metadata[f"{arch}.attention.head_count"],
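
From the caller's side the rename is invisible: `load_model` still reads the GGUF metadata, builds the args dataclass internally, and returns a module. A brief usage sketch, with a purely illustrative file path:

    from build.gguf_loader import load_model

    # Reads GGUF metadata, constructs TransformerArgs from it, and returns a
    # torch.nn.Module (only the "llama" architecture is supported).
    model = load_model("llama-2-7b.Q4_0.gguf")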

build/model.py

Lines changed: 12 additions & 12 deletions
@@ -22,7 +22,7 @@
 
 
 @dataclass
-class ModelArgs:
+class TransformerArgs:
     block_size: int = 2048
     vocab_size: int = 32000
     n_layers: int = 32
@@ -45,7 +45,7 @@ def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_heads
         if self.hidden_dim is None:
-            # If hidden_dim is not explicitly set in the ModelArgs,
+            # If hidden_dim is not explicitly set in the TransformerArgs,
             # then calculate implicitly based on dim and
             # also multiple of `args.multiple_of`
             multiple_of = self.multiple_of
@@ -73,7 +73,7 @@ def from_params(cls, params_path):
     def from_table(cls, name: str):
         json_path = config_path / f"{name}.json"
         if json_path.is_file():
-            return ModelArgs.from_params(json_path)
+            return TransformerArgs.from_params(json_path)
         else:
             known_model_params = [
                 config.replace(".json", "") for config in os.listdir(config_path)
@@ -86,7 +86,7 @@ def from_table(cls, name: str):
     def from_name(cls, name: str):
         json_path = config_path / f"{name}.json"
         if Path(json_path).is_file():
-            return ModelArgs.from_params(json_path)
+            return TransformerArgs.from_params(json_path)
 
         known_model_params = [
             config.replace(".json", "") for config in os.listdir(config_path)
@@ -113,7 +113,7 @@ def from_name(cls, name: str):
             f"Unknown model directory name {name}. Must be one of {known_model_params}."
         )
 
-        return ModelArgs.from_params(config_path / f"{config[0]}.json")
+        return TransformerArgs.from_params(config_path / f"{config[0]}.json")
 
 
 class KVCache(nn.Module):
@@ -145,7 +145,7 @@ def update(self, input_pos, k_val, v_val):
 
 
 class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.config = config
 
@@ -203,15 +203,15 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
 
     @classmethod
     def from_name(cls, name: str):
-        return cls(ModelArgs.from_name(name))
+        return cls(TransformerArgs.from_name(name))
 
     @classmethod
     def from_table(cls, name: str):
-        return cls(ModelArgs.from_table(name))
+        return cls(TransformerArgs.from_table(name))
 
     @classmethod
     def from_params(cls, params_path: str):
-        return cls(ModelArgs.from_params(params_path))
+        return cls(TransformerArgs.from_params(params_path))
 
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
@@ -224,7 +224,7 @@ def from_gguf(cls, gguf_path: str, **kwargs):
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.feed_forward = FeedForward(config)
@@ -240,7 +240,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: TransformerArgs):
         super().__init__()
         assert config.dim % config.n_heads == 0
 
@@ -340,7 +340,7 @@ def forward(
 
 
 class FeedForward(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
         self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
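
The rename leaves the `Transformer` constructors intact; only the dataclass they consume changes. A brief sketch of the two entry points, with the model name and field values chosen purely for illustration:

    from build.model import Transformer, TransformerArgs

    # Build from a named configuration JSON known to torchchat...
    model = Transformer.from_name("stories15M")

    # ...or construct the args dataclass directly; hidden_dim is derived
    # in __post_init__ when left unset.
    args = TransformerArgs(dim=288, n_layers=6, n_heads=6, vocab_size=32000)
    model = Transformer(args)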

docs/ADVANCED-USERS.md

Lines changed: 3 additions & 3 deletions
@@ -123,13 +123,13 @@ For example, for the stories15M model, this would be expressed as
 
 For models using a configuration not in the list of known
 configurations, you can construct the model by initializing the
-`ModelArgs` dataclass that controls model construction from a
+`TransformerArgs` dataclass that controls model construction from a
 parameter json using the `params-path ${PARAMS_PATH}` containing the
-appropriate model parameters to initialize the `ModelArgs` for the
+appropriate model parameters to initialize the `TransformerArgs` for the
 model. (We use the model constructor `Transformer.from_params()`).
 
 The parameter file should be in JSON format specifying these
-parameters. You can find the `ModelArgs` data class in
+parameters. You can find the `TransformerArgs` data class in
 [`model.py`](https://github.com/pytorch/torchchat/blob/main/model.py#L22).
 
 The final way to initialize a torchchat model is from GGUF. You load a
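
To make the documentation change concrete, a hypothetical parameter file might set a subset of the `TransformerArgs` fields (the values below are illustrative, not a published configuration):

    {
        "dim": 288,
        "n_layers": 6,
        "n_heads": 6,
        "vocab_size": 32000
    }

It would then be loaded with the constructor the docs mention, where the path is a placeholder:

    from build.model import Transformer

    # Initializes TransformerArgs from the JSON and builds the model.
    model = Transformer.from_params("params/my_model.json")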
