Skip to content
Snippets Groups Projects
Unverified Commit e6b0f971 authored by haozhx23's avatar haozhx23 Committed by GitHub
Browse files

Fix hsdp_device_mesh=None when HSDP and HYBRID_SHARD are enabled (#402)

parent 4e1466c5
No related branches found
No related tags found
No related merge requests found
......@@ -163,10 +163,9 @@ def main(**kwargs):
wandb_run.config.update(peft_config)
model.print_trainable_parameters()
-hsdp_device_mesh = None
+hsdp_device_mesh_plan = None
if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
-hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
print("HSDP device mesh is ready")
#setting up FSDP if enable_fsdp is enabled
......@@ -189,7 +188,7 @@ def main(**kwargs):
cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
-device_mesh=hsdp_device_mesh,
+device_mesh=hsdp_device_mesh_plan,
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment