Unverified commit ed3136f1, authored by dongwang218, committed by GitHub

Update hf weight conversion script to llama 3 (#551)

parent f6617fb8
@@ -7,16 +7,16 @@ This is the reverse conversion for `convert_llama_weights_to_hf.py` script from
 - Copy file params.json from the official llama download into that directory.
 - Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
 ```
-python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Meta-Llama-3-70B-Instruct --output-dir test70B --model-size 70B
 ```
 
 ## Step 1: Run inference
-Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+Checkout the official llama 3 inference [repo](https://github.com/meta-llama/llama3). Test using chat or text completion.
 ```
-torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_3_dir}/tokenizer.model
 ```
 
 For validation, please compare the converted weights with official llama 2 weights
 ```
-python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
+python compare_llama_weights.py test70B ${Llama-3-70B-Instruct_dir}
 ```
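As a quick check between the conversion and inference steps, the converted output directory can be inspected directly. A minimal sketch, assuming the Meta-style `consolidated.NN.pth` shard naming and the `test70B` output directory from the command above; the filenames and keys printed here are illustrative, not part of the commit:

```python
import glob
import os
import sys

import torch

# Assumed layout: the converter writes Meta-style shard files named
# consolidated.NN.pth into the output directory (e.g. test70B).
output_dir = sys.argv[1] if len(sys.argv) > 1 else "test70B"
for shard_path in sorted(glob.glob(os.path.join(output_dir, "consolidated.*.pth"))):
    state_dict = torch.load(shard_path, map_location="cpu")
    print(f"{os.path.basename(shard_path)}: {len(state_dict)} tensors")
    # Print a couple of well-known entries to confirm the shapes look sane.
    for key in ("tok_embeddings.weight", "output.weight"):
        if key in state_dict:
            print(f"  {key}: {tuple(state_dict[key].shape)}")
    del state_dict
```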
@@ -28,23 +28,25 @@ def main() -> None:
         assert len(one) == len(
             two
         ), "shard should have the same length: {} != {}".format(len(one), len(two))
 
-        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+        one = sorted(one.items(), key=lambda x: x[0])
+        two = sorted(two.items(), key=lambda x: x[0])
+        for _, (v, w) in enumerate(zip(one, two)):
             assert v[0] == w[0], "{} != {}".format(v[0], w[0])
             assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
                 v[0], v[1].shape, w[1].shape
             )
             delta = (v[1] - w[1]).abs().max().item()
-            deltas.append((i, v[0], delta))
+            deltas.append((i, v[0], delta, w[1].abs().mean().item()))
         del one
         del two
         gc.collect()
 
-    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    deltas = sorted(deltas, key=lambda x: x[-2], reverse=True)
     print("Top 10 largest deltas:")
-    for i, k, v in deltas[:10]:
-        print(f" shard {i} {k}: {v}")
+    for i, k, delta, value in deltas[:10]:
+        print(f" shard {i} {k}: {delta} vs {value}")
 
 
 if __name__ == "__main__":
...
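The change to `compare_llama_weights.py` sorts each shard's tensors by name before pairing them and records the mean absolute value of the reference tensor alongside each max delta, so a delta can be judged relative to the typical weight magnitude. A self-contained sketch of the same idea, using small made-up tensors rather than real shards:

```python
import torch

# Two toy "shards" standing in for a converted and a reference state dict.
converted = {"layers.0.wq.weight": torch.randn(8, 8), "norm.weight": torch.ones(8)}
reference = {k: v + 1e-6 * torch.randn_like(v) for k, v in converted.items()}

# Sort by name so the two dicts are compared key-by-key in a stable order.
one = sorted(converted.items(), key=lambda x: x[0])
two = sorted(reference.items(), key=lambda x: x[0])

deltas = []
for (name_a, tensor_a), (name_b, tensor_b) in zip(one, two):
    assert name_a == name_b and tensor_a.shape == tensor_b.shape
    delta = (tensor_a - tensor_b).abs().max().item()  # worst-case difference
    scale = tensor_b.abs().mean().item()              # typical weight magnitude
    deltas.append((name_a, delta, scale))

# Largest absolute deltas first; the scale column shows whether they matter.
for name, delta, scale in sorted(deltas, key=lambda x: x[1], reverse=True):
    print(f"{name}: max |diff| {delta:.3e} vs mean |weight| {scale:.3e}")
```

Sorting by name is what makes the element-wise zip meaningful when the two checkpoints enumerate their tensors in different orders.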
@@ -12,6 +12,7 @@ from transformers import LlamaForCausalLM  # @manual
 NUM_SHARDS = {
     "7B": 1,
+    "8B": 1,
     "13B": 2,
     "34B": 4,
     "30B": 4,
@@ -30,15 +31,12 @@ def write_model(model_path, model_size, output_base_path):
     n_heads_per_shard = n_heads // num_shards
     dim = params["dim"]
     dims_per_head = dim // n_heads
-    base = 10000.0
-    inv_freq = (
-        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    ).to(dtype)
+    llama_version = 3 if params.get("vocab_size") == 128256 else 2
 
     if "n_kv_heads" in params:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
-        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
-        key_value_dim = dim // num_key_value_heads
+        num_local_key_value_heads = num_key_value_heads // num_shards
+        key_value_dim = dims_per_head * num_key_value_heads
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_local_key_value_heads = n_heads_per_shard
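The two corrected grouped-query-attention formulas only make a difference when the number of KV heads and the shard count diverge; they happen to agree with the old ones for the 8-way-sharded 70B layout. A worked example with Llama 3 8B-style values (dim 4096, 32 attention heads, 8 KV heads, single shard; these numbers are assumptions used for illustration):

```python
# Llama 3 8B-style settings (assumed for illustration).
dim, n_heads, n_kv_heads, num_shards = 4096, 32, 8, 1
dims_per_head = dim // n_heads              # 128
n_heads_per_shard = n_heads // num_shards   # 32

# Old formulas (agree with the new ones only for the 8-shard 70B case):
old_local_kv_heads = n_heads_per_shard // n_kv_heads   # 32 // 8 = 4
old_key_value_dim = dim // n_kv_heads                  # 4096 // 8 = 512

# New formulas (match the actual checkpoint layout):
new_local_kv_heads = n_kv_heads // num_shards          # 8 // 1 = 8
new_key_value_dim = dims_per_head * n_kv_heads         # 128 * 8 = 1024

print(old_local_kv_heads, old_key_value_dim)   # 4 512
print(new_local_kv_heads, new_key_value_dim)   # 8 1024
```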
@@ -72,7 +70,10 @@ def write_model(model_path, model_size, output_base_path):
         for i, tensor in enumerate(tensors):
             state_dict[i][name] = tensor.clone()
 
-    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    concat_dim = 0 if llama_version == 3 else 1
+    insert_chunk(
+        "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
+    )
     insert("norm.weight", loaded["model.norm.weight"])
     insert_chunk("output.weight", loaded["lm_head.weight"], 0)
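`insert_chunk` splits a full Hugging Face tensor into `num_shards` pieces along the given dimension before writing the Meta-style shards. The new `concat_dim` reflects that Llama 3 checkpoints shard `tok_embeddings.weight` along the vocabulary dimension (dim 0), while Llama 2 sharded it along the hidden dimension (dim 1). A small shape check with scaled-down sizes (the real embedding tables are far larger):

```python
import torch

num_shards = 8
# Scaled-down stand-in for the embedding table; real Llama 3 70B would be
# roughly 128256 x 8192 and Llama 2 70B roughly 32000 x 8192.
embed = torch.zeros(1024, 64)

# Llama 3 layout: split rows, so each shard holds a slice of the vocabulary.
chunks_v3 = torch.chunk(embed, num_shards, dim=0)
print(chunks_v3[0].shape)   # torch.Size([128, 64])

# Llama 2 layout: split columns, so each shard holds a slice of the hidden dim.
chunks_v2 = torch.chunk(embed, num_shards, dim=1)
print(chunks_v2[0].shape)   # torch.Size([1024, 8])
```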
@@ -136,7 +137,12 @@ def write_model(model_path, model_size, output_base_path):
             f"layers.{layer_i}.ffn_norm.weight",
             loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
         )
-    insert("rope.freqs", inv_freq)
+    if llama_version != 3:
+        base = 10000.0
+        inv_freq = (
+            1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+        ).to(dtype)
+        insert("rope.freqs", inv_freq)
 
     for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
         torch.save(
...
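Finally, `rope.freqs` is now written only for Llama 2-style outputs; Llama 3 checkpoints do not carry this tensor, since the official loader derives the rotary frequencies itself (and with a different rope base). For reference, a sketch of what the precomputed tensor looks like for 70B-style dimensions, with the sizes assumed for illustration:

```python
import torch

dim, n_heads = 8192, 64          # 70B-style sizes, assumed for illustration
dims_per_head = dim // n_heads   # 128
base = 10000.0                   # Llama 2 rope base

inv_freq = 1.0 / (
    base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
print(inv_freq.shape)   # torch.Size([64]) -> one frequency per rotary pair
print(inv_freq[:3])     # tensor([1.0000, 0.8660, 0.7499])
```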