@@ -22,7 +22,7 @@


 @dataclass
-class ModelArgs:
+class TransformerArgs:
     block_size: int = 2048
     vocab_size: int = 32000
     n_layers: int = 32
@@ -45,7 +45,7 @@ def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_heads
         if self.hidden_dim is None:
-            # If hidden_dim is not explicitly set in the ModelArgs,
+            # If hidden_dim is not explicitly set in the TransformerArgs,
             # then calculate implicitly based on dim and
             # also multiple of `args.multiple_of`
             multiple_of = self.multiple_of
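
A note on the fallback the comment above describes: Llama-style configs typically derive the FFN width from `dim` with a 4x expansion, scaled by roughly 2/3 for SwiGLU, then rounded up to `multiple_of`. A minimal sketch of that calculation, assuming those constants; `default_hidden_dim` is a hypothetical helper, not part of this diff:

    def default_hidden_dim(dim: int, multiple_of: int = 256) -> int:
        # Assumed heuristic: 4 * dim, shrunk to ~2/3 (SwiGLU uses three
        # projections instead of two), then rounded up to `multiple_of`.
        hidden_dim = int(2 * (4 * dim) / 3)
        return ((hidden_dim + multiple_of - 1) // multiple_of) * multiple_of

    assert default_hidden_dim(4096) == 11008  # the familiar Llama-2 7B FFN width
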
@@ -73,7 +73,7 @@ def from_params(cls, params_path):
     def from_table(cls, name: str):
         json_path = config_path / f"{name}.json"
         if json_path.is_file():
-            return ModelArgs.from_params(json_path)
+            return TransformerArgs.from_params(json_path)
         else:
             known_model_params = [
                 config.replace(".json", "") for config in os.listdir(config_path)
@@ -86,7 +86,7 @@ def from_table(cls, name: str):
     def from_name(cls, name: str):
         json_path = config_path / f"{name}.json"
         if Path(json_path).is_file():
-            return ModelArgs.from_params(json_path)
+            return TransformerArgs.from_params(json_path)

         known_model_params = [
             config.replace(".json", "") for config in os.listdir(config_path)
@@ -113,7 +113,7 @@ def from_name(cls, name: str):
             f"Unknown model directory name {name}. Must be one of {known_model_params}."
         )

-        return ModelArgs.from_params(config_path / f"{config[0]}.json")
+        return TransformerArgs.from_params(config_path / f"{config[0]}.json")


 class KVCache(nn.Module):
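
For readers skimming past the rename: the `update(self, input_pos, k_val, v_val)` signature in the next hunk header points at the usual static-cache pattern, where preallocated K/V buffers are written in place at the positions being decoded. A minimal sketch under that assumption (buffer shapes and dtype are guesses, not taken from this diff):

    import torch
    import torch.nn as nn

    class KVCacheSketch(nn.Module):
        def __init__(self, max_batch, max_seq, n_heads, head_dim, dtype=torch.bfloat16):
            super().__init__()
            shape = (max_batch, n_heads, max_seq, head_dim)
            self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
            self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

        def update(self, input_pos, k_val, v_val):
            # input_pos: [S]; k_val / v_val: [B, n_heads, S, head_dim]
            self.k_cache[:, :, input_pos] = k_val
            self.v_cache[:, :, input_pos] = v_val
            return self.k_cache, self.v_cache
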
@@ -145,7 +145,7 @@ def update(self, input_pos, k_val, v_val):


 class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.config = config

@@ -203,15 +203,15 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

     @classmethod
     def from_name(cls, name: str):
-        return cls(ModelArgs.from_name(name))
+        return cls(TransformerArgs.from_name(name))

     @classmethod
     def from_table(cls, name: str):
-        return cls(ModelArgs.from_table(name))
+        return cls(TransformerArgs.from_table(name))

     @classmethod
     def from_params(cls, params_path: str):
-        return cls(ModelArgs.from_params(params_path))
+        return cls(TransformerArgs.from_params(params_path))

     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
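
All of these constructors now delegate to the renamed class, so call sites stay equivalent. Hypothetical usage ("Meta-Llama-3-8B" is a placeholder config name, not taken from this diff):

    # These two forms construct the same model after the rename.
    model = Transformer.from_name("Meta-Llama-3-8B")

    args = TransformerArgs.from_name("Meta-Llama-3-8B")
    model = Transformer(args)
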
@@ -224,7 +224,7 @@ def from_gguf(cls, gguf_path: str, **kwargs):


 class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.feed_forward = FeedForward(config)
@@ -240,7 +240,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: TransformerArgs):
         super().__init__()
         assert config.dim % config.n_heads == 0

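
The assert in this hunk is what makes the per-head split well defined. Reading `n_local_heads` as the grouped-query K/V head count (an inference from the `__post_init__` hunk above, where -1 defaults it to `n_heads`), the projection widths work out as in this sketch; `qkv_widths` is a hypothetical helper:

    def qkv_widths(dim: int, n_heads: int, n_local_heads: int) -> tuple[int, int]:
        # The same guard as in Attention.__init__: each head needs an
        # integer width, head_dim = dim // n_heads.
        assert dim % n_heads == 0
        head_dim = dim // n_heads
        # Q keeps n_heads heads; K/V may use fewer (grouped-query attention).
        return n_heads * head_dim, n_local_heads * head_dim

    print(qkv_widths(4096, 32, 8))  # (4096, 1024): K/V projections are narrower
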
@@ -340,7 +340,7 @@ def forward(


 class FeedForward(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
         self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
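
The hunk cuts off after `w2`, but this layout usually comes with a third projection and a SiLU gate (SwiGLU). A sketch of the likely full block, with `w3` and the `forward` body assumed rather than shown in the diff:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FeedForwardSketch(nn.Module):
        # Assumed completion: w1/w3 as gate/up projections, w2 as the
        # down projection, combined with a SiLU gate.
        def __init__(self, dim: int, hidden_dim: int) -> None:
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w1(x)) * self.w3(x))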