From b84c5aa02c415c3563afa1ab3421868ac8a43ce2 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Wed, 1 Dec 2021 15:42:34 +0800
Subject: [PATCH] Accept config_params for deepspeed engine

---
 SwissArmyTransformer/training/deepspeed_training.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index a47795b..24cb9c6 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -162,7 +162,7 @@ def get_model(args, model_cls):
     return model
 
 
-def setup_model_and_optimizer(args, model_cls):
+def setup_model_and_optimizer(args, model_cls, config_params=None):
     """Setup model and optimizer."""
 
     model = get_model(args, model_cls)
@@ -179,7 +179,8 @@ def setup_model_and_optimizer(args, model_cls):
                 model_parameters=param_groups,
                 args=args,
                 mpu=mpu,
-                dist_init_required=False
+                dist_init_required=False,
+                config_params=config_params
             )
         else:
             raise ValueError('Currently, we only support training with deepspeed.')
-- 
GitLab