Commit f1d90d0f authored by Kai Wu

fix wandb config update

parent 98c0284a
```diff
@@ -34,7 +34,7 @@ The args used in the command above are:
 * `--use_peft` boolean flag to enable PEFT methods in the script
-* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
+* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`.
 We use `torchrun` here to spawn multiple processes for FSDP.
```
```diff
@@ -27,7 +27,7 @@ The args used in the command above are:
 * `--use_peft` boolean flag to enable PEFT methods in the script
-* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
+* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`.
 * `--quantization` boolean flag to enable int8 quantization
```
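Both documentation hunks drop `prefix` from the list of supported `--peft_method` values, leaving `lora` and `llama_adapter`. As a rough illustration of what such a flag selects, here is a toy mapping from method name to a PEFT config, assuming the Hugging Face `peft` package; this is not the repository's actual helper, and the hyperparameter values are placeholders.

```python
# Toy sketch: map a --peft_method string to a PEFT config object.
# Assumes the Hugging Face `peft` package; the hyperparameter values are
# placeholders, and this is not the repository's actual config helper.
from peft import AdaptionPromptConfig, LoraConfig

def peft_config_for(method: str):
    if method == "lora":
        # Low-rank adapters injected into attention projections
        return LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
    if method == "llama_adapter":
        # Llama-Adapter is exposed in peft as "adaption prompt" tuning
        return AdaptionPromptConfig(adapter_len=10, adapter_layers=30)
    # `prefix` was removed from the docs in this commit, so it is not mapped here
    raise ValueError(f"unsupported peft_method: {method}")
```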
```diff
@@ -154,12 +154,13 @@ def main(**kwargs):
     # Load the pre-trained peft model checkpoint and setup its configuration
     if train_config.from_peft_checkpoint:
         model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
         peft_config = model.peft_config()
     # Generate the peft config and start fine-tuning from original model
     else:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
-        if wandb_run:
-            wandb_run.config.update(peft_config)
+    if wandb_run:
+        wandb_run.config.update(peft_config)
     model.print_trainable_parameters()
```
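The functional change hoists the `wandb` config update out of the `else` branch, so the PEFT configuration is logged whether training starts fresh or resumes from a PEFT checkpoint; previously, resuming skipped the update. Below is a minimal, self-contained sketch of the resulting pattern, assuming the `peft`, `transformers`, and `wandb` packages; the base model, project name, and LoRA values are placeholders, not the repository's defaults.

```python
# Minimal sketch of the post-fix control flow. Assumes `peft`,
# `transformers`, and `wandb`; the model name, project name, and LoRA
# values are placeholders, not the repository's defaults.
import wandb
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")         # placeholder base model
wandb_run = wandb.init(project="peft-demo", mode="offline")  # offline: no login needed

resume_from = ""  # set to a saved adapter directory to resume fine-tuning

if resume_from:
    # Resuming: adapter weights and config are loaded from disk.
    model = PeftModel.from_pretrained(model, resume_from, is_trainable=True)
    # On PeftModel, peft_config is a dict mapping adapter name -> config.
    peft_config = next(iter(model.peft_config.values()))
else:
    # Fresh run: build a config and wrap the base model.
    peft_config = LoraConfig(r=8, lora_alpha=32, target_modules=["c_attn"])
    model = get_peft_model(model, peft_config)

# The fix: log the config on both paths, not only in the fresh-run branch.
if wandb_run:
    wandb_run.config.update(peft_config)

model.print_trainable_parameters()
```

Keeping a single logging point after the branch also means any future path that constructs `peft_config` gets logged automatically.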