Skip to content
Snippets Groups Projects
Commit b84c5aa0 authored by duzx16's avatar duzx16
Browse files

Accept config_params for deepspeed engine

parent 49dea453
No related branches found
No related tags found
No related merge requests found
...@@ -162,7 +162,7 @@ def get_model(args, model_cls): ...@@ -162,7 +162,7 @@ def get_model(args, model_cls):
return model return model
def setup_model_and_optimizer(args, model_cls): def setup_model_and_optimizer(args, model_cls, config_params=None):
"""Setup model and optimizer.""" """Setup model and optimizer."""
model = get_model(args, model_cls) model = get_model(args, model_cls)
...@@ -179,7 +179,8 @@ def setup_model_and_optimizer(args, model_cls): ...@@ -179,7 +179,8 @@ def setup_model_and_optimizer(args, model_cls):
model_parameters=param_groups, model_parameters=param_groups,
args=args, args=args,
mpu=mpu, mpu=mpu,
dist_init_required=False dist_init_required=False,
config_params=config_params
) )
else: else:
raise ValueError('Currently, we only support training with deepspeed.') raise ValueError('Currently, we only support training with deepspeed.')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment