diff --git a/README.md b/README.md
index ab122be269381e696b1895c1ae31ab0684320f75..370be8974ef3d3d9fe77c080844261c2c3b29115 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 **For more in depth information checkout the following:**
 
 * [Single GPU Fine-tuning](./docs/single_gpu.md)
-* [Multi-GPU Fine-tuning](./docs/mutli_gpu.md)
+* [Multi-GPU Fine-tuning](./docs/multi_gpu.md)
 * [LLM Fine-tuning](./docs/LLM_finetuning.md)
 * [Adding custom datasets](./docs/Dataset.md)
 * [Inference](./docs/inference.md)
@@ -107,13 +107,21 @@ torchrun --nnodes 1 --nproc_per_node 4  llama_finetuning.py --enable_fsdp --use_
 
 Here we use FSDP as discussed in the next section which can be used along with PEFT methods. To make use of PEFT methods with FSDP make sure to pass `use_peft` and `peft_method` args along with `enable_fsdp`. Here we are using `BF16` for training.
 
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used, which can speed up the fine-tuning job. This feature is provided through HuggingFace's `optimum` library as a one-liner API; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4  llama_finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+```
+
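+Under the hood, the flag wraps the model with the `optimum` BetterTransformer API, the same one-liner applied in `llama_finetuning.py` and the inference scripts when `use_fast_kernels` is set (shown here as a standalone sketch):
+
+```python
+from transformers import LlamaForCausalLM
+from optimum.bettertransformer import BetterTransformer
+
+model = LlamaForCausalLM.from_pretrained("path/to/model")  # any Llama 2 HuggingFace checkpoint
+model = BetterTransformer.transform(model)  # swaps in the accelerated attention kernels
+```
+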
 ### Fine-tuning using FSDP Only
 
 If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
 
 ```
 
diff --git a/UPDATES.md b/UPDATES.md
new file mode 100644
index 0000000000000000000000000000000000000000..f90b4142a998c0d79510390d4a0f7082d382c0e3
--- /dev/null
+++ b/UPDATES.md
@@ -0,0 +1,19 @@
+## System Prompt Update
+
+### Observed Issue
+We received feedback from the community on our prompt template and are providing an update to reduce the false refusal rates observed. False refusals occur when the model incorrectly refuses to answer a question that it should answer, for example because of overly broad instructions to be cautious in how it provides responses.
+
+### Updated approach
+Based on evaluation and analysis, we recommend the removal of the system prompt as the default setting.  Pull request [#626](https://github.com/facebookresearch/llama/pull/626) removes the system prompt as the default option, but still provides an example to help enable experimentation for those using it. 
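+
+For reference, a dialog can still opt in to a system prompt by including it explicitly as the first message, mirroring the samples in `inference/chats.json` (shown here as the Python structure the chat scripts consume after loading the JSON):
+
+```python
+dialog = [
+    {"role": "system", "content": "Always answer with emojis"},
+    {"role": "user", "content": "How to go from Beijing to NY?"},
+]
+```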
+
+## Token Sanitization Update
+
+### Observed Issue
+The PyTorch scripts currently provided for tokenization and model inference allow for direct prompt injection via string concatenation. Prompt injection allows special system and instruction prompt strings to be added through user-provided prompts.
+
+As noted in the documentation, these strings are required to use the fine-tuned chat models. However, prompt injections have also been used for manipulating or abusing models by bypassing their safeguards, allowing for the creation of content or behaviors otherwise outside the bounds of acceptable use. 
+
+### Updated approach
+We recommend sanitizing [these strings](https://github.com/facebookresearch/llama#fine-tuned-chat-models) from any user-provided prompts. Sanitizing user prompts mitigates malicious or accidental abuse of these strings. The provided scripts have been updated to do this.
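+
+As an illustrative sketch (not the exact implementation in the updated scripts), sanitization can strip the special strings from user input before the prompt is assembled; `sanitize_prompt` below is a hypothetical helper using the markers defined in `inference/chat_utils.py`:
+
+```python
+def sanitize_prompt(user_prompt: str) -> str:
+    # Remove the special instruction/system markers so user text cannot inject them
+    for tag in ("[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"):
+        user_prompt = user_prompt.replace(tag, "")
+    return user_prompt
+```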
+
+Note: even with this update, safety classifiers should still be applied to catch unsafe behaviors or content produced by the model. An [example](https://github.com/facebookresearch/llama-recipes/blob/main/inference/inference.py) of how to deploy such a classifier can be found in the llama-recipes repository.
\ No newline at end of file
diff --git a/configs/training.py b/configs/training.py
index 4c50372dafbabfbd4acc032476d8111e635bb7f0..7b0e82d44363e563a2420ef6a2e7977ad8c539c8 100644
--- a/configs/training.py
+++ b/configs/training.py
@@ -32,6 +32,7 @@ class train_config:
     dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
 
     
     
diff --git a/docs/inference.md b/docs/inference.md
index 509eb2bcef73c95ba75f9561a3303669e5b45b7a..67ee3dca697a4a17d0d12404cb3b885fe52f91ae 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -49,6 +49,18 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
+
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used, which can speed up inference for batched inputs. This feature is provided through HuggingFace's `optimum` library as a one-liner API; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg --use_fast_kernels
+
+python inference/inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels
+
+```
+
 ## Loading back FSDP checkpoints
 
 In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
diff --git a/docs/mutli_gpu.md b/docs/multi_gpu.md
similarity index 91%
rename from docs/mutli_gpu.md
rename to docs/multi_gpu.md
index a4396deeade7438bf026f8ac41a2348b28b9438c..5c8fdf353fb2b77cb8facaaeff2ba9bb1a8e511a 100644
--- a/docs/mutli_gpu.md
+++ b/docs/multi_gpu.md
@@ -44,6 +44,13 @@ The args used in the command above are:
 
 We use `torchrun` here to spawn multiple processes for FSDP.
 
+## Flash Attention and Xformer Memory Efficient Kernels
+
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used, which can speed up the fine-tuning job. This feature is provided through HuggingFace's `optimum` library as a one-liner API; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4  ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
+```
 
 ### Fine-tuning using FSDP Only
 
@@ -51,7 +58,7 @@ If interested in running full parameter finetuning without making use of PEFT me
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 
+torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --use_fast_kernels
 
 ```
 
@@ -75,7 +82,7 @@ Currently 4 datasets are supported that can be found in [Datasets config file](.
 * `alpaca_dataset` : to get this open source data please download the `aplaca.json` to `ft_dataset` folder.
 
 ```bash
-wget -P ft_dataset https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json
+wget -P ft_datasets https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json
 ```
 
 * `samsum_dataset`
diff --git a/docs/single_gpu.md b/docs/single_gpu.md
index a449ce6386e285399b03f9d7df0345569e11e04c..95353b388b0b2c7b5c452d4c48c60cd1e621cfd9 100644
--- a/docs/single_gpu.md
+++ b/docs/single_gpu.md
@@ -47,7 +47,7 @@ Currently 4 datasets are supported that can be found in [Datasets config file](.
 * `alpaca_dataset` : to get this open source data please download the `aplaca.json` to `ft_dataset` folder.
 
 ```bash
-wget -P ft_dataset https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json
+wget -P ft_datasets https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json
 ```
 
 * `samsum_dataset`
diff --git a/ft_datasets/alpaca_dataset.py b/ft_datasets/alpaca_dataset.py
index 4d492460f111f5fca279dc48d227fbb9b9f65017..77cbb27eac56e73f9f9eb0f415398d2bd9b17921 100644
--- a/ft_datasets/alpaca_dataset.py
+++ b/ft_datasets/alpaca_dataset.py
@@ -42,6 +42,9 @@ class InstructionDataset(Dataset):
         return len(self.ann)
 
     def __getitem__(self, index):
+        IGNORE_INDEX = -100  # The default ignore_index used by torch.nn.CrossEntropyLoss
+
+
         ann = self.ann[index]
         if ann.get("input", "") == "":
             prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
@@ -66,7 +69,7 @@ class InstructionDataset(Dataset):
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
-        labels[~label_mask] = 0
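+        # Tokens with a negative label (the masked-out positions) are set to IGNORE_INDEX so CrossEntropyLoss excludes them from the loss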
+        labels[~label_mask] = IGNORE_INDEX
         example_mask = example_mask.float()
         label_mask = label_mask.float()
 
diff --git a/inference/chat_completion.py b/inference/chat_completion.py
index bc5311d62ac735ced4043357ebae0144cf335884..d5c8378bdbbdc5f2db3fc33ef032dc47e23a07a4 100644
--- a/inference/chat_completion.py
+++ b/inference/chat_completion.py
@@ -34,6 +34,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:
@@ -59,6 +60,18 @@ def main(
     model = load_model(model_name, quantization)
     if peft_model:
         model = load_peft_model(model, peft_model)
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' will enable
+        using of Flash Attention or Xformer memory-efficient kernels 
+        based on the hardware being used. This would speed up inference when used for batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)   
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
diff --git a/inference/chat_utils.py b/inference/chat_utils.py
index c8c90582b81c1292fffcf2fe4c2fa91f3515713f..89911dd09ef32876d2d6622965b348e9f1b92c77 100644
--- a/inference/chat_utils.py
+++ b/inference/chat_utils.py
@@ -16,22 +16,11 @@ Dialog = List[Message]
 
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-DEFAULT_SYSTEM_PROMPT = """\
-You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
-
 def format_tokens(dialogs, tokenizer):
     prompt_tokens = []
     for dialog in dialogs:
-        if dialog[0]["role"] != "system":
-                dialog = [
-                    {
-                        "role": "system",
-                        "content": DEFAULT_SYSTEM_PROMPT,
-                    }
-                ] + dialog
-        dialog = [
+        if dialog[0]["role"] == "system":
+            dialog = [
             {
                 "role": dialog[1]["role"],
                 "content": B_SYS
@@ -47,7 +36,7 @@ def format_tokens(dialogs, tokenizer):
             "starting with user and alternating (u/a/u/a/u...)"
         )
         """
-        Please verify that yout tokenizer support adding "[INST]", "[/INST]" to your inputs.
+        Please verify that your tokenizer supports adding "[INST]", "[/INST]" to your inputs.
         Here, we are adding it manually.
         """
         dialog_tokens: List[int] = sum(
diff --git a/inference/chats.json b/inference/chats.json
index 4b1021bac148ea9da8b1c813873cf639fc0bbd7a..5d41f492944b4d6c58a3aefa2cd1824430d8f129 100644
--- a/inference/chats.json
+++ b/inference/chats.json
@@ -18,5 +18,12 @@
             "content": "Always answer with emojis"
         },
         {"role": "user", "content": "How to go from Beijing to NY?"}
+    ],
+    [
+        {
+            "role": "system",
+            "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
+        },
+        {"role": "user", "content": "Write a brief birthday message to John"}
     ]
 ]
\ No newline at end of file
diff --git a/inference/inference.py b/inference/inference.py
index 985ce68d51e561ccd4d2d980e1dfd617d444487d..81668e3fb5b748a37e0f657c5fa43c8fc03541a8 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -33,6 +33,7 @@ def main(
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
+    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:
@@ -52,6 +53,23 @@ def main(
     torch.manual_seed(seed)
     
     model = load_model(model_name, quantization)
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+    
+    if use_fast_kernels:
+        """
+        Setting 'use_fast_kernels' will enable
+        using of Flash Attention or Xformer memory-efficient kernels 
+        based on the hardware being used. This would speed up inference when used for batched inputs.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)    
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
@@ -80,12 +98,13 @@ def main(
                 print(report)
         print("Skipping the inference as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
-
+        
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
     batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
+
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
diff --git a/llama_finetuning.py b/llama_finetuning.py
index ccf8c68457d13597cfe96768b6f8808b94fe243b..b7c2dc60355c203fae73c808c8a595baeeb4d5c3 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -24,7 +24,6 @@ from transformers import (
     BitsAndBytesConfig
 )
 import torch.distributed as dist
-
 # Unused imports removed
 from utils.train_utils import (
     set_tokenizer_params,
@@ -95,7 +94,17 @@ def main(**kwargs):
         load_in_8bit=True if train_config.quantization else None,
         device_map="auto" if train_config.quantization else None,
     )
-    
+    if train_config.enable_fsdp and train_config.use_fast_kernels:
+        """
+        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
+        using of Flash Attention or Xformer memory-efficient kernels 
+        based on the hardware being used. This would speed up fine-tuning.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model) 
+        except ImportError:
+            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
     
     # Prepare the model for int8 training if quantization is enabled
diff --git a/requirements.txt b/requirements.txt
index 9258c3e4e75e7714034a4fd177cad7e4615dffef..6c499b1b9275547c40cb1d16310a5da0cbec1b1d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,4 @@ transformers>=4.31.0
 sentencepiece
 py7zr
 scipy
-
+optimum
diff --git a/scripts/spellcheck.sh b/scripts/spellcheck.sh
index 7f423d5037cafa310a2761e395dc0dba9d270214..9cd9936eab55940f689cda0bf0ddfc958c48c91d 100755
--- a/scripts/spellcheck.sh
+++ b/scripts/spellcheck.sh
@@ -1,3 +1,6 @@
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # Source: https://github.com/pytorch/torchx/blob/main/scripts/spellcheck.sh
 set -ex
 sudo apt-get install aspell
diff --git a/scripts/spellcheck_conf/wordlist.txt b/scripts/spellcheck_conf/wordlist.txt
index 77a5572184d8470f9d780402670d067d3d66e600..27c7323cdbcfb06bfca606234c4d2ba31c577b9d 100644
--- a/scripts/spellcheck_conf/wordlist.txt
+++ b/scripts/spellcheck_conf/wordlist.txt
@@ -1089,4 +1089,35 @@ fragmentations
 intra
 nightlies
 recenly
-uncomment
\ No newline at end of file
+uncomment
+BFloat
+DDP
+LLM
+Xformer
+accuracies
+activations
+anyprecision
+aplaca
+assembels
+boolean
+checkpoining
+defatults
+gradinets
+itermediate
+recommond
+scaler
+sharding
+slurm
+summarization
+theJfleg
+xA
+Jupyter
+LLM
+Xformer
+dataset's
+jupyter
+mutli
+summarization
+xA
+Sanitization
+tokenization
diff --git a/utils/train_utils.py b/utils/train_utils.py
index 3fa4c0cf1d34a3b1e6bbad010597216d7669d1e6..8113ef173bb748a083e88a899ac8033f131e5421 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -5,6 +5,7 @@ import os
 import sys
 from typing import List
 import yaml
+import time
 
 import fire
 import torch
@@ -73,9 +74,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     train_loss = []
     val_prep = []
     val_loss =[]
+    epoch_times = []
+    checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
     for epoch in range(train_config.num_epochs):
+        epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
@@ -106,7 +110,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"\n step {step} is completed and loss is {loss.detach().float()}")
                 else:
                     print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-                    
+        epoch_end_time = time.perf_counter() - epoch_start_time
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -117,6 +122,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
+        
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -136,6 +142,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
           
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
+            checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
                     dist.barrier()
@@ -165,18 +172,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                            print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
+                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                             print("=====================================================")
 
                     if not train_config.use_peft and  train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                        print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                         print("=====================================================")                     
                 if train_config.enable_fsdp:
                     dist.barrier()
-            
+            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+            checkpoint_times.append(checkpoint_end_time)
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
@@ -189,10 +197,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
-            
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+    avg_epoch_time = sum(epoch_times)/ len(epoch_times) 
+    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)   
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
@@ -204,7 +213,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.run_validation:
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
-        
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+    
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)