diff --git a/inference/inference.py b/inference/inference.py
index b6398d407f0cf95e9c7d22ee91689bff0ec3d620..60b1cd6e9e07675bb7dd9dbdda175ff1b1fd6b7b 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -7,6 +7,7 @@ import fire
 import torch
 import os
 import sys
+import time
 from typing import List
 
 from transformers import LlamaTokenizer
@@ -49,15 +50,15 @@ def main(
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
+
     model = load_model(model_name, quantization)
-
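+    # Replay the 8-way tensor parallelism used in pretraining (transformers' experimental pretraining_tp path for exact logit reproducibility, at some speed cost)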
+    model.config.pretraining_tp = 8
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": "[PAD]",
+            "pad_token": "<PAD>",
         }
     )
     
@@ -88,7 +89,8 @@ def main(
 
     batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
-    
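+    # Time the end-to-end generate() call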
+    start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
             **batch,
@@ -103,7 +105,9 @@ def main(
             length_penalty=length_penalty,
             **kwargs 
         )
-
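+    # Report end-to-end generation latency in milliseconds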
+    e2e_inference_time = (time.perf_counter() - start) * 1000
+    print(f"Inference time: {e2e_inference_time:.2f} ms")
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     
     # Safety check of the model output