Commit fc5485d9 authored by kldarek

fixing wandb for fsdp

parent 83a7c1ec
@@ -50,7 +50,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
         import wandb
     except ImportError:
         raise ImportError(
-            "You are trying to use wandb which is not currently installed"
+            "You are trying to use wandb which is not currently installed. "
             "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
@@ -59,7 +59,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     update_config(wandb_config, **kwargs)
     run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
     run.config.update(train_config)
-    run.config.update(fsdp_config)
+    run.config.update(fsdp_config, allow_val_change=True)
     return run
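The second hunk above adds allow_val_change=True when fsdp_config is merged into the run config. wandb complains (and, depending on version, errors) when an existing config key is rewritten with a different value, so the flag makes repeated config updates safe. A minimal stand-alone sketch of that behaviour, with a hypothetical project name (not the repo's code):

# Minimal sketch (not the repo's code): rewriting a config key with a changed
# value needs allow_val_change=True, otherwise wandb rejects the overwrite.
import wandb

run = wandb.init(project="demo-project", mode="offline")  # offline: no account needed
run.config.update({"sharding_strategy": "NO_SHARD"})
# Re-writing the same key with a different value is only accepted with the flag.
run.config.update({"sharding_strategy": "FULL_SHARD"}, allow_val_change=True)
run.finish()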
@@ -84,6 +84,8 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
     if train_config.enable_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
@@ -152,9 +154,8 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-        if train_config.enable_wandb:
-            if not train_config.enable_fsdp or rank==0:
-                wandb_run.config.update(peft_config)
+        if wandb_run:
+            wandb_run.config.update(peft_config)
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -260,7 +261,7 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
-        wandb_run if train_config.enable_wandb else None,
+        wandb_run,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
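The main() hunks above replace the repeated enable_wandb / rank checks with a single sentinel: wandb_run starts as None, is created only on rank 0 (or when FSDP is off), and every later call site, including the unconditional pass-through into the training call, just checks whether the run object exists. A minimal sketch of that pattern, with hypothetical names (not the repo's code):

# Minimal sketch (hypothetical names): create the wandb run only on rank 0 and
# let a None sentinel guard every later wandb call on the other ranks.
import wandb

def maybe_setup_wandb(enable_wandb: bool, enable_fsdp: bool, rank: int):
    wandb_run = None
    if enable_wandb and (not enable_fsdp or rank == 0):
        wandb_run = wandb.init(project="demo-project")  # hypothetical project name
    return wandb_run

# Call sites no longer need to re-check the rank:
# if wandb_run:
#     wandb_run.config.update(peft_config)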
@@ -275,11 +275,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             print(f" {eval_ppl=} {eval_epoch_loss=}")
 
     if wandb_run:
-        if not train_config.enable_fsdp or rank==0:
-            wandb_run.log({
-                            'eval/perplexity': eval_ppl,
-                            'eval/loss': eval_epoch_loss,
-                        }, commit=False)
+        wandb_run.log({
+                        'eval/perplexity': eval_ppl,
+                        'eval/loss': eval_epoch_loss,
+                    }, commit=False)
 
     return eval_ppl, eval_epoch_loss
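In the evaluation hook the per-rank check disappears for the same reason, and the metrics are logged with commit=False so they stay pending and are written to the same wandb step as the next committed (training) log call. A small stand-alone illustration, with made-up numbers and a hypothetical project name:

# Minimal sketch: commit=False stages the eval metrics; the next committed
# log call flushes them onto the same wandb step.
import wandb

run = wandb.init(project="demo-project", mode="offline")  # hypothetical project
run.log({"eval/perplexity": 12.3, "eval/loss": 2.51}, commit=False)  # made-up numbers
run.log({"train/loss": 2.43})  # this committed call records both dicts at one step
run.finish()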