diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index a47795b82cdf8623818f585a463b7645750189ab..24cb9c658af091ee4ec0a528d4f26ecb8f74059a 100644 --- a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -162,7 +162,7 @@ def get_model(args, model_cls): return model -def setup_model_and_optimizer(args, model_cls): +def setup_model_and_optimizer(args, model_cls, config_params=None): """Setup model and optimizer.""" model = get_model(args, model_cls) @@ -179,7 +179,8 @@ def setup_model_and_optimizer(args, model_cls): model_parameters=param_groups, args=args, mpu=mpu, - dist_init_required=False + dist_init_required=False, + config_params=config_params ) else: raise ValueError('Currently, we only support training with deepspeed.')