diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index a47795b82cdf8623818f585a463b7645750189ab..24cb9c658af091ee4ec0a528d4f26ecb8f74059a 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -162,7 +162,7 @@ def get_model(args, model_cls):
     return model
 
 
-def setup_model_and_optimizer(args, model_cls):
+def setup_model_and_optimizer(args, model_cls, config_params=None):
     """Setup model and optimizer."""
 
     model = get_model(args, model_cls)
@@ -179,7 +179,8 @@ def setup_model_and_optimizer(args, model_cls):
                 model_parameters=param_groups,
                 args=args,
                 mpu=mpu,
-                dist_init_required=False
+                dist_init_required=False,
+                config_params=config_params
             )
         else:
             raise ValueError('Currently, we only support training with deepspeed.')