From b84c5aa02c415c3563afa1ab3421868ac8a43ce2 Mon Sep 17 00:00:00 2001 From: duzx16 <zx-du20@mails.tsinghua.edu.cn> Date: Wed, 1 Dec 2021 15:42:34 +0800 Subject: [PATCH] Accept config_params for deepspeed engine --- SwissArmyTransformer/training/deepspeed_training.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index a47795b..24cb9c6 100644 --- a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -162,7 +162,7 @@ def get_model(args, model_cls): return model -def setup_model_and_optimizer(args, model_cls): +def setup_model_and_optimizer(args, model_cls, config_params=None): """Setup model and optimizer.""" model = get_model(args, model_cls) @@ -179,7 +179,8 @@ def setup_model_and_optimizer(args, model_cls): model_parameters=param_groups, args=args, mpu=mpu, - dist_init_required=False + dist_init_required=False, + config_params=config_params ) else: raise ValueError('Currently, we only support training with deepspeed.') -- GitLab