diff --git a/.github/scripts/spellcheck_conf/wordlist.txt b/.github/scripts/spellcheck_conf/wordlist.txt
index a5ff50c951cdba08180556db61f0b93fd9cd0b6d..90ef6e234f2393d9f5743d40d084a13d37ae94c1 100644
--- a/.github/scripts/spellcheck_conf/wordlist.txt
+++ b/.github/scripts/spellcheck_conf/wordlist.txt
@@ -1351,6 +1351,7 @@ Weaviate
 MediaGen
 SDXL
 SVD
+QLORA
 Agentic
 AutoGen
 DeepLearning
@@ -1399,6 +1400,8 @@ sqlite
 customerservice
 fn
 ExecuTorch
+nf
+quant
 DLAI
 agentic
 containts
diff --git a/docs/multi_gpu.md b/docs/multi_gpu.md
index e3e7b0aefe7389b63755c1623480b2f04865c66f..c87f9b9c23414ff65892d287fbe806835c2ddd44 100644
--- a/docs/multi_gpu.md
+++ b/docs/multi_gpu.md
@@ -56,6 +56,14 @@ torchrun --nnodes 1 --nproc_per_node 8  recipes/quickstart/finetuning/finetuning
 
 ```
 
+### Fine-tuning using FSDP + QLORA
+
+This has been tested on 4 H100 GPUs.
+
+```bash
+FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --quantization 4bit --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
+```
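+
+The 4-bit defaults come from `quantization_config` (`fp4` quant type with `bfloat16` compute and storage dtypes). If you want the `nf4` quant type used in the QLORA paper instead, you can optionally override it on the command line, for example:
+
+```bash
+FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --quantization 4bit --quantization_config.quant_type nf4 --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
+```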
+
 ### Fine-tuning using FSDP on 70B Model
 
 If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs.
diff --git a/docs/single_gpu.md b/docs/single_gpu.md
index 168acadd9387033dd798056958f18b7e85e044d8..3f6834ef839d0f34164157832c22b8ae5bdc139d 100644
--- a/docs/single_gpu.md
+++ b/docs/single_gpu.md
@@ -17,10 +17,11 @@ To run the examples, make sure to install the llama-recipes package (See [README
 
 Get access to a machine with one GPU or if using a multi-GPU machine please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id` and run the following. It runs by default with `samsum_dataset` for summarization application.
 
+**NOTE** To run the fine-tuning with `QLORA`, make sure to set `--peft_method lora` and `--quantization 4bit`.
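+
+For example, a 4-bit QLORA run looks like the following (a sketch combining the flags above; adjust the model path and output directory to your setup):
+
+```bash
+python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 4bit --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+```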
 
 ```bash
 
-python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization --use_fp16 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization 8bit --use_fp16 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 ```
 The args used in the command above are:
@@ -51,16 +52,16 @@ to run with each of the datasets set the `dataset` flag in the command as shown
 ```bash
 # grammer_dataset
 
-python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization  --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization 8bit --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 # alpaca_dataset
 
-python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization  --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization 8bit --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 
 # samsum_dataset
 
-python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization  --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization 8bit --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 ```
 
diff --git a/recipes/quickstart/finetuning/README.md b/recipes/quickstart/finetuning/README.md
index ad4808082052741c27dda5d21c517d7af1e68b7d..b41ef6d43d79bbf92eec1d117a3eb3b9b8806efd 100644
--- a/recipes/quickstart/finetuning/README.md
+++ b/recipes/quickstart/finetuning/README.md
@@ -54,7 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
-    quantization: bool = False
+    quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
     dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
@@ -101,7 +101,7 @@ It lets us specify the training settings for everything from `model_name` to `da
 You can enable [W&B](https://wandb.ai/) experiment tracking by using `use_wandb` flag as below. You can change the project name, entity and other `wandb.init` arguments in `wandb_config`.
 
 ```bash
-python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model --use_wandb
+python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 8bit --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model --use_wandb
 ```
 You'll be able to access a dedicated project or run link on [wandb.ai](https://wandb.ai) and see your dashboard like the one below.
 <div style="display: flex;">
diff --git a/recipes/quickstart/finetuning/multigpu_finetuning.md b/recipes/quickstart/finetuning/multigpu_finetuning.md
index d375b76b2bf7f6e62cd17644f1dd110b2aed1af5..b936660a3ff11533b8dcaae3545da32bca4edb51 100644
--- a/recipes/quickstart/finetuning/multigpu_finetuning.md
+++ b/recipes/quickstart/finetuning/multigpu_finetuning.md
@@ -18,6 +18,14 @@ We will also need 2 packages:
 ## How to run it
 Get access to a machine with multiple GPUs (in this case we tested with 4 A100 and A10s).
 
+### With FSDP + QLORA
+
+This has been tested on 4 H100 GPUs.
+
+```bash
+FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --quantization 4bit --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
+```
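+
+As in the single-GPU setup, the 4-bit defaults (`fp4` quant type, `bfloat16` compute and storage dtypes) come from `quantization_config`; to use the `nf4` quant type from the QLORA paper instead, you can optionally override it on the command line, for example:
+
+```bash
+FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --quantization 4bit --quantization_config.quant_type nf4 --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
+```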
+
 ### With FSDP + PEFT
 
 <details open>
diff --git a/recipes/quickstart/finetuning/singlegpu_finetuning.md b/recipes/quickstart/finetuning/singlegpu_finetuning.md
index f1a4a38bf49e31eae3b4bc8a652fa5d1a40c7ebc..1b054be18d6a5bd61f89a720d49a41d86921e4a3 100644
--- a/recipes/quickstart/finetuning/singlegpu_finetuning.md
+++ b/recipes/quickstart/finetuning/singlegpu_finetuning.md
@@ -15,14 +15,17 @@ To run fine-tuning on a single GPU, we will make use of two packages:
 
 ## How to run it?
 
+**NOTE** To run the fine-tuning with `QLORA`, make sure to set `--peft_method lora` and `--quantization 4bit --quantization_config.quant_type nf4`.
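+
+For example, a 4-bit QLORA run looks like the following (a sketch combining the flags above; adjust the model path and output directory to your setup):
+
+```bash
+python finetuning.py --use_peft --peft_method lora --quantization 4bit --quantization_config.quant_type nf4 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+```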
+
+
 ```bash
-python finetuning.py  --use_peft --peft_method lora --quantization --use_fp16 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+FSDP_CPU_RAM_EFFICIENT_LOADING=1 python finetuning.py  --use_peft --peft_method lora --quantization 8bit --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 ```
 The args used in the command above are:
 
 * `--use_peft` boolean flag to enable PEFT methods in the script
 * `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
-* `--quantization` boolean flag to enable int8 quantization
+* `--quantization` string flag to enable `8bit` or `4bit` quantization
 
 > [!NOTE]
 > In case you are using a multi-GPU machine please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id`.
@@ -48,16 +51,16 @@ to run with each of the datasets set the `dataset` flag in the command as shown
 ```bash
 # grammar_dataset
 
-python -m finetuning.py  --use_peft --peft_method lora --quantization  --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python finetuning.py  --use_peft --peft_method lora --quantization 8bit --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 # alpaca_dataset
 
-python -m finetuning.py  --use_peft --peft_method lora --quantization  --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python finetuning.py  --use_peft --peft_method lora --quantization 8bit --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 
 # samsum_dataset
 
+python finetuning.py  --use_peft --peft_method lora --quantization 8bit --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
+python -m finetuning.py  --use_peft --peft_method lora --quantization 8bit  --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
 
 ```
 
diff --git a/recipes/quickstart/inference/local_inference/README.md b/recipes/quickstart/inference/local_inference/README.md
index 427edd54e1c2cc0f7c9f754f456029817b317e39..de610137b7050bd2e07ff69b975d1fedf9265846 100644
--- a/recipes/quickstart/inference/local_inference/README.md
+++ b/recipes/quickstart/inference/local_inference/README.md
@@ -46,7 +46,7 @@ Padding would be required for batch inference. In this this [example](inference.
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
 
 ```bash
-python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json  --quantization --use_auditnlg
+python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json  --quantization 8bit --use_auditnlg
 
 ```
 
@@ -55,7 +55,7 @@ python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --pro
 Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up inference when used for batched inputs. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json  --quantization --use_auditnlg --use_fast_kernels
+python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json  --quantization 8bit --use_auditnlg --use_fast_kernels
 
 python inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels
 
diff --git a/recipes/responsible_ai/llama_guard/README.md b/recipes/responsible_ai/llama_guard/README.md
index 42233bec87ce653e5d94adfdb749c4a0da8209d0..39055d8ca7f659b8f316817881614f3de4e179fe 100644
--- a/recipes/responsible_ai/llama_guard/README.md
+++ b/recipes/responsible_ai/llama_guard/README.md
@@ -2,7 +2,7 @@
 <!-- markdown-link-check-disable -->
 Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the main repository for each model, [Meta Llama Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard) and Meta [Llama Guard 2](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard2).
 
-This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path. 
+This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path.
 
 ## Requirements
 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
@@ -10,7 +10,7 @@ This folder contains an example file to run inference with a locally hosted mode
 
 
 ## Llama Guard inference script
-For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent. 
+For testing, you can add User or User/Agent interactions into the prompts list and then run the script to verify the results. When the conversation has one or more Agent responses, it's considered to be of type agent.
 
 
 ```
@@ -66,7 +66,4 @@ In this case, the default categories are applied by the tokenizer, using the `ap
 
 Use this command for testing with a quantized Llama model, modifying the values accordingly:
 
-`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization --enable_llamaguard_content_safety`
-
-
-
+`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization 8bit --enable_llamaguard_content_safety`
diff --git a/src/llama_recipes/configs/__init__.py b/src/llama_recipes/configs/__init__.py
index 5db9c216bbe566dbfaef2e05bd76a72c087a450b..67d2d9a67b15fb796b1bd3d1d239fe136f777561 100644
--- a/src/llama_recipes/configs/__init__.py
+++ b/src/llama_recipes/configs/__init__.py
@@ -5,3 +5,4 @@ from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix
 from llama_recipes.configs.fsdp import fsdp_config
 from llama_recipes.configs.training import train_config
 from llama_recipes.configs.wandb import wandb_config
+from llama_recipes.configs.quantization import quantization_config
diff --git a/src/llama_recipes/configs/quantization.py b/src/llama_recipes/configs/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecefa2fafc2588a619026e5c0ecf7b36342a1e1d
--- /dev/null
+++ b/src/llama_recipes/configs/quantization.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass
+import torch
+from transformers import BitsAndBytesConfig
+
+@dataclass
+class quantization_config:
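+    """Bits-and-bytes quantization settings consumed by the finetuning script.
+
+    create_bnb_config() maps the --quantization value ("4bit" or "8bit") to a
+    BitsAndBytesConfig; the 4-bit specific fields below are ignored for "8bit".
+    """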
+    quant_type: str = "fp4" # "fp4" or "nf4"
+    compute_dtype: torch.dtype = torch.bfloat16
+    use_double_quant: bool = False
+    quant_storage: torch.dtype = torch.bfloat16
+
+    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
+        if quantization not in {"4bit", "8bit"}:
+            raise ValueError("quantization must be either '4bit' or '8bit'")
+
+        if quantization == "4bit":
+            config_params = {
+                "bnb_4bit_quant_type": self.quant_type,
+                "bnb_4bit_compute_dtype": self.compute_dtype,
+                "bnb_4bit_use_double_quant": self.use_double_quant,
+                "bnb_4bit_quant_storage": self.quant_storage,
+            }
+
+            return BitsAndBytesConfig(load_in_4bit=True, **config_params)
+        else:
+            return BitsAndBytesConfig(load_in_8bit=True)
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index 7ae8265b172a551a43c36940ec32c3c99a11d5b6..14d77f37ff100138e17e0d76e202f5b59dd774f1 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -35,7 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
-    quantization: bool = False
+    quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
     dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index a2470d8df8b4ae33b1d3737ebecdceb503242e7c..76dd4d56994c9c4dc3a82ddf673e2674835ea8a8 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from collections import Counter
 import os
 
 import dataclasses
@@ -8,7 +9,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
+from peft import get_peft_model, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -18,6 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
     AutoTokenizer,
+    BitsAndBytesConfig,
     LlamaForCausalLM,
     LlamaConfig,
 )
@@ -25,6 +27,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -48,6 +51,7 @@ from llama_recipes.utils.train_utils import (
     get_policies,
 )
 from accelerate.utils import is_xpu_available
+from warnings import warn
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
     try:
@@ -66,7 +70,6 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     run.config.update(fsdp_config, allow_val_change=True)
     return run
 
-
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
@@ -97,38 +100,31 @@ def main(**kwargs):
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+
+    # Set up the quantization config
+    bnb_config = None
+    if train_config.quantization:
+        if isinstance(train_config.quantization, bool):
+            warn("Quantization (--quantization) is a boolean, please specify it as '4bit' or '8bit'. Defaulting to '8bit', but this might change in the future.", FutureWarning)
+            train_config.quantization = "8bit"
+
+        if train_config.quantization == "8bit" and train_config.enable_fsdp:
+            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")
+
+        quant_config = QUANTIZATION_CONFIG()
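+        # apply any command-line overrides, e.g. --quantization_config.quant_type nf4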
+        update_config(quant_config, **kwargs)
+        bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
-        """
-        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
-        this avoids cpu oom when loading large models like llama 70B, in which case
-        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
-        overhead and currently requires latest nightly.
-        """
-        if rank == 0:
-            model = LlamaForCausalLM.from_pretrained(
-                train_config.model_name,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            )
-        else:
-            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
-
-    else:
-        model = LlamaForCausalLM.from_pretrained(
-            train_config.model_name,
-            load_in_8bit=True if train_config.quantization else None,
-            device_map="auto" if train_config.quantization else None,
-            use_cache=use_cache,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-        )
+    model = LlamaForCausalLM.from_pretrained(
+        train_config.model_name,
+        quantization_config=bnb_config,
+        use_cache=use_cache,
+        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+        device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+        torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+    )
 
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
@@ -142,14 +138,10 @@ def main(**kwargs):
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
-    # Prepare the model for int8 training if quantization is enabled
-    if train_config.quantization:
-        model = prepare_model_for_kbit_training(model)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
-    if train_config.enable_fsdp and fsdp_config.pure_bf16:
+    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
         model.to(torch.bfloat16)
-
+
     if train_config.use_peft:
         # Load the pre-trained peft model checkpoint and setup its configuration
         if train_config.from_peft_checkpoint:
@@ -181,7 +173,6 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -195,8 +186,10 @@ def main(**kwargs):
             param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
-        if fsdp_config.fsdp_activation_checkpointing:
-            apply_fsdp_checkpointing(model)
+        if fsdp_config.fsdp_activation_checkpointing:
+            model.enable_input_require_grads()
+            model.gradient_checkpointing_enable()
+            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
@@ -211,7 +204,6 @@ def main(**kwargs):
         dataset_config,
         split="train",
     )
-
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -271,7 +263,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
     # Start the training process
     results = train(
         model,