diff --git a/src/llama_recipes/utils/hf_llama_conversion/README.md b/src/llama_recipes/tools/README.md
similarity index 69%
rename from src/llama_recipes/utils/hf_llama_conversion/README.md
rename to src/llama_recipes/tools/README.md
index c02d9647e1668b044857795f18c54dd55b0d7046..ebc05a0e8c181016058e89e0577e1688b7516805 100644
--- a/src/llama_recipes/utils/hf_llama_conversion/README.md
+++ b/src/llama_recipes/tools/README.md
@@ -7,16 +7,18 @@ This is the reverse conversion for `convert_llama_weights_to_hf.py` script from
 - Copy file params.json from the official llama download into that directory.
 - Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
 ```
-python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Meta-Llama-3-70B-Instruct --output-dir test70B --model-size 70B
 ```
 
 ## Step 1: Run inference
-Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+Check out the official llama 3 inference [repo](https://github.com/meta-llama/llama3). Test using chat or text completion.
 ```
-torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_3_dir}/tokenizer.model
 ```
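+The number of ranks passed to `--nproc_per_node` should match the shard count of the converted checkpoint (8 for 70B).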
 
-For validation, please compare the converted weights with official llama 2 weights
+For validation, please compare the converted weights with official llama 3 weights
 ```
-python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
+python compare_llama_weights.py test70B ${llama_3_70b_instruct_dir}
 ```
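+The script prints the 10 largest per-tensor deltas, each next to the reference weight's mean absolute value, so a delta can be judged against the scale of the weights it came from.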
diff --git a/src/llama_recipes/utils/hf_llama_conversion/compare_llama_weights.py b/src/llama_recipes/tools/compare_llama_weights.py
similarity index 77%
rename from src/llama_recipes/utils/hf_llama_conversion/compare_llama_weights.py
rename to src/llama_recipes/tools/compare_llama_weights.py
index 64bcbb83b05a6b0dd92efa8d462cf235620cf2eb..25d16aa76a759d19c6730e95fc5ac4e33e7e4d90 100644
--- a/src/llama_recipes/utils/hf_llama_conversion/compare_llama_weights.py
+++ b/src/llama_recipes/tools/compare_llama_weights.py
@@ -28,23 +28,27 @@ def main() -> None:
         assert len(one) == len(
             two
         ), "shard should have the same length: {} != {}".format(len(one), len(two))
+        one = sorted(one.items(), key=lambda x: x[0])
+        two = sorted(two.items(), key=lambda x: x[0])
 
-        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+        for v, w in zip(one, two):
             assert v[0] == w[0], "{} != {}".format(v[0], w[0])
             assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
                 v[0], v[1].shape, w[1].shape
             )
 
             delta = (v[1] - w[1]).abs().max().item()
-            deltas.append((i, v[0], delta))
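+            # Also record the reference weight's mean absolute value to put the delta in scale.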
+            deltas.append((i, v[0], delta, w[1].abs().mean().item()))
         del one
         del two
         gc.collect()
 
-    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    deltas = sorted(deltas, key=lambda x: x[-2], reverse=True)
     print("Top 10 largest deltas:")
-    for i, k, v in deltas[:10]:
-        print(f"  shard {i} {k}: {v}")
+    for i, k, delta, value in deltas[:10]:
+        print(f"  shard {i} {k}: {delta} vs {value}")
 
 
 if __name__ == "__main__":
diff --git a/src/llama_recipes/tools/convert_hf_weights_to_llama.py b/src/llama_recipes/tools/convert_hf_weights_to_llama.py
index bdd7d8c4920f6ab2db9f83d267c70bba40ea05dd..356e4a4b9aa2f09bcdbef09b6ed57318bf8012db 100644
--- a/src/llama_recipes/tools/convert_hf_weights_to_llama.py
+++ b/src/llama_recipes/tools/convert_hf_weights_to_llama.py
@@ -12,6 +12,8 @@ from transformers import LlamaForCausalLM  # @manual
 
 NUM_SHARDS = {
     "7B": 1,
+    "8B": 1,
     "13B": 2,
     "34B": 4,
     "30B": 4,
@@ -30,15 +31,14 @@ def write_model(model_path, model_size, output_base_path):
     n_heads_per_shard = n_heads // num_shards
     dim = params["dim"]
     dims_per_head = dim // n_heads
-    base = 10000.0
-    inv_freq = (
-        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    ).to(dtype)
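+    # Heuristic: Llama 3 checkpoints are identified by their 128256-token vocabulary.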
+    llama_version = 3 if params.get("vocab_size") == 128256 else 2
 
     if "n_kv_heads" in params:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
-        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
-        key_value_dim = dim // num_key_value_heads
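+        # GQA/MQA: KV heads are split evenly across shards; each head is dims_per_head wide.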
+        num_local_key_value_heads = num_key_value_heads // num_shards
+        key_value_dim = dims_per_head * num_key_value_heads
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_local_key_value_heads = n_heads_per_shard
@@ -72,7 +72,11 @@ def write_model(model_path, model_size, output_base_path):
         for i, tensor in enumerate(tensors):
             state_dict[i][name] = tensor.clone()
 
-    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    concat_dim = 0 if llama_version == 3 else 1
+    insert_chunk(
+        "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
+    )
     insert("norm.weight", loaded["model.norm.weight"])
     insert_chunk("output.weight", loaded["lm_head.weight"], 0)
 
@@ -136,7 +140,13 @@
             f"layers.{layer_i}.ffn_norm.weight",
             loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
         )
-    insert("rope.freqs", inv_freq)
+    if llama_version != 3:
+        base = 10000.0
+        inv_freq = (
+            1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+        ).to(dtype)
+        insert("rope.freqs", inv_freq)
 
     for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
         torch.save(