Commit 5da84b29 authored by Matthias Reso

Fix usage of dataclass for train_config and fsdp_config

parent aa5dee24
@@ -21,7 +21,8 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
@@ -47,6 +48,7 @@ from llama_recipes.utils.train_utils import (
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
     # Set the seeds for reproducibility
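For context on the change above to the finetuning entry point: instantiating the configs means update_config now mutates per-call dataclass instances instead of shared class attributes, so kwargs passed to main() only affect the current run. The sketch below illustrates that pattern with hypothetical stand-in dataclasses and a simplified update helper; the project's real configs and update_config carry many more fields and checks.

from dataclasses import dataclass, fields

# Hypothetical stand-ins for the train_config / fsdp_config dataclasses;
# the real ones in llama_recipes.configs define many more fields.
@dataclass
class TrainConfigSketch:
    lr: float = 1e-4
    weight_decay: float = 0.0
    batching_strategy: str = "packing"
    use_peft: bool = False

@dataclass
class FSDPConfigSketch:
    pure_bf16: bool = False

def update_config_sketch(configs, **kwargs):
    # Simplified update helper: copy matching kwargs onto each config
    # *instance*, leaving the class defaults untouched.
    for config in configs:
        for f in fields(config):
            if f.name in kwargs:
                setattr(config, f.name, kwargs[f.name])

# Mirrors main(**kwargs) after the fix: fresh instances per call.
train_config, fsdp_config = TrainConfigSketch(), FSDPConfigSketch()
update_config_sketch((train_config, fsdp_config), weight_decay=0.01, use_peft=True)
assert train_config.weight_decay == 0.01 and train_config.use_peft is True
assert TrainConfigSketch().use_peft is False  # a later call still sees the defaults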
@@ -101,8 +101,8 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     get_dataset.return_value = get_fake_dataset()
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
     get_model.return_value = Linear(1,1)
     main(**kwargs)
     assert train.call_count == 1
@@ -123,10 +123,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {
-        "batching_strategy": "packing",
-        "use_peft": False,
-    }
+    kwargs = {"batching_strategy": "packing"}
     get_dataset.return_value = get_fake_dataset()
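Regarding the removal of "use_peft": False from test_batching_strategy: one reading, consistent with the commit message, is that the old code applied kwargs to shared config classes, so a flag set during an earlier test could leak into later ones and each test had to reset it explicitly; with per-call dataclass instances the defaults reset on their own. A small, self-contained illustration of that difference (hypothetical names, not the project's actual classes):

# Old style: configuration held as class attributes that the update helper mutated.
class SharedTrainConfig:
    use_peft = False

def legacy_update(config_cls, **kwargs):
    # Writes onto the class itself, so the change outlives this call.
    for key, value in kwargs.items():
        setattr(config_cls, key, value)

legacy_update(SharedTrainConfig, use_peft=True)  # e.g. a PEFT test runs first
assert SharedTrainConfig.use_peft is True        # the flag sticks around, forcing
                                                 # later tests to pass use_peft=False

# New style: a fresh dataclass instance per main() call resets every default.
from dataclasses import dataclass

@dataclass
class InstanceTrainConfig:
    use_peft: bool = False

first_run = InstanceTrainConfig()
first_run.use_peft = True                        # only this run is affected
assert InstanceTrainConfig().use_peft is False   # the next test starts clean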