Merged
4 changes: 3 additions & 1 deletion examples/mediatek/aot_utils/llm_utils/sanity_checks.py
@@ -204,7 +204,9 @@ def check_weights_exist(weight_dir):
f"No weight files found in {weight_dir}! Weight files should be either .bin or .safetensors file types."
)
safetensors_l = [f for f in os.listdir(weight_dir) if f.endswith(".safetensors")]
bin_l = [f for f in os.listdir(weight_dir) if f.endswith(".bin")]
bin_l = [
f for f in os.listdir(weight_dir) if f.endswith(".bin") and "embedding" not in f
]
if len(safetensors_l) & len(bin_l):
raise RuntimeError(
"Weights should only be in either .bin or .safetensors format, not both."
4 changes: 3 additions & 1 deletion examples/mediatek/model_export_scripts/llama.py
@@ -419,6 +419,9 @@ def main():
print(f"Max Num Token: {max_num_token}")
print(f"Max Cache Size: {max_cache_size}")

if args.dataset is not None:
embedding_layer = get_embedding_layer(config, weight_dir, state_dict)

# Instantiate model chunks
print("Instantiating submodels")
models = []
@@ -437,7 +440,6 @@ def main():
cal_dataset = None
if args.dataset is not None:
cal_dataset = load_dataset("text", data_files=args.dataset, split="train")
embedding_layer = get_embedding_layer(config, weight_dir, state_dict)
Contributor: Any specific reason to remove this line?

Collaborator (Author): Oh, I didn't remove it; I shifted it above to line 422. During chunk.load_weights (line 437) the weights are popped from the original state_dict, so in the tie_word_embeddings=True case the embedding weights would no longer be present in the state_dict by the time we get the embedding layer, hence I shifted it up. (A short sketch of this ordering issue follows this file's diff.)

master_rot_emb = get_master_rot_emb(config, dtype=torch.float32)
if args.preformatter is not None:
cal_dataset = cal_dataset.map(
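To make the ordering issue from the discussion above concrete, here is a minimal, hypothetical sketch. The toy shapes and simplified helper signatures are illustrative only (the real get_embedding_layer takes config, weight_dir and state_dict); the point is just that once a chunk's load_weights pops the tied embedding tensor out of the shared state_dict, it can no longer be fetched.

```python
import torch

# Toy shared state_dict; with tie_word_embeddings=True the lm_head chunk
# reuses embed_tokens.weight, so that entry gets popped when the chunk loads.
state_dict = {"model.embed_tokens.weight": torch.randn(8, 4)}

def load_chunk_weights(sd, keys):
    # Mirrors the behaviour described above: each chunk pops the entries it
    # consumes out of the shared dict.
    return {k: sd.pop(k) for k in keys}

def get_embedding_layer(sd):
    # Simplified stand-in for the real helper; it needs the embedding tensor
    # to still be present in the dict.
    return torch.nn.Embedding.from_pretrained(sd["model.embed_tokens.weight"])

# Correct order (what the export script now does): build the embedding layer
# before any chunk pops the tied weight.
embedding_layer = get_embedding_layer(state_dict)
load_chunk_weights(state_dict, ["model.embed_tokens.weight"])

# Reversing the two calls would raise KeyError: 'model.embed_tokens.weight'.
```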
11 changes: 6 additions & 5 deletions examples/mediatek/models/llm_models/modeling_common.py
@@ -675,8 +675,8 @@ def load_weights(self, state_dict, state_dict_start_idx):
)
else:
if self.config.tie_word_embeddings:
lm_head_weight_key = "embed_tokens.weight"
lm_head_bias_key = "embed_tokens.bias"
lm_head_weight_key = f"{prefix}embed_tokens.weight"
lm_head_bias_key = f"{prefix}embed_tokens.bias"
else:
lm_head_weight_key = "lm_head.weight"
lm_head_bias_key = "lm_head.bias"
@@ -751,15 +751,16 @@ def get_example_inputs(
for _ in range(2 * self.num_blocks)
],
)
# Specify dims that would be dynamic during calibration
# Note: Assume cache size fixed shape as torch dynamic shape cannot handle dim 3 being
# combination of 2 dynamic dims
if get_dym_shape:
nt = Dim("num_token", max=num_token)
cache_dims = tuple(({} for _ in range(2 * self.num_blocks)))
dynamic_shapes = (
Contributor: I think I may have asked in the previous PR - what does dynamic_shapes do? Probably good to add a comment to explain.

Collaborator (Author): Hi @cccclai, dynamic_shapes is passed to torch.export.export_for_training to indicate that the shapes of some inputs may vary during calibration. After calibration, the dynamic_shapes argument is no longer needed. (A short sketch follows this hunk.)

{0: None, 1: nt, 2: None},
{0: None, 1: None, 2: nt, 3: nt + cache_size},
{0: None, 1: None, 2: nt, 3: None},
{0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
{0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: nt + cache_size},
{0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: Dim.STATIC},
cache_dims,
)
return example_inputs, dynamic_shapes
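To illustrate the author's explanation of dynamic_shapes above, here is a minimal, hypothetical sketch using a toy module; only Dim, Dim.STATIC and torch.export.export_for_training come from the diff and comment, everything else is made up for illustration.

```python
import torch
from torch.export import Dim, export_for_training

class Toy(torch.nn.Module):
    def forward(self, hidden_states, mask):
        # Both inputs share the (dynamic) token dimension at index 1.
        return hidden_states + mask

max_num_token = 32
nt = Dim("num_token", max=max_num_token)

example_inputs = (
    torch.randn(1, max_num_token, 8),  # e.g. hidden states
    torch.randn(1, max_num_token, 8),  # e.g. an additive mask
)

# Dim 1 (the number of tokens) may differ between calibration batches;
# every other dim is pinned with Dim.STATIC, as in the diff above.
dynamic_shapes = (
    {0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
    {0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
)

exported = export_for_training(Toy(), example_inputs, dynamic_shapes=dynamic_shapes)

# The exported graph now accepts any token count up to max_num_token,
# so a shorter calibration batch still runs:
exported.module()(torch.randn(1, 5, 8), torch.randn(1, 5, 8))
```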
@@ -0,0 +1,23 @@
{
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 128000,
"eos_token_id": 128001,
"head_dim": 64,
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 16,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 500000.0,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.45.0.dev0",
"vocab_size": 128256,
"tokenizer": "pretrained_fast"
}
@@ -0,0 +1,16 @@
{
"bos_token": {
"content": "<|begin_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|eot_id|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}