Skip to content
Snippets Groups Projects
Commit fac71cd1 authored by Matthias Reso's avatar Matthias Reso
Browse files

Fix src/tests/test_train_utils.py

parent 1090ccf4
Branches
No related tags found
No related merge requests found
......@@ -27,7 +27,12 @@ def temp_output_dir():
@patch("llama_recipes.utils.train_utils.nullcontext")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
def test_gradient_accumulation(
autocast,
scaler,
nullcontext,
mem_trace,
mocker):
model = mocker.MagicMock(name="model")
model().loss.__truediv__().detach.return_value = torch.tensor(1)
......@@ -47,6 +52,9 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
train_config.max_train_step = 0
train_config.max_eval_step = 0
train_config.save_metrics = False
train_config.flop_counter_start = 0
train_config.use_profiler = False
train_config.flop_counter = True
train(
model,
......@@ -103,6 +111,7 @@ def test_save_to_json(temp_output_dir, mocker):
train_config.max_train_step = 0
train_config.max_eval_step = 0
train_config.output_dir = temp_output_dir
train_config.flop_counter_start = 0
train_config.use_profiler = False
results = train(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment