Commit 43cb6a2d authored by Matthias Reso

Remove check for nightlies for low_cpu_fsdp and bump torch version to 2.2 instead

parent cad284c6
-torch>=2.0.1
+torch>=2.2
 accelerate
 appdirs
 loralib
...
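With the runtime nightly check removed from the training script (see the hunk further down), this pin is now the only thing enforcing the minimum torch version. A hedged sketch of an equivalent optional runtime guard, for installs that bypass this requirements file (not part of the commit; names are illustrative):

# Optional runtime guard equivalent to the torch>=2.2 pin above.
# Not part of this commit; shown only as an illustrative sketch.
from packaging import version

import torch

MIN_TORCH = "2.2"

def check_torch_version(minimum: str = MIN_TORCH) -> None:
    # Compare the installed torch release tuple against the required minimum.
    if version.parse(torch.__version__).release < version.parse(minimum).release:
        raise RuntimeError(
            f"torch>={minimum} is required (e.g. for low_cpu_fsdp), "
            f"but found torch=={torch.__version__}"
        )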
@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-from pkg_resources import packaging
 
 import dataclasses
 import fire
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try:
+    try:
         import wandb
     except ImportError:
         raise ImportError(
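setup_wandb uses the usual optional-dependency pattern: import wandb lazily and fail with an actionable message when it is missing. A hedged sketch of how the rest of such a helper typically continues (the error text, project name, and config handling are illustrative, not this file's exact code):

# Illustrative sketch of an optional-wandb helper; not this file's implementation.
import dataclasses

def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "wandb is required when use_wandb=True; install it with `pip install wandb`."
        )
    run = wandb.init(project="llama_recipes")  # assumed project name
    # Log the training and FSDP configs so runs are reproducible from the dashboard.
    run.config.update(dataclasses.asdict(train_config))
    run.config.update(dataclasses.asdict(fsdp_config), allow_val_change=True)
    return run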
@@ -97,7 +96,7 @@ def main(**kwargs):
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
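For context, the low_cpu_fsdp path that the removed check used to guard loads the full weights only on rank 0 and builds an empty model on the meta device everywhere else; FSDP then broadcasts rank 0's weights via sync_module_states. A minimal sketch of that pattern, assuming torch>=2.2 and a generic model_name (helper shape and argument names are illustrative, not this file's exact code):

# Minimal sketch of the low_cpu_fsdp loading pattern described in the docstring above.
import torch
from transformers import LlamaConfig, LlamaForCausalLM

def load_model_low_cpu(model_name: str, rank: int):
    if rank == 0:
        # Only rank 0 materializes the full weights in CPU memory,
        # so the host holds one copy instead of one copy per rank.
        model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
    else:
        # Every other rank builds the module on the meta device: no storage is
        # allocated, and FSDP(sync_module_states=True) later broadcasts the
        # real weights from rank 0.
        llama_config = LlamaConfig.from_pretrained(model_name)
        llama_config.use_cache = False
        with torch.device("meta"):
            model = LlamaForCausalLM(llama_config)
    return model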
@@ -157,12 +151,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
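The hsdp_device_mesh call above builds the 2-D mesh that HYBRID_SHARD expects: replication across the outer dimension, sharding within the inner one. A hedged sketch of how such a mesh can be created with the public torch>=2.2 API (the helper in this repo may differ in details):

# Illustrative sketch only: builds a 2-D device mesh for HSDP (HYBRID_SHARD)
# from replica/sharding group sizes, using torch>=2.2's init_device_mesh.
from torch.distributed.device_mesh import init_device_mesh

def build_hsdp_device_mesh(replica_group_size: int, sharding_group_size: int):
    # Outer mesh dimension: model replicas (DDP-style gradient all-reduce).
    # Inner mesh dimension: shards within each replica (FSDP-style sharding).
    return init_device_mesh(
        "cuda",
        (replica_group_size, sharding_group_size),
        mesh_dim_names=("replicate", "shard"),
    )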
@@ -171,7 +165,7 @@ def main(**kwargs):
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
 
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
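The last hunk picks the accelerator device id; in the surrounding code a CUDA branch typically follows the XPU one. A hedged sketch of that selection logic (the CUDA fallback is an assumption here, not visible in the hunk above):

# Illustrative device-id selection: prefer XPU when available, else CUDA.
import torch
from accelerate.utils import is_xpu_available

def pick_device_id() -> int:
    device_id = 0
    if is_xpu_available():
        device_id = torch.xpu.current_device()
    elif torch.cuda.is_available():
        device_id = torch.cuda.current_device()  # assumed fallback, not shown in the hunk
    return device_id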