diff --git a/llama_finetuning.py b/llama_finetuning.py
index ad33de3838faf21ebaae9df43624f4eddbcc13e7..cf573b020336d298e684631c56e1d9849b325b32 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -137,7 +137,7 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
-            sync_module_states=True if train_config.low_cpu_fsdp else False,
+            sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
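
The change replaces the redundant ternary "True if train_config.low_cpu_fsdp else False" with the boolean flag itself; behavior is unchanged. For context, the sketch below (not the repository's code) shows the low-CPU-memory FSDP pattern this argument participates in: only rank 0 materializes real weights, the other ranks build the model on the meta device, and sync_module_states broadcasts rank 0's weights after wrapping. The tiny nn.Linear model and the low_cpu_fsdp variable are stand-ins, and the script is assumed to be launched with torchrun.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    low_cpu_fsdp = True  # stand-in for train_config.low_cpu_fsdp

    if low_cpu_fsdp and rank != 0:
        # Non-zero ranks build the model on the meta device: no real weight storage.
        with torch.device("meta"):
            model = nn.Linear(1024, 1024)
    else:
        model = nn.Linear(1024, 1024)

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        # A bool flag can be passed directly; "True if flag else False" is redundant.
        sync_module_states=low_cpu_fsdp,
        # Meta-device ranks need real (empty) CUDA storage before rank 0's
        # weights are broadcast to them.
        param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
        if low_cpu_fsdp and rank != 0 else None,
    )

    dist.destroy_process_group()

Run with, for example: torchrun --nproc_per_node=2 sketch.py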